基于Bert语言模型的中文短文本分类

基于Bert语言模型的中文短文本分类

一、前言

本次的任务是基于谷歌开源的Bert语言模型,进行微调,完成中文短文本分类任务。利用爬虫从微博客户端中获取热门评论,做为训练语料。

二、添加自定义类MyDataProcessor

添加自定义类MyDataProcessor,完成训练和测试语料的文件读取和预处理工作。

class MyDataProcessor(DataProcessor):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    # 读入训练文本数据
    file_path = os.path.join(data_dir,'train_sentiment.txt')
    f = open(file_path,'r',encoding='utf-8')
    train_data = []
    index = 0
    # 以行的方式读入
    for line in f.readlines() :
        # guid用来区分每一个example
        guid = "train-%d" % (index)
        line = line.replace('\n','').split('\t')
        # text_a 要分类的文本
        text_a = tokenization.convert_to_unicode(str(line[1]))
        # 文本对应的分类类别
        label = str(line[2])
        train_data.append(
            InputExample(guid=guid,text_a=text_a,text_b=None,label=label))
        index += 1
    return train_data

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    # 读入测试文本数据
    file_path = os.path.join(data_dir, 'test_sentiment.txt')
    f = open(file_path, 'r', encoding='utf-8')
    dev_data = []
    index = 0
    # 以行的方式读入
    for line in f.readlines():
        # guid用来区分每一个example
        guid = "dev-%d" % index
        line = line.replace('\n', '').split('\t')
        # text_a 要分类的文本
        text_a = tokenization.convert_to_unicode(str(line[1]))
        # 文本对应的分类类别
        label = str(line[2])
        dev_data.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        index += 1
    return dev_data

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    # 读入文本数据
    file_path = os.path.join(data_dir, 'test.csv')
    test_df = pd.read_csv(file_path,encoding='utf-8')
    test_data = []
    # 以行的方式读入
    for index,test in enumerate(test_df.values) :
        # guid用来区分每一个example
        guid = "test-%d" % index
        # text_a 要分类的文本
        text_a = tokenization.convert_to_unicode(str(test[0]))
        # 文本对应的分类类别
        label = str(test[1])
        test_data.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        index += 1
    return test_data

  def get_labels(self):
    """Gets the list of labels for this data set."""
    return ['0','1','2']

MydataProcessor类中包含三个方法:get_train_examples(),get_dev_examples(),get_test_examples。其实逻辑思路按照开源程序中英文文本分类任务的进行修改即可。需要注意的是,英文文本分类任务中包含text_a 和 text_b,而在本次任务中只有text_a因此可以将text_b = None.最终 get_labels()方法返回三个标签[“0”,“1”,“2”],分别对应[“中立”,“正向”,“负向”] .

三、修改主类

在这里插入图片描述
在主类中添加自定义类的类名。

四、修改运行参数

–task_name=mydata # 修改为自定义类类名
–do_train=true # 是否训练
–do_eval=true # 是否验证
–data_dir=…/GLUE/glue_data/mydata
–vocab_file=…/GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt
–bert_config_file=…/GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json
–init_checkpoint=…/GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt
–max_seq_length=128 # 文本最大长度
–train_batch_size=6
–learning_rate=2e-5 # 学习率
–num_train_epochs=1.0
–output_dir=…/GLUE/chineseoutput # 模型的最终保存位置

中文文本分类任务:

--task_name=mydata
--do_train=true
--do_eval=true
--data_dir=../GLUE/glue_data/mydata
--vocab_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt
--bert_config_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json
--init_checkpoint=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt
--max_seq_length=128
--train_batch_size=6
--learning_rate=2e-5
--num_train_epochs=1.0
--output_dir=../GLUE/chineseoutput

英文文本分类任务:

--task_name=MRPC
--do_train=true
--do_eval=true
--data_dir=../GLUE/glue_data/MRPC
--vocab_file=../GLUE/BERT_BASE_DIR/uncased_L-12_H-768_A-12/vocab.txt
--bert_config_file=../GLUE/BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_config.json
--init_checkpoint=../GLUE/BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_model.ckpt
--max_seq_length=128
--train_batch_size=6
--learning_rate=2e-5
--num_train_epochs=3.0
--output_dir=../GLUE/output

五、运行

配置完参数后直接运行
报错:tensorflow.python.framework.errors_impl.DataLossError: Checksum does not match: stored 4283821441 vs. calculated on the restored bytes 2653108158
在这里插入图片描述
经过查阅资料,发现可能是由于ckpt文件有问题,下载的Bert预训练模型中的中文ckpt文件出错;
解决方法: 重新换一个ckpt文件

更换过后
报错:tensorflow.python.framework.errors_impl.OutOfRangeError: Read fewer bytes than requested
在这里插入图片描述
解决方案: 重新换回原来的文件
在这里插入图片描述
模型可以跑起来了!!!
在这里插入图片描述
等待若干小时
在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

敷衍zgf

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值