Shallow copy vs. deep copy of Model objects in PyTorch

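This post looks at copying models in PyTorch: the difference between plain assignment and copy.deepcopy, and how to modify model parameters correctly. Examples show the effect of each kind of copy and how to update parameters through state_dict().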


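While writing code today, I ran into two surprises when copying a model in PyTorch:

1. Assigning model.state_dict()[name] = param does not change the model's parameters.
2. After model1 = model2, modifying model1 also changes model2. Assignment therefore behaves like a shallow copy: it does not create an independent model, and both names refer to the same Model object.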

Verifying the problems:

Setting up the test case

A minimal model with a single neuron (in=2, out=1):

import torch
import copy
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=1)
    def forward(self, input):
        x = torch.relu(self.linear_1(input))
        return x
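Initializing the models:
model = Model()
model2 = Model()               # an independent, freshly initialized model
model3 = model                 # plain assignment: a second name for the same object
model4 = copy.deepcopy(model)  # deep copy: a new object with its own parameters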
model and model2

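The results show that assigning with = does not modify model2. Each call to state_dict() returns a new dictionary, so rebinding one of its keys only changes that temporary dictionary and leaves the model untouched.
Calling copy_() does modify model2: the tensors returned by state_dict() share storage with the model's parameters, and Tensor.copy_(src) writes the values of src into that storage in place.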
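Both behaviours can be checked directly. The following is a minimal sketch, assuming a recent PyTorch version in which state_dict() (with the default keep_vars=False) returns detached tensors that share memory with the parameters; nn.Linear stands in for the Model class above.

```python
import torch
from torch import nn

model = nn.Linear(in_features=2, out_features=1)

# Each call builds a brand-new OrderedDict object...
sd1 = model.state_dict()
sd2 = model.state_dict()
print(sd1 is sd2)  # False

# ...but the tensors inside it share memory with the live parameters.
print(sd2["weight"].data_ptr() == model.weight.data_ptr())  # True

# Rebinding a key only changes the temporary dict, not the model.
before = model.weight.detach().clone()
sd1["weight"] = sd1["weight"] * 2
print(torch.equal(model.weight.detach(), before))  # True: model unchanged

# copy_() writes into the shared storage, so the model does change.
model.state_dict()["weight"].copy_(before * 2)
print(torch.equal(model.weight.detach(), before * 2))  # True: model updated
```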

print("model.state_dict()", model.state_dict())
print("model2.state_dict(): ", model2.state_dict())

# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1524, -0.4679]])), ('linear_1.bias', tensor([-0.5376]))])
# model2.state_dict():  OrderedDict([('linear_1.weight', tensor([[0.6001, 0.0914]])), ('linear_1.bias', tensor([0.5060]))])

for name, param in model2.state_dict().items():
    # no effect: this only rebinds a key in a temporary dict
    model2.state_dict()[name] = param * 2

print("model.state_dict()", model.state_dict())
print("model2.state_dict(): ", model2.state_dict())

# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1524, -0.4679]])), ('linear_1.bias', tensor([-0.5376]))])
# model2.state_dict():  OrderedDict([('linear_1.weight', tensor([[0.6001, 0.0914]])), ('linear_1.bias', tensor([0.5060]))])

for name, param in model2.state_dict().items():
    # in-place write into the storage shared with model2's parameters
    model2.state_dict()[name].copy_(param * 2)

print("model.state_dict()", model.state_dict())
print("model2.state_dict(): ", model2.state_dict())

# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1524, -0.4679]])), ('linear_1.bias', tensor([-0.5376]))])
# model2.state_dict():  OrderedDict([('linear_1.weight', tensor([[1.2001, 0.1827]])), ('linear_1.bias', tensor([1.0119]))])
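When the goal is simply to write new values into a model, two more conventional routes than calling copy_() on state_dict() entries are sketched below (assumptions: a recent PyTorch version, and nn.Linear standing in for the Model class above). The first builds a modified dictionary and passes it to load_state_dict(), which copies each tensor into the matching parameter; the second modifies the parameters in place with autograd disabled.

```python
import torch
from torch import nn

# nn.Linear stands in for the Model class defined above.
model2 = nn.Linear(in_features=2, out_features=1)
original = {name: t.clone() for name, t in model2.state_dict().items()}

# Option 1: build a new dict, then load it.
# Creating the dict alone does not touch model2; load_state_dict() copies
# each tensor into the existing parameter with the same key.
doubled = {name: t * 2 for name, t in model2.state_dict().items()}
model2.load_state_dict(doubled)
for name, t in model2.state_dict().items():
    print(name, torch.equal(t, original[name] * 2))  # True for both keys

# Option 2: modify the parameters directly, in place, without tracking gradients.
with torch.no_grad():
    for p in model2.parameters():
        p.mul_(2)
for name, t in model2.state_dict().items():
    print(name, torch.equal(t, original[name] * 4))  # True for both keys
```

Option 1 is convenient when the new values come from another model or a checkpoint; option 2 avoids building an intermediate dictionary.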
model3 and model4

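The results show that modifying model3 also changes model, while modifying model4 does not.

This follows Python's own shallow vs. deep copy semantics: model3 = model only binds a second name to the same object, whereas copy.deepcopy(model) builds a new object with its own copies of the parameter tensors.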
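Object identity makes the three cases easy to tell apart. The sketch below also includes copy.copy, Python's actual shallow copy, which the experiment here does not use; it assumes nn.Module does not override __copy__, so the copy gets a new outer object whose internal parameter dictionary is shared with the original. nn.Linear stands in for the Model class.

```python
import copy
from torch import nn

model = nn.Linear(in_features=2, out_features=1)

alias = model                # assignment: no copy, just a second name
shallow = copy.copy(model)   # shallow copy: new outer object, shared contents
deep = copy.deepcopy(model)  # deep copy: new object and new parameter tensors

print(alias is model)                # True
print(shallow is model)              # False
print(shallow.weight is model.weight)  # True: the parameter itself is shared
print(deep.weight is model.weight)     # False
print(deep.weight.data_ptr() == model.weight.data_ptr())  # False: separate memory
```

So even a genuine shallow copy would not isolate the parameters; only deepcopy gives model4 storage of its own.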

print("model.state_dict()", model.state_dict())
print("model3.state_dict(): ", model3.state_dict())
print("model4.state_dict(): ", model4.state_dict())

# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.0855,  0.5340]])), ('linear_1.bias', tensor([-0.4530]))])
# model3.state_dict():  OrderedDict([('linear_1.weight', tensor([[-0.0855,  0.5340]])), ('linear_1.bias', tensor([-0.4530]))])
# model4.state_dict():  OrderedDict([('linear_1.weight', tensor([[-0.0855,  0.5340]])), ('linear_1.bias', tensor([-0.4530]))])



for name, param in model.state_dict().items():
    # model3 is the same object as model, so this doubles model as well
    model3.state_dict()[name].copy_(param * 2)


print("model.state_dict()", model.state_dict())
print("model3.state_dict(): ", model3.state_dict())
print("model4.state_dict(): ", model4.state_dict())

# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1710,  1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model3.state_dict():  OrderedDict([('linear_1.weight', tensor([[-0.1710,  1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model4.state_dict():  OrderedDict([('linear_1.weight', tensor([[-0.0855,  0.5340]])), ('linear_1.bias', tensor([-0.4530]))])



for name, param in model.state_dict().items():
    # model was already doubled above, so model4 ends up at 4x its original values
    model4.state_dict()[name].copy_(param * 2)

print("model.state_dict()", model.state_dict())
print("model3.state_dict(): ", model3.state_dict())
print("model4.state_dict(): ", model4.state_dict())

# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1710,  1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model3.state_dict():  OrderedDict([('linear_1.weight', tensor([[-0.1710,  1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model4.state_dict():  OrderedDict([('linear_1.weight', tensor([[-0.3420,  2.1360]])), ('linear_1.bias', tensor([-1.8121]))])
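A usage pattern that builds on this, offered as an assumption about how the result might be applied rather than something the experiment above tests: keep an independent deep copy and refresh its values from the original with load_state_dict() when needed. The values are copied, but the storage stays separate, so later edits to the original do not leak into the copy.

```python
import copy
import torch
from torch import nn

model = nn.Linear(in_features=2, out_features=1)
model4 = copy.deepcopy(model)

# Change model4 independently, as in the experiment above.
with torch.no_grad():
    model4.weight.mul_(2)

# Sync model4 back to model's current values.
model4.load_state_dict(model.state_dict())
print(torch.equal(model4.weight, model.weight))  # True: same values

# The storage is still separate, so this edit to model does not reach model4.
with torch.no_grad():
    model.weight.add_(1.0)
print(torch.equal(model4.weight, model.weight))  # False
```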