怎样判断一个操作是否是可导的(Pytorch) & 不可导的操作的影响

在Pytorch中,看一个操作是否可导,即经过这个操作梯度是否还能顺利传递。

 可以看到,经过+操作后得到的z,仍能保持梯度的传递

而像torch.argmax(),  torch.eq() 这些操作就不行了,这些操作就是不可导的

遇到不可导的操作并不一定会导致训练报错,但是可能会导致那一个分支无法更新网络,使得加上之后网络性能却没啥提升

 

而如果像soft argmax

import torch
import torch.nn as nn

def soft_argmax(x):
	"""
	Arguments: voxel patch in shape (batch_size, channel, H, W, depth)
	Return: 3D coordinates in shape (batch_size, channel, 3)
	"""
	# alpha is here to make the largest element really big, so it
	# would become very close to 1 after softmax
	alpha = 10000.0 
	N,C,L = x.shape
	soft_max = nn.functional.softmax(x*alpha,dim=2)
	soft_max = soft_max.view(x.shape)
	indices_kernel = torch.arange(start=0, end=L).unsqueeze(0)
	# indices_kernel = indices_kernel.view((H,W,D))
	# indices_kernel = indices_kernel.view(H,W)
	conv = soft_max*indices_kernel
	indices = conv.sum(2)
	# z = indices%D
	# y = (indices).floor()%W
	# x = (((indices).floor())/W).floor()%H
	# coords = torch.stack([x,y,z],dim=2)
	# coords = torch.stack([x,y],dim=2)
    #coords[0][0]代表第一个channel的最大点的坐标值
    #coords[0][1]代表第2个channel的最大点的坐标值
	return indices
 
if __name__ == "__main__":
	x = torch.randn(1024,16,35*35,requires_grad=True) # (batch_size, channel, H, W, depth)
	coords = soft_argmax(x)
    #coords是[b,c,2]
	print(coords)

 操作就是可导的

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值