使用paddle ernie预训练模型进行中文文本分类代码

需要解决的问题:在电商中有一些黑产使用机器脚本自动注册大量的垃圾店铺,而这些垃圾店铺的店铺名有一些是无意义的乱文,例如“唇评照桌”,“脑冻砸路忻故”等,因此需要训练一个中文文本的二分类模型来识别哪些是垃圾店铺名哪些是正常的店铺名

下面的文章中包含了:
1)使用paddle的预训练模型ernie_tiny进行中文文本的二分类模型训练
2)使用动态图模型进行数据预测
3)将动态图模型转化为静态图模型
4)加载静态图模型,并用其进行数据预测

注意:动态图模型便于模型的调试,但是预测的速度较慢;但是静态图模型不利于模型的调试,但是预测的速度快,一般是训练好的动态图模型转化为静态图模型,然后部署上线预测

一、训练分类模型

train.py


import pandas as pd
import paddle
import paddlehub as hub
import ast
import argparse
from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset
 
class MyDataset(TextClassificationDataset):
    # 数据集存放目录
    base_path = 'data/shop_name'
    # 数据集的标签列表,多分类标签格式为['0', '1', '2', '3',...]
    label_list = ['0', '1']
 
    def __init__(self, tokenizer, max_seq_len: int = 10, mode: str = 'train'):
        if mode == 'train':
            data_file = 'train.tsv'
        elif mode == 'test':
            data_file = 'test.tsv'
        else:
            data_file = 'dev.tsv'
        super().__init__(
            base_path=self.base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_list=self.label_list,
            is_file_with_header=True)

 
# 转成tsv格式
file_path = "data/shop_name/shop_name_train.csv"
text = pd.read_csv(file_path, sep="\t")
text = text.sample(frac=1)  # 打乱数据集
print(len(text))
 
train = text[:int(len(text) * 0.8)]
dev = text[int(len(text) * 0.8):int(len(text) * 0.9)]
test = text[int(len(text) * 0.9):]
 
train.to_csv('data/shop_name/train.tsv', sep='\t', header=None, index=False, columns=None, mode="w")
dev.to_csv('data/shop_name/dev.tsv', sep='\t', header=None, index=False, columns=None, mode="w")
test.to_csv('data/shop_name/test.tsv', sep='\t', header=None, index=False, columns=None, mode="w")
 
# 验证train,dev,test标签分布是否均匀
for file in ['train', 'dev', 'test']:
    file_path = f"data/shop_name/{file}.tsv"
    text = pd.read_csv(file_path, sep="\t", header=None)
    prob = dict()
    total = len(text[0])
    for i in text[0]:
        if prob.get(i) is None:
            prob[i] = 1
        else:
            prob[i] += 1
    # 按标签排序
    prob = {i[0]: round(i[1] / total, 3) for i in sorted(prob.items(), key=lambda k:
  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值