Alibaba Tianchi - Diabetes Named Entity Recognition - CRF

Updated at: https://blog.csdn.net/zc1226/article/details/138974229?spm=1001.2014.3001.5501


This post adds a CRF module on top of the model from the previous post; only part of the code changes. Refer to the previous post if you need the full pipeline.
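The CRF layer used below comes from the third-party pytorch-crf package (imported as torchcrf). If your environment does not have it yet, install it first:

pip install pytorch-crf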

I. Experiment steps

1. Prepare the dataset

The code is as follows (example):

# Unchanged from the previous post

2. Convert to a DataLoader

The code is as follows (example):

# Unchanged from the previous post
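For reference, the batches consumed in the later steps are dicts with input_ids, attention_mask, and labels keys. A minimal sketch of the wrapping step, assuming the tokenized train_dataset and test_dataset from the previous post (the batch size here is a placeholder):

from torch.utils.data import DataLoader

# Each dataset item is assumed to be a dict with 'input_ids',
# 'attention_mask' and 'labels', all padded to a fixed length
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)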

3. Build the model

The code is as follows (example):

import torch
import torch.nn as nn
from torchcrf import CRF

class BERT_CRF(torch.nn.Module):
    def __init__(self, bertModel):
        super(BERT_CRF, self).__init__()
        self.bert = bertModel
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(768, 31)  # 768-dim hidden states -> 31 tags
        # Add the CRF module
        self.crf = CRF(31, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids, attention_mask)
        sequence_output = self.dropout(outputs[0])
        emissions = self.linear(sequence_output)
        # --- changed part begins ---
        if labels is not None:
            # Training: the CRF returns a log-likelihood, so negate it for the loss
            loss = -self.crf(emissions, labels, mask=attention_mask.byte(), reduction='mean')
            return loss
        else:
            # Inference: Viterbi decoding returns one list of tag ids per sequence
            predictions = self.crf.decode(emissions, mask=attention_mask.byte())
            return predictions
        # --- changed part ends ---
        
from transformers import BertModel

bertModel = BertModel.from_pretrained('hfl/chinese-roberta-wwm-ext')

model = BERT_CRF(bertModel)
model = model.to(device)
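As a quick sanity check (not from the original post; shapes and values here are arbitrary), both branches of forward can be exercised with dummy tensors: passing labels yields a scalar CRF loss, while omitting them yields one list of predicted tag ids per sequence.

# Dummy batch: 2 sequences of length 8
dummy_ids = torch.randint(0, 100, (2, 8)).to(device)
dummy_mask = torch.ones(2, 8, dtype=torch.long).to(device)
dummy_labels = torch.randint(0, 31, (2, 8)).to(device)

loss = model(dummy_ids, attention_mask=dummy_mask, labels=dummy_labels)
print(loss.item())  # scalar negative log-likelihood

preds = model(dummy_ids, attention_mask=dummy_mask)
print(len(preds), len(preds[0]))  # 2 sequences, 8 tags each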

4. Train the model

The code is as follows (example):

# Flatten the labels and drop padded positions
def reshape_and_remove_pad(labels, attention_mask):
    # Flatten so positions can be selected in one pass
    labels = labels.reshape(-1)

    # Keep only positions where the attention mask is 1 (ignore pad)
    select = attention_mask.reshape(-1) == 1
    labels = labels[select]
    return labels
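# Quick illustration (not in the original post): one sequence of length 4
# whose last two positions are padding keeps only its two real labels:
#   reshape_and_remove_pad(torch.tensor([[5, 7, 0, 0]]),
#                          torch.tensor([[1, 1, 0, 0]]))  -> tensor([5, 7])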

import numpy as np
from torch.optim import AdamW

your_labels_list = list(range(31))  # tag ids 0..30
optimizer = AdamW(model.parameters(), lr=3e-5)
epochs = 3
total_steps = len(train_dataloader) * epochs
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for _, data in enumerate(train_dataloader, 0):
        batch_input_ids = data['input_ids'].to(device, dtype=torch.long)
        batch_input_mask = data['attention_mask'].to(device, dtype=torch.long)
        batch_labels = data['labels'].to(device, dtype=torch.long)

        # With the CRF module the model returns the loss directly,
        # so no separate cross-entropy loss function is needed
        loss = model(batch_input_ids, attention_mask=batch_input_mask, labels=batch_labels)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (_ + 1) % 100 == 0:
            # Average over the last 100 batches, then reset the accumulator
            avg_train_loss = total_loss / 100
            print(f"Average train loss: {avg_train_loss}")
            total_loss = 0
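Note that total_steps is computed but never used above. One optional way to put it to work (an addition, not part of the original code) is a linear learning-rate schedule from transformers, created before the loop and stepped right after optimizer.step():

from transformers import get_linear_schedule_with_warmup

# Linear decay to zero over all training steps, with no warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps)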

5. Evaluate the model

The code is as follows (example):

from sklearn.metrics import precision_score, recall_score, f1_score

model.eval()
all_big_idx = np.array([])
all_targets = np.array([])
with torch.no_grad():
    for _, data in enumerate(test_dataloader, 0):
        batch_input_ids = data['input_ids'].to(device, dtype=torch.long)
        batch_input_mask = data['attention_mask'].to(device, dtype=torch.long)
        batch_labels = data['labels'].to(device, dtype=torch.long)

        # Call the model without labels so it returns decoded predictions,
        # from which the metrics can be computed
        out = model(batch_input_ids, attention_mask=batch_input_mask)
        # Flatten the per-sequence tag lists into a single list
        out = [label for sublist in out for label in sublist]
        targets = reshape_and_remove_pad(batch_labels, batch_input_mask)

        all_big_idx = np.concatenate((all_big_idx, out))
        targets = targets.cpu().numpy()
        all_targets = np.concatenate((all_targets, targets))

        if (_ + 1) % 10 == 0:
            # sklearn metrics expect (y_true, y_pred)
            precision = precision_score(all_targets, all_big_idx, labels=your_labels_list, average='weighted')
            recall = recall_score(all_targets, all_big_idx, labels=your_labels_list, average='weighted')
            f = f1_score(all_targets, all_big_idx, labels=your_labels_list, average='weighted')
            print(precision, recall, f)
        if (_ + 1) % 1000 == 0:
            print(out)
            print(targets)
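The running scores above accumulate as batches arrive; for a consolidated per-label view at the end, sklearn's classification_report (an optional addition, not in the original post) can be printed once after the loop:

from sklearn.metrics import classification_report

# Per-label precision/recall/F1 over the whole test set
print(classification_report(all_targets, all_big_idx, labels=your_labels_list, zero_division=0))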

Summary

Added a CRF layer on top of BERT and switched the backbone to hfl/chinese-roberta-wwm-ext.
[Figure in the original post: partial F1 results from a run]
