是训练yolov5_obb时出现的问题,用了预训练模型,但是加载时权重参数训练时出错:
ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
raise RuntimeError('Error(s) in loading state_dict for {}:\n t{}'.format(
RuntimeError:Error(s) in loading state_dict for Model:
size mismatch for model.24.m.0.weight: copying a param with shape torch.Size([603,256,1,1]) from checkpoint,the shape in current model is torch.Size([564, 256, 1, 1]).
size mismatch for model.24.m.0.bias: copying a param with shape torch.Size([603,256,1,1]) from checkpoint the shape in current model is torch.Size([564, 256, 1,1]).
size mismatch for model.24.m.1.weight: copying a param with shape torch.Size([603,256,1,1]) from checkpoint,the shape in current model is torch.Size([564, 256, 1, 1]).
size mismatch for model.24.m.1.bias: copying a param with shape torch.Size([603,256,1,1]) from checkpoint the shape in current model is torch.Size([564, 256, 1,1]).
size mismatch for model.24.m.2.weight: copying a param with shape torch.Size([603,256,1,1]) from checkpoint,the shape in current model is torch.Size([564, 256, 1, 1]).
size mismatch for model.24.m.2.bias: copying a param with shape torch.Size([603,256,1,1]) from checkpoint the shape in current model is torch.Size([564, 256, 1,1]).
查找了一些办法,基本上是关于 loading_state_dict()的用法,如何去处理加载的权重,开始直接定位到了ckpt[‘model’]的这几层的权重,强行将几层的权重维度匹配到model中,但是训练时仍然没有解决问题。
关于获取权重文件的参数:
ckpt = torch.load(weights, map_location=device)
tmpckpt = ckpt['model'].float().state_dict()
"""下面是让权重为度匹配的处理方法,具体情况具体分析"""
csd = {}
for k,v in tmpckpt.items():
if model.state_dict()[k].numel() != v.numel():
tmpwht = v[:(model.state_dict()[k].shape[0]),:,:,:]
csd[k] = tmpwht
else:
csd[k] = v
涉及到的几个函数:
① intersect_dicts(csd, model.state_dict(), exclude=exclude)
def intersect_dicts(da, db, exclude=()):
"""筛选预训练权重中的键值对,用于筛选字典中的键值对
将da中的值复制给da,但是除了exclude中的键值对"""
# 返回字典da中的键值对 要求键k在字典db中且全部都不在exclude中 同时da中值的shape对应db中值的shape(相同)
return {k: v for k, v in da.items() if k in db
and not any(x in k for x in exclude) and v.shape == db[k].shape}
② model.load_state_dict(csd, strict=False)
其中strict=False就是为了预防权重加载时维度匹配达到兼容的,
strict=True 则要求模型维度与权重模型完全一致
但是这些搞明白下来仍然没有解决问题,而且有一段调试时间报错并没有定位到:
ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
而这也是直接的问题所在。
一般来说,这种维度的不匹配 会由于类别数量的不一致 导致权重模型无法与要训练的模型的维度相同,从而报错,所以要解决这个错只能修改这处的信息了。
但是,这处的问题并不好改,如果也是使用yolov5_rotated代码训练,建议对照yolov5的训练代码,可以发现这个rotated框架在优化器部分的代码写的不完善,而且预训练判断部分的代码也有问题,即使成功训练起来,epoch数量也不是从0,1开始,所以主要还是框架考虑不够完善。
有效的解决办法:
建议把 if pretrained 部分对应 optimizer 和 ema 部分的代码注释掉,这样在训练时可以不受optimizer和ema的干扰
想完善一点的,参考yolov5的train.py 改一下相关代码,完全可以解决问题并训练起来