【Pytorch|Bug】解决 RuntimeError: Error(s) in loading state_dict for Network: size mismatch

问题背景

Github开源项目:https://github.com/zhang-tao-whu/e2ec

python train_net.py coco_finetune --bs 12 \
--type finetune --checkpoint data/model/model_coco.pth

报错如下:

loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
load model: data/model/model_coco.pth
Traceback (most recent call last):
  File "train_net.py", line 67, in <module>
    main()
  File "train_net.py", line 64, in main
    train(network, cfg)
  File "train_net.py", line 40, in train
    begin_epoch = load_network(network, model_dir=args.checkpoint, strict=False)
  File "/root/autodl-tmp/e2ec/train/model_utils/utils.py", line 66, in load_network
    net.load_state_dict(net_weight, strict=strict)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Network:
        size mismatch for dla.ct_hm.2.weight: copying a param with shape torch.Size([80, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
        size mismatch for dla.ct_hm.2.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).

由于我自己的数据集类别只有1,而COCO数据集有80个类别,预训练模型中dla.ct_hm.2 参数的大小与我的不符,所以需要舍弃预训练模型中的这个参数的权重。

解决方法

e2ec/train/model_utils/utils.py 中修改:

def load_network(net, model_dir, strict=True, map_location=None):

    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0

    print('load model: {}'.format(model_dir))
    if map_location is None:
        pretrained_model = torch.load(model_dir, map_location={'cuda:0': 'cpu', 'cuda:1': 'cpu',
                                                               'cuda:2': 'cpu', 'cuda:3': 'cpu'})
    else:
        pretrained_model = torch.load(model_dir, map_location=map_location)
    if 'epoch' in pretrained_model.keys():
        epoch = pretrained_model['epoch'] + 1
    else:
        epoch = 0
    pretrained_model = pretrained_model['net']

    net_weight = net.state_dict()
    for key in net_weight.keys():
        net_weight.update({key: pretrained_model[key]})
    '''
	舍弃部分参数
	'''
    net_weight.pop("dla.ct_hm.2.weight")
    net_weight.pop("dla.ct_hm.2.bias")
    
    net.load_state_dict(net_weight, strict=strict)
    return epoch

注意:load_state_dict 中设置 strict=False 只对增加或删除部分层有用,对于在原来参数上改变维度大小的情况不适用。

<think>好的,我现在要解决用户在加载LSTM模型state_dict时遇到的size mismatch问题,特别是关于linear.weight和linear.bias的错误。首先,我需要回忆一下PyTorch模型加载的相关知识,以及常见的错误原因和解决方法。 用户提到的问题是在加载state_dict时出现尺寸不匹配,错误涉及linear层的权重和偏置。根据引用中的信息,比如引用[2]中的例子,类似的错误可能是因为模型结构在保存和加载时不一致,或者在DataParallel的使用上存在问题。比如,当使用DataParallel包装模型后,保存的state_dict中的参数名称会带有"module."前缀,而加载时如果模型没有用DataParallel包装,就会导致名称不匹配,从而引发尺寸错误。这可能是一个原因。 另一个可能的原因是模型结构本身发生了变化。比如,用户可能在保存模型后修改了LSTM层的隐藏层大小、输入维度,或者后面的全连接层(linear层)的输入输出维度,导致加载时参数的形状与当前模型不匹配。例如,如果原来的linear层的输入是500维,而现在改成了2000维,那么对应的weight和bias的尺寸就会不一致,导致加载错误。 接下来,我需要考虑解决方法。根据引用[3],设置strict=False可以忽略不匹配的参数,但这样可能导致模型部分参数没有被加载,影响性能。所以这只是临时解决方法,不能根本解决问题。正确的做法应该是确保模型结构在保存和加载时完全一致。 具体到用户的问题,linear层的尺寸不匹配,可能是因为LSTM层的输出维度发生了变化,导致后面的全连接层的输入尺寸需要调整。例如,假设LSTM的hidden_size被修改了,或者是否使用双向LSTM(bidirectional)的情况不同,这会影响全连接层的输入维度。比如,双向LSTM的输出维度是hidden_size*2,如果之前是单向的,而现在改为双向,那么后面的linear层的输入维度需要乘以2,否则就会出现尺寸不匹配。 此外,检查是否有无意识地修改了模型的层结构,比如添加或删除了某些层,或者在模型定义时参数设置错误。例如,在定义LSTM时,设置input_size、hidden_size或者num_layers时可能有误,导致后续层的参数尺寸不符。 另外,检查保存的checkpoint是否完整,是否正确地保存了整个模型而不仅仅是state_dict。有时候如果保存的是模型的其他部分,可能在加载时结构不一致。 可能的解决步骤包括: 1. 检查模型定义,确保保存和加载时的结构完全一致,特别是LSTM的参数和后面的linear层。 2. 检查是否使用了DataParallel或相关的包装,确保加载时模型的结构与保存时一致。 3. 如果结构确实需要改变,考虑手动调整state_dict中的参数名称或尺寸,或者在加载后重新初始化某些层。 4. 打印当前模型和checkpoint中的参数形状,对比差异,找出具体是哪些层不匹配。 比如,用户可以打印当前模型的state_dict和加载的checkpoint中的state_dict,比较linear.weight和linear.bias的形状,看是否因为输入或输出维度变化导致。例如: ```python print("Current model's linear.weight:", model.linear.weight.shape) print("Checkpoint's linear.weight:", checkpoint['linear.weight'].shape) ``` 这样可以帮助确定问题所在。如果发现是输入维度不同,可能需要调整LSTM层的输出,或者调整后续层的结构,使其与保存的checkpoint中的参数匹配。 总结来说,关键点在于确保模型结构的一致性,避免在保存和加载之间修改模型架构,正确处理DataParallel带来的参数名前缀问题,以及在必要时调整参数或结构。</think>### 解决方案:LSTM模型加载state_dict时出现`size mismatch`的解决方法 #### 1. 根本原因分析 - **模型结构不一致**:加载的state_dict与当前模型结构参数(如LSTM的`input_size`、`hidden_size`、`num_layers`,或后续全连接层的维度)不匹配[^2][^3]。 - **DataParallel前缀问题**:保存模型时使用了`DataParallel`包装,但加载时未使用,导致参数名多出`module.`前缀,引发形状不匹配[^2]。 - **双向LSTM改动**:若LSTM从单向改为双向,输出维度翻倍,但后续全连接层未调整输入维度,导致`linear.weight`尺寸不匹配。 - **参数误修改**:训练时调整了模型层的维度(如隐藏层大小)但未同步修改后续层。 --- #### 2. 具体解决步骤 **步骤1:检查模型结构一致性** - 确保训练和加载时的模型定义**完全一致**,重点关注: ```python # LSTM层参数检查示例 class MyModel(nn.Module): def __init__(self): super().__init__() self.lstm = nn.LSTM( input_size=100, # 必须与保存时一致 hidden_size=256, # 必须与保存时一致 bidirectional=False # 若改为True,后续层需适配 ) self.linear = nn.Linear(256, 10) # 输入维度需匹配LSTM的hidden_size # 若使用双向LSTM,输入维度应为 hidden_size * 2 # self.linear = nn.Linear(256 * 2, 10) ``` **步骤2:处理DataParallel前缀问题** - **情况1**:保存时用`DataParallel`,加载时未用: ```python # 加载时手动去除"module."前缀 state_dict = torch.load('checkpoint.pth') new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} model.load_state_dict(new_state_dict) ``` - **情况2**:保存时未用`DataParallel`,加载时使用: ```python # 加载时添加"module."前缀 model = nn.DataParallel(model) model.load_state_dict(torch.load('checkpoint.pth')) ``` **步骤3:手动调整参数维度(适用于结构微调)** - 若必须修改模型结构,可**选择性加载匹配参数**: ```python checkpoint = torch.load('checkpoint.pth') model_dict = model.state_dict() # 仅加载名称和形状一致的参数 matched_dict = {k: v for k, v in checkpoint.items() if k in model_dict and v.shape == model_dict[k].shape} model_dict.update(matched_dict) model.load_state_dict(model_dict) print(f"Loaded {len(matched_dict)}/{len(checkpoint)} parameters") ``` **步骤4:严格模式与警告处理** - 使用`strict=False`跳过不匹配参数(慎用,可能影响性能): ```python model.load_state_dict(checkpoint, strict=False) # 忽略不匹配参数[^3] ``` - 通过警告信息定位具体不匹配的层: ``` Missing keys: ['linear.weight', 'linear.bias'] # 缺失的键(当前模型新增的层) Unexpected keys: ['module.lstm.weight_ih_l0'] # 意外的键(checkpoint中多余的层) ``` --- #### 3. 验证方法 - **打印参数形状对比**: ```python # 检查关键层形状 print("当前模型 linear.weight:", model.linear.weight.shape) print("Checkpoint linear.weight:", checkpoint['linear.weight'].shape) ``` - **输出示例**: ``` 当前模型 linear.weight: torch.Size([10, 256]) Checkpoint linear.weight: torch.Size([10, 512]) # 不匹配!需调整LSTM输出维度 ``` --- #### 4. 高级调试技巧 - **可视化模型结构**:使用`torchsummary`库输出模型结构: ```python from torchsummary import summary summary(model, input_size=(100, 32)) # 输入维度示例(seq_len, batch_size) ``` - **检查LSTM输出维度**: ```python # 验证LSTM输出维度是否与linear层匹配 lstm_output, _ = model.lstm(torch.randn(10, 32, 100)) # (seq_len, batch, input_size) print("LSTM输出维度:", lstm_output.shape[-1]) # 应为 hidden_size * num_directions ``` ---
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

宇宙爆肝锦标赛冠军

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值