pytorch 中的 backward()

今天在学 pytorch 反向传播时发现 backward() 函数是可以往里面传参的,于是仔细查了一下这个函数及其参数是干什么的。

github上有大牛分析如下:

https://sherlockliao.github.io/2017/07/10/backward/

这里再简单总结一下。

如果 backward() 没有参数,调用 backward() 函数的变量必须是一个标量,即形状为(1,),否则就会报错。这时候其实相当于传参 backward(torch.Tensor([1.0])),这里的 1.0 可以理解为最后每个梯度要乘以的步长。

如果 backward() 有参数,比如 backward(torch.FloatTensor([0.1, 1.0, 0.001])),那么此时调用 backward() 函数的变量可以不是标量,也可以是向量,但是要注意的是,这个向量的维度必须和 backward() 的参数向量的维度相同,其实此时就相当于函数有多个输出,每个输出都要算一个梯度,并且每个输出算出来的梯度的步长对应于 backward() 的参数向量对应的维度,且最终某个输入 x 的梯度是所有这些输出针对 x 算出的梯度的加权求和(权值向量就是 backward 的参数向量)。
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值