记忆要点:
dim = 0 说明是处理行
dim = 1 说明是处理列
keepdim = True 保留处理的行/列的特征
keepdim = False 不保留处理的行/列的特征
网上流传的版本有很多,但是我们根据结果来说话。我的理解是哪个维度发生了变化就是处理的是哪个维度。
if __name__ == "__main__":
#模型参数初始化
num_input = 784
num_output = 10
W = torch.tensor(np.random.normal(0,0.1,(num_input,num_output)),dtype=torch.float32)
b = torch.tensor(num_output,dtype=torch.float32)
W.requires_grad_(requires_grad = True)
b.requires_grad_(requires_grad = True)
#多维Tensor按维度操作
X = torch.tensor([[1,2,3],[4,5,6]])
print(X.sum(dim = 0,keepdim = True)) # dim为0,将同一列中所有行相加,并在结果中保留行特征 1,3
print(X.sum(dim = 1,keepdim = True)) # dim为1,同一行中所有列相加,并在结果中保留列特征 2,1
print(X.sum(dim = 0,keepdim = False))# dim为0,将同一列中所有行相加,并在结果中不保留行特征 3
print(X.sum(dim = 1,keepdim = False)) # dim为1,同一行中所有列相加,并在结果中不保留行特征 2
#
tensor([[5, 7, 9]])
tensor([[ 6],
[15]])
tensor([5, 7, 9])
tensor([ 6, 15])
#