TagTree Project Summary

We train a QA system on top of a RoBERTa model. The answers fall into two main categories: spans of words taken from the passage or table, and numbers from the passage or table (or values derived from those numbers).
The dataset format is as follows:

{
  "table": {
    "uid": "3ffd9053-a45d-491c-957a-1b2fa0af0570",
    "table": [
      [
        "",
        "2019",
        "2018",
        "2017"
      ],
      [
        "Fixed Price",
        "$  1,452.4",
        "$  1,146.2",
        "$  1,036.9"
      ],
      [
        "..."
      ]
    ]
  },
  "paragraphs": [
    {
      "uid": "f4ac7069-10a2-47e9-995c-3903293b3d47",
      "order": 1,
      "text": "Sales by Contract Type: Substantially all of our contracts are fixed-price type contracts. Sales included in Other contract types represent cost plus and time and material type contracts."
    },
    {
      "uid": "79e37805-6558-4a8c-b033-32be6bffef48",
      "order": 2,
      "text": "On a fixed-price type contract, ... The table below presents total net sales disaggregated by contract type (in millions)"
    }
  ],
  "questions": [
    {
      "uid": "f4142349-eb72-49eb-9a76-f3ccb1010cbc",
      "order": 1,
      "question": "In which year is the amount of total sales the largest?",
      "answer": [
        "2019"
      ],
      "derivation": "",
      "answer_type": "span",
      "answer_from": "table-text",
      "rel_paragraphs": [
        "2"
      ],
      "req_comparison": true,
      "scale": ""
    },
    {
      "uid": "eb787966-fa02-401f-bfaf-ccabf3828b23",
      "order": 2,
      "question": "What is the change in Other in 2019 from 2018?",
      "answer": -12.6,
      "derivation": "44.1-56.7",
      "answer_type": "arithmetic",
      "answer_from": "table-text",
      "rel_paragraphs": [
        "2"
      ],
      "req_comparison": false,
      "scale": "million"
    }
  ]
}
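
A minimal way to load and inspect one example (the file name follows the data_format pattern used in the preparation script below; the path is illustrative):

import json

with open("tatqa_dataset_train.json") as f:
    dataset = json.load(f)

sample = dataset[0]
print(sample["table"]["table"][0])            # header row of the table
print(sample["questions"][0]["question"])     # first question
print(sample["questions"][0]["answer_type"])  # "span", "arithmetic", ...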

Reading the dataset

Main data-preparation script

The Reader class reads the raw files, and the resulting data is stored as pickle files.

import os
import pickle
import argparse
from data_builder.tatqa_roberta_tagtree_dataset import TagTaTQATestReader, TagTaTQAReader
from transformers import RobertaTokenizer, BertTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--model_path", type=str, default='./')
parser.add_argument("--passage_length_limit", type=int, default=463)
parser.add_argument("--question_length_limit", type=int, default=46)
parser.add_argument("--encoder", type=str, default="bert")
parser.add_argument("--mode", type=str, default='train')

args = parser.parse_args()

if args.encoder == 'roberta':
    tokenizer = RobertaTokenizer.from_pretrained(args.model_path + "/roberta.large")
    sep = '<s>'
elif args.encoder == 'bert':
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
    sep = '[SEP]'
elif args.encoder == 'finbert':
    tokenizer = BertTokenizer.from_pretrained(args.model_path + "/finbert")
    sep = '[SEP]'
if args.mode == 'test':
    data_reader = TagTaTQATestReader(tokenizer, args.passage_length_limit, args.question_length_limit, sep=sep)
    data_mode = ["test"]
else:
    data_reader = TagTaTQAReader(tokenizer, args.passage_length_limit, args.question_length_limit, sep=sep)
    data_mode = ["train", "dev"]

data_format = "tatqa_dataset_{}.json"
print(f'==== NOTE ====: encoder:{args.encoder}, mode:{args.mode}')

for dm in data_mode:
    dpath = os.path.join(args.input_path, data_format.format(dm))
    data = data_reader._read(dpath)
    print(data_reader.skip_count)
    data_reader.skip_count = 0
    print("Save data to {}.".format(os.path.join(args.output_dir, f"tagtree_{args.encoder}_cached_{dm}.pkl")))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, f"tagtree_{args.encoder}_cached_{dm}.pkl"), "wb") as f:
        pickle.dump(data, f)
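
Assuming the main script above is saved as prepare_dataset.py (the file name and paths here are illustrative), the cached pickle files can be produced with something like:

python prepare_dataset.py \
    --input_path ./dataset_tagtree \
    --output_dir ./data/roberta \
    --model_path ./ \
    --encoder roberta \
    --mode train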
The Reader class

The class is initialized with a tokenizer. Its _read method takes a file path and loads the JSON data; the text is tokenized into tokens and ids, the answers are converted into the corresponding tags and labels, each example is stored in a dict, and the whole dataset is returned as a list.

import json

import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Dict

# NOTE: the project-specific helpers used below (table_tokenize, paragraph_tokenize,
# question_tokenizer, _concat, get_head_class, HEAD_CLASSES_, SCALE) are assumed to be
# imported from the project's data_builder utilities; they are not shown in this excerpt.


class TagTaTQAReader(object):
    def __init__(self, tokenizer,
                 passage_length_limit: int = None, question_length_limit: int = None, sep="<s>"):
        self.max_pieces = 512
        self.tokenizer = tokenizer
        self.passage_length_limit = passage_length_limit
        self.question_length_limit = question_length_limit
        self.sep_start = self.tokenizer._convert_token_to_id(sep)
        self.sep_end = self.tokenizer._convert_token_to_id(sep)
        tokens = self.tokenizer._tokenize("Feb 2 Nov")  # warm-up call; the result is not used
        self.skip_count = 0

        self.HEAD_CLASSES = HEAD_CLASSES_

    def _make_instance(self, input_ids, attention_mask, token_type_ids, paragraph_mask, table_mask,
                       paragraph_number_value, table_cell_number_value, paragraph_index, table_cell_index,
                       tags_ground_truth, scale_ground_truth,
                       paragraph_tokens, table_cell_tokens, answer_dict, question_id, span_pos_labels, head_class):
        return {
            "input_ids": np.array(input_ids),
            "attention_mask": np.array(attention_mask),
            "token_type_ids": np.array(token_type_ids),
            "paragraph_mask": np.array(paragraph_mask),
            "table_mask": np.array(table_mask),
            "paragraph_number_value": np.array(paragraph_number_value),
            "table_cell_number_value": np.array(table_cell_number_value),
            "paragraph_index": np.array(paragraph_index),
            "table_cell_index": np.array(table_cell_index),
            "tag_labels": np.array(tags_ground_truth),
            "scale_label": int(scale_ground_truth),
            "paragraph_tokens": paragraph_tokens,
            "table_cell_tokens": table_cell_tokens,
            "answer_dict": answer_dict,
            "question_id": question_id,
            "span_pos_labels": np.array(span_pos_labels),
            "head_label": int(head_class)
        }

    def _to_instance(self, question: str, table: List[List[str]], paragraphs: List[Dict], answer_from: str,
                     answer_type: str, answer: str, answer_mapping: Dict, scale: str,
                     question_id: str):
        question_text = question.strip()

        head_class = get_head_class(answer_type, answer_mapping, self.HEAD_CLASSES)
        scale_class = SCALE.index(scale)

        if head_class is None:
            self.skip_count += 1
            return None

        table_cell_tokens, table_ids, table_tags, table_cell_number_value, table_cell_index = \
            table_tokenize(table, self.tokenizer, answer_mapping, answer_type)

        for i in range(len(table)):
            for j in range(len(table[i])):
                if table[i][j] == '' or table[i][j] == 'N/A' or table[i][j] == 'n/a':
                    table[i][j] = "NONE"
        table = pd.DataFrame(table, dtype=str)  # np.str was removed in newer NumPy; plain str works
        column_relation = {}
        for column_name in table.columns.values.tolist():
            column_relation[column_name] = str(column_name)
        table.rename(columns=column_relation, inplace=True)

        paragraph_tokens, paragraph_ids, paragraph_tags, paragraph_word_piece_mask, paragraph_number_mask, \
        paragraph_number_value, paragraph_index = \
            paragraph_tokenize(question, paragraphs, self.tokenizer, answer_mapping, answer_type)
        question_ids = question_tokenizer(question_text, self.tokenizer)

        input_ids, attention_mask, paragraph_mask, paragraph_number_value, paragraph_index, \
        table_mask, table_number_value, table_index, tags, token_type_ids, span_pos_label = \
            _concat(question_ids, table_ids, table_tags, table_cell_index, table_cell_number_value,
                    paragraph_ids, paragraph_tags, paragraph_index, paragraph_number_value, answer_type,
                    self.sep_start, self.sep_end, self.question_length_limit,
                    self.passage_length_limit, self.max_pieces)
        answer_dict = {"answer_type": answer_type, "answer": answer, "scale": scale, "answer_from": answer_from}
        return self._make_instance(input_ids, attention_mask, token_type_ids, paragraph_mask, table_mask,
                                   paragraph_number_value, table_number_value, paragraph_index, table_index,
                                   tags, scale_class, paragraph_tokens, table_cell_tokens, answer_dict, question_id,
                                   span_pos_label, head_class)

    def _read(self, file_path: str):
        print("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset = json.load(dataset_file)
        print("Reading the tatqa dataset")
        instances = []
        key_error_count = 0
        index_error_count = 0
        assert_error_count = 0
        for one in tqdm(dataset):
            table = one['table']['table']
            paragraphs = one['paragraphs']
            questions = one['questions']

            for question_answer in questions:
                try:
                    question = question_answer["question"].strip()
                    answer_type = question_answer["answer_type"]
                    answer = question_answer["answer"]
                    answer_mapping = question_answer["mapping"]
                    answer_from = question_answer["answer_from"]
                    scale = question_answer["scale"]
                    instance = self._to_instance(question, table, paragraphs, answer_from,
                                                 answer_type, answer, answer_mapping, scale,
                                                 question_answer["uid"])
                    if instance is not None:
                        instances.append(instance)
                except RuntimeError as e:
                    print(f"run time error:{e}")
                    print(question_answer["uid"])
                except KeyError:
                    key_error_count += 1
                    print(question_answer["uid"])
                    print("KeyError. Total Error Count: {}".format(key_error_count))
                except IndexError:
                    index_error_count += 1
                    print(question_answer["uid"])

                    print("IndexError. Total Error Count: {}".format(index_error_count))
                except AssertionError:
                    assert_error_count += 1
                    print(question_answer["uid"])
                    print("AssertError. Total Error Count: {}".format(assert_error_count))
        return instances
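
The HEAD_CLASSES_ and SCALE constants referenced above are not shown in this excerpt. Judging from the head and scale names used in the metric code further below (and from the head indices checked in the model's forward pass), they presumably look roughly like this (a sketch, not the project's actual definition):

HEAD_CLASSES_ = {"SPAN-TEXT": 0, "SPAN-TABLE": 1, "MULTI_SPAN": 2, "COUNT": 3, "ARITHMETIC": 4}
SCALE = ["", "thousand", "million", "billion", "percent"]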

Training

Main train & eval script

Batch generation, the model, and the metrics are each wrapped in their own classes.

import os
import json
import argparse
from datetime import datetime
from tools.model import TagtreeFineTuningModel
import options
from pprint import pprint
from data_builder.data_util import get_op_1, get_arithmetic_op_index_1, get_op_2, get_arithmetic_op_index_2
from data_builder.data_util import get_op_3, get_arithmetic_op_index_3
from data_builder.data_util import OPERATOR_CLASSES_
from tools.utils import create_logger, set_environment
from data_builder.tatqa_roberta_tagtree_batch_gen import TaTQABatchGen, TaTQATestBatchGen
from transformers import RobertaModel, BertModel
from tagtree.modeling_roberta_tagtree import MutiHeadModel
import torch.nn as nn
import numpy as np

parser = argparse.ArgumentParser("Tagop training task.")
options.add_data_args(parser)
options.add_train_args(parser)
options.add_bert_args(parser)
parser.add_argument("--encoder", type=str, default='roberta')
parser.add_argument("--op_mode", type=int, default=0)
parser.add_argument("--finbert_model", type=str, default='dataset_tagtree/finbert')
parser.add_argument("--ablation_mode", type=int, default=0)
parser.add_argument("--test_data_dir", type=str, default="tag_tree/data/roberta")
parser.add_argument("--model_path", type=str, default='./')

args = parser.parse_args()
if args.ablation_mode != 0:
    args.save_dir = args.save_dir + "_{}_{}".format(args.op_mode, args.ablation_mode)
    args.data_dir = args.data_dir + "_{}_{}".format(args.op_mode, args.ablation_mode)

if not os.path.exists(args.save_dir):
    os.mkdir(args.save_dir)

args.cuda = args.gpu_num > 0
args_path = os.path.join(args.save_dir, "args.json")
with open(args_path, "w") as f:
    json.dump((vars(args)), f)

args.batch_size = args.batch_size // args.gradient_accumulation_steps

logger = create_logger("Roberta Training", log_file=os.path.join(args.save_dir, args.log_file))

pprint(args)
set_environment(args.seed, args.cuda)


def main():
    best_result = float("-inf")
    logger.info("Loading data...")

    train_itr = TaTQABatchGen(args, data_mode="train", encoder=args.encoder)
    dev_itr = TaTQABatchGen(args, data_mode="dev", encoder=args.encoder)
    test_itr = TaTQATestBatchGen(args, data_mode="test", encoder=args.encoder)
    num_train_steps = int(args.max_epoch * len(train_itr) / args.gradient_accumulation_steps)
    logger.info("Num update steps {}!".format(num_train_steps))

    logger.info(f"Build {args.encoder} model.")
    if args.encoder == 'bert':
        bert_model = BertModel.from_pretrained('bert-large-uncased')
    elif args.encoder == 'roberta':
        bert_model = RobertaModel.from_pretrained(args.roberta_model)
    elif args.encoder == 'finbert':
        bert_model = BertModel.from_pretrained(args.finbert_model)

    if args.ablation_mode == 0:
        arithmetic_op_index = [3, 4, 6, 7, 8, 9]
    elif args.ablation_mode == 1:
        arithmetic_op_index = get_arithmetic_op_index_1(args.op_mode)
    elif args.ablation_mode == 2:
        arithmetic_op_index = get_arithmetic_op_index_2(args.op_mode)
    else:
        arithmetic_op_index = get_arithmetic_op_index_3(args.op_mode)

    network = MutiHeadModel(bert=bert_model,
                            config=bert_model.config,
                            scale_criterion=nn.CrossEntropyLoss(reduction='sum'),
                            scale_classes=5,
                            arithmetic_op_index=arithmetic_op_index,
                            op_mode=args.op_mode,
                            ablation_mode=args.ablation_mode, )

    model = TagtreeFineTuningModel(args, network, num_train_steps=num_train_steps)
    train_start = datetime.now()
    first = True
    for epoch in range(1, args.max_epoch + 1):
        model.reset()
        if not first:
            train_itr.reset()
        first = False
        logger.info('At epoch {}'.format(epoch))
        for step, batch in enumerate(train_itr):
            model.update(batch)
            if model.step % (args.log_per_updates * args.gradient_accumulation_steps) == 0 or model.step == 1:
                logger.info("Updates[{0:6}] train loss[{1:.5f}] head acc[{2:.5f}]remaining[{3}].\r\n".format(
                    model.updates, model.train_loss.avg, model.head_acc.avg,
                    str((datetime.now() - train_start) / (step + 1) * (num_train_steps - step - 1)).split('.')[0]))
                model.avg_reset()
        model.get_metrics(logger)
        model.reset()
        model.avg_reset()
        model.evaluate(dev_itr)
        logger.info("Evaluate epoch:[{0:6}] eval loss[{1:.5f}] head acc[{2:.5f}].\r\n".format(epoch, model.dev_loss.avg,
                                                                                              model.head_acc.avg))
        model.get_metrics(logger)
        model.avg_reset()
        model.predict(test_itr)
        metrics = model.get_metrics(logger)

        if metrics["f1"] > best_result:
            save_prefix = os.path.join(args.save_dir, "checkpoint_best")
            model.save(save_prefix, epoch)
            best_result = metrics["f1"]
            logger.info("Best eval F1 {} at epoch {}.\r\n".format(best_result, epoch))


if __name__ == "__main__":
    main()

Batch generation

The constructor takes the data path, loads the pickle file, shuffles all examples, and groups them into batches; self.offset tracks the current batch position. In __iter__, zip(*batch) regroups each batch by field rather than by instance, each field is converted into a tensor or a list, and the batch is yielded as a dict (see the small sketch below).
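
A toy illustration of the zip(*batch) transposition (field names and values are made up):

batch = [
    ("ids_0", "mask_0", "label_0"),   # instance 0
    ("ids_1", "mask_1", "label_1"),   # instance 1
]
ids, masks, labels = zip(*batch)
print(ids)     # ('ids_0', 'ids_1')   -> all input_ids grouped together
print(labels)  # ('label_0', 'label_1')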

class TaTQABatchGen(object):
    def __init__(self, args, data_mode, encoder='roberta'):
        dpath = f"tagtree_{encoder}_cached_{data_mode}.pkl"
        self.is_train = data_mode == "train"
        self.args = args
        with open(os.path.join(args.data_dir, dpath), 'rb') as f:
            print("Load data from {}.".format(dpath))
            data = pickle.load(f)

        all_data = []
        for item in data:
            input_ids = torch.from_numpy(item["input_ids"])
            attention_mask = torch.from_numpy(item["attention_mask"])
            token_type_ids = torch.from_numpy(item["token_type_ids"])
            paragraph_mask = torch.from_numpy(item["paragraph_mask"])
            table_mask = torch.from_numpy(item["table_mask"])
            paragraph_numbers = item["paragraph_number_value"]
            table_cell_numbers = item["table_cell_number_value"]
            paragraph_index = torch.from_numpy(item["paragraph_index"])
            table_cell_index = torch.from_numpy(item["table_cell_index"])
            tag_labels = torch.from_numpy(item["tag_labels"])
            scale_labels = torch.tensor(item["scale_label"])
            gold_answers = item["answer_dict"]
            paragraph_tokens = item["paragraph_tokens"]
            table_cell_tokens = item["table_cell_tokens"]
            question_id = item["question_id"]
            span_pos_labels = torch.from_numpy(item["span_pos_labels"])
            head_labels = torch.tensor(item["head_label"])
            all_data.append((input_ids, attention_mask, token_type_ids, paragraph_mask, table_mask, paragraph_index,
                             table_cell_index, tag_labels, scale_labels, gold_answers,
                             paragraph_tokens, table_cell_tokens, paragraph_numbers, table_cell_numbers, question_id,
                             span_pos_labels, head_labels))
        print("Load data size {}.".format(len(all_data)))
        self.data = TaTQABatchGen.make_batches(all_data, args.batch_size if self.is_train else args.eval_batch_size,
                                               self.is_train)
        self.offset = 0

    @staticmethod
    def make_batches(data, batch_size=32, is_train=True):
        if is_train:
            random.shuffle(data)
        if is_train:
            return [
                data[i: i + batch_size] if i + batch_size < len(data) else data[i:] + data[
                                                                                      :i + batch_size - len(data)]
                for i in range(0, len(data), batch_size)]
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    def reset(self):
        if self.is_train:
            indices = list(range(len(self.data)))
            random.shuffle(indices)
            self.data = [self.data[i] for i in indices]
            for i in range(len(self.data)):
                random.shuffle(self.data[i])
        self.offset = 0

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        while self.offset < len(self):
            batch = self.data[self.offset]
            self.offset += 1
            input_ids_batch, attention_mask_batch, token_type_ids_batch, paragraph_mask_batch, table_mask_batch, \
            paragraph_index_batch, table_cell_index_batch, tag_labels_batch, scale_labels_batch, \
            gold_answers_batch, paragraph_tokens_batch, \
            table_cell_tokens_batch, paragraph_numbers_batch, table_cell_numbers_batch, question_ids_batch, span_pos_labels_batch, head_labels_batch = zip(
                *batch)
            bsz = len(batch)
            input_ids = torch.LongTensor(bsz, 512)
            attention_mask = torch.LongTensor(bsz, 512)
            token_type_ids = torch.LongTensor(bsz, 512).fill_(0)
            paragraph_mask = torch.LongTensor(bsz, 512)
            table_mask = torch.LongTensor(bsz, 512)
            paragraph_index = torch.LongTensor(bsz, 512)
            table_cell_index = torch.LongTensor(bsz, 512)
            tag_labels = torch.LongTensor(bsz, 512)
            scale_labels = torch.LongTensor(bsz)
            span_pos_labels = torch.LongTensor(bsz, 2)
            head_labels = torch.LongTensor(bsz)
            paragraph_tokens = []
            table_cell_tokens = []
            gold_answers = []
            question_ids = []
            paragraph_numbers = []
            table_cell_numbers = []
            for i in range(bsz):
                input_ids[i] = input_ids_batch[i]
                attention_mask[i] = attention_mask_batch[i]
                token_type_ids[i] = token_type_ids_batch[i]
                paragraph_mask[i] = paragraph_mask_batch[i]
                table_mask[i] = table_mask_batch[i]
                paragraph_index[i] = paragraph_index_batch[i]
                table_cell_index[i] = table_cell_index_batch[i]
                tag_labels[i] = tag_labels_batch[i]
                scale_labels[i] = scale_labels_batch[i]
                paragraph_tokens.append(paragraph_tokens_batch[i])
                table_cell_tokens.append(table_cell_tokens_batch[i])
                paragraph_numbers.append(paragraph_numbers_batch[i])
                table_cell_numbers.append(table_cell_numbers_batch[i])
                gold_answers.append(gold_answers_batch[i])
                question_ids.append(question_ids_batch[i])
                span_pos_labels[i] = span_pos_labels_batch[i]
                head_labels[i] = head_labels_batch[i]
            out_batch = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids,
                         "paragraph_mask": paragraph_mask, "paragraph_index": paragraph_index, "tag_labels": tag_labels,
                         "scale_labels": scale_labels,
                         "paragraph_tokens": paragraph_tokens, "table_cell_tokens": table_cell_tokens,
                         "paragraph_numbers": paragraph_numbers,
                         "table_cell_numbers": table_cell_numbers, "gold_answers": gold_answers,
                         "question_ids": question_ids,
                         "span_pos_labels": span_pos_labels, "head_labels": head_labels, "table_mask": table_mask,
                         "table_cell_index": table_cell_index,
                         }
            if self.args.cuda:
                for k in out_batch.keys():
                    if isinstance(out_batch[k], torch.Tensor):
                        out_batch[k] = out_batch[k].cuda()
            yield out_batch
Model

Model wrapper

The training, evaluation, and prediction steps are wrapped here, together with loss.backward(), optimizer.step(), and optimizer.zero_grad().

class TagtreeFineTuningModel():
    def __init__(self, args, network, state_dict=None, num_train_steps=1):
        self.args = args
        self.train_loss = AverageMeter()
        self.dev_loss = AverageMeter()
        self.head_acc = AverageMeter()
        self.step = 0
        self.updates = 0
        self.network = network
        if state_dict is not None:
            print("Load Model!")
            self.network.load_state_dict(state_dict["state"])
        self.mnetwork = nn.DataParallel(self.network) if args.gpu_num > 1 else self.network

        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.network.tapas.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': args.bert_weight_decay, 'lr': args.bert_learning_rate},
            {'params': [p for n, p in self.network.tapas.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0, 'lr': args.bert_learning_rate},
            {'params': [p for n, p in self.network.named_parameters() if not n.startswith("tapas.")],
             "weight_decay": args.weight_decay, "lr": args.learning_rate}
        ]
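        # NOTE: this Adam takes warmup/t_total/max_grad_norm/schedule arguments, so it appears to be
        # a BERT-style optimizer (e.g. BertAdam from pytorch_pretrained_bert) rather than torch.optim.Adam.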
        self.optimizer = Adam(optimizer_parameters,
                              lr=args.learning_rate,
                              warmup=args.warmup,
                              t_total=num_train_steps,
                              max_grad_norm=args.grad_clipping,
                              schedule=args.warmup_schedule)
        if self.args.gpu_num > 0:
            self.network.cuda()

    def avg_reset(self):
        self.train_loss.reset()
        self.dev_loss.reset()
        self.head_acc.reset()

    def update(self, tasks):
        self.network.train()
        output_dict = self.mnetwork(**tasks)
        loss = output_dict["loss"]
        self.train_loss.update(loss.item(), 1)
        acc = output_dict["head_acc"]
        self.head_acc.update(acc.item(), 1)
        if self.args.gradient_accumulation_steps > 1:
            loss /= self.args.gradient_accumulation_steps
        loss.backward()
        if (self.step + 1) % self.args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.updates += 1
        self.step += 1

    @torch.no_grad()
    def evaluate(self, dev_data_list):
        dev_data_list.reset()
        self.network.eval()
        with torch.no_grad():
            for batch in dev_data_list:
                output_dict = self.network(**batch)
                loss = output_dict["loss"]
                self.dev_loss.update(loss.item(), 1)
                acc = output_dict["head_acc"]
                self.head_acc.update(acc.item(), 1)
        # self.network.train()

    @torch.no_grad()
    def predict(self, test_data_list):
        test_data_list.reset()
        self.network.eval()
        # pred_json = {}
        for batch in tqdm(test_data_list):
            self.network.predict(**batch, mode="eval")

    def reset(self):
        self.mnetwork.reset()

    def get_df(self):
        return self.mnetwork.get_df()

    def get_metrics(self, logger=None):
        return self.mnetwork.get_metrics(logger, True)

    def save(self, prefix, epoch):
        network_state = dict([(k, v.cpu()) for k, v in self.network.state_dict().items()])
        other_params = {
            'optimizer': self.optimizer.state_dict(),
            'config': self.args,
            'epoch': epoch
        }
        state_path = prefix + ".pt"
        other_path = prefix + ".ot"
        torch.save(other_params, other_path)
        torch.save(network_state, state_path)
        print('model saved to {}'.format(prefix))
The model (MutiHeadModel)


init

The constructor takes the BERT-style encoder used for the embeddings together with its config, and initializes the head_predictor network, the summary networks that produce the summary vectors fed into head_predictor, the start/end position network of the span heads, and the sequence tagging network used by the remaining heads.

    def __init__(self,
                 bert,
                 config,
                 scale_classes: int,
                 head_count: int = 5,
                 hidden_size: int = None,
                 dropout_prob: float = None,
                 scale_criterion: nn.CrossEntropyLoss = None,
                 arithmetic_op_index: List = None,
                 op_mode: int = None,
                 ablation_mode: int = None,
                 ):
        super(MutiHeadModel, self).__init__()
        self.tapas = bert
        self.config = config
        self.scale_classes = scale_classes
        self._metrics = TagTreeEmAndF1(mode=2)
        if hidden_size is None:
            hidden_size = self.config.hidden_size
        if dropout_prob is None:
            dropout_prob = self.config.hidden_dropout_prob
        self.head_count = head_count
        self.NLLLoss = nn.NLLLoss(reduction="sum")
        self.scale_criterion = scale_criterion
        self.scale_predictor = Default_FNN(3 * hidden_size, hidden_size, scale_classes, dropout_prob)

        self.paragraph_summary_vector_module = nn.Linear(hidden_size, 1)
        self.table_summary_vector_module = nn.Linear(hidden_size, 1)
        self.question_summary_vector_module = nn.Linear(hidden_size, 1)

        self.head_predictor = Default_FNN(3 * hidden_size, hidden_size, head_count, dropout_prob)

        self.single_span_head = SingleSpanHead(hidden_size)
        self.sequence_tag_head = SequenceTagHead(hidden_size, dropout_prob)

        self.HEAD_CLASSES = HEAD_CLASSES_
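
The summary_vector method called in the forward pass below is not included in this excerpt. A minimal sketch of what it presumably does, assuming TagOp-style attention pooling with the three *_summary_vector_module layers defined above and the masked_softmax helper listed in the AllenNLP section further below:

    def summary_vector(self, encoding, mask, in_type="paragraph"):
        # Score each token with the corresponding linear module, softmax over the
        # unmasked positions, and return the weighted sum of the token encodings.
        if in_type == "paragraph":
            alpha = self.paragraph_summary_vector_module(encoding)
        elif in_type == "table":
            alpha = self.table_summary_vector_module(encoding)
        else:
            alpha = self.question_summary_vector_module(encoding)
        alpha = masked_softmax(alpha.squeeze(-1), mask)        # [batch, seq_len]
        return (alpha.unsqueeze(-1) * encoding).sum(dim=1)     # [batch, hidden]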
Head

Single Span Head

from typing import Optional

import torch.nn as nn
import torch
from tag_tree.tagtree.tools.allennlp import replace_masked_values, masked_log_softmax


class SingleSpanHead(nn.Module):

    def __init__(self, input_size):
        super(SingleSpanHead, self).__init__()
        self.start_pos_predict = nn.Linear(input_size, 1)
        self.end_pos_predict = nn.Linear(input_size, 1)
        self.NLL = nn.NLLLoss(reduction="sum")

    def forward(self, input_vec, mask):
        # [batch_size, seq_len]
        start_logits = self.start_pos_predict(input_vec).squeeze(-1)
        end_logits = self.end_pos_predict(input_vec).squeeze(-1)

        start_log_probs = masked_log_softmax(start_logits, mask)
        end_log_probs = masked_log_softmax(end_logits, mask)

        return start_log_probs, end_log_probs

Sequence Tag Head

import torch.nn as nn
import torch
from tag_tree.tagtree.model_util import Default_FNN
from tag_tree.tagtree.tools.allennlp import replace_masked_values, masked_log_softmax
from typing import Tuple, Dict, Any, List, Union, Optional


class SequenceTagHead(nn.Module):

    def __init__(self, hidden_size, dropout_prob):
        super(SequenceTagHead, self).__init__()
        self.tag_predictor = Default_FNN(hidden_size, hidden_size, 2, dropout_prob)
        self.NLLLoss = nn.NLLLoss(reduction="sum")

    def forward(self, input_vec, table_mask, paragraph_mask):
        # input_vec: [batch, seq_len, dim] or [seq_len, dim]; masks: [batch, seq_len] or [seq_len]
        table_sequence_output = replace_masked_values(input_vec, table_mask.unsqueeze(-1), 0)
        table_tag_prediction = self.tag_predictor(table_sequence_output)
        table_tag_prediction = masked_log_softmax(table_tag_prediction, table_mask.unsqueeze(-1))
        table_tag_prediction = replace_masked_values(table_tag_prediction, table_mask.unsqueeze(-1), 0)

        paragraph_sequence_output = replace_masked_values(input_vec, paragraph_mask.unsqueeze(-1), 0)
        paragraph_tag_prediction = self.tag_predictor(paragraph_sequence_output)
        paragraph_tag_prediction = masked_log_softmax(paragraph_tag_prediction, paragraph_mask.unsqueeze(-1))
        paragraph_tag_prediction = replace_masked_values(paragraph_tag_prediction, paragraph_mask.unsqueeze(-1), 0)

        # [batch_size, seq_length, 2]

        return table_tag_prediction, paragraph_tag_prediction



forward
    def forward(self,
                input_ids: torch.LongTensor,
                attention_mask: torch.LongTensor,
                token_type_ids: torch.LongTensor,
                paragraph_mask: torch.LongTensor,
                paragraph_index: torch.LongTensor,
                tag_labels: torch.LongTensor,
                scale_labels: torch.LongTensor,
                gold_answers: List,
                paragraph_tokens: List[List[str]],
                paragraph_numbers: List[np.ndarray],
                table_cell_numbers: List[np.ndarray],
                question_ids: List[str],
                span_pos_labels: torch.LongTensor,
                head_labels: torch.LongTensor,
                position_ids: torch.LongTensor = None,
                table_mask: torch.LongTensor = None,
                table_cell_index: torch.LongTensor = None,
                table_cell_tokens: List[List[str]] = None,
                mode=None,
                epoch=None, ) -> Dict[str, torch.Tensor]:

        output_dict = {}
        token_representations = self.tapas(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids)[0]
        batch_size = token_representations.shape[0]
        question_mask = attention_mask - paragraph_mask - table_mask
        question_mask[:, 0] = 0  # exclude the [CLS] position from the question mask

        # cls token output
        cls_out = token_representations[:, 0, :]

        # question, table, paragraph summary to get head prediction
        paragraph_summary = self.summary_vector(token_representations, paragraph_mask, 'paragraph')
        table_summary = self.summary_vector(token_representations, table_mask, 'table')
        question_summary = self.summary_vector(token_representations, question_mask, 'question')

        # get different head data index from head labels
        paragraph_span_head_index = (head_labels.squeeze() == 0).cpu().detach()
        table_span_head_index = (head_labels.squeeze() == 1).cpu().detach()
        sequence_tag_head_index = (head_labels.squeeze() > 1).cpu().detach()

        # get head probabilities and predictions from head predictor
        answer_head_logits = self.head_predictor(torch.cat([paragraph_summary,
                                                            table_summary, question_summary], dim=-1))
        answer_head_log_probs = F.log_softmax(answer_head_logits, -1)
        predict_head = torch.argmax(answer_head_log_probs, dim=-1).unsqueeze(-1)

        # get scale probabilities from scale predictor
        scale_predict_logits = self.scale_predictor(torch.cat([cls_out,
                                                               paragraph_summary, table_summary], dim=-1))

        # get different head's prediction from different head label's data, and get corresponding position/tag labels
        table_tag_prediction, paragraph_tag_prediction = self.sequence_tag_head(
            token_representations[sequence_tag_head_index],
            table_mask[sequence_tag_head_index],
            paragraph_mask[sequence_tag_head_index])
        table_tag_labels = util.replace_masked_values(tag_labels[sequence_tag_head_index].float(),
                                                      table_mask[sequence_tag_head_index], 0)
        paragraph_tag_labels = util.replace_masked_values(tag_labels[sequence_tag_head_index].float(),
                                                          paragraph_mask[sequence_tag_head_index], 0)

        table_start_log_predictions, table_end_log_predictions = self.single_span_head(
            token_representations[table_span_head_index],
            table_mask[table_span_head_index])
        paragraph_start_log_predictions, paragraph_end_log_predictions = self.single_span_head(
            token_representations[paragraph_span_head_index],
            paragraph_mask[paragraph_span_head_index])

        # get loss
        sequence_tag_head_loss = self.NLLLoss(table_tag_prediction.transpose(1, 2),
                                              table_tag_labels.long()) + self.NLLLoss(
            paragraph_tag_prediction.transpose(1, 2), paragraph_tag_labels.long())

        table_span_loss = self.NLLLoss(table_start_log_predictions, span_pos_labels[table_span_head_index][:, 0]) \
                          + self.NLLLoss(table_end_log_predictions, span_pos_labels[table_span_head_index][:, 1])

        paragraph_span_loss = self.NLLLoss(paragraph_start_log_predictions,
                                           span_pos_labels[paragraph_span_head_index][:, 0]) \
                              + self.NLLLoss(paragraph_end_log_predictions,
                                             span_pos_labels[paragraph_span_head_index][:, 1])

        scale_loss = self.scale_criterion(scale_predict_logits, scale_labels)
        scale_prediction = torch.argmax(scale_predict_logits, dim=-1).unsqueeze(-1)

        head_loss = self.NLLLoss(answer_head_log_probs, head_labels)
        loss = head_loss + paragraph_span_loss + table_span_loss + sequence_tag_head_loss + scale_loss

        head_acc = (predict_head.squeeze(-1) == head_labels).float().mean()
        output_dict['loss'] = loss
        output_dict['head_acc'] = head_acc
        output_dict['scale'] = []
        output_dict['head'] = []
        output_dict["question_id"] = []
        output_dict["answer"] = []

        with torch.no_grad():
            for bz in range(batch_size):
                answer = None
                if predict_head[bz] == self.HEAD_CLASSES["SPAN-TEXT"]:
                    # [seq_len]
                    start_log_probs, end_log_probs = self.single_span_head(token_representations[bz, :, :],
                                                                           paragraph_mask[bz, :])
                    start_log_probs = start_log_probs.detach().cpu()
                    end_log_probs = end_log_probs.detach().cpu()
                    answer = get_best_span(start_log_probs, end_log_probs, paragraph_index[bz, :], paragraph_tokens[bz])

                elif predict_head[bz] == self.HEAD_CLASSES["SPAN-TABLE"]:
                    start_log_probs, end_log_probs = self.single_span_head(token_representations[bz, :, :],
                                                                           table_mask[bz, :])
                    start_log_probs = start_log_probs.detach().cpu()
                    end_log_probs = end_log_probs.detach().cpu()
                    answer = get_best_span(start_log_probs, end_log_probs, table_cell_index[bz, :],
                                           table_cell_tokens[bz])

                else:
                    table_tag_prediction, paragraph_tag_prediction = self.sequence_tag_head(
                        token_representations[bz, :, :],
                        table_mask[bz, :],
                        paragraph_mask[bz, :])
                    paragraph_tag_prediction = torch.argmax(paragraph_tag_prediction, dim=-1).float()
                    paragraph_token_tag_prediction = reduce_mean_index(paragraph_tag_prediction, paragraph_index[bz, :])
                    paragraph_token_tag_prediction = paragraph_token_tag_prediction.detach().cpu().numpy()

                    table_tag_prediction = torch.argmax(table_tag_prediction, dim=-1).float()
                    table_cell_tag_prediction = reduce_mean_index(table_tag_prediction, table_cell_index[bz, :])
                    table_cell_tag_prediction = table_cell_tag_prediction.detach().cpu().numpy()

                    if predict_head[bz] == self.HEAD_CLASSES["MULTI_SPAN"]:

                        paragraph_selected_span_tokens = \
                            get_span_tokens_from_paragraph(paragraph_token_tag_prediction, paragraph_tokens[bz])
                        table_selected_tokens = \
                            get_span_tokens_from_table(table_cell_tag_prediction, table_cell_tokens[bz])
                        answer = paragraph_selected_span_tokens + table_selected_tokens
                        answer = sorted(answer)

                    elif predict_head[bz] == self.HEAD_CLASSES["COUNT"]:

                        paragraph_selected_tokens = \
                            get_span_tokens_from_paragraph(paragraph_token_tag_prediction, paragraph_tokens[bz])
                        table_selected_tokens = \
                            get_span_tokens_from_table(table_cell_tag_prediction, table_cell_tokens[bz])
                        answer = len(paragraph_selected_tokens) + len(table_selected_tokens)

                    else:
                        pass

                    output_dict["answer"].append(answer)
                    output_dict["scale"].append(SCALE[int(scale_prediction[bz])])
                    output_dict["head"].append(predict_head[bz])
                    output_dict["question_id"].append(question_ids[bz])
                    predict_type = ""
                    if answer is not None:
                        self._metrics(gold_answers[bz], answer, predict_type, SCALE[int(scale_prediction[bz])])

        return output_dict
Utility functions

These functions turn the network outputs into answers. get_best_span builds a (seq_len, seq_len) matrix of start+end log-probabilities, applies an upper-triangular mask (the start must not come after the end), and takes the argmax over the whole matrix to find the most probable span (since the inputs are log_softmax outputs, adding them corresponds to multiplying probabilities).

def get_best_span(span_start_probs, span_end_probs, span_index, span_tokens):
    if span_start_probs.dim() != 1 or span_end_probs.dim() != 1:
        raise ValueError("Input shapes must be (passage_length,)")
    passage_length = span_start_probs.shape[0]
    device = span_start_probs.device
    # (passage_length, passage_length)
    span_log_probs = span_start_probs.unsqueeze(1) + span_end_probs.unsqueeze(0)
    # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
    # the span ends before it starts.
    span_log_mask = torch.triu(torch.ones((passage_length, passage_length),
                                          device=device)).log()
    valid_span_log_probs = span_log_probs + span_log_mask

    # Here we take the span matrix and flatten it, then find the best span using argmax.  We
    # can recover the start and end indices from this flattened list using simple modular
    # arithmetic.
    # (passage_length * passage_length,)
    best_spans = valid_span_log_probs.view(-1).argmax(-1)
    span_start_indice = best_spans // passage_length
    span_end_indice = best_spans % passage_length

    token_start_indice = span_index[span_start_indice]
    token_end_indice = span_index[span_end_indice]

    return ["".join(span_tokens[token_start_indice - 1: token_end_indice - 1])]


def get_span_tokens_from_paragraph(paragraph_token_tag_prediction, paragraph_tokens) -> List[str]:
    span_tokens = []
    span_start = False
    for i in range(1, min(len(paragraph_tokens) + 1, len(paragraph_token_tag_prediction))):
        if paragraph_token_tag_prediction[i] == 0:
            span_start = False
        if paragraph_token_tag_prediction[i] != 0:
            if not span_start:
                span_tokens.append([paragraph_tokens[i - 1]])
                span_start = True
            else:
                span_tokens[-1] += [paragraph_tokens[i - 1]]
    span_tokens = [" ".join(tokens) for tokens in span_tokens]
    return span_tokens


def get_span_tokens_from_table(table_cell_tag_prediction, table_cell_tokens) -> List[str]:
    span_tokens = []
    for i in range(1, len(table_cell_tag_prediction)):
        if table_cell_tag_prediction[i] != 0:
            span_tokens.append(str(table_cell_tokens[i - 1]))
    return span_tokens
def reduce_mean_index(values, index, max_length=512, name="index_reduce_mean"):
    return _index_reduce(values, index, max_length, "mean", name)


"""
def _index_reduce(values, index, max_length, index_reduce_fn, name):
    flat_index, num_index = flatten_index(index, max_length)
    bsz = values.shape[0]
    seq_len = values.shape[1]
    flat_values = values.reshape(bsz * seq_len)
    index_means = scatter(
        src=flat_values,
        index=flat_index.type(torch.long),
        dim=0,
        dim_size=num_index,
        reduce=index_reduce_fn,
    )
    output_values = index_means.view(bsz, -1)
    return output_values
"""


def _index_reduce(values, index, max_length, index_reduce_fn, name):
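    # `scatter` here is presumably torch_scatter.scatter (matching the src/index/dim/dim_size/reduce signature).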
    index_means = scatter(
        src=values,
        index=index.type(torch.long),
        dim=0,
        dim_size=512,
        reduce=index_reduce_fn,
    )
    output_values = index_means.view(-1)
    return output_values


def flatten_index(index, max_length=512, name="index_flatten"):
    batch_size = index.shape[0]
    offset = torch.arange(start=0, end=batch_size, device=index.device) * max_length
    offset = offset.view(batch_size, 1)
    return (index + offset).view(-1), batch_size * max_length
Metric

The metric computes EM and F1 over the results and also produces per-head and per-answer-type breakdown tables.

class TagTreeEmAndF1(object):
    """
    This :class:`Metric` takes the best span string computed by a model, along with the answer
    strings labeled in the data, and computes exact match and F1 score using the official DROP
    evaluator (which has special handling for numbers and for questions with multiple answer spans,
    among other things).
    """

    def __init__(self, mode: Mode = Mode.NUMBER_AND_SCALE) -> None:
        self._total_em = 0.0
        self._total_f1 = 0.0
        self._scale_em = 0.0
        self._head_em = 0.0
        self.head_correct_count = {"SPAN-TEXT": 0, "SPAN-TABLE": 0, "MULTI_SPAN": 0, "COUNT": 0, "ARITHMETIC": 0}
        self.head_total_count = {"SPAN-TEXT": 0, "SPAN-TABLE": 0, "MULTI_SPAN": 0, "COUNT": 0, "ARITHMETIC": 0}
        self.scale_correct_count = {"": 0, "thousand": 0, "million": 0, "billion": 0, "percent": 0}
        self.scale_total_count = {"": 0, "thousand": 0, "million": 0, "billion": 0, "percent": 0}
        self._count = 0
        self._details = []

    def __call__(self, ground_truth: dict, prediction: Union[str, List], pred_type, pred_scale="", pred_span=None,
                 gold_span=None,
                 pred_head=None, gold_head=None):  # type: ignore
        """
        Parameters
        ----------
        ground_truths: ``dict``
            All the ground truth answer annotations.
        prediction: ``Union[str, List]``
            The predicted answer from the model evaluated. This could be a string, or a list of string
            when multiple spans are predicted as answer.
        pred_scale: ``str``
        """
        # if not prediction:
        #     exact_match = 0
        #     f1_score = 0
        # else:
        #     gold_type, gold_answer, gold_scale = extract_gold_answers(ground_truth)
        #     ground_truth_answer_strings = get_answer_str(gold_type, gold_answer, gold_scale, self._mode)
        #     prediction = prediction if isinstance(prediction, list) else [prediction]
        #     prediction_strings = get_answer_str(pred_type, prediction, pred_scale, self._mode)
        #     exact_match, f1_score = metric_max_over_ground_truths(
        #             get_metrics,
        #             prediction_strings,
        #             ground_truth_answer_strings
        #     )
        if pred_head is not None:
            if pred_head == gold_head:
                self.head_correct_count[pred_head] += 1
                self._head_em += 1
            self.head_total_count[gold_head] += 1
        if pred_scale == ground_truth["scale"]:
            self.scale_correct_count[pred_scale] += 1
        self.scale_total_count[ground_truth["scale"]] += 1
        if not prediction:
            exact_match = 0
            f1_score = 0
            span_exact_match = 0
            span_f1_score = 0
        else:
            gold_type, gold_answer, gold_scale = extract_gold_answers(ground_truth)
            if not gold_answer:
                exact_match = 0
                f1_score = 0
                span_exact_match = 0
                span_f1_score = 0
            else:
                ground_truth_answer_strings = get_answer_str(gold_answer, gold_scale)

                if gold_scale == pred_scale:
                    self._scale_em += 1
                prediction = prediction if isinstance(prediction, list) else [prediction]
                prediction_strings = get_answer_str(prediction, pred_scale)
                prediction_strings = add_percent_pred(prediction_strings, pred_scale, prediction)
                exact_match, f1_score = metric_max_over_ground_truths(
                    get_metrics,
                    prediction_strings,
                    ground_truth_answer_strings
                )
                if gold_type in ['arithmetic', 'count']:
                    """if gold type equals with arithmetic and count, set the f1_score == exact_match"""
                    f1_score = exact_match
                if not pred_span:
                    span_exact_match = 0
                    span_f1_score = 0
                else:
                    pred_span_strings = get_answer_str(pred_span, "")
                    gold_span_strings = get_answer_str(gold_span, "")
                    span_exact_match, span_f1_score = metric_max_over_ground_truths(
                        get_metrics,
                        pred_span_strings,
                        gold_span_strings,
                    )

        self._total_em += exact_match
        self._total_f1 += f1_score
        self._count += 1
        it = {**ground_truth,
              **{"pred": prediction,
                 "pred_scale": pred_scale,
                 "em": exact_match,
                 "f1": f1_score,
                 "pred_span": pred_span,
                 "gold_span": gold_span,
                 "span_em": span_exact_match,
                 "span_f1": span_f1_score}}
        self._details.append(it)

    def get_overall_metric(self, reset: bool = False) -> Tuple[float, float, float, float]:
        """
        Returns
        -------
        Average exact match and F1 score (in that order) as computed by the official DROP script
        over all inputs.
        """
        exact_match = self._total_em / self._count if self._count > 0 else 0
        f1_score = self._total_f1 / self._count if self._count > 0 else 0
        scale_score = self._scale_em / self._count if self._count > 0 else 0
        head_score = self._head_em / self._count if self._count > 0 else 0
        head_em_detail = {"SPAN-TEXT": 0, "SPAN-TABLE": 0, "MULTI_SPAN": 0, "COUNT": 0, "ARITHMETIC": 0}
        scale_em_detail = {"": 0, "thousand": 0, "million": 0, "billion": 0, "percent": 0}
        for k in head_em_detail.keys():
            head_em_detail[k] = self.head_correct_count[k] / self.head_total_count[k] if self.head_total_count[
                                                                                             k] > 0 else 0
        print(head_em_detail)
        print(self.head_total_count)
        for k in scale_em_detail.keys():
            scale_em_detail[k] = self.scale_correct_count[k] / self.scale_total_count[k] if self.scale_total_count[
                                                                                                k] > 0 else 0
        print(scale_em_detail)
        print(self.scale_total_count)
        if reset:
            self.reset()
        return exact_match, f1_score, scale_score, head_score

    def get_detail_metric(self):
        df = pd.DataFrame(self._details)
        if len(self._details) == 0:
            return None, None
        em_pivot_tab = df.pivot_table(index='answer_type', values=['em'],
                                      columns=['answer_from'], aggfunc='mean').fillna(0)

        f1_pivot_tab = df.pivot_table(index='answer_type', values=['f1'],
                                      columns=['answer_from'], aggfunc='mean').fillna(0)
        return em_pivot_tab, f1_pivot_tab

    def get_raw_pivot_table(self):
        df = pd.DataFrame(self._details)
        pivot_tab = df.pivot_table(index='answer_type', values=['em'],
                                   columns=['answer_from'], aggfunc='count').fillna(0)
        return pivot_tab

    def get_raw(self):
        return self._details

    def reset(self):
        self._total_em = 0.0
        self._total_f1 = 0.0
        self._scale_em = 0.0
        self._head_em = 0.0
        self._count = 0
        self._details = []
        self.head_correct_count = {"SPAN-TEXT": 0, "SPAN-TABLE": 0, "MULTI_SPAN": 0, "COUNT": 0, "ARITHMETIC": 0}
        self.head_total_count = {"SPAN-TEXT": 0, "SPAN-TABLE": 0, "MULTI_SPAN": 0, "COUNT": 0, "ARITHMETIC": 0}
        self.scale_correct_count = {"": 0, "thousand": 0, "million": 0, "billion": 0, "percent": 0}
        self.scale_total_count = {"": 0, "thousand": 0, "million": 0, "billion": 0, "percent": 0}

    def __str__(self):
        return f"TagTreeEmAndF1(em={self._total_em}, f1={self._total_f1}, count={self._count})"

The F1 score is computed from the set intersection between the predicted and gold answer bags.

from typing import List, Set, Tuple, Union

import numpy as np
from scipy.optimize import linear_sum_assignment


def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.
    """
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            # if _match_numbers_if_present(gold_item, pred_item): no need to match number in tatqa
            scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores


def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0
    return f1
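
A quick illustrative call (the token bags are made up):

print(_compute_f1({"fixed", "price"}, {"fixed"}))  # precision 0.5, recall 1.0 -> F1 ≈ 0.67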


def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
    gold_numbers = set()
    predicted_numbers = set()
    for word in gold_bag:
        if is_number(word):
            gold_numbers.add(word)
    for word in predicted_bag:
        if is_number(word):
            predicted_numbers.add(word)
    if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
        return True
    return False


def get_metrics(predicted: Union[str, List[str], Tuple[str, ...]],
                gold: Union[str, List[str], Tuple[str, ...]]) -> Tuple[float, float]:
    """
    Takes a predicted answer and a gold answer (that are both either a string or a list of
    strings), and returns exact match and the DROP F1 metric for the prediction.  If you are
    writing a script for evaluating objects in memory (say, the output of predictions during
    validation, or while training), this is the function you want to call, after using
    :func:`answer_json_to_strings` when reading the gold answer from the released data file.
    """
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)
    # print("pred bags:" + str(predicted_bags))
    # print("answer bags:" + str(predicted_bags))

    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
        exact_match = 1.0
    else:
        exact_match = 0.0

    f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
    f1 = np.mean(f1_per_bag)
    f1 = round(f1, 2)
    return exact_match, f1

In the MutiHeadModel class, self._metrics is defined as a TagTreeEmAndF1 instance, and per-example results are accumulated through the TagTreeEmAndF1 __call__ method.

    def get_metrics(self, logger=None, reset: bool = False) -> Dict[str, float]:
        detail_em, detail_f1 = self._metrics.get_detail_metric()
        raw_detail = self._metrics.get_raw_pivot_table()
        exact_match, f1_score, scale_score, head_score = self._metrics.get_overall_metric(reset)
        print(f"raw matrix:{raw_detail}\r\n")
        print(f"detail em:{detail_em}\r\n")
        print(f"detail f1:{detail_f1}\r\n")
        print(f"global em:{exact_match}\r\n")
        print(f"global f1:{f1_score}\r\n")
        print(f"global scale:{scale_score}\r\n")
        print(f"global head:{head_score}\r\n")
        if logger is not None:
            logger.info(f"raw matrix:{raw_detail}\r\n")
            logger.info(f"detail em:{detail_em}\r\n")
            logger.info(f"detail f1:{detail_f1}\r\n")
            logger.info(f"global em:{exact_match}\r\n")
            logger.info(f"global f1:{f1_score}\r\n")
            logger.info(f"global scale:{scale_score}\r\n")
        return {'em': exact_match, 'f1': f1_score, "scale": scale_score}
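
A minimal accumulation sketch for TagTreeEmAndF1 (the answer and scale values are made up, and the exact scores depend on the project's get_answer_str/get_metrics helpers):

metrics = TagTreeEmAndF1()
gold = {"answer_type": "span", "answer": ["2019"], "scale": "", "answer_from": "table-text"}
metrics(gold, ["2019"], pred_type="", pred_scale="")
em, f1, scale_acc, head_acc = metrics.get_overall_metric(reset=True)
print(em, f1)  # expected to be 1.0 1.0 for an exact match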
AllenNLP

Some mask-aware helper operations:

def masked_softmax(vector: torch.Tensor,
                   mask: torch.Tensor,
                   dim: int = -1,
                   memory_efficient: bool = True,
                   mask_fill_value: float = -1e32) -> torch.Tensor:
    """
    ``torch.nn.functional.softmax(vector)`` does not work if some elements of ``vector`` should be
    masked.  This performs a softmax on just the non-masked portions of ``vector``.  Passing
    ``None`` in for the mask is also acceptable; you'll just get a regular softmax.

    ``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is
    broadcastable to ``vector's`` shape.  If ``mask`` has fewer dimensions than ``vector``, we will
    unsqueeze on dimension 1 until they match.  If you need a different unsqueezing of your mask,
    do it yourself before passing the mask into this function.

    If ``memory_efficient`` is set to true, we will simply use a very large negative number for those
    masked positions so that the probabilities of those positions would be approximately 0.
    This is not accurate in math, but works for most cases and consumes less memory.

    In the case that the input vector is completely masked and ``memory_efficient`` is false, this function
    returns an array of ``0.0``. This behavior may cause ``NaN`` if this is used as the last layer of
    a model that uses categorical cross-entropy loss. Instead, if ``memory_efficient`` is true, this function
    will treat every element as equal, and do softmax over equal numbers.
    """
    if mask is None:
        result = torch.nn.functional.softmax(vector, dim=dim)
    else:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        if not memory_efficient:
            # To limit numerical errors from large vector elements outside the mask, we zero these out.
            result = torch.nn.functional.softmax(vector * mask, dim=dim)
            result = result * mask
            result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
        else:
            masked_vector = vector.masked_fill((1 - mask).bool(), mask_fill_value)
            result = torch.nn.functional.softmax(masked_vector, dim=dim)
    return result


def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    ``torch.nn.functional.log_softmax(vector)`` does not work if some elements of ``vector`` should be
    masked.  This performs a log_softmax on just the non-masked portions of ``vector``.  Passing
    ``None`` in for the mask is also acceptable; you'll just get a regular log_softmax.

    ``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is
    broadcastable to ``vector's`` shape.  If ``mask`` has fewer dimensions than ``vector``, we will
    unsqueeze on dimension 1 until they match.  If you need a different unsqueezing of your mask,
    do it yourself before passing the mask into this function.

    In the case that the input vector is completely masked, the return value of this function is
    arbitrary, but not ``nan``.  You should be masking the result of whatever computation comes out
    of this in that case, anyway, so the specific values returned shouldn't matter.  Also, the way
    that we deal with this case relies on having single-precision floats; mixing half-precision
    floats with fully-masked vectors will likely give you ``nans``.

    If your logits are all extremely negative (i.e., the max value in your logit vector is -50 or
    lower), the way we handle masking here could mess you up.  But if you've got logit values that
    extreme, you've got bigger problems than this.
    """
    if mask is not None:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked.  We need a very small value instead of a
        # zero in the mask for these cases.  log(1 + 1e-45) is still basically 0, so we can safely
        # just add 1e-45 before calling mask.log().  We use 1e-45 because 1e-46 is so small it
        # becomes 0 - this is just the smallest value we can actually use.
        vector = vector + (mask + 1e-45).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)


def masked_max(vector: torch.Tensor,
               mask: torch.Tensor,
               dim: int,
               keepdim: bool = False,
               min_val: float = -1e7) -> torch.Tensor:
    """
    To calculate max along certain dimensions on masked values

    Parameters
    ----------
    vector : ``torch.Tensor``
        The vector to calculate max, assume unmasked parts are already zeros
    mask : ``torch.Tensor``
        The mask of the vector. It must be broadcastable with vector.
    dim : ``int``
        The dimension to calculate max
    keepdim : ``bool``
        Whether to keep dimension
    min_val : ``float``
        The minimal value for paddings

    Returns
    -------
    A ``torch.Tensor`` of including the maximum values.
    """
    one_minus_mask = (1.0 - mask).byte()
    replaced_vector = vector.masked_fill(one_minus_mask, min_val)
    max_value, _ = replaced_vector.max(dim=dim, keepdim=keepdim)
    return max_value