PyTorch常用函数总结(更新中)

目录

1.linspace函数(tensor生成)

2.rand函数和randn函数(随机tensor生成)

3.torch.squeeze函数和torch.unsqueeze函数(tensor的维度去除和增加)

4.repeat函数(tensor的复制)

5.expand函数(tensor扩张)

6.torch.cat函数和torch.stack函数(tensor的拼接)

7.torch.chunk函数和torch.split函数(tensor的拆分)

8.permute函数和transpose函数(交换维度)

9.reshape函数和resize函数(改变tensor的形状)

10.view函数(改变tensor形状)

11.torch.flatten函数(展平tensor)

12.meshgrid函数(生成网格)

13.torch.argmax(返回指定维度最大值的序号)

1.linspace函数(tensor生成)

torch.linspace(start, end, steps=分割点数,dtype=返回类型)

>>> t = torch.linspace(3, 10, 5)
>>> t
tensor([ 3.0000,  4.7500,  6.5000,  8.2500, 10.0000])

>>> c = torch.linspace(3, 10, 5, dtype=torch.int)
>>> c
tensor([ 3,  4,  6,  8, 10], dtype=torch.int32)

2.rand函数和randn函数(随机tensor生成)

rand函数和randn函数两个都是生成随机tensor的函数,两者的区别是rand生成的tensor是基于均匀分布的,而randn生成的向量是基于标准正态分布的

>>> print(torch.rand(2, 3))
tensor([[0.4525, 0.1064, 0.9951],
        [0.2394, 0.7348, 0.5460]])

>>> print(torch.randn(2, 3))
tensor([[ 1.7014,  1.1992, -0.5759],
        [-0.2074, -1.2572,  0.8707]])

3.torch.squeeze函数和torch.unsqueeze函数(tensor的维度去除和增加)

squeeze函数是将tenser维度为1的删除,保留其他维度

>>> import torch
>>> x = torch.randn(size=(2, 1, 2, 1, 2))
>>> x.shape
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x) 
# 把x中维度为1删除掉,保留其他维度
>>> y.shape
torch.Size([2, 2, 2])

同时可以指定dim去删除,但是不可以删除维度大于1的 

>>> x = torch.randn(size=(2, 1, 2, 1, 2))
>>> y = torch.squeeze(x, dim=1)
>>> y.shape
torch.Size([2, 2, 1, 2])
>>> z = torch.squeeze(x, dim=-1)
>>> z.shape
torch.Size([2, 1, 2, 1, 2])

unsqueeze函数主要是对tensor升维,

>>> x = torch.tensor([1, 2, 3, 4])
>>> y = torch.unsqueeze(x, dim=0)
>>> y
tensor([[1, 2, 3, 4]])
>>> y.shape
torch.Size([1, 4])

4.repeat函数(tensor的复制)

repeat函数是对tensor进行复制,可以指定维度复制,可以对非单数维复制

>>> import torch
>>> a = torch.tensor(([[1, 2, 3], [1, 2, 3]]))
>>> b = a.repeat(2, 2)
>>> b
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
>>> b.shape
torch.Size([4, 6])

 当参数个数大于原有tensor的维度时,在第0维扩展维度为1的维度,然后复制

>>> a = torch.tensor(([[1, 2, 3], [1, 2, 3]]))
>>> a.shape
torch.Size([2, 3])
>>> b = a.repeat(1, 2, 1)
>>> b
tensor([[[1, 2, 3],
         [1, 2, 3],
         [1, 2, 3],
         [1, 2, 3]]])
>>> b.shape
torch.Size([1, 4, 3])

5.expand函数(tensor扩张)

 expand函数是对tensor在指定dim进行扩张,与repeat函数区别是仅在单位维上扩张(复制)

