pytorch笔记

pytorch常用的一些函数

  1. unsqueeze()函数介绍:在指定维度上增加维度。
    例如:
    a = t.arange(0,6)
    a.view(2,3) #a的维度为(2,3)
    a.unsqueeze(1) #增加第二个维度 ,a的维度为(2,1,3)
    
  2. tensor.expand_as():把一个tensor变成和函数括号内一样形状的tensor,用法与expand()类似
    例如:
    x = torch.tensor([[1],[2],[3]]) #x.size为(3,1)
    x.expand(3,4) #x.size为(3,4)
    
  3. gather()函数:沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
    torch.gather(input, dim, index, out=None) → Tensor
    例如:
    b = torch.Tensor([[1,2,3],[4,5,6]])
    	print b
    	index_1 = torch.LongTensor([[0,1],[2,0]])
    	index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
    	print torch.gather(b, dim=1, index=index_1)
    	print torch.gather(b, dim=0, index=index_2)
    
    观察它的输出结果:
     1  2  3
     4  5  6
    [torch.FloatTensor of size 2x3]
    1  2
    6  4
    [torch.FloatTensor of size 2x2]
    1  5  6
    1  2  3
    [torch.FloatTensor of size 2x3]
    
  4. permute(dims):将tensor的维度换位。
    例如:
    import torch
    import numpy    as np
    a=np.array([[[1,2,3],[4,5,6]]])
    unpermuted=torch.tensor(a)
    print(unpermuted.size())  #  ——>  torch.Size([1, 2, 3])
    permuted=unpermuted.permute(2,0,1)
    print(permuted.size())     #  ——>  torch.Size([3, 1, 2])
    
  5. cat()函数:对数据沿着某一维度进行拼接。cat后数据的总维数不变
    例如:下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。
    import torch	
    torch.manual_seed(1)
    x = torch.randn(2,3)	
    y = torch.randn(1,3)	
    print(x,y)
    
    结果:
    0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661
    [torch.FloatTensor of size 2x3]
    -1.5228 0.3817 -1.0276
    [torch.FloatTensor of size 1x3]
    
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值