BERT+CRF+pytorch
先给出模型结构图:
https://blog.csdn.net/HUSTHY/article/details/109276404
BERT + CRF损失函数的计算
# CRF的损失函数计算
def loss_fn(self, bert_encode, output_mask, tags):
loss = self.crf.negative_log_loss(bert_encode, output_mask, tags)
return loss
注意重要的是计算损失函数的方式,不是简单的CrossEntropy,而是对CRF的发射矩阵进行训练与计算,CRF的定义如下:
import torch
import torch.nn as nn
from torch.autograd import Variable
class CRF(nn.Module):
"""线性条件随机场"""
def __init__(self, num_tag, use_cuda=False):
https://blog.csdn.net/lcomecon/article/details/108728880