nn.Module.load_state_dict()不能直接load tensor的问题
load_state_dict()函数的作用是将已有权重载入相应模型,载入的前提条件是被载入的权重的size与对应模型所需权重的size相同。
比如现有一个model如下:
model = nn.Conv2d(in_channels=3, kernel_size=3, out_channels=32, stride=1, bias=False)
它需要的权重的size是[32, 3, 3, 3]的
现有权重weight,需要将其载入model
则model.load_state_dict(weight)
weight的size必须也是[32, 3, 3, 3]
此外,对weight的type也有要求:
如果weight的type是Tensor,则会报错:
原因如上图,因为Tensor型变量没有copy属性,所以不能作为load_state_dict()函数的参数
而正如该函数中所示的‘dict’所表明,该函数所调用的对象类型应该为字典类型,
用dir()函数查看字典类型的属性:
可以看见字典类型的数据有copy这一属性
因此,如果现有Tensor类型的权重weight
需要在调用用load_state_dict()前完成Tensor到dict的转化
weight = {‘weight’ : weight}