利用BERT训练情感分类

1 篇文章 0 订阅

 

这里仅仅记录一下自己训练BERT到应用的过程。

一、下载BERT的代码

可以在github上直接下载github的代码。

或者可以直接在百度云直接下载(里面有训练好的模型和github上的不同的数据集),训练的数据集也在里面

链接:https://pan.baidu.com/s/1o-9tX431zpDCJdQQB5eiIw 
提取码:o837 
 

二、修改BERT代码

修改run_classifier.py这个代码文件,这个也是训练的启动文件。

1).首先添加一个读取数据集的类,看一下下面这个类

这个类是谷歌给我们提供的一个demo,作用就是读取训练集,后面只要我们训练就按照这个类的样子仿写一个类来读取我们自己的训练集。

这个是我写的类名。

我把代码直接贴在这里吧~~哈哈,这里读取数据的方式可以按照自己的来,但是return的数据格式一定要按规定的来,还有最后那个get_labels一定要返回自己分类的标签。

class SimProcessor(DataProcessor):
    """Processor for 情感分类"""

    def get_train_examples(self, data_dir):
        file_path = os.path.join(data_dir, 'train.txt')
        f = open(file_path, 'r', encoding='utf-8')
        train_data = []
        index = 0
        for line in f.readlines():
            guid = 'train-%d' % index#参数guid是用来区分每个example的
            line = line.replace("\n", "").split("\t")
            text_a = tokenization.convert_to_unicode(str(line[1]))#要分类的文本
            label = str(line[2])#文本对应的情感类别
            train_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))#加入到InputExample列表中
            index += 1
        return train_data


    def get_dev_examples(self, data_dir):
        file_path = os.path.join(data_dir, 'dev.txt')
        f = open(file_path, 'r', encoding='utf-8')
        dev_data = []
        index = 0
        for line in f.readlines():
            guid = 'dev-%d' % index
            line = line.replace("\n", "").split("\t")
            text_a = tokenization.convert_to_unicode(str(line[1]))
            label = str(line[2])
            dev_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            index += 1
        return dev_data

    def get_test_examples(self, data_dir):
        file_path = os.path.join(data_dir, 'test.txt')
        f = open(file_path, 'r', encoding='utf-8')
        test_data = []
        index = 0
        for line in f.readlines():
            guid = 'dev-%d' % index
            line = line.replace("\n", "").split("\t")
            text_a = tokenization.convert_to_unicode(str(line[1]))
            # label = str(line[2])
            label = '0'
            test_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            index += 1
        return test_data

    def get_labels(self):
        return ['0', '1', '2']

 

2)在main里在添加这样一串代码

给我们刚才写的那个类做一个映射,相当于给他起了个名字叫sim~~哈哈

3)再convert_single_example这个方法里添加一段代码,生成label2id的字典,因为我们后期要用到。

要添加的代码

output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
    if not os.path.exists(output_label2id_file):
        with open(output_label2id_file, 'wb') as w:
            pickle.dump(label_map, w)

我是直接添加到最后面了,

至此代码部分修改完毕。

看一下数据集的格式:

 

三、开始训练

 

1):再linux上训练直接执行以下代码即可

这是放在linux上的文件

python run_classifier.py \
  --task_name=sim \
  --do_train=true \
  --do_eval=true \
  --data_dir=data \
  --vocab_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=4 \
  --learning_rate=2e-5 \
  --num_train_epochs=5.0 \
  --output_dir=sim_model

2):如果想调试一下,就得再windos上训练了,由于我的windos配置比较低,所以我的batch设置为1了,不然跑不起来。

首先添加启动参数

--task_name=sim
\
--do_train=true
\
--do_eval=true
\
--data_dir=data
\
--vocab_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt
\
--bert_config_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json
\
--init_checkpoint=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt
\
--max_seq_length=128
\
--train_batch_size=1
\
--learning_rate=2e-5
\
--num_train_epochs=3.0
\
--output_dir=sim_model

后面直接在pycharm中debug或run就阔以了。

模型训练结束后可以在sim_model看看模型的输出文件:

可以用以下命令进行测试集预测:

python run_classifier.py \
  --task_name=sim \
  --do_predict=true \
  --data_dir=data \
  --vocab_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt \
  --max_seq_length=128 \
  --output_dir=sim_model

四、模型部署

因为生成的ckpt文件实在太大了,所以我们将其转换为pb格式的文件,命令如下:

python freeze_graph.py \
    -bert_model_dir ../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12 \
    -model_dir sim_model \
    -max_seq_len 128 \
    -num_labels 3

转换完成后生成这样一个文件:

在部署之前我们要先下载bert-base,然后安装一些依赖包,

pip install bert-base==0.0.7 -i https://pypi.python.org/simple
pip install flask 
pip install flask_compress
pip install flask_cors
pip install flask_json

下面我们来启动bert服务端:

bert-base-serving-start \
    -model_dir /root/python_demo/bert_classifivation/BERT_Chinese_Classification/sim_model \
    -bert_model_dir /root/python_demo/bert_classifivation/GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12 \
    -model_pb_dir /root/python_demo/bert_classifivation/BERT_Chinese_Classification/sim_model \
    -mode CLASS \
    -max_seq_len 128 \
    -http_port 8091 \
    -port 5575 \
    -port_out 5576 \
    -device_map 1 \
    -num_worker 16

启动完成后

然后需要另外启动一个窗口来测试模型,用以下语句:

curl -X POST http://10.106.1.61:8091/encode \
  -H 'content-type: application/json' \
  -d '{"id": 111,"texts": ["手机很流畅,很好用。","这个手机好差劲!!!店大欺客"], "is_tokenized": false}'

我们看一下结果

至此模型就可以使用了。

我们将模型部署成flask接口,就可以通过api来使用模型了,以下是代码:


import re
from flask import Flask, jsonify, request,abort
from bert_base.client import BertClient


app = Flask(__name__)

# 切分句子
def cut_sent(txt):
    # 先预处理去空格等
    txt = re.sub('([  \t]+)', r" ", txt)  # blank word
    txt = txt.rstrip()  
    nlist = txt.split("\n")
    nnlist = [x for x in nlist if x.strip() != '']  # 过滤掉空行
    return nnlist


# 对句子进行预测识别
def class_pred(list_text):
    with BertClient(ip='10.106.1.61', port=5575, port_out=5576, show_server_config=False, check_version=False,
                    check_length=False, timeout=10000, mode='CLASS') as bc:
        rst = bc.encode(list_text)
        print('result:', rst)

    # 返回结构为:
    # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}]
    # 抽取出标注结果
    pred_label = rst[0]['pred_label']
    result_txt = [[pred_label[i], list_text[i]] for i in range(len(pred_label))]
    return result_txt


#使用模型入口
@app.route('/model_use',methods = ['get','post'])
def model_use():
    try:
        res = {}
        txt = request.form.get('text', None)
        lstseg = cut_sent(txt)
        res['result'] = class_pred(lstseg)
        print('result:%s' % str(res))
        return jsonify(res)
    except Exception as e:
        # print(e)
        return '程序出错'



if __name__ == '__main__':
    # model_use('手机不错,买值了。')
    app.run(host='0.0.0.0',port=8910)


启动接口后:

利用postman使用接口:

  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值