torch.view(a , b , c) //表示三维
//其中a表示有多少组(b , c)
torch.view(a , b , c)[... , d] //表示取其中d列的所有行
例子:
a = torch.arange(12)
print(a)
b = a.view(3 , 2 , 2)
print(b)
b = a.view(3 , 2 , 2)[... , 0]
输出:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
tensor([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]]])
tensor([[ 0, 2],
[ 4, 6],
[ 8, 10]])