KBQA学习记录-NER的main函数

目录

一、main函数实现的内容

0.main()函数

1.CrfInputExample类

2.CrfInputFeatures类

3.NERprocessor类

4.类所需函数

①load_and_cache_example

②crf_convert_examples_to_features


一、main函数实现的内容

在main()函数中,主要是对样本的处理,使得我们能够得到能够输入模型训练的数据。需要一些辅助工具的类,提前定义。大概流程如下:

1.通过argparser添加参数

2.实例化NER processor类

3.实例化tokenizer = bertTokenizer(),构建训练数据时用

4.实例化BertCRF模型,训练用

5.获取训练数据

6.训练,所需函数使用“NER训练及验证”文章记录。

0.main()函数

该函数设置了参数,以及引导了整体流程。

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="数据文件目录,应当有train.txt dev.txt")

    parser.add_argument("--vob_file", default=None, type=str, required=True,
                        help="词表文件")

    parser.add_argument("--model_config", default=None, type=str, required=True,
                        help="模型配置文件json文件")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="输出结果的文件")

    # Other parameters
    parser.add_argument("--pre_train_model", default=None, type=str, required=False,
                        help="预训练的模型文件,参数矩阵。如果存在就加载")

    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="输入到bert的最大长度,通常不应该超过512")
    parser.add_argument("--do_train", action='store_true',
                        help="是否进行训练")
    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="训练集的batch_size")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="验证集的batch_size")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="梯度累计更新的步骤,用来弥补GPU过小的情况")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="学习率")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="权重衰减")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="最大的梯度更新")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="epoch 数目")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="让学习增加到1的步数,在warmup_steps后,再衰减到0")

    args = parser.parse_args()

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    #          filename='./output/bert-crf-ner.log',
    processor = NerProcessor()

    # 得到tokenizer
    tokenizer_inputs = ()
    tokenizer_kwards = {'do_lower_case': False,
                        'max_len': args.max_seq_length,
                        'vocab_file': args.vob_file}
    tokenizer = BertTokenizer(*tokenizer_inputs,**tokenizer_kwards)

    print(len(processor.get_labels()))
    model = BertCrf(config_name= args.model_config,model_name=args.pre_train_model,num_tags = len(processor.get_labels()),batch_first=True)
    model = model.to(args.device)


    train_dataset = load_and_cache_example(args,tokenizer,processor,'train')
    eval_dataset = load_and_cache_example(args,tokenizer,processor,'dev')
    test_dataset = load_and_cache_example(args, tokenizer, processor, 'test')

    if args.do_train:
        trains(args,train_dataset,eval_dataset,model)

1.CrfInputExample类

这个类里面定义的是样本相关的内容,样本id,样本text,样本label,用于后续调用

class CrfInputExample(object):
    def __init__(self, guid, text, label=None):
        self.guid = guid
        self.text = text
        self.label = label

2.CrfInputFeatures类

这里面定义的是样本特征的内容,也就是用于输入模型的内容。

class CrfInputFeatures(object):
    def __init__(self, input_ids, attention_mask, token_type_ids, label):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label = label

上面这两个类定义好,都是为了后面方便调用,直接获取相关内容。

3.NERprocessor类

这个类用来创建训练、验证、测试样本

class NerProcessor(DataProcessor):
    def get_train_examples(self,data_dir):
        return self._create_examples(
            os.path.join(data_dir,"train.txt"))

    def get_dev_examples(self, data_dir):
        return self._create_examples(
            os.path.join(data_dir, "dev.txt"))

    def get_test_examples(self, data_dir):
        return self._create_examples(
            os.path.join(data_dir, "test.txt"))


    def get_labels(self):
        return CRF_LABELS

    @classmethod
    def _create_examples(cls, path):
        lines = []
        max_len = 0
        with codecs.open(path, 'r', encoding='utf-8') as f:
            word_list = []
            label_list = []
            for line in f:
                tokens = line.strip().split(' ')
                if 2 == len(tokens):
                    word = tokens[0]
                    label = tokens[1]
                    word_list.append(word)
                    label_list.append(label)
                elif 1 == len(tokens) and '' == tokens[0]:
                    if len(label_list) > max_len:
                        max_len = len(label_list)

                    lines.append((word_list,label_list))
                    word_list = []
                    label_list = []
        examples = []
        for i,(sentence,label) in enumerate(lines):
            examples.append(
                CrfInputExample(guid=i,text=" ".join(sentence),label=label)
            )
        return examples

