When working with multi-dimensional tensors, the `dim` parameter of softmax is often confusing. The following example illustrates how it behaves.
import torch
import torch.nn as nn

# One Softmax module per candidate dimension of a 3-D tensor
m = nn.Softmax(dim=0)
n = nn.Softmax(dim=1)
k = nn.Softmax(dim=2)

input = torch.randn(2, 2, 3)
print(input)
print(m(input))
print(n(input))
print(k(input))
Output:

input:
tensor([[[ 0.5450, -0.6264,  1.0446],
         [ 0.6324,  1.9069,  0.7158]],

        [[ 1.0092,  0.2421, -0.8928],
         [ 0.0344,  0.9723,  0.4328]]])
dim=0 (normalizes across the first dimension, so elements at the same position in the two outer blocks sum to 1, e.g. 0.3860 + 0.6140 = 1):
tensor([[[0.3860, 0.2956, 0.8741],
         [0.6452, 0.7180, 0.5703]],

        [[0.6140, 0.7044, 0.1259],
         [0.3548, 0.2820, 0.4297]]])
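The rule generalizes: whichever `dim` you pass, the probabilities along that dimension sum to 1. A minimal sketch that checks this for every dimension of a random 3-D tensor (the tensor values here are fresh random draws, not the ones printed above):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 2, 3)

for d in range(x.dim()):
    # Summing the softmax output along the same dim it was
    # computed over should give a tensor of all ones.
    s = nn.Softmax(dim=d)(x).sum(dim=d)
    assert torch.allclose(s, torch.ones_like(s))
```

So `dim=0` normalizes across the outer blocks, `dim=1` across the rows within each block, and `dim=2` across the entries of each row.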