PyTorch 中的 tensor 及使用

这篇文章主要是围绕 PyTorch 中的 tensor 展开的,讨论了张量的求导机制,在不同设备之间的转换,神经网络中权重的更新等内容。面向的读者是使用过 PyTorch 一段时间的用户。本文中的代码例子基于 Python 3 和 PyTorch 1.1,如果文章中有错误或者没有说明白的地方,欢迎在评论区指正和讨论。

文章具体内容分为以下6个部分:

  1. tensor.requires_grad
  2. torch.no_grad()
  3. 反向传播及网络的更新
  4. tensor.detach()
  5. CPU and GPU
  6. tensor.item()

1. requires_grad

当我们创建一个张量 (tensor) 的时候,如果没有特殊指定的话,那么这个张量是默认是不需要求导的。我们可以通过 tensor.requires_grad 来检查一个张量是否需要求导。

在张量间的计算过程中,如果在所有输入中,有一个输入需要求导,那么输出一定会需要求导;相反,只有当所有输入都不需要求导的时候,输出才会不需要 [1]

举一个比较简单的例子,比如我们在训练一个网络的时候,我们从 DataLoader 中读取出来的一个 mini-batch 的数据,这些输入默认是不需要求导的,其次,网络的输出我们没有特意指明需要求导吧,Ground Truth 我们也没有特意设置需要求导吧。这么一想,哇,那我之前的那些 loss 咋还能自动求导呢?其实原因就是上边那条规则,虽然输入的训练数据是默认不求导的,但是,我们的 model 中的所有参数,它默认是求导的,这么一来,其中只要有一个需要求导,那么输出的网络结果必定也会需要求的。来看个实例:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# False

net = nn.Sequential(nn.Conv2d(3, 16, 3, 1),
                    nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():
    print(param[0], param[1].requires_grad)
# 0.weight True
# 0.bias True
# 1.weight True
# 1.bias True

output = net(input)
print(output.requires_grad)
# True

诚不欺我!但是,大家请注意前边只是举个例子来说明。在写代码的过程中,不要把网络的输入和 Ground Truth 的 requires_grad 设置为 True。虽然这样设置不会影响反向传播,但是需要额外计算网络的输入和 Ground Truth 的导数,增大了计算量和内存占用不说,这些计算出来的导数结果也没啥用。因为我们只需要神经网络中的参数的导数,用来更新网络,其余的导数都不需要。

好了,有个这个例子做铺垫,那么我们来得寸进尺一下。我们试试把网络参数的 requires_grad 设置为 False 会怎么样,同样的网络:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# False

net = nn.Sequential(nn.Conv2d(3, 16, 3, 1),
                    nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():
    param[1].requires_grad = False
    print(param[0], param[1].requires_grad)
# 0.weight False
# 0.bias False
# 1.weight False
# 1.bias False

output = net(input)
print(output.requires_grad)
# False

这样有什么用处&

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

芝麻开花666

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值