Transformer和SGCN生成分子,代码:deepHops,原文:Deep Scaffold Hopping with Multi-modal Transformer Neural Networks,按照README.md的顺序解析代码,整体流程如下:
1.Dataset split
python split_data.py -out_dir data40_tue_3d/0.60 -protein_group data40 -target_uniq_rate 0.6 -hopping_pairs_dir hopping_pairs_with_scaffold
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='split_data.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-hopping_pairs_dir', type=str,
                        default='/home/aht/paper_code/shaungjia/hopping_pairs',
                        help='the hopping pairs directory generated by make_pairs.py')
    parser.add_argument('-out_dir', type=str,
                        default='/home/aht/paper_code/shaungjia/code/MolecularGET-master/data_hopping',
                        help='the output directory after data splitting')
    parser.add_argument('-protein_group', type=str, choices=['data40', 'test6', 'test69'],
                        help='data40: forty proteins in training set, test6: six unseen proteins, test69: 69 unseen proteins, the R2 of scorer is greater than 0.66')
    parser.add_argument('-proteins', nargs='+', type=str, help='the pairs of proteins to split')
    parser.add_argument('-ratio_list', nargs='+', type=float, default=[0.8, 0.1, 0.1],
                        help='the ratio of train val test')
    parser.add_argument('-ref_max_occured', type=int, default=0,
                        help='the max number of times as source per molecule')
    parser.add_argument('-target_uniq_rate', type=float, default=0.0,
                        help='target_uniq_rate is the uniq rate of target molecules per protein')
    opt = parser.parse_args()

    d = opt.hopping_pairs_dir
    print(f"total task: {len(TASKS)}")
    ratio_list = opt.ratio_list
    # NOTE(review): output order is train/test/val while the -ratio_list help
    # says "train val test" -- harmless with the default [0.8, 0.1, 0.1], but
    # confirm before passing unequal val/test ratios.
    filename_list = ['train', 'test', 'val']
    assert len(ratio_list) == len(filename_list)
    # result[i] accumulates the per-protein DataFrames destined for filename_list[i]
    result = [[] for _ in range(len(filename_list))]

    if opt.protein_group is None:
        select_list = opt.proteins
    else:
        select_dict = {'data40': include_list, 'test6': external_test_set,
                       'test69': big_external_test_set}
        select_list = select_dict[opt.protein_group]

    for x in select_list:
        df = data_loader.load_data_frame(f"{d}/{x}.csv")
        df = shuffle(df)
        if opt.ref_max_occured > 0 or opt.target_uniq_rate > 0.0:
            # keep only pairs whose 2D scaffold similarity is below 0.6
            df = df[df.score_scaffold < 0.6]
            # split by unique source molecule so train and test share no references
            uniq_refs = shuffle(list(set(df.ref_smiles)))
            test_size = int(len(uniq_refs) * 0.1)
            print(f"uniq_refs: {len(uniq_refs)}, test size: {test_size}")
            smi_of_test = uniq_refs[:test_size]
            smi_of_train = uniq_refs[test_size:]
            # a pair belongs to a split only when BOTH of its molecules are inside it
            df_train = df[df.ref_smiles.isin(smi_of_train) & df.target_smiles.isin(smi_of_train)]
            df_test = df[df.ref_smiles.isin(smi_of_test) & df.target_smiles.isin(smi_of_test)]
            if opt.ref_max_occured > 0:
                df = select_limit_ref_num(df_train, opt.ref_max_occured)
            else:
                df = select_by_target_uniq_rate(df_train, opt.target_uniq_rate)
            # carve the validation set out of the (subsampled) training pool, 90/10
            df1 = df.sample(frac=0.9)
            df2 = df[~df.index.isin(df1.index)]
            result[0].append(df1.reset_index(drop=True))  # train
            result[1].append(df_test)                     # test
            result[2].append(df2.reset_index(drop=True))  # val
        else:
            # plain contiguous ratio split of the already-shuffled pairs
            pos = 0
            for i, r in enumerate(ratio_list):
                cur_size = int(len(df) * r)
                if cur_size > 0:
                    result[i].append(df[pos:pos + cur_size])
                pos += cur_size

    save_dir = opt.out_dir
    os.makedirs(save_dir, exist_ok=True)

    def write_line(out, line, add_space=True):
        """Write one record; with add_space the characters are space-separated
        (character-level tokenization for the seq2seq model)."""
        if add_space:
            line = ' '.join(line)
        out.write(line)
        out.write(os.linesep)

    for i, f in enumerate(filename_list):
        df = pd.concat(result[i], axis=0, ignore_index=True)
        df = shuffle(df)
        with open(f"{save_dir}/src-{f}.txt", 'w') as out:
            for s in df['ref_smiles']:
                write_line(out, s)
        with open(f"{save_dir}/tgt-{f}.txt", 'w') as out:
            for s in df['target_smiles']:
                write_line(out, s)
        # condition file: index of the protein target in TASKS, one per line
        with open(f"{save_dir}/cond-{f}.txt", 'w') as out:
            for t in df['target_chembl_id']:
                write_line(out, f"{TASKS.index(t)}", False)
