通过huggingface的transformers库,datasets库函数,调用BERT的分词器和预训练模型进行中文分类,如情感分析。

训练和验证完整代码如下:

'''  
项目:通过huggingface的transformers库,datasets库,函数,调用BERT的分词器和预训练模型进行中文分类,如情感分析。 
时间:2024年6月23日    
  
'''  
import sys    
import torch    
from datasets import load_dataset, load_from_disk #一个在线下载,一个本地加载  

# 前期准备数据工作    
class Dataset(torch.utils.data.Dataset):    
    def __init__(self, dataset):    
        # self.dataset = load_dataset(path="data/ChnSentiCorp/train", split=split)  
        # self.dataset = load_dataset(path='./data/ChnSentiCorp', split=split)        
        self.dataset = dataset # 通过本地加载    
  
    def __len__(self):         
        return len(self.dataset)         
  
    def __getitem__(self, i):         
        text = self.dataset[i]['text']         
        label = self.dataset[i]['label']         
        return text, label         
  
train_dataset = load_from_disk('./data/ChnSentiCorp/train')
valid_dataset = load_from_disk('./data/ChnSentiCorp/validation')
test_dataset = load_from_disk('./data/ChnSentiCorp/test')

train_dataset = Dataset(train_dataset)     
valid_dataset = Dataset(valid_dataset)     
test_dataset = Dataset(test_dataset)     
print("train_dataset:", len(train_dataset))     
print("valid_dataset:", len(valid_dataset))     
print("test_dataset:", len(valid_dataset))     
  
# 1.加载字典和分词工具,权重模型        
from transformers import BertTokenizer, BertModel        
tokenizer = BertTokenizer.from_pretrained('D:/Codes/Models/bert-base-chinese')
'''  
# 测试代码成功!       
text = "Hello, how are you?"       
inputs = tokenizer(text, return_tensors="pt")       
outputs = pretrained_model(**inputs)       
print(outputs)       
'''  
  
# 现将数据全部进行转换成token向量      
def collate_fn(data):      
    sents = [i[0] for i in data]     
    labels = [i[1] for i in data]     
  
    # 编码     
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents, # 所有句子  
                                   truncation=True,    
                                   padding='max_length',   
                                   max_length=500,   
                                   return_tensors='pt', # pytroch   
                                   return_length=True   
                                   )  
  
    input_ids = data['input_ids']  
    attention_mask = data['attention_mask']  
    token_type_ids = data['token_type_ids'] #第一个句子和特殊符号的位置是0,第二个句子的位置是1  
    labels = torch.LongTensor(labels)  
  
    return input_ids, attention_mask, token_type_ids, labels  
# 数据加载器  
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,  
                                     batch_size=16,  
                                     collate_fn=collate_fn,  
                                     shuffle=True,  
                                     drop_last=True)  
  
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
    break   
  
print('loader:', len(train_loader))   
print('input_ids.shpae:', input_ids.shape)   
print('attention_mask.shpae:', attention_mask.shape)
print('token_type_ids.shpae:', token_type_ids.shape)
  
pretrained_model = BertModel.from_pretrained('D:/Codes/Models/bert-base-chinese')
#不训练,不需要计算梯度   
for param in pretrained_model.parameters():   
    param.requires_grad_(False)   
  
#模型试算(经过模型变成, 每一个token变成768维的向量)   
out = pretrained_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
  
print('original_out:', out.last_hidden_state.shape)
  
  
# 2.定义下游任务的模型   
class Model(torch.nn.Module):   
    def __init__(self):   
        super().__init__()   
        self.fc = torch.nn.Linear(768, 2)   
    def forward(self, input_ids, attention_mask, token_type_ids):
  
        with torch.no_grad():   
            out = pretrained_model(input_ids=input_ids,   
                       attention_mask=attention_mask,   
                       token_type_ids=token_type_ids)   
        out = self.fc(out.last_hidden_state[:, 0])   
        out = out.softmax(dim=1)   
        return out   
  
model = Model()   
output_size = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).shape  
print('our_output_size:', output_size)   
  
# 3.训练   
from transformers import AdamW   
  
optimizer = AdamW(model.parameters(), lr=5e-4)   
criterion = torch.nn.CrossEntropyLoss()   
  
def train():  
    model.train()   
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = criterion(out, labels)   
        loss.backward()   
        optimizer.step()   
        optimizer.zero_grad()   
  
        if i % 5 == 0:   
            out = out.argmax(dim=1)   
            accuracy = (out == labels).sum().item() / len(labels)   
            print(i, loss.item(), accuracy)   
  
        if i == 300:   
            break   
    print('Completely!')   
train()   
  
# 4. 验证    
def vaild():    
    model.eval()    
    correct = 0    
    total = 0    
  
    loader_validation = torch.utils.data.DataLoader(dataset=valid_dataset,
                                              batch_size=32,    
                                              collate_fn=collate_fn,    
                                              shuffle=True,    
                                              drop_last=True)    
  
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_validation):
        if i == 5:    
            break    
        print(i)    
        with torch.no_grad():    
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        out = out.argmax(dim=1)    
        correct += (out == labels).sum().item()    
        total += len(labels)    
    print(correct / total)    
vaild()  

结果如下:

在这里插入图片描述

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值