使用Bert进行文本分类
本文为学习Datawhale 2021.8组队学习NLP入门之Transformer笔记
原学习文档地址:https://github.com/datawhalechina/learn-nlp-with-transformers
1 数据的读入
1.1 Transformer Datasets
使用Transformers Datasets库读取网络数据,可以用于在公开数据集上验证模型的好坏。
除了mnli-mm以外,其他任务都可以直接通过任务名字进行加载。数据加载之后会自动缓存。
from datasets import list_datasets, load_dataset, list_metrics, load_metric
actual_task = "mnli" if task == "mnli-mm" else task
dataset = load_dataset("glue", actual_task)
metric = load_metric('glue', actual_task)
注意容易出现网络问题,根据报错信息在hosts文件中设置github网址的端口,更新最新的datasets库版本,可以解决这个问题。
也可以下载好数据集后,手动放到cache里面,如
C:\Users\用户名.cache\huggingface\datasets\glue\cola\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad
datasets对象本身是一种DatasetDict数据结构,对于训练集、验证集和测试集,只需要使用对应的key(train,validation,test)即可得到相应的数据。
print(dataset)
# 输出
DatasetDict({
train: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 8551
})
validation: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1043
})
test: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1063
})
})
就是一个嵌套字典。
dataset[‘train’][0][‘sentence’] 可以这样来调用训练集里第一个数据的sentence
1.1.1 datasets.Metric
可以输入metric查看其使用方法
举例:
import numpy as np
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,