Xiao Hei takes on the LE-BERT model: BertSoftmaxForNer

This post walks through BertSoftmaxForNer, the plain BERT + softmax baseline from the LE-BERT repo: a BERT encoder followed by dropout and a per-token linear classifier, with a configurable loss ('ce', 'focal', or 'lsr' label smoothing) that is computed only over non-padding tokens.

import sys
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel, BertConfig

# FocalLoss and LabelSmoothingCrossEntropy live in the LE-BERT repo's losses/ package
sys.path.append('..')
from losses.focal_loss import FocalLoss
from losses.label_smoothing import LabelSmoothingCrossEntropy

class BertSoftmaxForNer(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = config.loss_type
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, ignore_index, labels=None):
        # input_ids: [batch_size, max_len]
        # attention_mask: [batch_size, max_len]
        # token_type_ids: [batch_size, max_len]
        # ignore_index: label id excluded from the loss, e.g. 0

        # outputs: ([batch_size, max_len, bert_dim], [batch_size, bert_dim])
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # sequence_output: [batch_size, max_len, bert_dim]
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # logits: [batch_size, max_len, num_tags]
        logits = self.classifier(sequence_output)
        # outputs: ([batch_size, max_len, num_tags],)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            assert self.loss_type in ['lsr', 'focal', 'ce']
            if self.loss_type == 'lsr':
                loss_fct = LabelSmoothingCrossEntropy(ignore_index=ignore_index)
            elif self.loss_type == 'focal':
                loss_fct = FocalLoss(ignore_index=ignore_index)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=ignore_index)

            # restrict the loss to real tokens using attention_mask
            if attention_mask is not None:
                # active_loss: [batch_size * max_len], True where the token is not padding
                active_loss = attention_mask.contiguous().view(-1) == 1
                # active_logits: [total valid tokens in the batch, num_tags]
                active_logits = logits.contiguous().view(-1, self.num_labels)[active_loss]
                # active_labels: [total valid tokens in the batch]
                active_labels = labels.contiguous().view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            # outputs: (loss, [batch_size, max_len, num_tags])
            outputs = (loss,) + outputs
        return outputs
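
Before the end-to-end test, here is a minimal standalone sketch (toy tensors only, nothing from the repo) of what the attention-mask selection inside forward does: the boolean mask flattens the batch and keeps only the real-token positions, so padding never contributes to the loss.

import torch

# toy batch: 2 sentences, max_len 4, 3 tags; the second sentence has 2 padding tokens
logits = torch.randn(2, 4, 3)
labels = torch.randint(0, 3, (2, 4))
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

active = attention_mask.view(-1) == 1           # [8] boolean, 6 entries True
active_logits = logits.view(-1, 3)[active]      # [6, 3]
active_labels = labels.view(-1)[active]         # [6]
print(active_logits.shape, active_labels.shape) # torch.Size([6, 3]) torch.Size([6])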

pretrain_model_path = 'bert-base-chinese'
input_ids = torch.randint(0, 100, [4, 10])
token_type_ids = torch.zeros([4, 10]).long()
attention_mask = torch.ones([4, 10]).long()
# word_embeddings / word_mask (and the word_* config fields below) are only
# consumed by the LEBert variants; BertSoftmaxForNer ignores them
word_embeddings = torch.randn([4, 10, 5, 200])
word_mask = torch.ones(4, 10, 5).long()
config = BertConfig.from_pretrained(pretrain_model_path, num_labels=20)
config.word_embed_dim = 200
config.loss_type = 'ce'
config.word_vocab_size = 2162
model = BertSoftmaxForNer(config)
outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                ignore_index=0)
print('output.shape:', outputs[0].shape)

Output:

output.shape: torch.Size([4, 10, 20])
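
The demo above only exercises the inference path (no labels), so outputs[0] is the logits tensor. A quick sketch of the training path, using randomly generated labels purely for illustration (values drawn from 1..19 so they never collide with ignore_index=0):

labels = torch.randint(1, 20, [4, 10])  # hypothetical gold tags
outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                ignore_index=0,
                labels=labels)
loss, logits = outputs[0], outputs[1]
print('loss:', loss.item(), 'logits.shape:', logits.shape)  # logits.shape: torch.Size([4, 10, 20])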