4.类所需函数

①load_and_cache_example

通过如下函数获取,该函数实现的内容大致是:

如果已经有存好的特征文件,就导入,否则就自己创建特征

创建特征:通过crf_convert_examples_to_features()函数获取特征

处理特征:将特征挨个抽取出来,每个特征都保存成一个列表,并转化为tensor,最后通过TensorDataset给整合起来(torch.utils.data.TensorDataset)

def load_and_cache_example(args,tokenizer,processor,data_type):


    type_list = ['train', 'dev', 'test']
    if data_type not in type_list:
        raise ValueError("data_type must be one of {}".format(" ".join(type_list)))

    cached_features_file = "cached_{}_{}".format(data_type, str(args.max_seq_length))
    cached_features_file = os.path.join(args.data_dir, cached_features_file)
    if os.path.exists(cached_features_file):
        features = torch.load(cached_features_file)
    else:
        label_list = processor.get_labels()
        if type_list[0] == data_type:
            examples = processor.get_train_examples(args.data_dir)
        elif type_list[1] == data_type:
            examples = processor.get_dev_examples(args.data_dir)
        elif type_list[2] == data_type:
            examples = processor.get_test_examples(args.data_dir)
        else:
            raise ValueError("UNKNOW ERROR")
        features = crf_convert_examples_to_features(examples=examples,tokenizer=tokenizer,max_length=args.max_seq_length,label_list=label_list)
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_label = torch.tensor([f.label for f in features], dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label)
    return dataset

②crf_convert_examples_to_features

将输入的样本转为特征,共四个特征

input_id:将序号,文本样本输入tokenizer.encode_plus,会返回多个值,第一个就是我们需要的input_id

attention_mask:根据input_id的长度,创建的全1的列表

token_type_id:将序号,文本样本输入tokenizer.encode_plus,会返回多个值,第二个就是我们需要的token_type_id

label_id:根据输入的标签列表,转成id之后,另外加上bert所需要的分隔符[CLS]等,对应位置可以添加0,因为有mask,后面计算的时候会自动抹除。

def crf_convert_examples_to_features(examples,tokenizer,
                                     max_length=512,
                                     label_list=None,
                                     pad_token=0,
                                     pad_token_segment_id = 0,
                                     mask_padding_with_zero = True):

    label_map = {label:i for i, label in enumerate(label_list)}

    features = []

    for (ex_index, example) in enumerate(examples):
        inputs = tokenizer.encode_plus(
            example.text,
            add_special_tokens=True,
            max_length=max_length,
            truncate_first_sequence=True  # We're truncating the first sequence in priority if True
        )
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)


        padding_length = max_length - len(input_ids)
        input_ids = input_ids + ([pad_token] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        # 第一个和第二个[0] 加的是[CLS]和[SEP]的位置,  [0]*padding_length是[pad] ,把这些都暂时算作"O",后面用mask 来消除这些,不会影响
        labels_ids = [0] + [label_map[l] for l in example.label] + [0] + [0]*padding_length



        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),max_length)
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids),max_length)

        assert len(labels_ids) == max_length, "Error with input length {} vs {}".format(len(labels_ids),max_length)


        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s " % " ".join([str(x) for x in labels_ids]))

        features.append(
            CrfInputFeatures(input_ids,attention_mask,token_type_ids,labels_ids)
        )
    return features
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值