softargmax可以作为argmax的近似。
因为argmax是不可导的,而softargmax是可导的
或者写作
思想就是使得softmax之后的结果,实现最大的那个很大,无限接近于1,而其他的都很小,无限接近于0,从而再乘上对应位置的index,就能实现把最大位置的index取出来的效果了
但是可以看到,位置计算不够准确(2.57 -> 3),一个原因就是最大值的概率不够大,或者说增大相对最大值而减弱其他值的影响就可以得到更加准确的位置坐标。
所以有了下面的式子
以beta=10为例子
就相当于除以温度τ
import torch import torch.nn as nn def soft_argmax(x, beta): """ Arguments: voxel patch in shape (batch_size, channel, H, W, depth) Return: 3D coordinates in shape (batch_size, channel, 3) """ #x: [bs,c,h,w] x = x.reshape(x.shape[0], x.shape[1], -1) #[bs,c,h*w] L = x.shape[2] soft_max = nn.functional.softmax(x*beta,dim=2) soft_max = soft_max.view(x.shape) indices = torch.arange(start=0, end=L).unsqueeze(0) soft_argmax = soft_max * indices indices = soft_argmax.sum(dim=2) #[bs,c] return indices if __name__ == "__main__": x = torch.randn(1024,16,35,35) #[bs, c, h, w] indices = soft_argmax(x, beta=10000.0) #[bs,c] print(indices)
Fdevmsy/PyTorch-Soft-Argmax: PyTorch implementation of Soft-Argmax in 1D/2D/3D (github.com)