Knowledge-based-BERT (Part 3)

K-BERT uses multiple pre-training tasks to address several weaknesses of applying NLP models to SMILES. Code: Knowledge-based-BERT; paper: Knowledge-based BERT: a method to extract molecular features like computational chemists. This part of the code walkthrough continues with the downstream task. The model framework is shown below:
[Figure: K-BERT model framework]

# Loop over every downstream task; each task is trained args['times'] times
# with a different random seed to get a robust performance estimate.
for task in args['task_name_list']:
    args['task_name'] = task
    args['data_path'] = '../data/task_data/' + args['task_name'] + '.npy'

    all_times_train_result = []
    all_times_val_result = []
    all_times_test_result = []
    result_pd = pd.DataFrame()
    result_pd['index'] = ['roc_auc', 'accuracy', 'sensitivity', 'specificity', 'f1-score', 'precision', 'recall',
                          'error rate', 'mcc']

    for time_id in range(args['times']):
        set_random_seed(2020 + time_id)  # a different, reproducible seed per repeat
        train_set, val_set, test_set, task_number = build_data.load_data_for_random_splited(
            data_path=args['data_path'], shuffle=True
        )
        print("Molecule graph is loaded!")

1.load_data_for_random_splited

import random

import numpy as np


def load_data_for_random_splited(data_path='example.npy', shuffle=True):
    # The .npy file stores five parallel lists:
    # [smiles, token indices, task labels, label masks, split groups]
    data = np.load(data_path, allow_pickle=True)
    smiles_list = data[0]
    tokens_idx_list = data[1]
    labels_list = data[2]
    mask_list = data[3]
    group_list = data[4]
    if shuffle:
        # Shuffling the group labels (not the molecules) randomly reassigns
        # which molecules fall into train/val/test while keeping the split sizes.
        random.shuffle(group_list)
    print(group_list)
    train_set = []
    val_set = []
    test_set = []
    # Every molecule carries the same number of task labels, so the length of
    # any one entry gives the task count (index 1 is an arbitrary choice here).
    task_number = len(labels_list[1])
    for i, group in enumerate(group_list):
        molecule = [smiles_list[i], tokens_idx_list[i], labels_list[i], mask_list[i]]
        if group == 'training':
            train_set.append(molecule)
        elif group == 'val':
            val_set.append(molecule)
        else:
            test_set.append(molecule)
    print('Training set: {}, Validation set: {}, Test set: {}, task number: {}'.format(
            len(train_set), len(val_set), len(test_set), task_number))
    return train_set, val_set, test_set, task_number
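
To make the expected file layout concrete, here is a minimal sketch that writes a toy .npy file matching the indexing above; the SMILES, token ids and labels are invented placeholders, not real K-BERT vocabulary entries:

import numpy as np

# Toy data: 3 molecules, 2 tasks; all five lists are parallel.
smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
tokens_idx = [[1, 5, 9, 0], [1, 7, 7, 2], [1, 5, 8, 3]]  # placeholder token ids
labels = [[1, 0], [0, 1], [1, 1]]                         # one label per task
mask = [[1, 1], [1, 1], [1, 0]]                           # 0 marks a missing label
group = ['training', 'val', 'test']

np.save('example.npy', np.array([smiles, tokens_idx, labels, mask, group], dtype=object))
# load_data_for_random_splited('example.npy') now returns a 1/1/1 split with task_number == 2.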

2.model

        # DataLoaders: only the training loader is shuffled; all three share
        # the same collate_fn.
        train_loader = DataLoader(dataset=train_set,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  collate_fn=collate_data)

        val_loader = DataLoader(dataset=val_set,
                                batch_size=args['batch_size'],
                                collate_fn=collate_data)

        test_loader = DataLoader(dataset=test_set,
                                 batch_size=args['batch_size'],
                                 collate_fn=collate_data)
        # Per-task positive-class weights, computed from the training split only.
        pos_weight_task = pos_weight(train_set)
        one_time_train_result = []
        one_time_val_result = []
        one_time_test_result = []
        print('***************************************************************************************************')
        print('{}, {}/{} time'.format(args['task_name'], time_id + 1, args['times']))
        print('***************************************************************************************************')

        # Weighted BCE with reduction='none', so per-label masks can be applied
        # in the training loop before averaging.
        loss_criterion = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight_task.to(args['device']))
        model = K_BERT_WCL(d_model=args['d_model'], n_layers=args['n_layers'], vocab_size=args['vocab_size'],
                           maxlen=args['maxlen'], d_k=args['d_k'], d_v=args['d_v'], n_heads=args['n_heads'],
                           d_ff=args['d_ff'], global_label_dim=args['global_labels_dim'],
                           atom_label_dim=args['atom_labels_dim'])
        # EarlyStopping also handles loading the pre-trained encoder weights.
        stopper = EarlyStopping(patience=args['patience'], pretrained_model=args['pretrain_model'],
                                pretrain_layer=args['pretrain_layer'],
                                task_name=args['task_name'] + '_downstream_k_bert_wcl', mode=args['mode'])
        model.to(args['device'])
        stopper.load_pretrained_model(model)
        optimizer = Adam(model.parameters(), lr=args['lr'])
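
collate_data, the collate_fn passed to all three loaders, is not shown in this excerpt. A minimal sketch consistent with the batch unpacking in run_an_eval_global_epoch below (assuming the token index lists are already padded to equal length):

import torch

def collate_data(batch):
    # Each sample is [smiles, tokens_idx, labels, mask]; keep SMILES as a list
    # of strings and stack the numeric fields into tensors.
    smiles, tokens_idx, labels, mask = map(list, zip(*batch))
    return smiles, torch.tensor(tokens_idx), torch.tensor(labels), torch.tensor(mask)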

2.1.pos_weight

def pos_weight(train_set):
    # Per-task ratio of negatives to positives in the training split.
    smiles, tokens_idx, labels, mask = map(list, zip(*train_set))
    task_pos_weight_list = []
    for j in range(len(labels[1])):  # number of tasks (same for every molecule)
        num_pos = 0
        num_neg = 0
        for i in labels:
            if i[j] == 1:
                num_pos += 1
            if i[j] == 0:
                num_neg += 1
        # Small epsilon guards against division by zero when a task has no positives.
        task_pos_weight = num_neg / (num_pos + 1e-8)
        task_pos_weight_list.append(task_pos_weight)
    return torch.tensor(task_pos_weight_list)
  • Why is task_pos_weight_list set this way? BCEWithLogitsLoss multiplies the positive term of the loss by pos_weight, i.e. loss = -[pos_weight · y · log σ(x) + (1 - y) · log(1 - σ(x))]. Setting pos_weight = num_neg / num_pos makes positives and negatives contribute roughly equally to the total loss, which counteracts the heavy class imbalance typical of molecular property datasets.
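
A quick numerical check of that intent (the numbers are made up): with pos_weight = num_neg / num_pos, the summed loss over the rare positives equals the summed loss over the negatives.

import torch

# Hypothetical imbalanced task: 2 positives, 8 negatives, every logit at 0.
logits = torch.zeros(10)
labels = torch.tensor([1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
pw = torch.tensor(8 / 2)  # num_neg / num_pos

loss = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=pw)(logits, labels)
print(loss[labels == 1].sum(), loss[labels == 0].sum())  # both ≈ 5.545 (= 8 * ln 2)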

2.2.load_pretrained_model

    def load_pretrained_model(self, model):
        # Each branch enumerates the checkpoint parameters to transfer: the
        # token/position embeddings plus the first N pre-trained encoder layers.
        if self.pretrain_layer == 1:
            pretrained_parameters = ['embedding.tok_embed.weight', 'embedding.pos_embed.weight',
                                     'embedding.norm.weight', 'embedding.norm.bias',
                                     'layers.0.enc_self_attn.linear.weight', 'layers.0.enc_self_attn.linear.bias',
                                     'layers.0.enc_self_attn.layernorm.weight', 'layers.0.enc_self_attn.layernorm.bias',
                                     'layers.0.enc_self_attn.W_Q.weight', 'layers.0.enc_self_attn.W_Q.bias',
                                     'layers.0.enc_self_attn.W_K.weight', 'layers.0.enc_self_attn.W_K.bias',
                                     'layers.0.enc_self_attn.W_V.weight', 'layers.0.enc_self_attn.W_V.bias',
                                     'layers.0.pos_ffn.fc.0.weight', 'layers.0.pos_ffn.fc.2.weight',
                                     'layers.0.pos_ffn.layernorm.weight', 'layers.0.pos_ffn.layernorm.bias']

        elif self.pretrain_layer == 2:
            pretrained_parameters = ['embedding.tok_embed.weight', 'embedding.pos_embed.weight',
                                     'embedding.norm.weight', 'embedding.norm.bias',
                                     'layers.0.enc_self_attn.linear.weight', 'layers.0.enc_self_attn.linear.bias',
                                     'layers.0.enc_self_attn.layernorm.weight', 'layers.0.enc_self_attn.layernorm.bias',
                                     'layers.0.enc_self_attn.W_Q.weight', 'layers.0.enc_self_attn.W_Q.bias',
                                     'layers.0.enc_self_attn.W_K.weight', 'layers.0.enc_self_attn.W_K.bias',
                                     'layers.0.enc_self_attn.W_V.weight', 'layers.0.enc_self_attn.W_V.bias',
                                     'layers.0.pos_ffn.fc.0.weight', 'layers.0.pos_ffn.fc.2.weight',
                                     'layers.0.pos_ffn.layernorm.weight', 'layers.0.pos_ffn.layernorm.bias',
                                     'layers.1.enc_self_attn.linear.weight', 'layers.1.enc_self_attn.linear.bias',
                                     'layers.1.enc_self_attn.layernorm.weight', 'layers.1.enc_self_attn.layernorm.bias',
                                     'layers.1.enc_self_attn.W_Q.weight', 'layers.1.enc_self_attn.W_Q.bias',
                                     'layers.1.enc_self_attn.W_K.weight', 'layers.1.enc_self_attn.W_K.bias',
                                     'layers.1.enc_self_attn.W_V.weight', 'layers.1.enc_self_attn.W_V.bias',
                                     'layers.1.pos_ffn.fc.0.weight', 'layers.1.pos_ffn.fc.2.weight',
                                     'layers.1.pos_ffn.layernorm.weight', 'layers.1.pos_ffn.layernorm.bias']

        elif self.pretrain_layer == 3:
            ...
        elif self.pretrain_layer == 'all_12layer':
            pretrained_parameters = ['embedding.tok_embed.weight', 'embedding.pos_embed.weight',
                                     'embedding.norm.weight', 'embedding.norm.bias',
                                     'layers.0.enc_self_attn.linear.weight', 'layers.0.enc_self_attn.linear.bias',
                                     'layers.0.enc_self_attn.layernorm.weight', 'layers.0.enc_self_attn.layernorm.bias',
                                     'layers.0.enc_self_attn.W_Q.weight', 'layers.0.enc_self_attn.W_Q.bias',
                                     'layers.0.enc_self_attn.W_K.weight', 'layers.0.enc_self_attn.W_K.bias',
                                     'layers.0.enc_self_attn.W_V.weight', 'layers.0.enc_self_attn.W_V.bias',
                                     'layers.0.pos_ffn.fc.0.weight', 'layers.0.pos_ffn.fc.2.weight',
                                     'layers.0.pos_ffn.layernorm.weight', 'layers.0.pos_ffn.layernorm.bias',
                                     'layers.1.enc_self_attn.linear.weight', 'layers.1.enc_self_attn.linear.bias',
                                     'layers.1.enc_self_attn.layernorm.weight', 'layers.1.enc_self_attn.layernorm.bias',
                                     'layers.1.enc_self_attn.W_Q.weight', 'layers.1.enc_self_attn.W_Q.bias',
                                     'layers.1.enc_self_attn.W_K.weight', 'layers.1.enc_self_attn.W_K.bias',
                                     'layers.1.enc_self_attn.W_V.weight', 'layers.1.enc_self_attn.W_V.bias',
                                     'layers.1.pos_ffn.fc.0.weight', 'layers.1.pos_ffn.fc.2.weight',
                                     'layers.1.pos_ffn.layernorm.weight', 'layers.1.pos_ffn.layernorm.bias',
                                     'layers.2.enc_self_attn.linear.weight', 'layers.2.enc_self_attn.linear.bias',
                                     'layers.2.enc_self_attn.layernorm.weight', 'layers.2.enc_self_attn.layernorm.bias',
                                     'layers.2.enc_self_attn.W_Q.weight', 'layers.2.enc_self_attn.W_Q.bias',
                                     'layers.2.enc_self_attn.W_K.weight', 'layers.2.enc_self_attn.W_K.bias',
                                     'layers.2.enc_self_attn.W_V.weight', 'layers.2.enc_self_attn.W_V.bias',
                                     'layers.2.pos_ffn.fc.0.weight', 'layers.2.pos_ffn.fc.2.weight',
                                     'layers.2.pos_ffn.layernorm.weight', 'layers.2.pos_ffn.layernorm.bias',
                                     'layers.3.enc_self_attn.linear.weight', 'layers.3.enc_self_attn.linear.bias',
                                     'layers.3.enc_self_attn.layernorm.weight', 'layers.3.enc_self_attn.layernorm.bias',
                                     'layers.3.enc_self_attn.W_Q.weight', 'layers.3.enc_self_attn.W_Q.bias',
                                     'layers.3.enc_self_attn.W_K.weight', 'layers.3.enc_self_attn.W_K.bias',
                                     'layers.3.enc_self_attn.W_V.weight', 'layers.3.enc_self_attn.W_V.bias',
                                     'layers.3.pos_ffn.fc.0.weight', 'layers.3.pos_ffn.fc.2.weight',
                                     'layers.3.pos_ffn.layernorm.weight', 'layers.3.pos_ffn.layernorm.bias',
                                     'layers.4.enc_self_attn.linear.weight', 'layers.4.enc_self_attn.linear.bias',
                                     'layers.4.enc_self_attn.layernorm.weight', 'layers.4.enc_self_attn.layernorm.bias',
                                     'layers.4.enc_self_attn.W_Q.weight', 'layers.4.enc_self_attn.W_Q.bias',
                                     'layers.4.enc_self_attn.W_K.weight', 'layers.4.enc_self_attn.W_K.bias',
                                     'layers.4.enc_self_attn.W_V.weight', 'layers.4.enc_self_attn.W_V.bias',
                                     'layers.4.pos_ffn.fc.0.weight', 'layers.4.pos_ffn.fc.2.weight',
                                     'layers.4.pos_ffn.layernorm.weight', 'layers.4.pos_ffn.layernorm.bias',
                                     'layers.5.enc_self_attn.linear.weight', 'layers.5.enc_self_attn.linear.bias',
                                     'layers.5.enc_self_attn.layernorm.weight', 'layers.5.enc_self_attn.layernorm.bias',
                                     'layers.5.enc_self_attn.W_Q.weight', 'layers.5.enc_self_attn.W_Q.bias',
                                     'layers.5.enc_self_attn.W_K.weight', 'layers.5.enc_self_attn.W_K.bias',
                                     'layers.5.enc_self_attn.W_V.weight', 'layers.5.enc_self_attn.W_V.bias',
                                     'layers.5.pos_ffn.fc.0.weight', 'layers.5.pos_ffn.fc.2.weight',
                                     'layers.5.pos_ffn.layernorm.weight', 'layers.5.pos_ffn.layernorm.bias',

                                     'layers.6.enc_self_attn.linear.weight', 'layers.6.enc_self_attn.linear.bias',
                                     'layers.6.enc_self_attn.layernorm.weight', 'layers.6.enc_self_attn.layernorm.bias',
                                     'layers.6.enc_self_attn.W_Q.weight', 'layers.6.enc_self_attn.W_Q.bias',
                                     'layers.6.enc_self_attn.W_K.weight', 'layers.6.enc_self_attn.W_K.bias',
                                     'layers.6.enc_self_attn.W_V.weight', 'layers.6.enc_self_attn.W_V.bias',
                                     'layers.6.pos_ffn.fc.0.weight', 'layers.6.pos_ffn.fc.2.weight',
                                     'layers.6.pos_ffn.layernorm.weight', 'layers.6.pos_ffn.layernorm.bias',
                                     'layers.7.enc_self_attn.linear.weight', 'layers.7.enc_self_attn.linear.bias',
                                     'layers.7.enc_self_attn.layernorm.weight', 'layers.7.enc_self_attn.layernorm.bias',
                                     'layers.7.enc_self_attn.W_Q.weight', 'layers.7.enc_self_attn.W_Q.bias',
                                     'layers.7.enc_self_attn.W_K.weight', 'layers.7.enc_self_attn.W_K.bias',
                                     'layers.7.enc_self_attn.W_V.weight', 'layers.7.enc_self_attn.W_V.bias',
                                     'layers.7.pos_ffn.fc.0.weight', 'layers.7.pos_ffn.fc.2.weight',
                                     'layers.7.pos_ffn.layernorm.weight', 'layers.7.pos_ffn.layernorm.bias',
                                     'layers.8.enc_self_attn.linear.weight', 'layers.8.enc_self_attn.linear.bias',
                                     'layers.8.enc_self_attn.layernorm.weight', 'layers.8.enc_self_attn.layernorm.bias',
                                     'layers.8.enc_self_attn.W_Q.weight', 'layers.8.enc_self_attn.W_Q.bias',
                                     'layers.8.enc_self_attn.W_K.weight', 'layers.8.enc_self_attn.W_K.bias',
                                     'layers.8.enc_self_attn.W_V.weight', 'layers.8.enc_self_attn.W_V.bias',
                                     'layers.8.pos_ffn.fc.0.weight', 'layers.8.pos_ffn.fc.2.weight',
                                     'layers.8.pos_ffn.layernorm.weight', 'layers.8.pos_ffn.layernorm.bias',
                                     'layers.9.enc_self_attn.linear.weight', 'layers.9.enc_self_attn.linear.bias',
                                     'layers.9.enc_self_attn.layernorm.weight', 'layers.9.enc_self_attn.layernorm.bias',
                                     'layers.9.enc_self_attn.W_Q.weight', 'layers.9.enc_self_attn.W_Q.bias',
                                     'layers.9.enc_self_attn.W_K.weight', 'layers.9.enc_self_attn.W_K.bias',
                                     'layers.9.enc_self_attn.W_V.weight', 'layers.9.enc_self_attn.W_V.bias',
                                     'layers.9.pos_ffn.fc.0.weight', 'layers.9.pos_ffn.fc.2.weight',
                                     'layers.9.pos_ffn.layernorm.weight', 'layers.9.pos_ffn.layernorm.bias',
                                     'layers.10.enc_self_attn.linear.weight', 'layers.10.enc_self_attn.linear.bias',
                                     'layers.10.enc_self_attn.layernorm.weight',
                                     'layers.10.enc_self_attn.layernorm.bias', 'layers.10.enc_self_attn.W_Q.weight',
                                     'layers.10.enc_self_attn.W_Q.bias', 'layers.10.enc_self_attn.W_K.weight',
                                     'layers.10.enc_self_attn.W_K.bias', 'layers.10.enc_self_attn.W_V.weight',
                                     'layers.10.enc_self_attn.W_V.bias', 'layers.10.pos_ffn.fc.0.weight',
                                     'layers.10.pos_ffn.fc.2.weight', 'layers.10.pos_ffn.layernorm.weight',
                                     'layers.10.pos_ffn.layernorm.bias',
                                     # NOTE: the comma after the line above was missing in the original,
                                     # silently concatenating it with 'fc.1.weight' into one bogus name.
                                     'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias', 'classifier_global.weight',
                                     'classifier_global.bias', 'classifier_atom.weight', 'classifier_atom.bias']
        # Load the checkpoint on CPU and copy over only the whitelisted parameters;
        # strict=False leaves every remaining (task-specific) parameter at its
        # random initialisation.
        pretrained_model = torch.load(self.pretrained_model, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_model['model_state_dict'].items() if k in pretrained_parameters}
        model_dict.update(pretrained_dict)  # note: model_dict is never loaded back, so this line is effectively dead code
        model.load_state_dict(pretrained_dict, strict=False)
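
Enumerating every parameter name by hand is error-prone, as the missing comma above shows. A more compact alternative, sketched here under the assumption that encoder layers are named layers.<i>.* (select_pretrained is a hypothetical helper, not part of the repository):

def select_pretrained(state_dict, n_layers):
    # Keep the embedding parameters and the first n_layers encoder layers.
    def wanted(name):
        if name.startswith('embedding.'):
            return True
        if name.startswith('layers.'):
            return int(name.split('.')[1]) < n_layers
        return False
    return {k: v for k, v in state_dict.items() if wanted(k)}

# Usage (hypothetical): transfer the first 6 encoder layers.
# checkpoint = torch.load(path, map_location='cpu')
# model.load_state_dict(select_pretrained(checkpoint['model_state_dict'], 6), strict=False)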

3.run

for epoch in range(args['num_epochs']):
    train_score = run_a_train_global_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
    # Validation and early stopping (the train-set evaluation result is
    # computed here but discarded).
    _ = run_an_eval_global_epoch(args, model, train_loader)[0]
    val_score = run_an_eval_global_epoch(args, model, val_loader)[0]
    test_score = run_an_eval_global_epoch(args, model, test_loader)[0]
    if epoch < 5:
        # Warm-up: pass a dummy score for the first five epochs so that early,
        # unstable validation scores are never saved as the best checkpoint.
        early_stop = stopper.step(0, model)
    else:
        early_stop = stopper.step(val_score, model)
    print('epoch {:d}/{:d}, {}, lr: {:.6f}, train: {:.4f}, valid: {:.4f}, best valid {:.4f}, '
          'test: {:.4f}'.format(epoch + 1, args['num_epochs'], args['metric_name'],
                                optimizer.param_groups[0]['lr'], train_score, val_score,
                                stopper.best_score, test_score))
    if early_stop:
        break
stopper.load_checkpoint(model)  # restore the best-validation weights before final evaluation

3.1.run_an_eval_global_epoch

import torch
from sklearn import metrics
# Meter and sesp_score come from the repository's utility module.


def run_an_eval_global_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, token_idx, global_labels, mask = batch_data
            token_idx = token_idx.long().to(args['device'])
            mask = mask.float().to(args['device'])
            global_labels = global_labels.float().to(args['device'])
            logits_global = model(token_idx)
            eval_meter.update(logits_global, global_labels, mask=mask)
            del token_idx, global_labels
            torch.cuda.empty_cache()
    # squeeze(dim=1) assumes a single-task (one-column) label matrix
    y_pred, y_true = eval_meter.compute_metric('return_pred_true')
    y_true_list = y_true.squeeze(dim=1).tolist()
    y_pred_list = torch.sigmoid(y_pred).squeeze(dim=1).tolist()
    # binarise the predicted probabilities at 0.5 for the threshold-based metrics
    y_pred_label = [1 if x >= 0.5 else 0 for x in y_pred_list]
    auc = metrics.roc_auc_score(y_true_list, y_pred_list)
    accuracy = metrics.accuracy_score(y_true_list, y_pred_label)
    se, sp = sesp_score(y_true_list, y_pred_label)
    pre, rec, f1, sup = metrics.precision_recall_fscore_support(y_true_list, y_pred_label)
    mcc = metrics.matthews_corrcoef(y_true_list, y_pred_label)
    # precision_recall_fscore_support returns per-class arrays; index 1 is the positive class
    f1 = f1[1]
    rec = rec[1]
    pre = pre[1]
    err = 1 - accuracy
    result = [auc, accuracy, se, sp, f1, pre, rec, err, mcc]
    return result
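
sesp_score is not shown in this excerpt. A minimal sketch consistent with how it is called above, computing sensitivity and specificity from the confusion matrix (not necessarily the repository's exact code):

from sklearn import metrics

def sesp_score(y_true, y_pred):
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()
    se = tp / (tp + fn)  # sensitivity = TP / (TP + FN), recall on positives
    sp = tn / (tn + fp)  # specificity = TN / (TN + FP), recall on negatives
    return se, sp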

3.2.step

    def step(self, score, model):
        if self.best_score is None:
            # First call: initialise the best score and save a checkpoint.
            self.best_score = score
            self.save_checkpoint(model)
        elif self._check(score, self.best_score):
            # Improvement: save a checkpoint and reset the patience counter.
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
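
_check is not shown in this excerpt. Given the mode argument passed to EarlyStopping, it presumably compares scores according to whether the metric is higher-is-better or lower-is-better, as in this sketch:

def _check(self, score, prev_best_score):
    # 'higher' for metrics such as ROC-AUC, 'lower' for losses (sketch only).
    if self.mode == 'higher':
        return score > prev_best_score
    return score < prev_best_score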