dim维度
dim=0代表是列,dim=1代表是行
import torch
a = [[1,3,5],
[2,4,6],
[7,8,9]]
a = torch.tensor(a).float()
t = a.mean(dim=0) #dim=0代表是列
print(t)
输出结果(列求均值):
t = a.mean(dim=1) # dim=1代表是行
print(t)
输出结果(行求均值):
import torch
a = [[1,3,5],
[2,4,6],
[7,8,9]]
a = torch.tensor(a).float()
t = a.mean(dim=0) #dim=0代表是列
print(t)
输出结果(列求均值):
t = a.mean(dim=1) # dim=1代表是行
print(t)
输出结果(行求均值):