Knowledge-based-BERT (Part 1)

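This article walks through a pretraining method that applies a BERT model to SMILES representations of molecules and learns molecular features through multi-head attention. Each molecule is augmented into several SMILES variants, which are used for contrastive learning. The model consists of encoder layers with position-wise feed-forward networks, and the loss covers classification tasks at both the global (molecule) level and the atom level. During pretraining, an EarlyStopping helper saves a checkpoint after every epoch.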

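K-BERT uses several pretraining tasks to address the shortcomings of treating SMILES as plain NLP text. Code: Knowledge-based-BERT. Paper: Knowledge-based BERT: a method to extract molecular features like computational chemists. Explanation: Knowledge-based BERT: a method to extract molecular features like computational chemists. The code analysis starts from K_BERT_pretrain. The model framework is as follows: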
[Figure: K-BERT model framework]

args['pretrain_data_path'] = '../data/pretrain_data/CHEMBL_maccs'
args['batch_size'] = 32
pretrain_set = build_data.load_data_for_contrastive_aug_pretrain(
                                        pretrain_data_path=args['pretrain_data_path'])
print("Pretrain data generation is complete !")

pretrain_loader = DataLoader(dataset=pretrain_set,
                             batch_size=args['batch_size'],
                             shuffle=True,
                             collate_fn=collate_pretrain_data)

1.load_data_for_contrastive_aug_pretrain

def load_data_for_contrastive_aug_pretrain(pretrain_data_path='./data/CHEMBL_wash_500_pretrain'):
    tokens_idx_list = []
    global_labels_list = []
    atom_labels_list = []
    atom_mask_list = []
    for i in range(80):
        pretrain_data = np.load(pretrain_data_path+'_contrastive_{}.npy'.format(i+1), allow_pickle=True)
        tokens_idx_list = tokens_idx_list + [x for x in pretrain_data[0]]
        global_labels_list = global_labels_list + [x for x in pretrain_data[1]]
        atom_labels_list = atom_labels_list + [x for x in pretrain_data[2]]
        atom_mask_list = atom_mask_list + [x for x in pretrain_data[3]]
        print(pretrain_data_path+'_contrastive_{}.npy'.format(i+1) + ' is loaded')
    pretrain_data_final = []
    for i in range(len(tokens_idx_list)):
        a_pretrain_data = [tokens_idx_list[i], global_labels_list[i], atom_labels_list[i], atom_mask_list[i]]
        pretrain_data_final.append(a_pretrain_data)
    return pretrain_data_final
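  • CHEMBL_maccs_contrastive_{}.npy is built by the build_contrastive_pretrain_selected_tasks script.
  • From the analysis below, each .npy file should store tokens_idx_all_list, global_label_list, atom_labels_list, and atom_mask_list. tokens_idx_all_list holds, for each molecule, the token index lists of its 5 SMILES strings, so its shape should be (n_smiles, 5, 201), with n_smiles being the number of molecules in the file. Example shapes for the other three are shown further down; they should simply gain a leading n_smiles dimension.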
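
To make the file layout concrete, here is a minimal sketch with toy shapes (2 molecules, 5 SMILES each, 4 tokens instead of 201). It packs four parallel lists into one object array, saves it, and regroups it into one record per molecule, which is what load_data_for_contrastive_aug_pretrain does with the real files. The toy sizes, the temporary file name, and the explicit np.empty packing are assumptions made for illustration; the repository builds the array with np.array(..., dtype=object).

```python
import os
import tempfile

import numpy as np

# Toy stand-ins for the four parallel lists (real shapes are far larger).
tokens_idx_all = [[[1, 3, 6, 0]] * 5, [[1, 3, 3, 6]] * 5]   # (n_mol, 5, seq_len)
global_labels = [[0, 1, 1], [1, 0, 0]]                      # (n_mol, n_global)
atom_labels = [[[2, 2]] * 4, [[2, 2]] * 4]                  # (n_mol, seq_len, n_atom_labels)
atom_mask = [[0, 1, 1, 0], [0, 1, 1, 1]]                    # (n_mol, seq_len)

# Pack the four lists into a length-4 object array, one slot per list
# (illustrative packing; each slot keeps its plain Python list).
packed = np.empty(4, dtype=object)
for slot, values in enumerate([tokens_idx_all, global_labels, atom_labels, atom_mask]):
    packed[slot] = values

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_contrastive_1.npy")  # hypothetical file name
    np.save(path, packed)
    loaded = np.load(path, allow_pickle=True)  # object arrays require allow_pickle

# Regroup into one [tokens, global, atom_labels, atom_mask] record per molecule.
samples = [list(record) for record in zip(*loaded)]
print(len(samples), len(samples[0][0]), len(samples[0][0][0]))  # 2 5 4
```

Each element of samples is what the DataLoader later hands to the collate function as a single item.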

1.1.build_contrastive_pretrain_selected_tasks

from experiment.build_data import build_maccs_pretrain_contrastive_data_and_save
import multiprocessing
import pandas as pd

task_name = 'CHEMBL'
if __name__ == "__main__":
    n_thread = 8
    data = pd.read_csv('../data/pretrain_data/'+task_name+'_5_contrastive_aug.csv')
    smiles_name_list = ['smiles', 'aug_smiles_0', 'aug_smiles_1', 'aug_smiles_2', 'aug_smiles_3']
    smiles_list = data[smiles_name_list].values.tolist()

    # To avoid running out of memory, process the dataset in 10 splits
    for i in range(10):
        n_split = int(len(smiles_list)/10)
        smiles_split = smiles_list[i*n_split:(i+1)*n_split]

        n_mol = int(len(smiles_split)/8)

        # creating processes
        p1 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[:n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+1)+'.npy'))
        p2 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[n_mol:2*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+2)+'.npy'))
        p3 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[2*n_mol:3*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+3)+'.npy'))
        p4 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[3*n_mol:4*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+4)+'.npy'))
        p5 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[4*n_mol:5*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+5)+'.npy'))
        p6 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[5*n_mol:6*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+6)+'.npy'))
        p7 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[6*n_mol:7*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+7)+'.npy'))
        p8 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[7*n_mol:],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+8)+'.npy'))

        # start all eight worker processes
        p1.start()
        p2.start()
        p3.start()
        p4.start()
        p5.start()
        p6.start()
        p7.start()
        p8.start()

        # wait until all eight worker processes have finished
        p1.join()
        p2.join()
        p3.join()
        p4.join()
        p5.join()
        p6.join()
        p7.join()
        p8.join()


        # all processes for this split have finished
        print("Done!")
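  • Molecules are collected from ChEMBL, augmented, and stored as SMILES; this script reads them in and generates the .npy files.
  • The input smiles_list has the following format, where each row holds the five SMILES of one molecule (integers stand in for SMILES strings here):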
import pandas as pd
import numpy as np
smiles_name_list = ['smiles', 'aug_smiles_0', 'aug_smiles_1', 'aug_smiles_2', 'aug_smiles_3']
data=pd.DataFrame(np.arange(15).reshape(3,5),columns=smiles_name_list)
smiles_list = data[smiles_name_list].values.tolist()
smiles_list
#[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
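
The slicing arithmetic in the script (10 splits, 8 processes per split, file index i*8+k) can be checked without any chemistry. The sketch below reproduces only the index computation; plan_chunks is a hypothetical helper name, not part of the repository. It also makes one consequence visible: because n_split is a floored division and every outer slice ends at (i+1)*n_split, up to 9 molecules at the end of the CSV are not written to any file.

```python
def plan_chunks(n_molecules, n_splits=10, n_procs=8):
    """Return (file_index, start, end) for every output .npy file."""
    plan = []
    n_split = n_molecules // n_splits            # molecules per outer split
    for i in range(n_splits):
        split_start = i * n_split
        split_end = (i + 1) * n_split
        n_mol = n_split // n_procs               # molecules per worker process
        for k in range(n_procs):
            start = split_start + k * n_mol
            # The last process takes whatever is left in this split.
            end = split_end if k == n_procs - 1 else start + n_mol
            plan.append((i * n_procs + k + 1, start, end))
    return plan


plan = plan_chunks(1005)
print(len(plan))    # 80 files, indexed 1..80
print(plan[7])      # (8, 84, 100): the 8th process absorbs the remainder of split 0
print(plan[-1])     # (80, 984, 1000): molecules 1000..1004 are never written
```

The 80 file indices match the range(80) loop in load_data_for_contrastive_aug_pretrain.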

1.2.build_maccs_pretrain_contrastive_data_and_save

def build_maccs_pretrain_contrastive_data_and_save(smiles_list, output_smiles_path, global_feature='MACCS'):
    # all smiles list
    smiles_list = smiles_list
    tokens_idx_all_list = []
    global_label_list = []
    atom_labels_list = []
    atom_mask_list = []
    for i, smiles_one_mol in enumerate(smiles_list):
        tokens_idx_list = [construct_input_from_smiles(smiles, global_feature=global_feature)[0] for
                           smiles in smiles_one_mol]
        if 0 not in tokens_idx_list:
            _ , global_labels, atom_labels, atom_mask = construct_input_from_smiles(smiles_one_mol[0],
                                                                global_feature=global_feature)
            tokens_idx_all_list.append(tokens_idx_list)
            global_label_list.append(global_labels)
            atom_labels_list.append(atom_labels)
            atom_mask_list.append(atom_mask)
            print('{}/{} is transformed!'.format(i+1, len(smiles_list)))
        else:
            print('{} is transformed failed!'.format(smiles_one_mol[0]))
    pretrain_data_list = [tokens_idx_all_list, global_label_list, atom_labels_list, atom_mask_list]
    pretrain_data_np = np.array(pretrain_data_list, dtype=object)
    np.save(output_smiles_path, pretrain_data_np)

tokens_idx_list takes the first element returned by construct_input_from_smiles for each of the molecule's SMILES.

1.3.construct_input_from_smiles

def construct_input_from_smiles(smiles, max_len=200, global_feature='MACCS'):
    try:
        # built a pretrain data from smiles
        atom_list = []
        atom_token_list = ['c', 'C', 'O', 'N', 'n', '[C@H]', 'F', '[C@@H]', 'S', 'Cl', '[nH]', 's', 'o', '[C@]',
                           '[C@@]', '[O-]', '[N+]', 'Br', 'P', '[n+]', 'I', '[S+]',  '[N-]', '[Si]', 'B', '[Se]', '[other_atom]']
        all_token_list = ['[PAD]', '[GLO]', 'c', 'C', '(', ')', 'O', '1', '2', '=', 'N', '3', 'n', '4', '[C@H]', 'F', '[C@@H]', '-', 'S', '/', 'Cl', '[nH]', 's', 'o', '5', '#', '[C@]', '[C@@]', '\\', '[O-]', '[N+]', 'Br', '6', 'P', '[n+]', '7', 'I', '[S+]', '8', '[N-]', '[Si]', 'B', '9', '[2H]', '[Se]', '[other_atom]', '[other_token]']

        # build the token-to-index dictionary
        word2idx = {}
        for i, w in enumerate(all_token_list):
            word2idx[w] = i
        # build the token list, then add the global token and padding
        token_list = smi_tokenizer(smiles)
        padding_list = ['[PAD]' for x in range(max_len-len(token_list))]
        tokens = ['[GLO]'] + token_list + padding_list
        mol = MolFromSmiles(smiles)
        atom_example = mol.GetAtomWithIdx(0)
        atom_labels_example = atom_labels(atom_example)
        atom_mask_labels = [2 for x in range(len(atom_labels_example))]
        atom_labels_list = []
        atom_mask_list = []

        index = 0
        tokens_idx = []
        for i, token in enumerate(tokens):
            if token in atom_token_list:
                atom = mol.GetAtomWithIdx(index)
                an_atom_labels = atom_labels(atom)
                atom_labels_list.append(an_atom_labels)
                atom_mask_list.append(1)
                index = index + 1
                tokens_idx.append(word2idx[token])
            else:
                if token in all_token_list:
                    atom_labels_list.append(atom_mask_labels)
                    tokens_idx.append(word2idx[token])
                    atom_mask_list.append(0)
                elif '[' in list(token):
                    atom = mol.GetAtomWithIdx(index)
                    tokens[i] = '[other_atom]'
                    an_atom_labels = atom_labels(atom)
                    atom_labels_list.append(an_atom_labels)
                    atom_mask_list.append(1)
                    index = index + 1
                    tokens_idx.append(word2idx['[other_atom]'])
                else:
                    tokens[i] = '[other_token]'
                    atom_labels_list.append(atom_mask_labels)
                    tokens_idx.append(word2idx['[other_token]'])
                    atom_mask_list.append(0)
        if global_feature == 'MACCS':
            global_label_list = global_maccs_data(smiles)
        elif global_feature == 'ECFP4':
            global_label_list = global_ecfp4_data(smiles)
        elif global_feature == 'RDKIT_des':
            global_label_list = global_rdkit_des_data(smiles)

        tokens_idx = [word2idx[x] for x in tokens]
        if len(tokens_idx) == max_len + 1:
            return tokens_idx, global_label_list, atom_labels_list, atom_mask_list
        else:
            return 0, 0, 0, 0
    except:
        return 0, 0, 0, 0
def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction
    """
    import re
    # raw string: avoids invalid-escape warnings; the regex itself is unchanged
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    # assert smi == ''.join(tokens)
    # return ' '.join(tokens)
    return tokens
    # Example:
    #   smi = 'C=CCC=CCO'
    #   smi_tokenizer(smi)
    #   -> ['C', '=', 'C', 'C', 'C', '=', 'C', 'C', 'O']
def atom_labels(atom, use_chirality=True):
    results = one_of_k_encoding(atom.GetDegree(),
                                [0, 1, 2, 3, 4, 5, 6]) + \
              one_of_k_encoding_unk(atom.GetHybridization(), [
                  Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                  Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
                  Chem.rdchem.HybridizationType.SP3D2, 'other']) + [atom.GetIsAromatic()] \
              + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                  [0, 1, 2, 3, 4])
    if use_chirality:
        try:
            results = results + one_of_k_encoding_unk(
                atom.GetProp('_CIPCode'),
                ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
        except:
            results = results + [False, False
                                 ] + [atom.HasProp('_ChiralityPossible')]
    atom_labels_list = np.array(results).tolist()
    atom_selected_index = [1, 2, 3, 4, 7, 8, 9, 13, 14, 15, 16, 17, 19, 20, 21]
    atom_labels_selected = [atom_labels_list[x] for x in atom_selected_index]
    return atom_labels_selected
    # Example:
    #   from rdkit.Chem import *
    #   from build_data import atom_labels
    #   mol = MolFromSmiles(smi)
    #   atom_example = mol.GetAtomWithIdx(0)
    #   atom_labels_example = atom_labels(atom_example)
    #   atom_labels_example
    #   -> [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0]
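  • tokens_idx is the list of vocabulary indices for the tokenized SMILES. global_label_list holds descriptors computed from the SMILES, here via global_maccs_data. atom_labels_list holds the label vector of every atom in the molecule; tokens that are not atoms get an all-2 vector. atom_mask_list marks whether each token is an atom. If construction fails, the function returns four zeros, so on success tokens_idx is a list and on failure it is the number 0.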
def global_maccs_data(smiles):
    mol = Chem.MolFromSmiles(smiles)
    maccs = MACCSkeys.GenMACCSKeys(mol)
    global_maccs_list = np.array(maccs).tolist()
    # keep only the keys whose negative/positive sample ratio is below 1000 and above 0.001
    selected_index = [3, 8, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165]
    selected_global_list = [global_maccs_list[x] for x in selected_index]
    return selected_global_list
  • A usage example follows; the if-else details of the implementation are not examined further.
from build_data import *
import numpy as np
smi1='C=CCC=CCO'
smi2='OCC=CCC=C'
res=construct_input_from_smiles(smi1)
#res=construct_input_from_smiles(smi2)
len(res),np.array(res[0]).shape,np.array(res[1]).shape,np.array(res[2]).shape,np.array(res[3]).shape
#(4, (201,), (154,), (201, 15), (201,)) smi1
#(4, (201,), (154,), (201, 15), (201,)) smi2
  • 201 is the sequence padded to 200 plus the [GLO] token, 154 is the length of selected_index, and each token is encoded as a vector of length 15.
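
The token-index and atom-mask part of construct_input_from_smiles does not depend on RDKit, so it can be tried in isolation. The following is a simplified sketch under stated assumptions: encode_smiles is a hypothetical name, the atom-label and MACCS parts are left out, and an over-long input returns None instead of the four zeros used in the original. The vocabularies and the tokenizer regex are copied from the listing above.

```python
import re

ATOM_TOKENS = ['c', 'C', 'O', 'N', 'n', '[C@H]', 'F', '[C@@H]', 'S', 'Cl', '[nH]', 's', 'o',
               '[C@]', '[C@@]', '[O-]', '[N+]', 'Br', 'P', '[n+]', 'I', '[S+]', '[N-]',
               '[Si]', 'B', '[Se]', '[other_atom]']
ALL_TOKENS = ['[PAD]', '[GLO]', 'c', 'C', '(', ')', 'O', '1', '2', '=', 'N', '3', 'n', '4',
              '[C@H]', 'F', '[C@@H]', '-', 'S', '/', 'Cl', '[nH]', 's', 'o', '5', '#', '[C@]',
              '[C@@]', '\\', '[O-]', '[N+]', 'Br', '6', 'P', '[n+]', '7', 'I', '[S+]', '8',
              '[N-]', '[Si]', 'B', '9', '[2H]', '[Se]', '[other_atom]', '[other_token]']
WORD2IDX = {token: idx for idx, token in enumerate(ALL_TOKENS)}
PATTERN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def encode_smiles(smiles, max_len=200):
    """Return (tokens_idx, atom_mask), both of length max_len + 1, or None if too long."""
    tokens = PATTERN.findall(smiles)
    if len(tokens) > max_len:
        return None
    tokens = ['[GLO]'] + tokens + ['[PAD]'] * (max_len - len(tokens))
    tokens_idx, atom_mask = [], []
    for token in tokens:
        if token in ATOM_TOKENS:
            is_atom = True
        elif token in WORD2IDX:
            is_atom = False                      # bond, ring digit, bracket, [GLO], [PAD]
        elif token.startswith('['):
            token, is_atom = '[other_atom]', True   # unseen bracket atom
        else:
            token, is_atom = '[other_token]', False
        tokens_idx.append(WORD2IDX[token])
        atom_mask.append(1 if is_atom else 0)
    return tokens_idx, atom_mask


idx, mask = encode_smiles('C=CCC=CCO')
print(len(idx), idx[:10], sum(mask))
# 201 [1, 3, 9, 3, 3, 3, 9, 3, 3, 6] 7
```

The mask sums to 7 because the molecule has six carbons and one oxygen; the two '=' tokens, [GLO], and all padding positions are marked 0.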

2.collate_pretrain_data

This function is passed to the DataLoader via collate_fn=collate_pretrain_data.

def collate_pretrain_data(data):
    tokens_idx, global_label_list, atom_labels_list, atom_mask_list = map(list, zip(*data))
    tokens_idx = torch.tensor(tokens_idx)
    global_label = torch.tensor(global_label_list)
    atom_labels = torch.tensor(atom_labels_list)
    atom_mask = torch.tensor(atom_mask_list)
    return tokens_idx, global_label, atom_labels, atom_mask
  • The batch is first unzipped into per-field lists, and each list is then converted to a tensor.
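
A quick way to see what the collate function produces is to feed it a few toy records. In this sketch make_sample is a hypothetical helper and the sizes (sequence length 6, 3 global labels, 2 atom labels) are deliberately small assumptions; with the real data described above the batch shapes would presumably be (32, 5, 201), (32, 154), (32, 201, 15), and (32, 201).

```python
import torch


def collate_pretrain_data(data):
    # Unzip the list of records into four per-field lists, then stack each one.
    tokens_idx, global_labels, atom_labels, atom_mask = map(list, zip(*data))
    return (torch.tensor(tokens_idx), torch.tensor(global_labels),
            torch.tensor(atom_labels), torch.tensor(atom_mask))


def make_sample(seq_len=6, n_smiles=5, n_global=3, n_atom_labels=2):
    """Build one toy record with the same nesting as a real pretraining record."""
    return [
        [[1] + [3] * (seq_len - 1) for _ in range(n_smiles)],   # 5 token-index lists
        [0] * n_global,                                          # global labels
        [[2] * n_atom_labels for _ in range(seq_len)],           # per-token atom labels
        [0] + [1] * (seq_len - 1),                               # atom mask
    ]


batch = collate_pretrain_data([make_sample() for _ in range(4)])
print([tuple(t.shape) for t in batch])
# [(4, 5, 6), (4, 3), (4, 6, 2), (4, 6)]
```

The second axis of the first tensor indexes the canonical SMILES and its four augmentations, which is how the five inputs of the model's forward method can be sliced out of one batch.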

3.loss

global_pos_weight = torch.tensor([884.17, 70.71, 43.32, 118.73, 428.67, 829.0, 192.84, 67.89, 533.86, 18.46, 707.55, 160.14, 23.19, 26.33, 13.38, 12.45, 44.91, 173.58, 40.14, 67.25, 171.12, 8.84, 8.36, 43.63, 5.87, 10.2, 3.06, 161.72, 101.75, 20.01, 4.35, 12.62, 331.79, 31.17, 23.19, 5.91, 53.58, 15.73, 10.75, 6.84, 3.92, 6.52, 6.33, 6.74, 24.7, 2.67, 6.64, 5.4, 6.71, 6.51, 1.35, 24.07, 5.2, 0.74, 4.78, 6.1, 62.43, 6.1, 12.57, 9.44, 3.33, 5.71, 4.67, 0.98, 8.2, 1.28, 9.13, 1.1, 1.03, 2.46, 2.95, 0.74, 6.24, 0.96, 1.72, 2.25, 2.16, 2.87, 1.8, 1.62, 0.76, 1.78, 1.74, 1.08, 0.65, 0.97, 0.71, 5.08, 0.75, 0.85, 3.3, 4.79, 1.72, 0.78, 1.46, 1.8, 2.97, 2.18, 0.61, 0.61, 1.83, 1.19, 4.68, 3.08, 2.83, 0.51, 0.77, 6.31, 0.47, 0.29, 0.58, 2.76, 1.48, 0.25, 1.33, 0.69, 1.03, 0.97, 3.27, 1.31, 1.22, 0.85, 1.75, 1.02, 1.13, 0.16, 1.02, 2.2, 1.72, 2.9, 0.26, 0.69, 0.6, 0.23, 0.76, 0.73, 0.47, 1.13, 0.48, 0.53, 0.72, 0.38, 0.35, 0.48, 0.12, 0.52, 0.15, 0.28, 0.36, 0.08, 0.06, 0.03, 0.07, 0.01])
global_pos_weight = torch.cat((global_pos_weight, global_pos_weight, global_pos_weight, global_pos_weight, global_pos_weight), 0)
atom_pos_weight = torch.tensor([4.81, 1.0, 2.23, 53.49, 211.94, 0.49, 2.1, 1.13, 1.22, 1.93, 5.74, 15.42, 70.09, 61.47, 23.2])
loss_criterion_global = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=global_pos_weight.to('cuda'))
loss_criterion_atom = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=atom_pos_weight.to('cuda'))
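  • The loss functions are defined with per-class positive weights. BCEWithLogitsLoss is binary cross-entropy applied to raw logits. atom_pos_weight has length 15, matching the length of the label vector that atom_labels produces for each atom. global_pos_weight has length 5*154, where 154 is the number of selected MACCS keys per molecule and 5 is the number of SMILES per molecule.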
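
The effect of pos_weight is easiest to see on a tiny example. The sketch below runs on CPU with made-up logits, targets, and weights (all illustrative assumptions) and compares torch.nn.BCEWithLogitsLoss against the per-element formula loss = -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], where p is the positive weight of that class.

```python
import math

import torch

logits = torch.tensor([[0.5, -1.0, 2.0]])
targets = torch.tensor([[1.0, 0.0, 1.0]])
pos_weight = torch.tensor([4.0, 1.0, 0.5])   # one weight per class (last dimension)

criterion = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)
loss = criterion(logits, targets)            # reduction='none' keeps one value per element


def manual_bce(x, y, p):
    """Weighted binary cross-entropy for a single logit x, target y, positive weight p."""
    sig = 1.0 / (1.0 + math.exp(-x))
    return -(p * y * math.log(sig) + (1.0 - y) * math.log(1.0 - sig))


expected = [manual_bce(x, y, p)
            for x, y, p in zip(logits[0].tolist(), targets[0].tolist(), pos_weight.tolist())]
matches = torch.allclose(loss[0], torch.tensor(expected), atol=1e-5)
print(tuple(loss.shape), matches)   # (1, 3) True
```

A weight above 1 makes a missed positive cost more, which compensates for rare positive labels such as the first MACCS keys with weights in the hundreds.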

4.K_BERT

model = K_BERT(d_model=args['d_model'], n_layers=args['n_layers'], vocab_size=args['vocab_size'],
               maxlen=args['maxlen'], d_k=args['d_k'], d_v=args['d_v'], n_heads=args['n_heads'], d_ff=args['d_ff'],
               global_label_dim=args['global_labels_dim'], atom_label_dim=args['atom_labels_dim'])

class K_BERT(nn.Module):
    def __init__(self, d_model, n_layers, vocab_size, maxlen, d_k, d_v, n_heads, d_ff, global_label_dim, atom_label_dim):
        super(K_BERT, self).__init__()
        self.maxlen = maxlen
        self.d_model = d_model
        self.embedding = Embedding(vocab_size, self.d_model, maxlen)
        self.layers = nn.ModuleList([EncoderLayer(self.d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)])
        self.fc_global = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.fc_atom = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier_global = nn.Linear(self.d_model, global_label_dim)
        self.classifier_atom = nn.Linear(self.d_model, atom_label_dim)

    def forward(self, canonical_input_ids, aug_input_ids_1, aug_input_ids_2, aug_input_ids_3, aug_input_ids_4):
        canonical_output = self.embedding(canonical_input_ids)
        aug_output_1 = self.embedding(aug_input_ids_1)
        aug_output_2 = self.embedding(aug_input_ids_2)
        aug_output_3 = self.embedding(aug_input_ids_3)
        aug_output_4 = self.embedding(aug_input_ids_4)

        canonical_enc_self_attn_mask = get_attn_pad_mask(canonical_input_ids)
        aug_enc_self_attn_mask_1 = get_attn_pad_mask(aug_input_ids_1)
        aug_enc_self_attn_mask_2 = get_attn_pad_mask(aug_input_ids_2)
        aug_enc_self_attn_mask_3 = get_attn_pad_mask(aug_input_ids_3)
        aug_enc_self_attn_mask_4 = get_attn_pad_mask(aug_input_ids_4)

        for layer in self.layers:
            canonical_output = layer(canonical_output, canonical_enc_self_attn_mask)
            aug_output_1 = layer(aug_output_1, aug_enc_self_attn_mask_1)
            aug_output_2 = layer(aug_output_2, aug_enc_self_attn_mask_2)
            aug_output_3 = layer(aug_output_3, aug_enc_self_attn_mask_3)
            aug_output_4 = layer(aug_output_4, aug_enc_self_attn_mask_4)

        h_canonical_global = self.fc_global(canonical_output[:, 0])
        h_aug_global_1 = self.fc_global(aug_output_1[:, 0])
        h_aug_global_2 = self.fc_global(aug_output_2[:, 0])
        h_aug_global_3 = self.fc_global(aug_output_3[:, 0])
        h_aug_global_4 = self.fc_global(aug_output_4[:, 0])
        # Note: for a tensor a of shape (3, 4, 5), a[:, 0] has shape (3, 5),
        # so output[:, 0] picks the [GLO] position of every sequence in the batch.
        h_cos_1 = torch.cosine_similarity(canonical_output[:, 0], aug_output_1[:, 0], dim=1)
        h_cos_2 = torch.cosine_similarity(canonical_output[:, 0], aug_output_2[:, 0], dim=1)
        h_cos_3 = torch.cosine_similarity(canonical_output[:, 0], aug_output_3[:, 0], dim=1)
        h_cos_4 = torch.cosine_similarity(canonical_output[:, 0], aug_output_4[:, 0], dim=1)
        
        consensus_score = (torch.ones_like(h_cos_1)*4-h_cos_1 - h_cos_2 - h_cos_3 - h_cos_4)/8
        
        logits_canonical_global = self.classifier_global(h_canonical_global)
        logits_global_aug_1 = self.classifier_global(h_aug_global_1)
        logits_global_aug_2 = self.classifier_global(h_aug_global_2)
        logits_global_aug_3 = self.classifier_global(h_aug_global_3)
        logits_global_aug_4 = self.classifier_global(h_aug_global_4)
        canonical_cos_score_matric = torch.abs(cos_similar(canonical_output[:, 0], canonical_output[:, 0]))
        
        diagonal_cos_score_matric = torch.eye(canonical_cos_score_matric.size(0), device=canonical_cos_score_matric.device).float()
        different_score = canonical_cos_score_matric - diagonal_cos_score_matric
        
        logits_global = torch.cat((logits_canonical_global, logits_global_aug_1, logits_global_aug_2,
                                   logits_global_aug_3, logits_global_aug_4), 1)

        h_atom = self.fc_atom(canonical_output[:, 1:])
        h_atom_emb = h_atom.reshape([len(canonical_output)*(self.maxlen - 1), self.d_model])
        logits_atom = self.classifier_atom(h_atom_emb)
        return logits_global, logits_atom, consensus_score, different_score

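The model takes one molecule's canonical SMILES plus four augmented SMILES as input. The embedding maps token ids (vocab_size=47, the length of all_token_list) to d_model-dimensional vectors, and the result passes through n_layers EncoderLayer modules (held in a ModuleList) before feeding the classification and contrastive tasks. logits_global concatenates the global classification outputs of the five SMILES, each computed from the [GLO] position only. logits_atom is computed from the canonical SMILES alone, using every position except [GLO]. consensus_score is derived from the cosine similarity between the canonical [GLO] vector and each augmented [GLO] vector, while different_score is the pairwise cosine similarity between the canonical [GLO] vectors of the molecules in a batch, with the identity matrix subtracted.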
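
The two contrastive scores can be reproduced on toy [GLO] vectors. This is a sketch under assumptions: the repository's cos_similar helper is not shown in this post, so it is assumed here to return the pairwise cosine similarity matrix, and the function names consensus_score and different_score are only illustrative wrappers around the lines in forward.

```python
import torch
import torch.nn.functional as F


def consensus_score(canonical_glo, aug_glos):
    """(n - sum of cosine similarities) / (2n); 0 when all views agree, 1 when all oppose."""
    cos = [torch.cosine_similarity(canonical_glo, aug, dim=1) for aug in aug_glos]
    return (len(aug_glos) - sum(cos)) / (2 * len(aug_glos))


def different_score(canonical_glo):
    """|pairwise cosine similarity| between molecules in the batch, minus the identity."""
    normed = F.normalize(canonical_glo, dim=1)
    pairwise = torch.abs(normed @ normed.T)   # assumed behaviour of cos_similar
    return pairwise - torch.eye(canonical_glo.size(0), device=canonical_glo.device)


torch.manual_seed(0)
glo = torch.randn(3, 8)                        # 3 molecules, d_model = 8

same = consensus_score(glo, [glo] * 4)         # augmented views identical to canonical
opposite = consensus_score(glo, [-glo] * 4)    # augmented views point the other way
diff = different_score(glo)
print(same.shape, diff.shape)                  # torch.Size([3]) torch.Size([3, 3])
```

With four augmented views the first function reduces to (4 - h_cos_1 - h_cos_2 - h_cos_3 - h_cos_4) / 8, so minimizing it pulls the views of one molecule together, while minimizing the off-diagonal entries of different_score pushes different molecules apart.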

4.1.get_attn_pad_mask

def get_attn_pad_mask(seq_q):
    batch_size, seq_len = seq_q.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)
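
A small input shows what the mask looks like. The sketch below repeats the function (using seq_q.eq(0) directly rather than seq_q.data.eq(0)) and applies it to one made-up sequence with two trailing [PAD] tokens.

```python
import torch


def get_attn_pad_mask(seq_q):
    batch_size, seq_len = seq_q.size()
    pad_attn_mask = seq_q.eq(0).unsqueeze(1)                 # (batch, 1, seq_len), True at [PAD]
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)


ids = torch.tensor([[1, 3, 6, 0, 0]])    # [GLO], C, O, [PAD], [PAD]
mask = get_attn_pad_mask(ids)
print(tuple(mask.shape))                 # (1, 5, 5)
print(mask[0, 0].tolist())               # [False, False, False, True, True]
```

Every row of the mask is identical: for each query position, the key positions holding [PAD] are flagged so that attention can ignore them.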

4.2.EncoderLayer

class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs
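  • This is the BERT encoder block: multi-head self-attention followed by a feed-forward network. The shape of enc_outputs is always the same as that of enc_inputs. It is similar to BertModel in molecular-graph-bert (Part 1), so it is not analyzed in detail here.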

4.2.1.MultiHeadAttention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.linear = nn.Linear(self.n_heads * self.d_v, self.d_model)
        self.layernorm = nn.LayerNorm(self.d_model)
        self.W_Q = nn.Linear(self.d_model, self.d_k * self.n_heads)
        self.W_K = nn.Linear(self.d_model, self.d_k * self.n_heads)
        self.W_V = nn.Linear(self.d_model, self.d_v * self.n_heads)
    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)

        context = ScaledDotProductAttention(self.d_k)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = self.linear(context)
        return self.layernorm(output + residual)
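
ScaledDotProductAttention is called above but its code is not listed in this post. The following is a sketch of what such a module conventionally computes, written as a plain function; it is an assumption, not the repository's implementation, and unlike the call above it also returns the attention weights so they can be inspected.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v, attn_mask):
    """q, k, v: (batch, heads, seq_len, d); attn_mask: bool, True where keys are padding."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    scores = scores.masked_fill(attn_mask, -1e9)     # padded keys get ~zero weight
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn


torch.manual_seed(0)
batch, heads, seq_len, d_k = 1, 2, 4, 8
q = torch.randn(batch, heads, seq_len, d_k)
k = torch.randn(batch, heads, seq_len, d_k)
v = torch.randn(batch, heads, seq_len, d_k)

# Pretend the last key position is [PAD] for every query and every head.
attn_mask = torch.zeros(batch, heads, seq_len, seq_len, dtype=torch.bool)
attn_mask[..., -1] = True

context, attn = scaled_dot_product_attention(q, k, v, attn_mask)
print(tuple(context.shape))   # (1, 2, 4, 8)
```

The context tensor has the same layout as q, which is why MultiHeadAttention can transpose it back, merge the heads, and add the residual without changing the sequence shape.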

4.2.2.PoswiseFeedForwardNet

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        self.layernorm = nn.LayerNorm(d_model)
    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return self.layernorm(output + residual)

5.train

optimizer = Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(task_name=args['task_name'])
model.to(args['device'])

for epoch in range(args['num_epochs']):
    start = time.time()
    # Train
    run_a_contrastive_pretrain_epoch(args, epoch, model, pretrain_loader, loss_criterion_global=loss_criterion_global,
                                   loss_criterion_atom=loss_criterion_atom, optimizer=optimizer)
    # Save a checkpoint for this epoch (pretrain_step neither validates nor stops early)
    stopper.pretrain_step(epoch, model)
    elapsed = (time.time() - start)
    m, s = divmod(elapsed, 60)
    h, m = divmod(m, 60)
    print("An epoch time used:", "{:d}:{:d}:{:d}".format(int(h), int(m), int(s)))
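
The loop above only calls pretrain_step, which saves a checkpoint every epoch without comparing scores. The patience-based path of the same helper, step, is the one that can actually stop training. A minimal sketch of that logic without any checkpointing is given below; PatienceTracker is a hypothetical name and the scores are made up.

```python
class PatienceTracker:
    """Stand-in for the step() logic of EarlyStopping, with checkpoint saving left out."""

    def __init__(self, mode='higher', patience=10):
        assert mode in ('higher', 'lower')
        self.mode = mode
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def _improved(self, score):
        if self.mode == 'higher':
            return score > self.best_score
        return score < self.best_score

    def step(self, score):
        if self.best_score is None or self._improved(score):
            self.best_score = score      # new best: reset the patience counter
            self.counter = 0
        else:
            self.counter += 1            # no improvement this round
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


tracker = PatienceTracker(mode='higher', patience=2)
history = [tracker.step(score) for score in [0.70, 0.75, 0.74, 0.73, 0.80]]
print(history)   # [False, False, False, True, True]
```

Note that the flag is never cleared once set, so a later improvement (0.80 here) updates best_score but still reports that training should stop; the caller is expected to break out of the loop on the first True.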

5.1.EarlyStopping

class EarlyStopping(object):
    def __init__(self, pretrained_model='Null_early_stop.pth',
                 pretrain_layer=6, mode='higher', patience=10, task_name="None"):
        assert mode in ['higher', 'lower']
        self.pretrain_layer = pretrain_layer
        self.mode = mode
        if self.mode == 'higher':
            self._check = self._check_higher
        else:
            self._check = self._check_lower

        self.patience = patience
        self.counter = 0
        self.filename = '../model/{}_early_stop.pth'.format(task_name)
        self.pretrain_save_filename = '../model/pretrain_{}_epoch_'.format(task_name)
        self.best_score = None
        self.early_stop = False
        self.pretrained_model = '../model/{}'.format(pretrained_model)

    def _check_higher(self, score, prev_best_score):
        return (score > prev_best_score)

    def _check_lower(self, score, prev_best_score):
        return (score < prev_best_score)

    def step(self, score, model):
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif self._check(score, self.best_score):
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                'EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def pretrain_step(self, epoch, model):
        print('Pretrain epoch {} is finished and the model is saved'.format(epoch))
        self.pretrain_save_checkpoint(epoch, model)

    def pretrain_save_checkpoint(self, epoch, model):
        '''Saves the model after the given pretraining epoch.'''
        torch.save({'model_state_dict': model.state_dict()}, self.pretrain_save_filename + str(epoch) + '.pth')
        # print(self.filename)

    def save_checkpoint(self, model):
        '''Saves model when the metric on the validation set gets improved.'''
        torch.save({'model_state_dict': model.state_dict()}, self.filename)
        # print(self.filename)

    def load_checkpoint(self, model):
        '''Load model saved with early stopping.'''
        # model.load_state_dict(torch.load(self.filename)['model_state_dict'])
        model.load_state_dict(torch.load(self.filename, map_location=torch.device('cpu'))['model_state_dict'])

    def load_pretrained_model(self, model):
        if self.pretrain_layer == 1:
            pretrained_parameters = ['embedding.tok_embed.weight', 'embedding.pos_embed.weight', 'embedding.norm.weight', 'embedding.norm.bias', 'layers.0.enc_self_attn.linear.weight', 'layers.0.enc_self_attn.linear.bias', 'layers.0.enc_self_attn.layernorm.weight', 'layers.0.enc_self_attn.layernorm.bias', 'layers.0.enc_self_attn.W_Q.weight', 'layers.0.enc_self_attn.W_Q.bias', 'layers.0.enc_self_attn.W_K.weight', 'layers.0.enc_self_attn.W_K.bias', 'layers.0.enc_self_attn.W_V.weight', 'layers.0.enc_self_attn.W_V.bias', 'layers.0.pos_ffn.fc.0.weight', 'layers.0.pos_ffn.fc.2.weight', 'layers.0.pos_ffn.layernorm.weight', 'layers.0.pos_ffn.layernorm.bias']

        elif self.pretrain_layer == 2:
            pretrained_parameters = ['embedding.tok_embed.weight', 'embedding.pos_embed.weight', 'embedding.norm.weight', 'embedding.norm.bias', 'layers.0.enc_self_attn.linear.weight', 'layers.0.enc_self_attn.linear.bias', 'layers.0.enc_self_attn.layernorm.weight', 'layers.0.enc_self_attn.layernorm.bias', 'layers.0.enc_self_attn.W_Q.weight', 'layers.0.enc_self_attn.W_Q.bias', 'layers.0.enc_self_attn.W_K.weight', 'layers.0.enc_self_attn.W_K.bias', 'layers.0.enc_self_attn.W_V.weight', 'layers.0.enc_self_attn.W_V.bias', 'layers.0.pos_ffn.fc.0.weight', 'layers.0.pos_ffn.fc.2.weight', 'layers.0.pos_ffn.layernorm.weight', 'layers.0.pos_ffn.layernorm.bias', 'layers.1.enc_self_attn.linear.weight', 'layers.1.enc_self_attn.linear.bias', 'layers.1.enc_self_attn.layernorm.weight', 'layers.1.enc_self_attn.layernorm.bias', 'layers.1.enc_self_attn.W_Q.weight', 'layers.1.enc_self_attn.W_Q.bias', 'layers.1.enc_self_attn.W_K.weight', 'layers.1.enc_self_attn.W_K.bias', 'layers.1.enc_self_attn.W_V.weight', 'layers.1.enc_self_attn.W_V.bias', 'layers.1.pos_ffn.fc.0.weight', 'layers.1.pos_ffn.fc.2.weight', 'layers.1.pos_ffn.layernorm.weight', 'layers.1.pos_ffn.layernorm.bias']

        elif self.pretrain_layer == 3:
        ...
        pretrained_model = torch.load(self.pretrained_model, map_location=torch.device('cpu'))
        # pretrained_model = torch.load(self.pretrained_model)
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_model['model_state_dict'].items() if k in pretrained_parameters}
        model_dict.update(pretrained_dict)
        model.load_state_dict(pretrained_dict, strict=False)
    def load_pretrained_model_continue(self, model):
        pretrained_parameters = ['embedding.tok_embed.weight', 'embedding.pos_embed.weight', 'embedding.norm.weight',
                                 'embedding.norm.bias', 'layers.0.enc_self_attn.linear.weight',
                                 'layers.0.enc_self_attn.linear.bias', 'layers.0.enc_self_attn.layernorm.weight',
                                 'layers.0.enc_self_attn.layernorm.bias', 'layers.0.enc_self_attn.W_Q.weight',
                                 'layers.0.enc_self_attn.W_Q.bias', 'layers.0.enc_self_attn.W_K.weight',
                                 'layers.0.enc_self_attn.W_K.bias', 'layers.0.enc_self_attn.W_V.weight',
                                 'layers.0.enc_self_attn.W_V.bias', 'layers.0.pos_ffn.fc1.weight',
                                 'layers.0.pos_ffn.fc1.bias', 'layers.0.pos_ffn.fc2.weight',
                                 'layers.0.pos_ffn.fc2.bias', 'layers.1.enc_self_attn.linear.weight',
                                 'layers.1.enc_self_attn.linear.bias', 'layers.1.enc_self_attn.layernorm.weight',
                                 'layers.1.enc_self_attn.layernorm.bias', 'layers.1.enc_self_attn.W_Q.weight',
                                 'layers.1.enc_self_attn.W_Q.bias', 'layers.1.enc_self_attn.W_K.weight',
                                 'layers.1.enc_self_attn.W_K.bias', 'layers.1.enc_self_attn.W_V.weight',
                                 'layers.1.enc_self_attn.W_V.bias', 'layers.1.pos_ffn.fc1.weight',
                                 'layers.1.pos_ffn.fc1.bias', 'layers.1.pos_ffn.fc2.weight',
                                 'layers.1.pos_ffn.fc2.bias', 'layers.2.enc_self_attn.linear.weight',
                                 'layers.2.enc_self_attn.linear.bias', 'layers.2.enc_self_attn.layernorm.weight',
                                 'layers.2.enc_self_attn.layernorm.bias', 'layers.2.enc_self_attn.W_Q.weight',
                                 'layers.2.enc_self_attn.W_Q.bias', 'layers.2.enc_self_attn.W_K.weight',
                                 'layers.2.enc_self_attn.W_K.bias', 'layers.2.enc_self_attn.W_V.weight',
                                 'layers.2.enc_self_attn.W_V.bias', 'layers.2.pos_ffn.fc1.weight',
                                 'layers.2.pos_ffn.fc1.bias', 'layers.2.pos_ffn.fc2.weight',
                                 'layers.2.pos_ffn.fc2.bias', 'layers.3.enc_self_attn.linear.weight',
                                 'layers.3.enc_self_attn.linear.bias', 'layers.3.enc_self_attn.layernorm.weight',
                                 'layers.3.enc_self_attn.layernorm.bias', 'layers.3.enc_self_attn.W_Q.weight',
                                 'layers.3.enc_self_attn.W_Q.bias', 'layers.3.enc_self_attn.W_K.weight',
                                 'layers.3.enc_self_attn.W_K.bias', 'layers.3.enc_self_attn.W_V.weight',
                                 'layers.3.enc_self_attn.W_V.bias', 'layers.3.pos_ffn.fc1.weight',
                                 'layers.3.pos_ffn.fc1.bias', 'layers.3.pos_ffn.fc2.weight',
                                 'layers.3.pos_ffn.fc2.bias', 'layers.4.enc_self_attn.linear.weight',
                                 'layers.4.enc_self_attn.linear.bias', 'layers.4.enc_self_attn.layernorm.weight',
                                 'layers.4.enc_self_attn.layernorm.bias', 'layers.4.enc_self_attn.W_Q.weight',
                                 'layers.4.enc_self_attn.W_Q.bias', 'layers.4.enc_self_attn.W_K.weight',
                                 'layers.4.enc_self_attn.W_K.bias', 'layers.4.enc_self_attn.W_V.weight',
                                 'layers.4.enc_self_attn.W_V.bias', 'layers.4.pos_ffn.fc1.weight',
                                 'layers.4.pos_ffn.fc1.bias', 'layers.4.pos_ffn.fc2.weight',
                                 'layers.4.pos_ffn.fc2.bias', 'layers.5.enc_self_attn.linear.weight',
                                 'layers.5.enc_self_attn.linear.bias', 'layers.5.enc_self_attn.layernorm.weight',
                                 'layers.5.enc_self_attn.layernorm.bias', 'layers.5.enc_self_attn.W_Q.weight',
                                 'layers.5.enc_self_attn.W_Q.bias', 'layers.5.enc_self_attn.W_K.weight',
                                 'layers.5.enc_self_attn.W_K.bias', 'layers.5.enc_self_attn.W_V.weight',
                                 'layers.5.enc_self_attn.W_V.bias', 'layers.5.pos_ffn.fc1.weight',
                                 'layers.5.pos_ffn.fc1.bias', 'layers.5.pos_ffn.fc2.weight',
                                 'layers.5.pos_ffn.fc2.bias', 'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias',
                                 'fc.5.weight', 'fc.5.bias', 'fc.7.weight', 'fc.7.bias', 'classifier_global.weight',
                                 'classifier_global.bias', 'classifier_atom.weight', 'classifier_atom.bias']
        # Load the checkpoint onto the CPU, regardless of the device it was saved from.
        pretrained_model = torch.load(self.pretrained_model, map_location=torch.device('cpu'))
        # Keep only the parameters named in pretrained_parameters above.
        pretrained_dict = {k: v for k, v in pretrained_model['model_state_dict'].items() if k in pretrained_parameters}
        # strict=False: parameters absent from pretrained_dict keep their current values.
        model.load_state_dict(pretrained_dict, strict=False)
  • In the end, the only thing this code is used for is the call to torch.save; the rest is unused.
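The name-based filtering in the loading code above can be illustrated without PyTorch. The following is a minimal sketch, assuming plain dicts and lists stand in for the checkpoint's state dict and its tensors; `select_pretrained`, `missing_names`, and `some_downstream_head.weight` are hypothetical names used only for illustration and are not part of the repository.

```python
def select_pretrained(checkpoint_state, allowed_names):
    """Keep only the checkpoint entries whose name is in the whitelist."""
    allowed = set(allowed_names)
    return {k: v for k, v in checkpoint_state.items() if k in allowed}


def missing_names(checkpoint_state, allowed_names):
    """List whitelisted names that the checkpoint does not contain."""
    return [name for name in allowed_names if name not in checkpoint_state]


# Toy stand-ins: lists play the role of tensors.
checkpoint_state = {
    'layers.0.pos_ffn.fc1.weight': [0.1, 0.2],
    'classifier_global.weight': [0.3],
    'some_downstream_head.weight': [0.9],  # hypothetical key, not whitelisted
}
allowed_names = [
    'layers.0.pos_ffn.fc1.weight',
    'classifier_global.weight',
    'classifier_atom.weight',  # whitelisted but absent from this toy checkpoint
]

selected = select_pretrained(checkpoint_state, allowed_names)
missing = missing_names(checkpoint_state, allowed_names)
print(sorted(selected))
print(missing)
```

Loading with strict=False tolerates names like the ones in `missing` silently, so checking them explicitly is one way to notice a mismatch between the whitelist and the checkpoint.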

5.1.run_a_contrastive_pretrain_epoch

def run_a_contrastive_pretrain_epoch(args, epoch, model, data_loader,
                                     loss_criterion_global, loss_criterion_atom, optimizer):
    model.train()
    total_loss = 0.0
    for batch_id, batch_data in enumerate(data_loader):
        token_idx, global_labels, atom_labels, atom_mask = batch_data
        batch_size = len(token_idx)
        # Column 0 holds the canonical SMILES tokens, columns 1-4 the four augmented versions.
        canonical_token_idx = token_idx[:, 0].long().to(args['device'])
        aug_token_idx_1 = token_idx[:, 1].long().to(args['device'])
        aug_token_idx_2 = token_idx[:, 2].long().to(args['device'])
        aug_token_idx_3 = token_idx[:, 3].long().to(args['device'])
        aug_token_idx_4 = token_idx[:, 4].long().to(args['device'])

        # Tile the global labels five times along dim 1, one copy per SMILES version.
        global_labels = global_labels.float().to(args['device'])
        global_labels = torch.cat((global_labels, global_labels, global_labels, global_labels, global_labels), 1)

        # Drop the first position, then flatten to one row per token.
        atom_labels = atom_labels[:, 1:].float().to(args['device'])
        atom_mask = atom_mask[:, 1:].float().to(args['device'])
        atom_labels = atom_labels.reshape([batch_size*(args['maxlen']-1), args['atom_labels_dim']])
        atom_mask = atom_mask.reshape(batch_size*(args['maxlen']-1), 1)

        logits_global, logits_atom, consensus_score, different_score = model(
            canonical_token_idx, aug_token_idx_1, aug_token_idx_2, aug_token_idx_3, aug_token_idx_4)
        # Compute each loss term once and reuse it for both the total loss and the log line.
        global_loss = loss_criterion_global(logits_global, global_labels).float().mean()
        atom_loss = (loss_criterion_atom(logits_atom, atom_labels)*(atom_mask != 0).float()).mean()
        consensus_loss = consensus_score.mean()
        different_loss = different_score.mean()
        loss = global_loss + atom_loss + consensus_loss + different_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() returns a Python float, so the autograd graph is not kept alive across batches.
        total_loss = total_loss + loss.item()*batch_size
        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}, consensus_loss {:.4f}, different_loss {:.4f}, global_loss {:.4f}, atom_loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item(), consensus_loss.item(),
            different_loss.item(), global_loss.item(), atom_loss.item()))
        del token_idx, global_labels, atom_labels, atom_mask, loss, logits_global, logits_atom
        torch.cuda.empty_cache()
    print('epoch {:d}/{:d}, pre-train loss {:.4f}'.format(
        epoch + 1, args['num_epochs'], total_loss))
    return total_loss
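One detail of the atom-level term is worth spelling out: the per-token loss is multiplied by the mask and then averaged with `.mean()`, so padded positions still count in the denominator. The sketch below is a simplified 1-D illustration in plain Python, assuming one loss value per token; `masked_mean_all` mirrors what the code above does, while `masked_mean_valid` is a hypothetical alternative shown only for comparison and is not what the source uses.

```python
def masked_mean_all(losses, mask):
    """Zero out masked positions, then divide by the total number of positions."""
    kept = [loss if m != 0 else 0.0 for loss, m in zip(losses, mask)]
    return sum(kept) / len(kept)


def masked_mean_valid(losses, mask):
    """Hypothetical alternative: average over unmasked positions only."""
    kept = [loss for loss, m in zip(losses, mask) if m != 0]
    return sum(kept) / len(kept) if kept else 0.0


# Four token positions; the last two are padding (mask == 0).
losses = [0.5, 1.5, 2.0, 4.0]
mask = [1, 1, 0, 0]

mean_all = masked_mean_all(losses, mask)
mean_valid = masked_mean_valid(losses, mask)
print(mean_all, mean_valid)
```

With the first form, the atom loss shrinks as the share of padding in a batch grows, which changes its weight relative to the global and contrastive terms.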