While writing code today I found that copying a model in PyTorch runs into the following issues:
1. Assigning with model.state_dict()[name] = param cannot modify the model's parameters (a sketch of an approach that does work follows this list)
2. After model1 = model2, modifying model1 also changes model2, which shows that this kind of model "copy" in PyTorch is a shallow copy (really just another reference to the same object)
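As a side note (this is not part of the experiments below), the usual way to actually change parameters is to write into the parameter tensors in place; a minimal sketch with a standalone nn.Linear:

import torch
from torch import nn

layer = nn.Linear(2, 1)

# In-place writes go into the storage the module actually holds, so the
# parameters really change. no_grad() is needed because in-place ops on
# leaf tensors that require grad are not allowed.
with torch.no_grad():
    for param in layer.parameters():
        param.mul_(2)  # double every weight and bias in place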
Verifying the issues:
Initialize the test case
Define a simple single neuron (in=2, out=1):
import torch
import copy
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=1)

    def forward(self, input):
        x = torch.relu(self.linear_1(input))
        return x
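A quick sanity check that the module runs (an illustrative snippet, not part of the original write-up):

m = Model()
print(m(torch.randn(1, 2)).shape)  # torch.Size([1, 1])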
Initialize the models:
model = Model()
model2 = Model()
model3 = model
model4 = copy.deepcopy(model)
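Before comparing parameters, a quick identity check (an illustrative addition) already hints at what is going on: model3 is the very same object as model, while model4 is an independent copy.

print(model3 is model)  # True  -> model3 is just another name for the same object
print(model4 is model)  # False -> deepcopy created an independent module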
model and model2
The result: assigning with = cannot modify the model's parameters, while copy_() can.
About the copy_() function: model.state_dict() builds a fresh OrderedDict on every call, but the tensors it contains share storage with the model's own parameters. Assigning a new tensor to a key of that dict only rebinds the dict entry and never touches the model; tensor.copy_() writes in place into the existing storage, so the parameters really change.
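To make the difference concrete, a tiny illustration with plain tensors (not part of the original experiment):

t = torch.zeros(3)
alias = t                      # same underlying storage as t
alias.copy_(torch.ones(3))     # in-place write: t is now tensor([1., 1., 1.])
alias = torch.full((3,), 2.0)  # rebinds the name only: t stays tensor([1., 1., 1.])
print(t)                       # tensor([1., 1., 1.])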
print("model.state_dict()", model.state_dict())
print("model2.state_dict(): ", model2.state_dict())
# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1524, -0.4679]])), ('linear_1.bias', tensor([-0.5376]))])
# model2.state_dict(): OrderedDict([('linear_1.weight', tensor([[0.6001, 0.0914]])), ('linear_1.bias', tensor([0.5060]))])
for name, param in model2.state_dict().items():
    model2.state_dict()[name] = param * 2
print("model.state_dict()", model.state_dict())
print("model2.state_dict(): ", model2.state_dict())
# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1524, -0.4679]])), ('linear_1.bias', tensor([-0.5376]))])
# model2.state_dict(): OrderedDict([('linear_1.weight', tensor([[0.6001, 0.0914]])), ('linear_1.bias', tensor([0.5060]))])
for name, param in model2.state_dict().items():
    model2.state_dict()[name].copy_(param * 2)
print("model.state_dict()", model.state_dict())
print("model2.state_dict(): ", model2.state_dict())
# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1524, -0.4679]])), ('linear_1.bias', tensor([-0.5376]))])
# model2.state_dict(): OrderedDict([('linear_1.weight', tensor([[1.2001, 0.1827]])), ('linear_1.bias', tensor([1.0119]))])
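If the goal is to copy parameters from one model into another of the same architecture, an alternative to looping with copy_() is load_state_dict (a sketch using the model and model2 defined above):

# Copies model's parameter values into model2's own tensors;
# the two models remain independent afterwards.
model2.load_state_dict(model.state_dict())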
model3 and model4
The results show that modifying model3 also changes model, because model3 and model refer to the same object (and therefore the same parameter tensors).
Modifying model4 does not change model, because deepcopy allocated independent parameter tensors.
print("model.state_dict()", model.state_dict())
print("model3.state_dict(): ", model3.state_dict())
print("model4.state_dict(): ", model4.state_dict())
# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.0855, 0.5340]])), ('linear_1.bias', tensor([-0.4530]))])
# model3.state_dict(): OrderedDict([('linear_1.weight', tensor([[-0.0855, 0.5340]])), ('linear_1.bias', tensor([-0.4530]))])
# model4.state_dict(): OrderedDict([('linear_1.weight', tensor([[-0.0855, 0.5340]])), ('linear_1.bias', tensor([-0.4530]))])
for name, param in model.state_dict().items():
    model3.state_dict()[name].copy_(param * 2)
print("model.state_dict()", model.state_dict())
print("model3.state_dict(): ", model3.state_dict())
print("model4.state_dict(): ", model4.state_dict())
# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1710, 1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model3.state_dict(): OrderedDict([('linear_1.weight', tensor([[-0.1710, 1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model4.state_dict(): OrderedDict([('linear_1.weight', tensor([[-0.0855, 0.5340]])), ('linear_1.bias', tensor([-0.4530]))])
for name, param in model.state_dict().items():
    model4.state_dict()[name].copy_(param * 2)
print("model.state_dict()", model.state_dict())
print("model3.state_dict(): ", model3.state_dict())
print("model4.state_dict(): ", model4.state_dict())
# model.state_dict() OrderedDict([('linear_1.weight', tensor([[-0.1710, 1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model3.state_dict(): OrderedDict([('linear_1.weight', tensor([[-0.1710, 1.0680]])), ('linear_1.bias', tensor([-0.9061]))])
# model4.state_dict(): OrderedDict([('linear_1.weight', tensor([[-0.3420, 2.1360]])), ('linear_1.bias', tensor([-1.8121]))])
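Takeaway: to obtain an independent copy of a model, either deepcopy it (as with model4) or build a fresh instance and load the parameters into it (a sketch reusing the Model class defined above):

model_copy = copy.deepcopy(model)                 # independent object, independent parameters
model_copy2 = Model()
model_copy2.load_state_dict(model.state_dict())   # same values, separate storage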