神经网络view(),torch.flatten(),torch.nn.Flatten()
在神经网络中经常看到view(),torch.flatten(),torch.nn.Flatten()这几个方法。这几个方法一般用于改变tensor的形状。为日后方便使用下面就一一透彻的理解一下。
1、view()
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),w代表的是列(想要变为几列)#这里所说并不严谨,只是为了更好理解,
view()的参数 | 作用 |
---|---|
h | 取值代表行数,当不知道要变为几行,但知道要变为几列时可取-1 |
w | 取值代表列数,当不知道要变为几列,但知道要变为几行时可取-1 |
注意:元素个数要能整除行和列|
下面看几个例子就理解了。
1、把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor
import torch
a=torch.Tensor([[[1,2,3],[4,5,6],[7,8,9]]])
b=torch.Tensor([1,2,3,4,5,6,7,8,9])
#结果:
torch.Size([1, 3, 3])
tensor([[[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]]])
torch.Size([9])
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
#a是[1,3,3]的tensor向量:
b是[9]的tensor向量
a1 = a.view(3,-1)
b1 = b.view(3,-1)
#a1和b1的结果:
torch.Size([3, 3])
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
torch.Size([3, 3])
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
结果一样
2、当知道要变成的tensor的行时:
import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=a.view(3,-1)
#a的结果:2行3列
tensor([[[1., 2., 3.],
[4., 5., 6.]]])
#b的结果:变为了3行2列
tensor([[1., 2.],
[3., 4.],
[5., 6.]])
2、当知道要变成的tensor的列时:
import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=a.view(-1,1)
#a的结果:2行三列
tensor([[[1., 2., 3.],
[4., 5., 6.]]])
#b的结果:变为了6行1列
tensor([[1.],
[2.],
[3.],
[4.],
[5.],
[6.]])
2、torch.nn.Flatten()
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim代表合并的维度,开始的默认值为1,结束的默认值为-1,因此常被使用在神经网络当中,将每个batch的数据拉伸成一维。
下面举几个例子:
1、默认参数时:
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten()
a1 = F(a)
a的大小:
torch.Size([8, 3, 64, 64])
a1的大小:
torch.Size([8, 12288])
默认将第0维保留下来,其余拍成一维
2、有一个参数时(一个参数代表开始合并的维度):
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(2)
a1 = F(a)
a的大小:
torch.Size([8, 3, 64, 64])
a1的大小:
torch.Size([8, 3, 4096])
从第二维开始,拍成一维
3、有两个参数时(前一个参数代表开始合并的维度,后一个参数代表结束合并的维度)
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(1,2)
a1 = F(a)
a的大小:
torch.Size([8, 3, 64, 64])
a1的大小:
torch.Size([8, 192, 64])
将第一维到第二维拍成一维,其余不变
3、torch.flatten()
与 torch.nn.flatten 类似,都是用于展平 tensor 的,但是torch.flatten默认是从0开始的。
torch.flatten(t, start_dim=0, end_dim=-1)
t表示的时要展平的tensor,start_dim是开始展平的维度,end_dim是结束展平的维度
这里只举一个例子,其余与torch.nn.Flatten()是一样的。
import torch
a = torch.randn(8,3,64,64)
F = torch.flatten(a)
a的大小:
torch.Size([8, 3, 64, 64])
F的大小(默认从第0维展平):
torch.Size([98304])