关于.state_dict()和.load_state_dict()分别是浅拷贝和深拷贝的问题


参考链接:https://blog.csdn.net/qq_43827595/article/details/115324118

https://blog.csdn.net/qq_41596915/article/details/122005538

关于深拷贝和浅拷贝的参考文章:

https://blog.csdn.net/weixin_50829653/article/details/127675849?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7ERate-2-127675849-blog-127951831.235%5Ev43%5Epc_blog_bottom_relevance_base1&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7ERate-2-127675849-blog-127951831.235%5Ev43%5Epc_blog_bottom_relevance_base1&utm_relevant_index=3

https://blog.csdn.net/weixin_44133119/article/details/123307291

  • model.state_dict()实际上是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。
  • model.load_state_dict(xxx) 是深拷贝

原理图:

model.state_dict()["conv1.weight"]指向tensor1

  1. model.state_dict()["conv1.weight"].copy_(...)原地改变了tensor1的值,原model也受到影响
  2. 相当于改变了箭头的指向,原model没有受到影响
  3. model.state_dict()["conv1.weight"][0]=torch.tensor(0)相当于改变了第二层参数箭头的指向,原model受到影响
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=18, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x

1、model.state_dict()["conv1.weight"].copy_(...)

a = Tudui()
m = a.state_dict()
for k, v in m.items():
    m[k].copy_(torch.zeros_like(v))
    break
f = a.state_dict()

可见原模型参数被改变了

2、model.state_dict()["conv1.weight"]=torch.tensor(0)

a = Tudui()
m = a.state_dict()
for k, v in m.items():
    m[k] = torch.tensor(0)
    break
f = a.state_dict()

可见原模型参数没有被改变

3、model.state_dict()["conv1.weight"][0]=torch.tensor(0)
 

a = Tudui()
m = a.state_dict()
for k, v in m.items():
    m[k][0] = torch.tensor(0)
    break
f = a.state_dict()

可见原模型参数被改变了

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值