Pytorch中常见函数1

函数:unfold(), view(), reshape(), permute(), transpose(), flatten(), cat(), chunk(), split(),
stack(), take(), tile(), unbind(), squeeze(), unsqueeze(), where(), full(), cumprod(), gather()


一、Tensor操作的函数

1.unfold()

unfold函数相当于一维滑动窗口的操作
unfold(dimension, size, step) -> Tensor
dimension->表示从哪个维度展开滑动
size->滑动窗口的大小
step->每次滑动的步长
对于一维的情况
代码如下(示例):

x = torch.arange(1., 5)
print(x)
x = x.unfold(0, 2, 1)   
print(x)
print(x.shape)

最终输出的x是3个大小为2的窗口
输出结果如下:

tensor([1., 2., 3., 4.])
tensor([[1., 2.],
        [2., 3.],
        [3., 4.]])
torch.Size([3, 2])

对于二维的情况
连续两次使用unfold()函数,就变成了一个2维滑动窗口的效果

x = torch.randn(4, 3)
print(x)
x = x.unfold(0, 2, 1).unfold(1, 2, 1)  # 获得了3×2个大小为2×2的窗口
print(x)
print(x.shape)

输出结果:

tensor([[-1.0882, -0.9543,  0.7692],
        [ 1.0062, -0.5342,  2.0376],
        [-1.2618, -0.2669,  0.0763],
        [ 0.8875, -1.6452, -0.7613]])
tensor([[[[-1.0882, -0.9543],
          [ 1.0062, -0.5342]],

         [[-0.9543,  0.7692],
          [-0.5342,  2.0376]]],


        [[[ 1.0062, -0.5342],
          [-1.2618, -0.2669]],

         [[-0.5342,  2.0376],
          [-0.2669,  0.0763]]],


        [[[-1.2618, -0.2669],
          [ 0.8875, -1.6452]],

         [[-0.2669,  0.0763],
          [-1.6452, -0.7613]]]])
torch.Size([3, 2, 2, 2])

2.view()

view()将tensor重塑为想要的shape。其操作是将张量展平成一维之后,再排列成想要的形状。
代码如下(示例):

x = torch.randn(4, 3)
print(x)
x = x.view((3, 4))
print(x)
x = x.view(-1)	# 将x张量展平
print(x)

# 变形后不知道其中一个维度的大小,可用-1表示
x = x.view(2, -1, 2)
print(x.shape)

tips:view()函数,从3×4的形状变到4×3,这个操作与转置操作不同。
输出结果:

tensor([[ 0.2126, -0.8827,  1.2574],
        [ 0.7362,  0.2389, -0.9048],
        [ 0.6516, -1.2950,  0.7889],
        [ 1.5935,  1.4653,  0.5844]])
tensor([[ 0.2126, -0.8827,  1.2574,  0.7362],
        [ 0.2389, -0.9048,  0.6516, -1.2950],
        [ 0.7889,  1.5935,  1.4653,  0.5844]])
tensor([ 0.2126, -0.8827,  1.2574,  0.7362,  0.2389, -0.9048,  0.6516, -1.2950,
         0.7889,  1.5935,  1.4653,  0.5844])
torch.Size([2, 3, 2])

3.reshape()

reshape()与view()都是对张量shape进行重塑。
详细区别参考博客:PyTorch:view() 与 reshape() 区别详解

x = torch.randn(4, 3)
print(x)
x = x.reshape((3, 4))
print(x)
x = x.reshape(-1)	# 将x张量展平
print(x)

# 变形后不知道其中一个维度的大小,可用-1表示
x = x.reshape(2, -1, 2)
print(x.shape)

输出:

tensor([[-0.1333, -1.4792, -1.5896],
        [ 1.4252,  1.9033,  1.1032],
        [ 0.4291,  1.4380, -2.6079],
        [ 0.6619,  1.5024,  0.1740]])
tensor([[-0.1333, -1.4792, -1.5896,  1.4252],
        [ 1.9033,  1.1032,  0.4291,  1.4380],
        [-2.6079,  0.6619,  1.5024,  0.1740]])
tensor([-0.1333, -1.4792, -1.5896,  1.4252,  1.9033,  1.1032,  0.4291,  1.4380,
        -2.6079,  0.6619,  1.5024,  0.1740])
torch.Size([2, 3, 2])

4.permute()

permute()是对tensor进行转置。
代码如下:

