Setting up the development environment
In addition to a basic Python installation, install the datasets and transformers packages in Anaconda. Alternatively, run the following directly in a Jupyter Notebook:
!pip install datasets
!pip install transformers
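Before moving on, it can help to confirm that both packages are visible to the interpreter the notebook is using. The following is a minimal sketch that relies only on the standard library; the helper name missing_packages is hypothetical and is not part of datasets or transformers.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found by the import system."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Check the two packages this tutorial depends on.
missing = missing_packages(["datasets", "transformers"])
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```

If a package is reported as missing even after pip install, the notebook kernel is probably running in a different environment from the one pip installed into.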
Import the required names from these two packages:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, AutoModelForSequenceClassification
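Load the data and the pretrained model. Here we load the Rotten Tomatoes (rotten_tomatoes) review dataset and Microsoft's distilled BERT model, microsoft/xtremedistil-l6-h256-uncased. Rotten Tomatoes is a binary classification dataset, so set num_labels=2 when loading the model, and set hidden_dropout_prob to reduce overfitting.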
dataset = load_dataset("rotten_tomatoes")
tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
# hidden_dropout_prob is a probability, so it must be below 1; 0.2 is an assumed value here.
model = AutoModelForSequenceClassification.from_pretrained("microsoft/xtremedistil-l6-h256-uncased", num_labels=2, hidden_dropout_prob=0.2)
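To make the role of hidden_dropout_prob more concrete, here is a plain-Python sketch of inverted dropout. It is only an illustration of the idea, not the implementation used by transformers or PyTorch; the function name dropout and the sample values are assumptions made for this example. Each value is zeroed with probability p and the survivors are scaled by 1/(1-p), which is why p has to be a probability rather than a number like 2.

```python
import random


def dropout(values, p, rng=random):
    """Inverted dropout: zero each value with probability p, scale the rest by 1/(1-p)."""
    # This sketch restricts p to [0, 1) so that the scaling factor is well defined.
    if not 0.0 <= p < 1.0:
        raise ValueError(f"dropout probability must be in [0, 1), got {p}")
    keep = 1.0 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]


# Seeded generator so repeated runs drop the same positions.
rng = random.Random(0)
activations = [1.0] * 10
print(dropout(activations, 0.2, rng))
```

Because a different random subset of activations is removed at every training step, the network cannot rely on any single hidden unit, which is the sense in which dropout helps against overfitting.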
Note that the official model-loading code given in the model card is:
tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h256-unc