Copying a PyTorch Model

We often need to duplicate a model so that the two copies serve different purposes (e.g., trained separately, or used as teacher/student models). Although torch does not provide an interface like model.clone(), this is easily achieved with copy.deepcopy().

copy.deepcopy() not only copies all of the original model's parameters exactly, it preserves the device placement as well. The two copies can then be optimized independently without interfering with each other.
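A minimal sketch of the "exact copy, independent afterwards" behavior (run on CPU here for portability; the same holds on any device):

```python
import copy
import torch

m = torch.nn.Linear(3, 3)
mc = copy.deepcopy(m)

# The copies start with identical parameter values...
print(torch.equal(m.weight, mc.weight))   # True

# ...but they are distinct tensors, so updating one in place
# leaves the other untouched.
with torch.no_grad():
    m.weight.add_(1.0)
print(torch.equal(m.weight, mc.weight))   # False
```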

Verification:

>>> import torch
>>> import copy 
>>> m=torch.nn.Linear(3,3).to("cuda:4")
>>> mc=copy.deepcopy(m)  # deep-copy the original model's parameters

>>> t1=torch.randn(3,3)  # for model 'm'
>>> t1=t1.to("cuda:4")
>>> t1
tensor([[-0.6198,  0.2503,  0.9287],
        [ 0.6553, -0.6422, -2.0498],
        [-0.7867, -0.6862,  1.9102]], device='cuda:4')

>>> t2=torch.randn(3,3)  # for model 'mc'
>>> t2=t2.to("cuda:4")
>>> t2
tensor([[ 0.9616,  1.1679, -0.3201],
        [ 0.6383, -0.4115, -1.5540],
        [ 0.6649,  0.8439, -1.3090]], device='cuda:4')

>>> out1=m(t1)
>>> loss1=torch.sum(out1)
# loss1 can be backpropagated to m
>>> torch.autograd.grad(loss1,list(m.parameters())[0],retain_graph=True)
(tensor([[-0.7512, -1.0782,  0.7891],
        [-0.7512, -1.0782,  0.7891],
        [-0.7512, -1.0782,  0.7891]], device='cuda:4'),)
# but it cannot be backpropagated to mc
>>> torch.autograd.grad(loss1,list(mc.parameters())[0],retain_graph=True)
Traceback (most recent call last):
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

# the symmetric case: loss2 backpropagates to mc but not to m
>>> out2=mc(t2)
>>> loss2=torch.sum(out2)
>>> torch.autograd.grad(loss2,list(m.parameters())[0],retain_graph=True)
Traceback (most recent call last):
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
>>> torch.autograd.grad(loss2,list(mc.parameters())[0],retain_graph=True)
(tensor([[ 2.2648,  1.6003, -3.1831],
        [ 2.2648,  1.6003, -3.1831],
        [ 2.2648,  1.6003, -3.1831]], device='cuda:4'),)
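Since gradients only reach the copy whose forward pass produced the loss, each copy can be trained with its own optimizer. A short sketch (hypothetical random data, CPU for portability):

```python
import copy
import torch

m = torch.nn.Linear(3, 3)
mc = copy.deepcopy(m)

# Each copy gets its own optimizer; stepping one never moves the other.
opt_m = torch.optim.SGD(m.parameters(), lr=0.1)

x = torch.randn(4, 3)
loss = m(x).sum()
loss.backward()   # gradients flow only into m's parameters
opt_m.step()

print(mc.weight.grad is None)             # True: mc received no gradients
print(torch.equal(m.weight, mc.weight))   # False: only m was updated
```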

Reference:
torch forum - Can I deepcopy a model?
