import torch

class A(torch.nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.conv = torch.nn.Conv2d(2, 4, 1)

a = A()
a.state_dict().keys()
This prints:
odict_keys(['conv.weight', 'conv.bias'])
Then run:
b = {}
for key in a.state_dict().keys():
    if 'weight' in key:
        # split the [4, 2, 1, 1] conv weight along the input-channel dim
        w1, w2 = a.state_dict()[key].chunk(2, 1)
        # split each half along the output-channel dim into four views
        (b[key + '.w11'], b[key + '.w12'], b[key + '.w13'], b[key + '.w14']) = w1.chunk(4, 0)
        (b[key + '.w21'], b[key + '.w22'], b[key + '.w23'], b[key + '.w24']) = w2.chunk(4, 0)
c = torch.Tensor([1000])
b['conv.weight.w11'].copy_(c)
print(a.state_dict())
This prints:
OrderedDict([('conv.weight', tensor([[[[ 1.0000e+03]],
         [[ 2.5883e-01]]],
        [[[ 1.0379e-01]],
         [[ 3.4200e-01]]],
        [[[ 4.6878e-01]],
         [[ 6.9393e-01]]],
        [[[-4.8772e-01]],
         [[-2.3880e-01]]]])), ('conv.bias', tensor([-0.6252, -0.5940,  0.5863, -0.1846]))])
We can see that the value of c has indeed been copied into model a's parameters. Note that b['conv.weight.w11'].copy_(c) cannot be replaced with b['conv.weight.w11'] = c: the in-place copy_ writes through the chunked view into the storage shared with the model's weight, whereas a plain assignment merely rebinds the dictionary entry to a new tensor and never touches the model.
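To make the distinction concrete, here is a minimal sketch (assuming the default state_dict() behavior, where the returned tensors share storage with the module's parameters; the variable names are illustrative only):

import torch

m = torch.nn.Conv2d(2, 4, 1)
# chunk() returns views, so w11 still shares storage with m's weight
w11 = m.state_dict()['weight'].chunk(2, 1)[0].chunk(4, 0)[0]

# In-place copy_ writes through the shared storage: the model sees the new value.
w11.copy_(torch.Tensor([1000]))
print(m.state_dict()['weight'][0, 0])   # tensor([[1000.]])

# Plain assignment only rebinds the local name to a new tensor; the model is unchanged.
w11 = torch.Tensor([[[[-1.]]]])
print(m.state_dict()['weight'][0, 0])   # still tensor([[1000.]])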