- 【Python】Parser 用法-通俗易懂!
- df 至少包含四个被用到的列:ref_smiles、target_smiles、score_scaffold、target_chembl_id,分别是源参考分子、骨架跃迁分子、二者的2D相似度,以及目标蛋白的ChEMBL ID(写cond文件时转成TASKS中的索引)
- 拿到40个训练集蛋白质和骨架跃迁分子对,取出2D相似性小于0.6的那些分子对
- result列表分别是训练集,测试集和验证集,最后写入文件
def select_limit_ref_num(df, ref_max_occured=5):
    """Randomly keep at most `ref_max_occured` pairs per unique source molecule.

    Args:
        df: pairs DataFrame with at least a `ref_smiles` column.
        ref_max_occured: maximum number of rows kept per distinct ref_smiles
            value (groups with fewer rows keep all of them).

    Returns:
        DataFrame containing at most `ref_max_occured` rows per ref_smiles.
    """
    # One groupby pass with a random per-group sample replaces the original
    # per-ref boolean-mask loop, which was O(#refs * #rows).
    sampled = [group.sample(n=min(len(group), ref_max_occured))
               for _, group in df.groupby('ref_smiles', sort=False)]
    return pd.concat(sampled)
def select_by_target_uniq_rate(df, target_uniq_rate):
    """Subsample pairs so each unique target appears ~1/target_uniq_rate times.

    Every distinct target_smiles keeps `floor(1/rate)` random rows, plus one
    extra row with probability `1/rate - floor(1/rate)`, so the expected count
    per target is exactly 1/target_uniq_rate (groups with fewer rows keep all).

    Args:
        df: pairs DataFrame with at least a `target_smiles` column.
        target_uniq_rate: desired uniqueness rate; must be > 0 (caller guards).

    Returns:
        Subsampled DataFrame.
    """
    base_num = int(np.floor(1 / target_uniq_rate))
    prob_threshold = 1 / target_uniq_rate - base_num
    print(f"base_num: {base_num}, prob_threshold: {prob_threshold}")
    # One groupby pass with a random per-group sample replaces the original
    # per-target boolean-mask loop, which was O(#targets * #rows).
    selected = []
    for _, group in df.groupby('target_smiles', sort=False):
        select_num = base_num
        # with probability prob_threshold this target keeps one extra row
        if prob_threshold > 0 and np.random.uniform() <= prob_threshold:
            select_num += 1
        selected.append(group.sample(n=min(len(group), select_num)))
    return pd.concat(selected)
2.Data preprocessing
python preprocess.py -train_src data40_tue_3d/0.60/src-train.txt -train_tgt data40_tue_3d/0.60/tgt-train.txt -train_cond data40_tue_3d/0.60/cond-train.txt -valid_src data40_tue_3d/0.60/src-val.txt -valid_tgt data40_tue_3d/0.60/tgt-val.txt -valid_cond data40_tue_3d/0.60/cond-val.txt -save_data data40_tue_3d/0.60/seqdata -share_vocab -src_seq_length 1000 -tgt_seq_length 1000 -src_vocab_size 1000 -tgt_vocab_size 1000 -with_3d_confomer
def main():
    """Preprocessing entry point: parse options, extract src/tgt features,
    build the `Fields` schema, then build and save sharded train/valid
    datasets and the vocabulary under the `-save_data` prefix."""
    opt = parse_args()
    if (opt.max_shard_size > 0):
        raise AssertionError("-max_shard_size is deprecated, please use \
-shard_size (number of examples) instead.")
    init_logger(opt.log_file)
    logger.info("Extracting features...")
    # Workaround attempt for multi-process prepare failures (file-descriptor
    # exhaustion); the original author notes it had no effect.
    torch.multiprocessing.set_sharing_strategy('file_system')
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    # raise the soft open-file limit to 65535, keeping the hard limit
    resource.setrlimit(resource.RLIMIT_NOFILE, (65535, rlimit[1]))
    # END
    src_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_src, 'src')
    tgt_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_tgt, 'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)
    # presumably registers the extra condition field(s) -- verify in myutils
    myutils.add_more_field(fields)
    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)
    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)
    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)


if __name__ == "__main__":
    main()
