使用huggingface实现BERT+BILSTM情感3分类(附数据集源代码)

一、前言

        最近在学习自然语言处理和大模型实战,通过实战来总结一下学习内容,顺便将学的一些东西发表在博客上,希望能对看到文章的您有帮助,有任何问题也可以发表讨论或联系作者。

        GitHub源代码:Befineyou/bert-bilstm-in-Sentiment-classification: The author applies BERT+BILSTM to emotion classification (github.com)

        本次实战讲解围绕BERT+BILSTM模型开展情感3分类进行,介绍在huggingface框架下使得我们具有方便的模型训练方法。

二、huggingface 下载

        本文使用的是中文bert,需要提前从huggingface官网中将预训练好的模型组件下载下来,下载网站来自:bert-base-chinese at main (huggingface.co)

三、数据集介绍

        数据集来源:疫情期间网民情绪识别 竞赛 - DataFountain

        看到了一位博主的数据处理讲解代码,讲的很不错,大家可以去看一下这个博主的处理方法:【NLP实战】基于Bert和双向LSTM的情感分类【上篇】_bert-lstm-CSDN博客

四、模型搭建

4.1 确定编码工具 

        编码工具是为了将抽象文本数据转化为词典后的数据。

from transformers import BertTokenizer
token = BertTokenizer.from_pretrained('bert-base-chinese')
print(token)

4.2 自定义数据集

        huggingface的一大好处就是该社区集成了大量的数据集和模型,可以直接调用,但当处理自己的数据集(csv)的形式,需要将csv转化为huggingface的dataset类型,从而进行构建,注意在我的示例数据集中,已经默认text列为文本内容列,label为文本标签列。

train_data = pd.read_csv('data/train_clean.csv')
train_dataset = Dataset.from_pandas(train_data)

class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.dataset = train_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        text = self.dataset[item]['text']
        label = self.dataset[item]['label']
        return text, label

train_dataset = Dataset()

4.3 定义数据整理函数以及数据集加载器

         训练时要将数据打包成loader的形式,需要将抽象的文本数据转换为编码后的数据,这时就用到了之前定义好的编码工具

input_ids 代表句子中每个字的词典编号

attention_mask 只有0或1,0代表空,也就是PAD

token_type_ids 只有0或1,0代表第一个句子和特殊符号,1代表第二个句子

注意,在编码时的句子已经被修改,一般在句首加CLS,句尾加PAD,根据max_length选择对句子截断或者加PAD

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    #labels = torch.LongTensor(labels)
    # labels = torch.tensor(labels).long()
    labels = torch.tensor([label if label != -1 else 0 for label in labels]).long()

    return input_ids, attention_mask, token_type_ids, labels
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

 4.4 定义预训练模型

        首先要将BERT预训练模型添加进来

from transformers import BertModel

pretrained = BertModel.from_pretrained('bert-base-chinese')
for param in pretrained.parameters():
    param.requires_grad_(False)

        之后就是定义Bert+BILSTM的模型

class BertBiLSTMClassifier(nn.Module):
    def __init__(self, num_classes, hidden_size=768, lstm_hidden_size=128, lstm_layers=1):
        super(BertBiLSTMClassifier, self).__init__()
        # BiLSTM层
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers,
                            batch_first=True, bidirectional=True)
        # 全连接层用于分类
        self.fc = nn.Linear(lstm_hidden_size * 2, num_classes)


    def forward(self, input_ids, attention_mask, token_type_ids):
        # BERT的前向传播
        with torch.no_grad():
            outputs = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        #pooled_output = outputs.pooler_output
        last_hidden_state = outputs.last_hidden_state
        # 将BERT输出输入BiLSTM
        lstm_out, _ = self.lstm(last_hidden_state)

        # 提取BiLSTM的最后一层输出
        lstm_out = lstm_out[:, -1, :]

        # 全连接层分类
        logits = self.fc(lstm_out)

        return logits
model = BertBiLSTMClassifier(num_classes)

 4.5 训练模型

def train():
    optimizer = AdamW(model.parameters(), lr=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)
    model.train()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if i % 10 == 0:
            out = out.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)
        if i % 90 == 0:
            torch.save(model.state_dict(), f'bert_cnn_model_epoch_{i}.pth')

 4.6 测试模型

def test():
    loader_test = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    model.eval()
    correct = 0
    total = 0
    for i,(input_ids,attention_mask,token_type_ids,labels) in enumerate(loader_test):
        if i==5:
            break
        print(i)
        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids = token_type_ids)
            #out = out.argmax(dim=1)
            out = torch.argmax(out, dim=1)
            correct +=(out==labels).sum().item()
            total +=len(labels)
        print(correct/total)
test()

 这样一个简单的BERT+BILSTM实战模型便可以应用了。

五、总结

如有疑问,可以联系作者。

  • 21
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值