今天聊一聊Flatten展平,torch、numpy等都有对应的方法

深度学习中,经常要用到Flatten去对输入数据的维度进行转换,所以就来聊一聊Flatten

直接上源码
在这里插入图片描述
torch官方的文档
在这里插入图片描述
这里提到的总结一下(好好品味这段话):

torch.nn.Flatten中是一个带有两个参数的方法,分别是start_dim和end_dim。默认start_dim的值为1,end_dim的值为-1。当然,我们也可以自己传参进去。start_dim表示要对原张量数据展平的第一个维度,end_dim则表示最后一个维度,且一般情况下end_dim>=start_dim(end_dim=-1除外)。

实践是最好的老师,讲再多都是纸上得来终觉浅。

导入torch

import torch

定义四个张量

a = torch.randn(100, 1, 28, 28)
k = torch.randn(100, 2, 28)
j = torch.randn(100, 2)
z = torch.randn(100)

定义Flatten层的模型

b = torch.nn.Flatten()

把张量放进去Flatten作为输入

print(b(a).shape)
print(b(k).shape)
print(b(j).shape)
print(b(z).shape)     # 报错

得到的结果是:

torch.Size([100, 784])
torch.Size([100, 56])
torch.Size([100, 2])
# 最后一个结果报错

可以看到每个张量第2个维度是的size都是相同的,a作为输入的张量,得到的输出是第二个维度size全为784的张量,k作为输入的张量得到的是第二个维度size全为56的张量,j作为输入的张量得到的是第二个维度size全为2的张量

每一个Flatten对应的处理就是

torch.nn.Flatten(start_dim=1, end_dim=3)
torch.nn.Flatten(start_dim=1, end_dim=2)
torch.nn.Flatten(start_dim=1, end_dim=1)
# 最后一个报错,无法找到start_dim=1,即第二个维度
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

刘先生的u写倒了

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值