>>> a = torch.randn(1, 1, 3)
>>> b = a.expand(-1, 3, -1)
>>> c = a.expand(3, 3 ,3)
>>> a
tensor([[[-0.8384,  0.5715, -1.5192]]])
>>> b.shape
torch.Size([1, 3, 3])
>>> b
tensor([[[-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192]]])
>>> c.shape
torch.Size([3, 3, 3])
>>> c
tensor([[[-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192]],
        [[-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192]],
        [[-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192],
         [-0.8384,  0.5715, -1.5192]]])

6.torch.cat函数和torch.stack函数(tensor的拼接)

cat函数是将两个tensor按照指定维度拼接在一起,除了拼接维数dim数值可以不同,其他数值需要相同,不然无法对齐。(注,cat函数不会新增维度)

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> c = torch.cat((a, b), dim=0)
>>> c.shape
torch.Size([5, 3])
>>> a, b, c
(tensor([[-0.4516, -0.3405,  0.7484],
        [ 0.5414, -1.2156,  0.2577]]), 
 tensor([[-1.9560,  1.1200, -0.1139],
        [-0.5212,  0.3194,  0.8153],
        [-0.8011,  1.1799, -0.0382]]), 
 tensor([[-0.4516, -0.3405,  0.7484],
        [ 0.5414, -1.2156,  0.2577],
        [-1.9560,  1.1200, -0.1139],
        [-0.5212,  0.3194,  0.8153],
        [-0.8011,  1.1799, -0.0382]]))

stack函数 是在指定参数dim新增维度,然后拼接

>>> a = torch.arange(6.0).reshape(2, 3)
>>> a
tensor([[0., 1., 2.],
        [3., 4., 5.]])
>>> b = torch.linspace(0, 10, 6).reshape(2, 3)
>>> b
tensor([[ 0.,  2.,  4.],
        [ 6.,  8., 10.]])
>>> f = torch.stack((a, b), dim=1)
>>> f
tensor([[[ 0.,  1.,  2.],
         [ 0.,  2.,  4.]],
        [[ 3.,  4.,  5.],
         [ 6.,  8., 10.]]])
>>> f.shape
torch.Size([2, 2, 3])

7.torch.chunk函数和torch.split函数(tensor的拆分)

chunk函数是按照某维度,对tensor进行均匀切分

>>> a = torch.arange(12).reshape(4, 3)
>>> a
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
>>> c = torch.chunk(a, 4, dim=0)
>>> c
(tensor([[0, 1, 2]]), 
tensor([[3, 4, 5]]), 
tensor([[6, 7, 8]]), 
tensor([[ 9, 10, 11]]))

split函数是对tensor在指定dim维度下,按照split_size进行拆分,若split_size不能整除,则剩下数据为一块

>>> a = torch.arange(12).reshape(4, 3)
>>> a
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
>>> s = torch.split(a, 2, dim=0)
>>> s
(tensor([[0, 1, 2],
        [3, 4, 5]]), 
tensor([[ 6,  7,  8],
        [ 9, 10, 11]]))

>>> s = torch.split(a, 2, dim=1)
>>> s
(tensor([[ 0,  1],
        [ 3,  4],
        [ 6,  7],
        [ 9, 10]]), 
tensor([[ 2],
        [ 5],
        [ 8],
        [11]]))

8.permute函数和transpose函数(交换维度)

permute函数可以将维度位置进行交换,和transpose函数区别是可以对任意高维矩阵进行转置

>>> a = torch.randn(2, 3, 5)
>>> a.shape
torch.Size([2, 3, 5])
>>> a.permute(2, 0, 1).shape
torch.Size([5, 2, 3])

 transpose函数tensor转置,只可以操作2D转置

>>> a = torch.randn(2, 3)
>>> a
tensor([[ 2.0059, -0.5201,  1.0149],
        [-0.3918, -0.8710,  0.5579]])
>>> torch.transpose(a, 0, 1)
tensor([[ 2.0059, -0.3918],
        [-0.5201, -0.8710],
        [ 1.0149,  0.5579]])