x = torch.randn(1, 2, 3)	# 张量大小为1x2x3
print(x.size())
print(x.permute(2, 0, 1).size())	
# 将下标为2的维度转置到第0个维度(2->0)0->11->2

输出结果:

torch.Size([1, 2, 3])
torch.Size([3, 1, 2])

5.transpose()

transpose()只能对tensor的某两个维度进行转置。

x = torch.randn(2, 3, 4)
print(x.size())
print(x.transpose(0, 2).size())		# 维度0和维度2进行转置

输出:

torch.Size([2, 3, 4])
torch.Size([4, 3, 2])

6.flatten()

flatten字面意思就是展平,拉平。

t = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]],
                  [[9, 10, 11, 12], [13, 14, 15, 16]]])
print(t.size())     
print(torch.flatten(t))     # 把张量t展平成一维
print(t.flatten(start_dim=1))   # 从下标为1的维度开始,到最后一个下标拉平,大小变为2x8(原来为2x2x4)
print(t.flatten(start_dim=0, end_dim=1))   # 从下标为0的维度开始,到下标为1的维度拉平,大小为4x4(原来为2x2x4)

输出:

torch.Size([2, 2, 4])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16]])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

7.cat()

cat()作用是让各Tensor在某个维度上拼接。各Tensor非拼接的维度上的维度数必须相同。
第一个参数的各tensor类型可以是元组,也可以是列表
第二个参数dim,选择在哪个维度上进行拼接

x = torch.randn(2, 3)
print(x)
print(torch.cat([x, x, x], 0))    # 在第0个维度上进行拼接,大小为6x3
print(torch.cat((x, x, x), 1))   # 在第1个维度上进行拼接,大小为2x9

输出:

tensor([[-1.4191,  0.9298, -0.8898],
        [-1.3625,  0.2327, -0.9634]])
tensor([[-1.4191,  0.9298, -0.8898],
        [-1.3625,  0.2327, -0.9634],
        [-1.4191,  0.9298, -0.8898],
        [-1.3625,  0.2327, -0.9634],
        [-1.4191,  0.9298, -0.8898],
        [-1.3625,  0.2327, -0.9634]])
tensor([[-1.4191,  0.9298, -0.8898, -1.4191,  0.9298, -0.8898, -1.4191,  0.9298,
         -0.8898],
        [-1.3625,  0.2327, -0.9634, -1.3625,  0.2327, -0.9634, -1.3625,  0.2327,
         -0.9634]])

8.chunk()

chunk()是将一个张量分割成特定数目的张量。如果给定dim不能整除chunks,最后一个张量会比较小。
torch.chunk(input, chunks, dim=0)
chunk()例子


9.split()

torch.split(tensor, split_size_or_sections, dim=0)
split_size_or_sections类型可以是int或者list(int)
dim(int)张量划分维度
在这里插入图片描述


10、stack()

stack()是在一个新的维度连接几个tensor张量。
在这里插入图片描述


11、take()

take()把张量展平,按索引取值。
torch.take(input, index)
在这里插入图片描述


12、tile()

拷贝张量
torch.tile(input, dims)
在这里插入图片描述


13、unbind()

对指定的维度把张量拆分为多个小的张量。
torch.unbind(input, dim=0)
在这里插入图片描述


14、squeeze()与unsqueeze()

squeeze()移除张量中维度大小为1的。
torch.squeeze(input, dim=None)
unsqueeze()增加张量的维度。
torch.unsqueeze(input, dim)
在这里插入图片描述
在这里插入图片描述


15、where()

在这里插入图片描述


16、full()

torch.full( size , fill_value)
size ( int…torch.Size ) --定义输出张量形状的列表、元组或整数。
fill_value ( Scalar ) – 填充输出张量的值。
在这里插入图片描述


17、cumprod()

torch.cumprod( input , dim , * , dtype = None , out = None )
input ( Tensor ) – 输入张量。
dim ( int ) – 进行操作的维度。
在这里插入图片描述
a=torch.tensor([ x 1 x_1 x1, x 2 x_2 x2, x 3 x_3 x3])
b=torch.cumprod(a, dim=0)
b=tensor([ x 1 x_1 x1, x 1 ∗ x 2 x_1*x_2 x1x2, x 1 ∗ x 2 ∗ x 3 x_1*x_2*x_3 x1x2x3])


18、gather()

torch.gather(t,dim=1,index=index_a)
dim=0竖着取值,index是行索引。
在这里插入图片描述
dim=1横着取值,index列索引。
在这里插入图片描述

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值