2021-12-17

 2021SC@SDUSC

In main.py, training versus development/testing is split into two branches: training mode executes the train module, while everything else executes the evaluate module.

if args.mode == 'train':
    train_model(args, 0)
else:
    evaluate_model(args, 0)
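For context, here is a minimal sketch of how such a mode switch is usually wired up with argparse; the flag names and defaults are assumptions for illustration only, not the project's actual definitions in main.py:

import argparse

# Hypothetical argument definitions; the real parser in main.py may differ.
parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'dev', 'test'], default='train')
parser.add_argument('--encoder_type', type=str, default='')
parser.add_argument('--eval_file', type=str, default='', help='checkpoint file or directory to evaluate')
parser.add_argument('--gpus', type=int, default=0)
args = parser.parse_args()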

Next, let us analyze evaluate_model in main.py.

def evaluate_model(args, local_rank):
    eval_file = args.eval_file
    gpus = args.gpus
    args = PARAMS_MAP[args.encoder_type]
    if 'google' in args['lm_model']:
        lm_args = args['lm_model'][args['lm_model'].index('/')+1:]
    else:
        lm_args = args['lm_model']
    args['prefix'] = str(args['encoder_type']) + '_' + args['task'] + '_' + args['feature'] + '_lr' + str(
        args['lr']) + '_' + str(args['batch_multiplier']) + '_' + lm_args
    assert len(args['cnn_filters']) % 2 == 0
    model = GraphC_QAModel(args, local_rank)
    model.evaluate_model(eval_file, gpus)
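To make the prefix construction concrete, here is a small stand-alone sketch using hypothetical parameter values (the real values come from PARAMS_MAP[args.encoder_type]):

params = {'encoder_type': 'graph_cn', 'task': 'csqa', 'feature': 'amr',
          'lr': 1e-05, 'batch_multiplier': 4,
          'lm_model': 'google/electra-large-discriminator'}  # hypothetical values
# Strip the 'google/' organization prefix from the language-model name, as above.
lm_args = params['lm_model'][params['lm_model'].index('/') + 1:] if 'google' in params['lm_model'] else params['lm_model']
prefix = str(params['encoder_type']) + '_' + params['task'] + '_' + params['feature'] + '_lr' + str(params['lr']) + '_' + str(params['batch_multiplier']) + '_' + lm_args
print(prefix)  # graph_cn_csqa_amr_lr1e-05_4_electra-large-discriminator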

The key pieces are again two components, GraphC_QAModel and evaluate_model; the evaluate_model called here is the one defined in the GraphC_QAModel class in train.py. The rest of this post analyzes that evaluate_model method.

This method reads the development and test datasets, builds the model, runs the evaluation, and returns values measuring the model's performance. Its output is a vector of 14 elements.

    def evaluate_model(self, eval_file, gpus):
        self.device = torch.device("cuda:" + str(gpus) if torch.cuda.is_available() else "cpu")
        print('device', self.device)
        test_models = []
        if os.path.isdir(eval_file):
            for file in os.listdir(eval_file):
                fname = os.path.join(eval_file, file)
                if os.path.isfile(fname):
                    test_models.append(fname)
            model_args = torch.load(fname, map_location=self.device)['args']
        else:
            test_models.append(eval_file)
            model_args = torch.load(eval_file, map_location=self.device)['args']
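        # The block below restores the training-time hyperparameters as an immutable
        # namedtuple and rebuilds the concept/token/character/relation vocabularies
        # from the saved vocab files, together with the LexicalMap used by the data pipeline.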

        from data import Vocab, DataLoader, STR, END, CLS, SEL, TL, rCLS
        model_args = collections.namedtuple("HParams", sorted(model_args.keys()))(**model_args)
        vocabs = dict()
        vocabs['concept'] = Vocab(model_args.concept_vocab, 5, [CLS])
        vocabs['token'] = Vocab(model_args.token_vocab, 5, [STR, END])
        vocabs['token_char'] = Vocab(model_args.token_char_vocab, 100, [STR, END])
        vocabs['concept_char'] = Vocab(model_args.concept_char_vocab, 100, [STR, END])
        vocabs['relation'] = Vocab(model_args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
        lexical_mapping = LexicalMap()
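        # Next, re-create the evaluation model: load the pretrained language model
        # (config, tokenizer and weights) via MODEL_CLASSES, then instantiate
        # Reasoning_AMR_CN_DUAL with the hyperparameters saved in the checkpoint.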

        if self.args.encoder_:
            vocabs, lexical_mapping = self._prepare_data()
            config_class, model_class, tokenizer_class = MODEL_CLASSES[self.args.encoder_type]
            bert_config = config_class.from_pretrained(
                self.args.lm_model,
            )
            bert_tokenizer = tokenizer_class.from_pretrained(
                self.args.lm_model
            )
            bert_model = model_class.from_pretrained(
                self.args.lm_model,
                from_tf=bool(".ckpt" in self.args.lm_model),
                config=self.args.lm_model,
            ).to(self.device)
            eval_model = Reasoning_AMR_CN_DUAL(vocabs,
                                               model_args.concept_char_dim, model_args.concept_dim,
                                               model_args.cnn_filters, model_args.char2concept_dim,
                                               model_args.rel_dim, model_args.rnn_hidden_size, model_args.rnn_num_layers,
                                               model_args.embed_dim, model_args.bert_embed_dim, model_args.ff_embed_dim,
                                               model_args.num_heads,
                                               model_args.dropout,
                                               model_args.snt_layer,
                                               model_args.graph_layers,
                                               model_args.pretrained_file, self.device, model_args.batch_size,
                                               model_args.lm_model, bert_config, bert_model, bert_tokenizer, model_args.bert_max_length,
                                               model_args.n_answers,
                                               model_args.encoder_type,
                                               model_args.gcn_concept_dim, model_args.gcn_hidden_dim, model_args.gcn_output_dim, model_args.max_conceptnet_length,
                                               model_args.conceptnet_path,
            )
        else:
            eval_model = ''
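        # Build the test DataLoader over self.args.test_data (for_train='Eval'),
        # define the index-to-letter answer template, and log the evaluation setup.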

        test_data = DataLoader(self.args, vocabs, lexical_mapping, self.args.test_data, model_args.batch_size,
                               for_train='Eval')
        answer_tempelate = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}
        # Evaluate!
        logger.info("***** Running Evaluating *****")
        logger.info("  Task: %s", self.args.task)
        logger.info("  Num examples = %d", len(test_data))
        logger.info("  Running Language Model = %s", model_args.lm_model)
        logger.info("  Running Model = %s", model_args.encoder_type)
        logger.info("  Running File = %s", eval_file)
        logger.info("  Test data = %s", self.args.test_data)

        for test_model in test_models:
            eval_model.load_state_dict(torch.load(test_model, map_location=self.device)['model'])
            eval_model = eval_model.cuda(self.device)
            eval_model.eval()
            running_corrects = 0
            eval_loss_sum, batch_acm = 0, 0
            with open(test_model + model_args.prefix + '.csv', 'w', newline='') as csvfile:
                csvwriter = csv.writer(csvfile, delimiter=',',
                                       quoting=csv.QUOTE_MINIMAL)
                for batch in test_data:
                    batch = move_to_cuda(batch, self.device)
                    eval_logits, eval_labels, ans_ids, = eval_model(batch, train=False)
                    eval_logits_forpred = eval_logits.clone().detach()
                    pred_values, pred_indices = torch.max(eval_logits_forpred, 1)
                    eval_labels = eval_labels.tolist()
                    eval_pred = pred_indices.tolist()
                    corrects = [i for i, j in zip(eval_labels, eval_pred) if i == j]
                    batch_acm += 1
                    # Statistics
                    running_corrects += len(corrects)
                    for i, pred in enumerate(eval_pred):
                        csvwriter.writerow([ans_ids[i], answer_tempelate[int(pred_indices[i])]])
                print('Overall accuracy: ', (running_corrects / len(test_data)))
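Each checkpoint therefore yields a CSV file of (answer id, predicted choice letter) pairs, named after the checkpoint and the run prefix, and the overall accuracy on the test data is printed once the loop finishes.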
