x = torch.range(1, 5).view(1, 1, 2, 2)
有时候会纳闷这个view()传入四个参数到底是什么,博客上大多数都是讲传入1个或者两个参数。
这里我查看了一下官方文档和自己的测试发现,第一个参数类似于一个标志位,根据这个来计算后面tensor的大小
x = torch.randn(4, 4)
print(x.size())
y = x.view(16)
print(y.size())
z = x.view(-1, 8) # -1表示该维度取决于其它维度大小,即(4*4)/ 8
print(z.size())
m = x.view(2, 2, 4) #如果传入2,则后面只能是16/2,也就是后面的tensor维度的几乘几只能为8
print(m)
print(m.size())
t = x.view(2, 2, 2, 2)
print(t)
print(t.size())
可以简单的理解为第一个是用来规范后面的size。拿图上的例子t来说,第一个位置传入了2,那么后面的尺寸只能是4*4/2=8.那么我们后面三个位置传入的参数只能是2*2*2或者1*1*8或者2*1*4等,只能是8的公约数。具体想要什么尺寸就看你的要求[batch_size, channel, h, w]