深度学习中,经常要用到Flatten去对输入数据的维度进行转换,所以就来聊一聊Flatten
直接上源码
torch官方的文档
这里提到的总结一下(好好品味这段话):
torch.nn.Flatten中是一个带有两个参数的方法,分别是start_dim和end_dim。默认start_dim的值为1,end_dim的值为-1。当然,我们也可以自己传参进去。start_dim表示要对原张量数据展平的第一个维度,end_dim则表示最后一个维度,且一般情况下end_dim>=start_dim(end_dim=-1除外)。
实践是最好的老师,讲再多都是纸上得来终觉浅。
导入torch
import torch
定义四个张量
a = torch.randn(100, 1, 28, 28)
k = torch.randn(100, 2, 28)
j = torch.randn(100, 2)
z = torch.randn(100)
定义Flatten层的模型
b = torch.nn.Flatten()
把张量放进去Flatten作为输入
print(b(a).shape)
print(b(k).shape)
print(b(j).shape)
print(b(z).shape) # 报错
得到的结果是:
torch.Size([100, 784])
torch.Size([100, 56])
torch.Size([100, 2])
# 最后一个结果报错
可以看到每个张量第2个维度是的size都是相同的,a作为输入的张量,得到的输出是第二个维度size全为784的张量,k作为输入的张量得到的是第二个维度size全为56的张量,j作为输入的张量得到的是第二个维度size全为2的张量
每一个Flatten对应的处理就是
torch.nn.Flatten(start_dim=1, end_dim=3)
torch.nn.Flatten(start_dim=1, end_dim=2)
torch.nn.Flatten(start_dim=1, end_dim=1)
# 最后一个报错,无法找到start_dim=1,即第二个维度