9.reshape函数和resize函数(改变tensor的形状)

reshape函数是对tensor形状进行改变,但是元素个数不变

>>> a = torch.randn(4)
>>> a
tensor([-1.2296,  0.2770, -0.7377, -0.9197])
>>> torch.reshape(a, (2, 2))
tensor([[-1.2296,  0.2770],
        [-0.7377, -0.9197]])

 resize函数是对tensor尺寸进行调整,但与view不同,它可以修改tensor的尺寸。如果新的尺寸超过原来的尺寸,会自动分配新的内存空间,而如果新的尺寸小于原尺寸,则之前的数据依旧会被保存。

>>> b = torch.arange(0, 6)
tensor([0, 1, 2, 3, 4, 5])
>>> b.resize(2, 3)
tensor([[0, 1, 2],
        [3, 4, 5]])

10.view函数(改变tensor形状)

view函数可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,返回的新tensor与源tensor共享内存,即更改其中一个,另外一个也会跟着改变。

>>> a = torch.arange(0, 6)
>>> a
tensor([0, 1, 2, 3, 4, 5])
>>> b = a.view(-1, 3)
>>> b
tensor([[0, 1, 2],
        [3, 4, 5]])

11.torch.flatten函数(展平tensor)

torch.flatten(input, start_dim=0, end_dim=-1)将Tensor从start_dim展平到end_dim

>>> t = torch.randn([3, 2, 3])
>>> t
tensor([[[-0.9367,  0.5908,  0.8416],
         [ 0.2513,  1.4519, -0.6566]],
        [[-0.0527, -0.2044,  0.0610],
         [ 1.7106, -1.7432,  0.2789]],
        [[ 1.4201,  0.4685, -0.9257],
         [ 1.5866,  0.9033, -0.8300]]])
>>> f = torch.flatten(t)
>>> f
tensor([-0.9367,  0.5908,  0.8416,  0.2513,  1.4519, -0.6566, -0.0527, -0.2044,
         0.0610,  1.7106, -1.7432,  0.2789,  1.4201,  0.4685, -0.9257,  1.5866,
         0.9033, -0.8300])

>>> t = torch.randn([3, 2, 3])
>>> t
tensor([[[-0.4668,  1.1902, -0.1419],
         [ 0.8214, -1.3676,  1.5473]],
        [[-0.1603, -0.8412,  0.0700],
         [-0.3150, -1.2187, -0.8126]],
        [[ 0.2662, -0.5577, -1.8281],
         [ 0.4309,  0.6306, -1.8118]]])
>>> f = torch.flatten(t, start_dim=1)
>>> f
tensor([[-0.4668,  1.1902, -0.1419,  0.8214, -1.3676,  1.5473],
        [-0.1603, -0.8412,  0.0700, -0.3150, -1.2187, -0.8126],
        [ 0.2662, -0.5577, -1.8281,  0.4309,  0.6306, -1.8118]])

12.meshgrid函数(生成网格)

输入两个数据类型相同的一维tensor,输出两个tensor(tensor行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数)。(注,输入两个tensor数据类型和维度不是一维会出错)

>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([5, 6, 7])
>>> resultx, resulty = torch.meshgrid(x,y)
>>> print(x,y)
tensor([1, 2, 3]) tensor([5, 6, 7])
>>> print(resultx)
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
>>> print(resulty)
tensor([[5, 6, 7],
        [5, 6, 7],
        [5, 6, 7]])

13.torch.argmax(返回指定维度最大值的序号)

>>> a = torch.tensor([[1, 5, 5, 2], [9, 6, -2, 8], [-3, 7, -9, 1]])
>>> a
>>> tensor([[ 1,  5,  5,  2],
        [ 9,  6, -2,  8],
        [-3,  7, -9,  1]])
>>> b = torch.argmax(a, dim=0)
>>> b
tensor([1, 2, 0, 1])

  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值