2.1.parse_args
def parse_args():
    """Parse command-line options for preprocess.py.

    Registers the stock OpenNMT preprocessing options plus two
    deepHops-specific flags, seeds torch with `-seed`, and warns about
    pre-existing .pt shards.

    Returns:
        The parsed argparse namespace.
    """
    parser = argparse.ArgumentParser(
        description='preprocess.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)
    # Chinese help text means "generate targets". NOTE(review): the flag name
    # looks like a typo of "parallel_run" -- confirm callers before renaming.
    parser.add_argument('-parrel_run', action='store_true', default=False,
                        help='生成靶标')
    # Chinese help text: whether atom features carry 3D coordinates
    # in their last three dimensions.
    parser.add_argument('-with_3d_confomer', action='store_true', default=False,
                        help='原子特征是否在最后3个维度加上坐标')
    opt = parser.parse_args()
    torch.manual_seed(opt.seed)  # reproducible torch-side preprocessing
    check_existing_pt_files(opt)
    return opt
def add_md_help_argument(parser):
    """Register the `-md` flag, which prints Markdown-formatted help and exits."""
    parser.add_argument(
        '-md',
        action=MarkdownHelpAction,
        help='print Markdown-formatted help text and exit.',
    )
def preprocess_opts(parser):
""" Pre-procesing options """
# Data options
group = parser.add_argument_group('Data')
group.add_argument('-data_type', default="text",
help="""Type of the source input.
Options are [text|img].""")
group.add_argument('-train_src', required=True,
help="Path to the training source data")
group.add_argument('-train_tgt', required=True,
help="Path to the training target data")
group.add_argument('-train_cond', required=False,
help="Path to the training 生成条件")
group.add_argument('-valid_src', required=True,
help="Path to the validation source data")
group.add_argument('-valid_tgt', required=True,
help="Path to the validation target data")
group.add_argument('-valid_cond', required=False,
help="Path to the validation 生成条件")
group.add_argument('-src_dir', default="",
help="Source directory for image or audio files.")
group.add_argument('-save_data', required=True,
help="Output file for the prepared data")
group.add_argument('-max_shard_size', type=int, default=0,
help="""Deprecated use shard_size instead""")
group.add_argument('-shard_size', type=int, default=1000000,
help="""Divide src_corpus and tgt_corpus into
smaller multiple src_copus and tgt corpus files, then
build shards, each shard will have
opt.shard_size samples except last shard.
shard_size=0 means no segmentation
shard_size>0 means segment dataset into multiple shards,
each shard has shard_size samples""")
...
def check_existing_pt_files(opt):
    """Warn on stderr when sharded train/valid/vocab `.pt` files already exist
    under the `opt.save_data` prefix, so the user can back them up before they
    are mixed up; execution continues to allow resumed runs."""
    # training later globs {train|valid}.[0-9]*.pt, so probe the same patterns
    for shard_kind in ('train', 'valid', 'vocab'):
        pattern = f"{opt.save_data}.{shard_kind}*.pt"
        if glob.glob(pattern):
            sys.stderr.write("Please backup existing pt file: %s, "
                             "to avoid tampering, but now support continue run!\n" % pattern)
            # the original hard stop (sys.exit(1)) is intentionally disabled
2.2.init_logger
import logging
# Module-level handle to the root logger (getLogger() with no name);
# init_logger() below installs its level and handlers.
logger = logging.getLogger()
def init_logger(log_file=None):
    """Configure the root logger: INFO level, a console handler, and (when
    `log_file` is a non-empty path) an additional file handler.

    Existing handlers are replaced outright, so repeated calls do not stack
    duplicate console handlers. Returns the configured root logger.
    """
    fmt = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    to_console = logging.StreamHandler()
    to_console.setFormatter(fmt)
    # wholesale replacement keeps the console setup idempotent
    root_logger.handlers = [to_console]
    if log_file and log_file != '':
        to_file = logging.FileHandler(log_file)
        to_file.setFormatter(fmt)
        root_logger.addHandler(to_file)
    return root_logger
2.3.inputters
def get_num_features(data_type, corpus_file, side):
    """Count the number of extra features for one side of a corpus.

    Args:
        data_type (str): one of 'text', 'img' or 'audio'.
        corpus_file (str): file path of the corpus to inspect.
        side (str): 'src' or 'tgt'.

    Returns:
        Number of features on `side`.

    Raises:
        ValueError: when `data_type` is not a supported type.
    """
    assert side in ["src", "tgt"]
    # dispatch on data type via flat guard clauses
    if data_type == 'text':
        return TextDataset.get_num_features(corpus_file, side)
    if data_type == 'img':
        return ImageDataset.get_num_features(corpus_file, side)
    if data_type == 'audio':
        return AudioDataset.get_num_features(corpus_file, side)
    raise ValueError("Data type not implemented")
def get_num_features(corpus_file, side):
    """Peek at the first line of `corpus_file` and count its features.

    All lines are assumed to carry the same number of features, and both
    src and tgt corpora are plain text, so probing one line is enough.

    Args:
        corpus_file (str): file path of the text corpus.
        side (str): 'src' or 'tgt'.

    Returns:
        Number of features found on the probed line.
    """
    with codecs.open(corpus_file, "r", "utf-8") as corpus:
        probe_tokens = corpus.readline().strip().split()
    _, _, num_feats = TextDataset.extract_text_features(probe_tokens)
    return num_feats