1.模型下载
bert-base-Chinese是一个基于Transformer架构的中文预训练模型,使用了大量的中文语料进行训练。bert-base-chinese具有12层的Transformer编码器,包含约110万个参数。它在多个中文自然语言处理任务上表现出色,如文本分类、命名实体识别和情感分析等。
通过使用bert-base-chinese,我们可以将其用作下游任务的特征提取器或者进行微调以适应特定任务
#模型下载 cache_dir模型下载路径
from modelscope import snapshot_download
model_dir = snapshot_download('tiansz/bert-base-chinese',cache_idr='')
2.数据集
从huggingface上进行数据集下载
3.定义增量微调任务
from transformers import BertModel
import torch
# 定义设备信息 选择是GPU还是cpu
DEVICE= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
# 将模型加载到device上
pretrained = (BertModel.from_pretrained(r'D:\PyProject\trainModel\model\bert\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f')
.to(DEVICE))
print(pretrained)
# 增量微调,根据 (dense): Linear(in_features=768, out_features=768, bias=True) 进行下游匹配处理
# 定义下游任务(增量微调模型->情感分析)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
#设计全连接网络,设置in_features参数为加载的bert模型特征提取维度768,设置out_features进行二分类任务
self.fc = torch.nn.Linear(768,2)
def forward(self,input_ids,attention_mask,token_type_ids):
# 冻结模型的参数,让他不参与训练
with torch.no_grad():
out = pretrained(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
# 增量模型参与训练
out = self.fc(out.last_hidden_state[:,0])
return out
4.自定义数据集
加载数据集,包括训练集,测试集,验证集。内容为text与label,分别是评价内容与标注标签
from torch.utils.data import Dataset
from datasets import load_from_disk
class MyDataset(Dataset):
def __init__(self,split):
#从磁盘加载数据
self.dataset = load_from_disk(r"D:\PyProject\trainModel\data\chn_senti_corp")
if split == 'train':
self.dataset = self.dataset["train"]
elif split == "test":
self.dataset = self.dataset["test"]