Reference links: https://blog.csdn.net/qq_43827595/article/details/115324118
https://blog.csdn.net/qq_41596915/article/details/122005538
Reference article on deep copy vs. shallow copy:
https://blog.csdn.net/weixin_44133119/article/details/123307291
- model.state_dict() is effectively a shallow copy: if you write param = model.state_dict() and then modify the tensors inside param in place, the model's parameters change accordingly. The model object holds references to its parameter tensors, and the shallow copy only duplicates these outermost "pointers", not the tensor data itself.
- model.load_state_dict(xxx) is a deep copy: it copies the values from the given dict into the model's own parameter tensors.
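The deep-copy behavior of load_state_dict can be verified directly: after loading, mutating the source dict's tensors leaves the model untouched. A minimal sketch (using a small nn.Linear as a stand-in module; any nn.Module behaves the same):

```python
import torch
from torch import nn

# Tiny model just for this demo.
net = nn.Linear(2, 2, bias=False)

# Build a source state_dict filled with zeros and load it.
src = {"weight": torch.zeros(2, 2)}
net.load_state_dict(src)

# Mutating the source tensor afterwards does NOT affect the model:
# load_state_dict copied the *values* into the model's own parameter tensor.
src["weight"].fill_(1.0)
print(net.weight)  # still all zeros
print(net.weight.data_ptr() == src["weight"].data_ptr())  # False: separate storage
```

The differing data_ptr() confirms the model's parameter and the source tensor live in separate storage.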
Mental model:
model.state_dict()["conv1.weight"] points to tensor1
- model.state_dict()["conv1.weight"].copy_(...) changes the values of tensor1 in place, so the original model is affected
- model.state_dict()["conv1.weight"] = torch.tensor(0) only redirects the dict entry's "arrow" to a new tensor, so the original model is not affected
- model.state_dict()["conv1.weight"][0] = torch.tensor(0) redirects an "arrow" one level down, i.e. writes into tensor1's storage, so the original model is affected
import torch
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=18, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x
1. model.state_dict()["conv1.weight"].copy_(...)

a = Tudui()
m = a.state_dict()
for k, v in m.items():
    m[k].copy_(torch.zeros_like(v))
    break
f = a.state_dict()

As you can see, the original model's parameters were changed: copy_ writes into the parameter tensor in place.
2. model.state_dict()["conv1.weight"] = torch.tensor(0)

a = Tudui()
m = a.state_dict()
for k, v in m.items():
    m[k] = torch.tensor(0)
    break
f = a.state_dict()

As you can see, the original model's parameters were not changed: the assignment only rebinds the dict entry to a new tensor.
3. model.state_dict()["conv1.weight"][0] = torch.tensor(0)

a = Tudui()
m = a.state_dict()
for k, v in m.items():
    m[k][0] = torch.tensor(0)
    break
f = a.state_dict()

As you can see, the original model's parameters were changed: indexed assignment writes into the tensor's storage in place.
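A practical consequence of the shallow-copy behavior: if you want a snapshot of the parameters that later training cannot touch, do not keep the raw state_dict() around. One common pattern (a sketch, not the only option) is to deep-copy it:

```python
import copy
import torch
from torch import nn

# Stand-in module for this demo; any nn.Module works the same way.
net = nn.Conv2d(3, 6, 3)

# A true snapshot: deep-copy every tensor so later in-place updates
# to the live parameters cannot reach the saved values.
snapshot = copy.deepcopy(net.state_dict())

# Zero the live parameters in place; the snapshot keeps the old values.
with torch.no_grad():
    net.weight.zero_()

print(torch.all(net.weight == 0))        # the live model is zeroed
print(torch.all(snapshot["weight"] == 0))  # the snapshot still holds the original values
```

An equivalent alternative is {k: v.clone() for k, v in net.state_dict().items()}, which also detaches the copied tensors from the model's storage.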