记录一下,pytorch中的view()函数

x = torch.range(1, 5).view(1, 1, 2, 2)

有时候会纳闷这个view()传入四个参数到底是什么,博客上大多数都是讲传入1个或者两个参数。

这里我查看了一下官方文档和自己的测试发现,第一个参数类似于一个标志位,根据这个来计算后面tensor的大小

x = torch.randn(4, 4)
print(x.size())
y = x.view(16)
print(y.size())
z = x.view(-1, 8)  # -1表示该维度取决于其它维度大小,即(4*4)/ 8
print(z.size())
m = x.view(2, 2, 4) #如果传入2,则后面只能是16/2,也就是后面的tensor维度的几乘几只能为8
print(m)
print(m.size())
t = x.view(2, 2, 2, 2)
print(t)
print(t.size())

 可以简单的理解为第一个是用来规范后面的size。拿图上的例子t来说,第一个位置传入了2,那么后面的尺寸只能是4*4/2=8.那么我们后面三个位置传入的参数只能是2*2*2或者1*1*8或者2*1*4等,只能是8的公约数。具体想要什么尺寸就看你的要求[batch_size, channel, h, w]

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值