学习代码的第一个拦路虎,在pytorch中非常常用的维度转换。
首先了解一下tensor的size是怎么来的,几个中括号就说明有几个维度,然后看第一个中括号里用逗号分隔开了几个元素,就是第一个维度的值,其他依次类推。例如Tensor([[[1,2,3],[4,5,6]]])中第一个中括号里为[[1,2,3],[4,5,6]],只有一个元素,第二个中括号内为[1,2,3],[4,5,6],有两个元素,第三个为[1,2,3],有三个元素。
view()
view变换维度,把原先tensor中的数据按行优先的顺序排成一个一维数据(这里应该是因为要求地址是连续存储的),然后按照输入参数要求,组合成其他维度的tensor。例如:
a=torch.Tensor([[[1,2,3],[4,5,6]]])# ——> torch.Size([1, 2, 3])
print(a.view(3,2))# ——> torch.Size([3, 2])
#输出为:
tensor([[1