函数:unfold(), view(), reshape(), permute(), transpose(), flatten(), cat(), chunk(), split(),
stack(), take(), tile(), unbind(), squeeze(), unsqueeze(), where(), full(), cumprod(), gather()
文章目录
一、Tensor操作的函数
1.unfold()
unfold函数相当于一维滑动窗口的操作
unfold(dimension, size, step) -> Tensor
dimension->表示从哪个维度展开滑动
size->滑动窗口的大小
step->每次滑动的步长
对于一维的情况
代码如下(示例):
x = torch.arange(1., 5)
print(x)
x = x.unfold(0, 2, 1)
print(x)
print(x.shape)
最终输出的x是3个大小为2的窗口
输出结果如下:
tensor([1., 2., 3., 4.])
tensor([[1., 2.],
[2., 3.],
[3., 4.]])
torch.Size([3, 2])
对于二维的情况
连续两次使用unfold()函数,就变成了一个2维滑动窗口的效果
x = torch.randn(4, 3)
print(x)
x = x.unfold(0, 2, 1).unfold(1, 2, 1) # 获得了3×2个大小为2×2的窗口
print(x)
print(x.shape)
输出结果:
tensor([[-1.0882, -0.9543, 0.7692],
[ 1.0062, -0.5342, 2.0376],
[-1.2618, -0.2669, 0.0763],
[ 0.8875, -1.6452, -0.7613]])
tensor([[[[-1.0882, -0.9543],
[ 1.0062, -0.5342]],
[[-0.9543, 0.7692],
[-0.5342, 2.0376]]],
[[[ 1.0062, -0.5342],
[-1.2618, -0.2669]],
[[-0.5342, 2.0376],
[-0.2669, 0.0763]]],
[[[-1.2618, -0.2669],
[ 0.8875, -1.6452]],
[[-0.2669, 0.0763],
[-1.6452, -0.7613]]]])
torch.Size([3, 2, 2, 2])
2.view()
view()将tensor重塑为想要的shape。其操作是将张量展平成一维之后,再排列成想要的形状。
代码如下(示例):
x = torch.randn(4, 3)
print(x)
x = x.view((3, 4))
print(x)
x = x.view(-1) # 将x张量展平
print(x)
# 变形后不知道其中一个维度的大小,可用-1表示
x = x.view(2, -1, 2)
print(x.shape)
tips:view()函数,从3×4的形状变到4×3,这个操作与转置操作不同。
输出结果:
tensor([[ 0.2126, -0.8827, 1.2574],
[ 0.7362, 0.2389, -0.9048],
[ 0.6516, -1.2950, 0.7889],
[ 1.5935, 1.4653, 0.5844]])
tensor([[ 0.2126, -0.8827, 1.2574, 0.7362],
[ 0.2389, -0.9048, 0.6516, -1.2950],
[ 0.7889, 1.5935, 1.4653, 0.5844]])
tensor([ 0.2126, -0.8827, 1.2574, 0.7362, 0.2389, -0.9048, 0.6516, -1.2950,
0.7889, 1.5935, 1.4653, 0.5844])
torch.Size([2, 3, 2])
3.reshape()
reshape()与view()都是对张量shape进行重塑。
详细区别参考博客:PyTorch:view() 与 reshape() 区别详解
x = torch.randn(4, 3)
print(x)
x = x.reshape((3, 4))
print(x)
x = x.reshape(-1) # 将x张量展平
print(x)
# 变形后不知道其中一个维度的大小,可用-1表示
x = x.reshape(2, -1, 2)
print(x.shape)
输出:
tensor([[-0.1333, -1.4792, -1.5896],
[ 1.4252, 1.9033, 1.1032],
[ 0.4291, 1.4380, -2.6079],
[ 0.6619, 1.5024, 0.1740]])
tensor([[-0.1333, -1.4792, -1.5896, 1.4252],
[ 1.9033, 1.1032, 0.4291, 1.4380],
[-2.6079, 0.6619, 1.5024, 0.1740]])
tensor([-0.1333, -1.4792, -1.5896, 1.4252, 1.9033, 1.1032, 0.4291, 1.4380,
-2.6079, 0.6619, 1.5024, 0.1740])
torch.Size([2, 3, 2])
4.permute()
permute()是对tensor进行转置。
代码如下:
x = torch.randn(1, 2, 3) # 张量大小为1x2x3
print(x.size())
print(x.permute(2, 0, 1).size())
# 将下标为2的维度转置到第0个维度(2->0),0->1,1->2
输出结果:
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
5.transpose()
transpose()只能对tensor的某两个维度进行转置。
x = torch.randn(2, 3, 4)
print(x.size())
print(x.transpose(0, 2).size()) # 维度0和维度2进行转置
输出:
torch.Size([2, 3, 4])
torch.Size([4, 3, 2])
6.flatten()
flatten字面意思就是展平,拉平。
t = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]],
[[9, 10, 11, 12], [13, 14, 15, 16]]])
print(t.size())
print(torch.flatten(t)) # 把张量t展平成一维
print(t.flatten(start_dim=1)) # 从下标为1的维度开始,到最后一个下标拉平,大小变为2x8(原来为2x2x4)
print(t.flatten(start_dim=0, end_dim=1)) # 从下标为0的维度开始,到下标为1的维度拉平,大小为4x4(原来为2x2x4)
输出:
torch.Size([2, 2, 4])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16]])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
7.cat()
cat()作用是让各Tensor在某个维度上拼接。各Tensor非拼接的维度上的维度数必须相同。
第一个参数的各tensor类型可以是元组,也可以是列表
第二个参数dim,选择在哪个维度上进行拼接
x = torch.randn(2, 3)
print(x)
print(torch.cat([x, x, x], 0)) # 在第0个维度上进行拼接,大小为6x3
print(torch.cat((x, x, x), 1)) # 在第1个维度上进行拼接,大小为2x9
输出:
tensor([[-1.4191, 0.9298, -0.8898],
[-1.3625, 0.2327, -0.9634]])
tensor([[-1.4191, 0.9298, -0.8898],
[-1.3625, 0.2327, -0.9634],
[-1.4191, 0.9298, -0.8898],
[-1.3625, 0.2327, -0.9634],
[-1.4191, 0.9298, -0.8898],
[-1.3625, 0.2327, -0.9634]])
tensor([[-1.4191, 0.9298, -0.8898, -1.4191, 0.9298, -0.8898, -1.4191, 0.9298,
-0.8898],
[-1.3625, 0.2327, -0.9634, -1.3625, 0.2327, -0.9634, -1.3625, 0.2327,
-0.9634]])
8.chunk()
chunk()是将一个张量分割成特定数目的张量。如果给定dim不能整除chunks,最后一个张量会比较小。
torch.chunk(input, chunks, dim=0)
9.split()
torch.split(tensor, split_size_or_sections, dim=0)
split_size_or_sections类型可以是int或者list(int)
dim(int)张量划分维度
10、stack()
stack()是在一个新的维度连接几个tensor张量。
11、take()
take()把张量展平,按索引取值。
torch.take(input, index)
12、tile()
拷贝张量
torch.tile(input, dims)
13、unbind()
对指定的维度把张量拆分为多个小的张量。
torch.unbind(input, dim=0)
14、squeeze()与unsqueeze()
squeeze()移除张量中维度大小为1的。
torch.squeeze(input, dim=None)
unsqueeze()增加张量的维度。
torch.unsqueeze(input, dim)
15、where()
16、full()
torch.full( size , fill_value)
size ( int…torch.Size ) --定义输出张量形状的列表、元组或整数。
fill_value ( Scalar ) – 填充输出张量的值。
17、cumprod()
torch.cumprod( input , dim , * , dtype = None , out = None )
input ( Tensor ) – 输入张量。
dim ( int ) – 进行操作的维度。
a=torch.tensor([
x
1
x_1
x1,
x
2
x_2
x2,
x
3
x_3
x3])
b=torch.cumprod(a, dim=0)
b=tensor([
x
1
x_1
x1,
x
1
∗
x
2
x_1*x_2
x1∗x2,
x
1
∗
x
2
∗
x
3
x_1*x_2*x_3
x1∗x2∗x3])
18、gather()
torch.gather(t,dim=1,index=index_a)
dim=0竖着取值,index是行索引。
dim=1横着取值,index列索引。