每天学点pytorch--torch.nn.ReLU(inplace=False)中inplace的作用

记录pytorch中遇到的一些问题,文章没有顺序关系

官方连接:

ReLU — PyTorch 1.10.0 documentation

inplace为True时,计算结果会对原来的结果进行覆盖。

还是看下pytorch中的具体操作:

>>> import torch
>>> import torch.nn as nn
>>> conv1 = nn.Conv2d(3, 3, kernel_size=3)
>>> rl1 = nn.ReLU(inplace=True)
>>> rl2 = nn.ReLU()
>>> input = torch.randn(1,3,5,5)
>>> o1 = conv1(input)
>>> id(o1)
139670453299872
>>> o1
tensor([[[[-0.1162,  0.5905,  1.0601],
          [-0.1423,  0.7013,  0.1079],
          [ 0.1096, -0.3253, -0.6799]],

         [[ 0.3407,  0.5013, -0.2121],
          [-0.6805, -0.8362,  0.3360],
          [ 1.1606, -0.2564,  0.2965]],

         [[ 0.4317, -0.2480,  0.2381],
          [-0.0314, -0.0850,  0.1920],
          [-0.2762,  0.0338, -0.2298]]]], grad_fn=<ThnnConv2DBackward>)
>>> h1 = rl1(o1)
>>> id(h1)
139670453299872 #和o1的id一样,说明h1和o1指向同一个地方
>>> o1 # o1的值发生了变化,inplace操作起了作用
tensor([[[[0.0000, 0.5905, 1.0601],
          [0.0000, 0.7013, 0.1079],
          [0.1096, 0.0000, 0.0000]],

         [[0.3407, 0.5013, 0.0000],
          [0.0000, 0.0000, 0.3360],
          [1.1606, 0.0000, 0.2965]],

         [[0.4317, 0.0000, 0.2381],
          [0.0000, 0.0000, 0.1920],
          [0.0000, 0.0338, 0.0000]]]], grad_fn=<ReluBackward1>)
>>> h1
tensor([[[[0.0000, 0.5905, 1.0601],
          [0.0000, 0.7013, 0.1079],
          [0.1096, 0.0000, 0.0000]],

         [[0.3407, 0.5013, 0.0000],
          [0.0000, 0.0000, 0.3360],
          [1.1606, 0.0000, 0.2965]],

         [[0.4317, 0.0000, 0.2381],
          [0.0000, 0.0000, 0.1920],
          [0.0000, 0.0338, 0.0000]]]], grad_fn=<ReluBackward1>)

从上面的操作可以看出,如果采用inplace操作,输入参数o1的结果被直接修改。

下面看下inplace为False时的结果:

>>> o1 = conv1(input)
>>> id(o1)
139670453299712
>>> o1
tensor([[[[-0.1162,  0.5905,  1.0601],
          [-0.1423,  0.7013,  0.1079],
          [ 0.1096, -0.3253, -0.6799]],

         [[ 0.3407,  0.5013, -0.2121],
          [-0.6805, -0.8362,  0.3360],
          [ 1.1606, -0.2564,  0.2965]],

         [[ 0.4317, -0.2480,  0.2381],
          [-0.0314, -0.0850,  0.1920],
          [-0.2762,  0.0338, -0.2298]]]], grad_fn=<ThnnConv2DBackward>)
>>> h1 = rl2(o1) # relu的inplace为False
>>> id(h1) #h1和o1的id不一样了
139670453330560
>>> o1 #查看o1的值,发现没有改变
tensor([[[[-0.1162,  0.5905,  1.0601],
          [-0.1423,  0.7013,  0.1079],
          [ 0.1096, -0.3253, -0.6799]],

         [[ 0.3407,  0.5013, -0.2121],
          [-0.6805, -0.8362,  0.3360],
          [ 1.1606, -0.2564,  0.2965]],

         [[ 0.4317, -0.2480,  0.2381],
          [-0.0314, -0.0850,  0.1920],
          [-0.2762,  0.0338, -0.2298]]]], grad_fn=<ThnnConv2DBackward>)
>>> h1 #h1是经过了relu操作后的结果
tensor([[[[0.0000, 0.5905, 1.0601],
          [0.0000, 0.7013, 0.1079],
          [0.1096, 0.0000, 0.0000]],

         [[0.3407, 0.5013, 0.0000],
          [0.0000, 0.0000, 0.3360],
          [1.1606, 0.0000, 0.2965]],

         [[0.4317, 0.0000, 0.2381],
          [0.0000, 0.0000, 0.1920],
          [0.0000, 0.0338, 0.0000]]]], grad_fn=<ReluBackward0>)

inplace为False时,不修改输入的值,而是生成一个新的对象,符合预期。

采用原地操作可以节省内存,但是在多分支(Multi-branch)的网络中,使用时需要注意,比如:

conv1 = nn.Conv2d(3, 3, kernel_size=3)
conv2 = nn.Conv2d(3, 3, kernel_size=3)
rl1 = nn.ReLU(inplace=True)
...
x = conv1(x)
h1 = rl1(x)
h2 = conv2(x) # 此时x的值可能已经变化了

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值