一、x = x.view()
x = x.view(x.size(0), -1)
在PyTorch中,x.view(x.size(0), -1)
是一种常用的操作,用于改变张量(Tensor)的形状而不改变其数据。这里的x
是一个多维张量,而.view()
函数是用来重新塑形这个张量的,同时保持其元素的总数不变。
具体来说,x.view(x.size(0), -1)
的含义是:
x.size(0)
:这部分获取了张量x
的第一个维度的大小(即,如果x
是一个形状为(a, b, c)
的张量,那么x.size(0)
就等于a
)。在大多数情况下,这代表了批处理中的样本数或者是序列的长度,取决于上下文。-1
:在.view()
函数中,-1
是一个特殊的值,表示该维度的大小会自动计算,以便保持总元素数不变。换句话说,PyTorch会根据其他维度的大小和总元素数来推断出-1
应该代表的具体数值。
因此,x.view(x.size(0), -1)
的作用是将张量x
重新塑形为一个二维张量,其中第一维的大小保持不变(即原始张量的第一个维度的大小),而第二维的大小则自动调整,以包含所有剩余的元素。这种操作在需要将多维数据“展平”为二维数据以进行某些操作(如全连接层)时非常有用。
例如,如果x
是一个形状为(64, 3, 28, 28)
的张量(通常表示一个包含64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素的数据集),那么x.view(x.size(0), -1)
将会把x
重新塑形为一个形状为(64, 3*28*28)
的张量,其中每个样本都被展平成了一个长向量。
二、torch.flatten()函数
x = torch.flatten(x, start_dim=0, end_dim=2)
x = torch.flatten(x, 0)
当你使用 x = torch.flatten(x, 0)
时,这里的 0
是 start_dim
参数的值,而 end_dim
参数仍然默认为 -1
。这意呀着展平操作将从张量 x
的第一个维度(索引为0的维度)开始,并且一直进行到张量的最后一个维度。
然而,由于 start_dim
被设置为0,并且 end_dim
默认为 -1
,实际上这会将整个张量 x
完全展平为一个一维张量。换句话说,无论原始张量 x
的形状如何,调用 torch.flatten(x, 0)
后,x
将变成一个一维张量,其长度等于原始张量中所有元素的总数。
例如,如果原始张量 x
的形状是 (a, b, c, d)
,那么调用 x = torch.flatten(x, 0)
后,x
的新形状将是 (a*b*c*d,)
,即一个包含 a*b*c*d
个元素的一维张量。
这种完全展平的操作在需要将多维数据转换为适合某些特定操作(如完全连接层的前馈传播)的一维形式时非常有用。然而,它也意味着你丢失了原始数据的形状信息,除非你在其他地方记录了这些信息或者你的操作不需要保留这些形状信息。
x = torch.flatten(x, 1)
在PyTorch中,torch.flatten(x, start_dim=0, end_dim=-1)
函数用于将张量x
在指定的维度范围内展平(或扁平化),而不改变其数据。这里的start_dim
是开始展平的维度(包含该维度),end_dim
是结束展平的维度(不包含该维度),默认情况下end_dim
为-1,即最后一个维度。
当你使用x = torch.flatten(x, 1)
时,你告诉PyTorch从第二个维度(索引为1,因为索引是从0开始的)开始,一直到最后一个维度,将所有的这些维度都展平成一个维度。这意味着,如果x
是一个多维张量,那么除了第一个维度之外的所有维度都将被合并成一个维度。
例如,如果x
的形状是(64, 3, 28, 28)
(代表64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素),那么x = torch.flatten(x, 1)
将会把x
展平成一个形状为(64, 3*28*28)
的张量。这里,第一个维度(样本数64)保持不变,而剩下的三个维度(3, 28, 28)被合并成了一个维度。
这种操作在处理图像数据时特别有用,尤其是在需要将图像数据传递给全连接层之前,因为全连接层通常期望输入是二维的(尽管在实践中,通常会先通过一个或多个卷积层来处理图像数据)。通过展平操作,你可以将多维的图像数据转换成二维的形式,以便进行后续处理。
x = torch.flatten(x, 2)
在PyTorch中,torch.flatten(input, start_dim=0, end_dim=-1)
函数用于将多维张量(tensor)展平(flatten)为一维张量,但你可以通过指定start_dim
和end_dim
参数来控制从哪一维度开始展平,以及在哪一维度结束(不包括该维度)。这意味着你可以保留张量的某些维度不变,而将其他维度展平。
对于你的代码 x = torch.flatten(x, 2)
,这里:
x
是你想要展平的原始张量。2
是start_dim
参数的值,而end_dim
参数默认为-1
,表示展平操作会一直进行到张量的最后一个维度。
因此,torch.flatten(x, 2)
的意思是从张量x
的第3维(因为索引从0开始)开始,将之后的所有维度都展平成一个维度。如果x
的形状是例如 (a, b, c, d, e)
,那么torch.flatten(x, 2)
之后,x
的形状将变为 (a, b, c*d*e)
。这里,a
和b
维度保持不变,而c
、d
和e
三个维度被合并成了一个新的维度。
这种操作在处理多维数据时非常有用,特别是当你需要将一部分数据的维度保持不变,而将其他部分数据“展平”以便于后续处理(如全连接层处理)时。
三、示例
import torch
A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
# print(A.size)
print(A.shape)
B = torch.flatten(A,1)
print(B.shape)
# print(B)
C = torch.flatten(A,0,1)
print(C.shape)
# print(C)
D = torch.flatten(A,2)
print(D.shape)
# print(D)
E = torch.flatten(A,0)
print(E.shape)
# print(E)
F = A.view(A.size(0), -1)
print(F.shape)
# print(F)
G = A.view(A.size(0), -1, 1)
print(G.shape)
输出:
torch.Size([2, 3, 4])
torch.Size([2, 12])
torch.Size([6, 4])
torch.Size([2, 3, 4])
torch.Size([24])
torch.Size([2, 12])
torch.Size([2, 12, 1])