本文仅为记录代码
# main:
if __name__ == '__main__':
    # Seed the run with a fixed value before anything else happens.
    setup(seed=42)

    # Bare reference to MolVocab, kept to avoid the pylint warning.
    a = MolVocab

    # Raise the rdkit logger threshold to CRITICAL to silence its output.
    RDLogger.logger().setLevel(RDLogger.CRITICAL)

    # Initialize MolVocab
    mol_vocab = MolVocab

    args = parse_args()
    command = args.parser_name

    if command in ('finetune', 'eval'):
        # Both sub-commands run cross-validation; only the logger name differs.
        logger_name = 'train' if command == 'finetune' else 'eval'
        logger = create_logger(name=logger_name, save_dir=args.save_dir, quiet=False)
        cross_validate(args, logger)
    elif command == 'pretrain':
        logger = create_logger(name='pretrain', save_dir=args.save_dir)
        pretrain_model(args, logger)
    elif command == 'fingerprint':
        # NOTE(review): the result is unused in this branch; the call is kept
        # because any side effects of get_newest_train_args are not visible here.
        train_args = get_newest_train_args()
        logger = create_logger(name='fingerprint', save_dir=None, quiet=False)
        feas = generate_fingerprints(args, logger)
        # Store the fingerprints in a compressed .npz archive under the key 'fps'.
        np.savez_compressed(args.output_path, fps=feas)
    elif command == 'predict':
        train_args = get_newest_train_args()
        avg_preds, test_smiles = make_predictions(args, train_args)
        write_prediction(avg_preds, test_smiles, args)
# generate_fingerprints:
def generate_fingerprints(args: Namespace, logger: Logger = None) -> List[List[float]]:
    """
    Load a dataset and a checkpoint, then generate fingerprints for the dataset.

    Only the first entry of ``args.checkpoint_paths`` is used.

    :param args: Arguments. ``checkpoint_paths``, ``data_path`` and ``cuda`` are
                 read here; the object is also handed to the data reader, the
                 checkpoint loader and ``do_generate``.
    :param logger: Logger for progress messages. When None, a logger named
                   'fingerprints' is created.
    :return: The value returned by ``do_generate`` for the loaded model and data.
    """
    # Looked up first, so a missing checkpoint entry fails before any data is read.
    ckpt_path = args.checkpoint_paths[0]
    if logger is None:
        logger = create_logger('fingerprints', quiet=False)

    print('Loading data')
    dataset = MoleculeDataset(
        get_data(
            path=args.data_path,
            args=args,
            use_compound_names=False,
            max_data_size=float("inf"),
            skip_invalid_smiles=False,
        )
    )
    logger.info(f'Total size = {len(dataset):,}')
    logger.info('Generating...')

    # Load model
    model = load_checkpoint(ckpt_path, cuda=args.cuda, current_args=args, logger=logger)
    return do_generate(model=model, data=dataset, args=args)
# do_generate:
def do_generate(model: nn.Module,
                data: MoleculeDataset,
                args: Namespace,
                *,
                batch_size: int = 32,
                num_workers: int = 4,
                ) -> List[List[float]]:
    """
    Do the fingerprint generation on a dataset using the pre-trained models.

    The model is switched to evaluation mode and every forward pass runs under
    ``torch.no_grad()``. Note that ``args.bond_drop_rate`` is overwritten with 0
    on the caller's ``args`` object and is not restored afterwards.

    :param model: A model. It is called as ``model(batch, features_batch)``.
    :param data: A MoleculeDataset.
    :param args: Arguments. Passed to ``MolCollator``; its ``bond_drop_rate``
                 attribute is set to 0 in place.
    :param batch_size: Number of dataset items per batch (default 32).
    :param num_workers: Number of DataLoader worker processes (default 4).
    :return: A list of fingerprints: the rows of each batch's output array, in
             dataset order (the loader does not shuffle).
    """
    model.eval()
    # Mutates the caller's args; the new value outlives this call.
    args.bond_drop_rate = 0

    loader = DataLoader(data,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers,
                        collate_fn=MolCollator(args=args, shared_dict={}))

    fingerprints = []
    # Each collated item has five fields; only the 2nd and 3rd feed the model.
    for _, batch, features_batch, _, _ in loader:
        with torch.no_grad():
            batch_out = model(batch, features_batch)
            # extend() iterates the array's first axis, adding one entry per row.
            fingerprints.extend(batch_out.detach().cpu().numpy())
    return fingerprints