In PyTorch, whether an operation is differentiable comes down to one question: can gradients still propagate through it?
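As a minimal sketch of that check, the hypothetical helper `grad_flows` below builds an input with `requires_grad=True` and applies an operation to it. It then asks whether the output is still attached to the autograd graph (has a `grad_fn`) and whether `backward()` actually delivers a gradient to the input.

```python
import torch

def grad_flows(fn, *shape):
    """Hypothetical helper: True if a gradient from fn's output reaches its input."""
    x = torch.randn(*shape, requires_grad=True)
    out = fn(x)
    # An output with no grad_fn is detached from the graph: nothing to backpropagate
    if out.grad_fn is None:
        return False
    out.sum().backward()
    return x.grad is not None

print(grad_flows(lambda t: t * 2, 3))             # True
print(grad_flows(lambda t: torch.argmax(t), 3))   # False
print(grad_flows(lambda t: torch.eq(t, 0), 3))    # False
```

A `grad_fn` alone is not a full guarantee. Some operations, such as `torch.round`, are recorded in the graph but their backward returns zeros, so it is also worth inspecting the values in `x.grad`.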
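For example, the tensor z produced by a + operation keeps gradient flow intact: backpropagating from z reaches both of its inputs.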
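A small sketch of the + case: z carries a `grad_fn`, and calling `backward()` fills in the gradients of both inputs. Each gradient is 1 per element, since d(z_i)/d(x_i) = d(z_i)/d(y_i) = 1.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
z = x + y                  # addition is differentiable

print(z.requires_grad)     # True
print(z.grad_fn)           # <AddBackward0 object at ...>

z.sum().backward()         # reduce to a scalar before calling backward
print(x.grad, y.grad)      # tensor([1., 1.]) tensor([1., 1.])
```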
Operations such as torch.argmax() and torch.eq(), on the other hand, cut the gradient flow: they are non-differentiable.
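The sketch below shows what this looks like in practice. `torch.argmax` returns an integer index tensor and `torch.eq` returns a boolean tensor. Neither has a `grad_fn`, so trying to backpropagate from them raises a `RuntimeError`.

```python
import torch

x = torch.randn(4, requires_grad=True)
idx = torch.argmax(x)             # integer index (LongTensor), detached from the graph
mask = torch.eq(x, x.max())       # boolean tensor, also detached

print(idx.requires_grad, idx.grad_fn)    # False None
print(mask.requires_grad, mask.grad_fn)  # False None

raised = False
try:
    # Casting to float does not reattach the tensor to the graph
    idx.float().backward()
except RuntimeError as e:
    raised = True
    print("RuntimeError:", e)
```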
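A non-differentiable operation does not necessarily make training crash with an error. However, the branch that passes through it may never update its part of the network, so adding that branch brings little or no performance gain.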
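The following sketch reproduces this silent failure with a hypothetical two-branch model, where `branch_a`, `branch_b` and the loss setup are made up for illustration. The output of `branch_b` reaches the loss only through `argmax`. `backward()` runs without error, yet `branch_b` receives no gradient at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-branch model: branch_a feeds the loss directly,
# branch_b only through a non-differentiable argmax
branch_a = nn.Linear(8, 4)
branch_b = nn.Linear(8, 4)

x = torch.randn(16, 8)
target = torch.randn(16, 4)

out_a = branch_a(x)
idx_b = torch.argmax(branch_b(x), dim=1)      # gradient flow stops here
out = out_a + idx_b.float().unsqueeze(1)      # combine the branches, shape (16, 4)

loss = nn.functional.mse_loss(out, target)
loss.backward()                               # no error is raised

print(branch_a.weight.grad is None)  # False: branch_a is trained
print(branch_b.weight.grad is None)  # True: branch_b never gets a gradient
```

Standard `torch.optim` optimizers skip parameters whose `.grad` is `None`, so `branch_b` would keep its initial weights for the whole run.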
By contrast, soft argmax is differentiable. It replaces the hard index selection with a softmax-weighted sum of positions:

```python
import torch
import torch.nn as nn

def soft_argmax(x):
    """
    Arguments: scores of shape (batch_size, channel, L), where L is the
               flattened spatial size (e.g. H * W)
    Return: soft index of the maximum along the last dimension,
            shape (batch_size, channel)
    """
    # alpha is here to make the largest element really big, so it
    # would become very close to 1 after softmax
    alpha = 10000.0
    N, C, L = x.shape  # expects a 3D tensor
    soft_max = nn.functional.softmax(x * alpha, dim=2)
    # Position indices 0..L-1, created with x's dtype/device so this also works on GPU
    indices_kernel = torch.arange(L, dtype=x.dtype, device=x.device)
    # Expected position under the softmax weights: sum_i p_i * i
    indices = (soft_max * indices_kernel).sum(dim=2)
    # indices[b][c] is the (soft) position of the maximum of channel c in sample b.
    # For a flattened H x W map: y = indices.floor() % W,
    # x = (indices.floor() / W).floor() % H; floor() has zero gradient, though.
    return indices

if __name__ == "__main__":
    x = torch.randn(1024, 16, 35 * 35, requires_grad=True)  # (batch_size, channel, H*W)
    coords = soft_argmax(x)  # coords has shape [batch_size, channel]
    print(coords)
    print(coords.requires_grad)  # True: the result stays attached to the autograd graph
```
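To check that gradients really pass through soft argmax, the sketch below uses a compact 1D version. It computes the same softmax-weighted sum of positions over the last dimension. The function name `soft_argmax_1d` and the smaller default `alpha=10.0` are assumptions for illustration. The sketch compares the result with the hard `torch.argmax` and then backpropagates.

```python
import torch
import torch.nn.functional as F

def soft_argmax_1d(x, alpha=10.0):
    """Hypothetical variant: differentiable approximation of argmax over the last dim."""
    weights = F.softmax(x * alpha, dim=-1)   # sharpened probabilities over positions
    idx = torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
    return (weights * idx).sum(dim=-1)       # expected position

x = torch.tensor([[0.1, 0.9, 0.3, -0.5]], requires_grad=True)
hard = torch.argmax(x, dim=-1)
soft = soft_argmax_1d(x)

print(hard)            # tensor([1])
print(soft)            # slightly above 1 (about 1.002)

soft.sum().backward()
print(x.grad)          # non-zero entries: gradients reach x
```

`alpha` is a trade-off. A larger value brings the result closer to the hard argmax, but it also drives the softmax toward a one-hot vector, which makes the gradients very small. With a value as large as 10000.0, most of the gradient entries may end up effectively zero even though the operation is formally differentiable.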
How to Tell Whether an Operation Is Differentiable (PyTorch) & the Impact of Non-Differentiable Operations
First published 2022-01-22 16:57:48