Intro: When I was processing a dataset and practicing how to manipulate data along different dimensions, the “dim” argument confused me. 😯
Start: Let’s read the code. 🚶 If you already know why it behaves this way, congratulations. 🎆
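import torch
# A 3-D tensor of ones: 2 blocks, each with 5 rows and 4 columns.
a = torch.ones((2, 5, 4))
print(a.shape)  # torch.Size([2, 5, 4])
# With no axis, sum() adds every element and returns a 0-d (scalar) tensor.
print("value of scalar: ", a.sum())  # tensor(40.)
print("shape of scalar: ", a.sum().shape)  # torch.Size([])
print("a: ", a)
# axis (NumPy-style alias of dim) names the dimension that is summed away;
# keepdims=True (alias of keepdim) keeps that dimension with size 1.
print("axis=1: ", a.sum(axis=1))  # shape (2, 4), every entry 5
print("axis=1, keepdims=True: \n", a.sum(axis=1, keepdims=True))  # shape (2, 1, 4)
print("axis=2: ", a.sum(axis=2))  # shape (2, 5), every entry 4
print("axis=2, keepdims=True: \n", a.sum(axis=2, keepdims=True))  # shape (2, 5, 1)
print("axis=0: ", a.sum(axis=0))  # shape (5, 4), every entry 2
print("axis=0, keepdims=True: \n", a.sum(axis=0, keepdims=True))  # shape (1, 5, 4)
# A list of axes reduces all of them at once.
print("axis=[0, 2]: ", a.sum(axis=[0, 2]))  # shape (5,), every entry 8
print("axis=[0, 2], keepdims=True: \n", a.sum(axis=[0, 2], keepdims=True))  # shape (1, 5, 1)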
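The pattern in these printouts can be stated as a rule: the axes you pass are the ones that disappear from the shape (or shrink to size 1 with keepdims=True), and because every entry of a is 1, each output value is simply the number of elements that were added, i.e. the product of the reduced sizes. The sketch below checks that rule against PyTorch. `reduced_shape` and `as_dim_list` are hypothetical helpers written only for illustration, not PyTorch functions; the sketch uses `dim`/`keepdim`, the parameter names in the PyTorch docs, for which `axis`/`keepdims` are accepted aliases.

```python
import math

import torch


def as_dim_list(dims):
    """Hypothetical helper: normalize an int or a list of ints into a list of axes."""
    return [dims] if isinstance(dims, int) else list(dims)


def reduced_shape(shape, dims, keepdim=False):
    """Hypothetical helper: predict the shape of t.sum(dim=dims, keepdim=keepdim).

    Reduced axes are dropped, or kept with size 1 when keepdim=True.
    """
    reduced = {d % len(shape) for d in as_dim_list(dims)}
    if keepdim:
        return tuple(1 if i in reduced else n for i, n in enumerate(shape))
    return tuple(n for i, n in enumerate(shape) if i not in reduced)


a = torch.ones((2, 5, 4))
for dims in [0, 1, 2, [0, 2], [0, 1, 2]]:
    for keepdim in (False, True):
        out = a.sum(dim=dims, keepdim=keepdim)
        # The summed axes vanish from the shape (or become size 1).
        assert tuple(out.shape) == reduced_shape(a.shape, dims, keepdim)
        # a is all ones, so every entry counts the elements along the reduced axes.
        count = math.prod(a.shape[d] for d in as_dim_list(dims))
        assert torch.all(out == count)
        print(f"dim={dims}, keepdim={keepdim}: shape {tuple(out.shape)}, each entry {count}")
```

For instance, dim=1 turns the (2, 5, 4) tensor into shape (2, 4) with every entry equal to 5, while dim=[0, 2] leaves shape (5,) with every entry equal to 8 (2 * 4). Reducing over all three axes gives the same 0-d result as a.sum().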