RuntimeError: Error(s) in loading state_dict for BertClassifier (model mismatch)

RuntimeError: Error(s) in loading state_dict for BertClassifier:
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([400, 768]) from checkpoint, the shape in current model is torch.Size([360, 768]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([400, 100]) from checkpoint, the shape in current model is torch.Size([360, 90]).
	size mismatch for lstm.weight_ih_l0_reverse: copying a param with shape torch.Size([400, 768]) from checkpoint, the shape in current model is torch.Size([360, 768]).
	size mismatch for lstm.weight_hh_l0_reverse: copying a param with shape torch.Size([400, 100]) from checkpoint, the shape in current model is torch.Size([360, 90]).
	size mismatch for linear1.weight: copying a param with shape torch.Size([100, 401]) from checkpoint, the shape in current model is torch.Size([90, 361]).
	size mismatch for linear1.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([90]).
	size mismatch for linear2.weight: copying a param with shape torch.Size([7, 100]) from checkpoint, the shape in current model is torch.Size([7, 90]).
	size mismatch for linear1_ent.weight: copying a param with shape torch.Size([50, 200]) from checkpoint, the shape in current model is torch.Size([45, 180]).
	size mismatch for linear1_ent.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([45]).
	size mismatch for linear2_ent.weight: copying a param with shape torch.Size([2, 50]) from checkpoint, the shape in current model is torch.Size([2, 45]).

Cause

The parameter shapes stored in the checkpoint do not match those of the current model: the checkpoint was saved from a model built with different hyperparameters than the one being constructed now.
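The shapes in the error message already reveal which hyperparameter changed. For a PyTorch nn.LSTM, weight_ih_l0 has shape (4 * hidden_size, input_size), so the hidden size can be read back directly from the checkpoint; a quick back-of-the-envelope check using the numbers from the error above:

# nn.LSTM stores weight_ih_l0 as (4 * hidden_size, input_size)
ckpt_hid = 400 // 4    # = 100, the hid the checkpoint was trained with
model_hid = 360 // 4   # = 90, the hid the current model was built with (args.hid)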

First, locate the definition of the current model:

import torch.nn as nn

class BertClassifier(nn.Module):
    'Neural Network Architecture'
    def __init__(self, args):
        super(BertClassifier, self).__init__()

        self.hid_size = args.hid
        self.batch_size = args.batch
        self.num_layers = args.num_layers
        self.num_classes = len(args.label_to_id)
        self.num_ent_classes = 2

        self.dropout = nn.Dropout(p=args.dropout)
        # lstm is shared for both relation and entity
        self.lstm = nn.LSTM(768, self.hid_size, self.num_layers, bias=False, bidirectional=True)

        # MLP classifier for relation
        self.linear1 = nn.Linear(self.hid_size*4 + args.n_fts, self.hid_size)
        self.linear2 = nn.Linear(self.hid_size, self.num_classes)

        # MLP classifier for entity
        self.linear1_ent = nn.Linear(self.hid_size*2, int(self.hid_size / 2))
        self.linear2_ent = nn.Linear(int(self.hid_size / 2), self.num_ent_classes)

        self.act = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.softmax_ent = nn.Softmax(dim=2)

model = BertClassifier(args)
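If the error message is long, the mismatched parameters can also be listed programmatically before calling load_state_dict. A minimal sketch, assuming the checkpoint is a plain state_dict saved at a hypothetical path model.pt:

import torch

ckpt = torch.load('model.pt', map_location='cpu')  # hypothetical path, replace with your checkpoint
for name, param in model.state_dict().items():
    # report every parameter whose shape differs between checkpoint and current model
    if name in ckpt and ckpt[name].shape != param.shape:
        print(name, 'checkpoint:', tuple(ckpt[name].shape), 'current:', tuple(param.shape))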

Inspect the error message and find the entries that do not match.

For example:

size mismatch for linear2_ent.weight: copying a param with shape torch.Size([2, 50]) from checkpoint, the shape in current model is torch.Size([2, 45]).

For linear2_ent.weight, find this layer in the BertClassifier(nn.Module) definition and look at how it is constructed:

self.linear2_ent = nn.Linear(int(self.hid_size / 2), self.num_ent_classes)

An nn.Linear weight has shape (out_features, in_features), so the checkpoint's [2, 50] means int(self.hid_size / 2) was 50, i.e. args.hid was 100 when the checkpoint was saved, while the current model's [2, 45] comes from args.hid = 90. Compare the hyperparameters of the two models and change the current one so that the shapes agree.
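In most cases the right fix is to rebuild the current model with the hyperparameters the checkpoint was trained with, not to edit the checkpoint. A minimal sketch, assuming hid is the only argument that changed:

args.hid = 100                      # assumed original value, derived from the checkpoint shapes
model = BertClassifier(args)        # rebuild the model so every layer gets matching sizes
model.load_state_dict(torch.load('model.pt', map_location='cpu'))  # hypothetical path

Note that passing strict=False to load_state_dict does not silence these errors; it only skips missing and unexpected keys, not size mismatches.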
