view()函数作用和reshape函数类似,就是对tensor的shape进行调整,可以通过view函数将tensor的shape调整成一个你希望的样子。
import torch
torch.manual_seed(2)
a=torch.randn(4,5)
print(a)
print(a.view(-1,2)) # 此时 -1,代表默认值,代表根据后面的列数来计算行数
'''
a
=
tensor([[-1.0408, 0.9166, -1.3042, -1.1097, 0.0299],
[-0.0498, 1.0651, 0.8860, -0.8110, 0.6737],
[-1.1233, -0.0919, 0.1405, 1.1191, 0.3152],
[ 1.7528, -0.7396, -1.2425, -0.1752, 0.6990]])
a.view(-1,2)
=
tensor([[-1.0408, 0.9166],
[-1.3042, -1.1097],
[ 0.0299, -0.0498],
[ 1.0651, 0.8860],
[-0.8110, 0.6737],
[-1.1233, -0.0919],
[ 0.1405, 1.1191],
[ 0.3152, 1.7528],
[-0.7396, -1.2425],
[-0.1752, 0.6990]])
'''
此时通过view调整,将一个 45,一共20个数字,变成了 102 (10为 view中 -1的默认数,代表按照后面的列数来计算的行数)
在训练神经网络时,经常会遇到这样的一段代码
x = x.view(x.size(0), -1)
x.size(0)指batchsize的值。这句话的出现就是为了将前面多维度的tensor展平成一维,然后再输入给 nn.Linear()结构,-1指在不告诉函数有多少列的情况下,根据原tensor数据和batchsize自动分配列数。