Implement SoftArgmax with Pytorch.
在编程时,有时候需要返回一个张量最大值所在的维度序号(如分类任务中返回概率最大的类别编号、定位任务中返回概率最大的空间坐标编号),此时需要用到argmax操作。
Pytorch中的argmax函数定义为torch.argmax(input, dim=None, keepdim=False),其中的dim参数指定寻找最大值的维度,keepdim参数指定是否保持原张量的维度。
如一个尺寸为(3,4,5)(3,4,5)(3,4,

本文介绍了在Pytorch中如何实现SoftArgmax操作,作为argmax的可微分替代方案,用于在网络中进行反向传播。详细阐述了argmax函数的工作原理,并提供了SoftArgmax在三维张量上的应用实例。
订阅专栏 解锁全文
3260





