4.tensor的索引与变形

代码如下:

import torch

if __name__ == '__main__':

    #一、tensor的索引

    a=torch.Tensor([[1,2,3],[0,3,2]])
    #1.根据下标索引
    print(a[1][2])
    print("1.根据下标索引",a[:,1])

    #2.选择a中大于1的元素,返回和a相同大小的tensor,符合条件的输出1,否则输出0
    s=a>1
    print("2.选择a中大于1的元素",s)

    #3.选择符合条件的元素并返回
    print("3.选择符合条件的元素并返回",a[s])

    #4.torch.where(condition,x,y),满足condition的位置输出x,否则输出y
    hh=torch.where(a>1,2,a)
    print("4.torch.where(condition,x,y)",hh)

    #5.clamp()函数
    t=a.clamp(1,2)#限制最小值为1,最大值为2
    print("5.clamp()函数",t)

    #6.选择非0元素的坐标
    g=torch.nonzero(a)
    print("6.选择非0元素的坐标",g)


    #二、tensor的变形
    #常见的有 #view,resize,reshape
            #transpose,permute
            #squeeze,unsqueeze
            #expand,exoand_as
    #1.view,resize,reshape
    a=torch.arange(1,17)
    print(a.shape)
    print(a.reshape(2,8).shape)#reshape
    print(a.resize(4,4).shape)#resize
    print(a.view(8,2).shape)#view

    #2.transpose,permute:各维度之间的位置变换
    a=torch.Tensor([[1,2,3],[4,5,6]])
    b=a.transpose(0,1)#将第0维与第1维的元素进行转置
    print(b)

    c=a.permute(1,0)#按照1,0的维度顺序重新排序
    print(c)

    #3.squeeze,unsqueeze,用来去除size为1的维度
    a=torch.arange(4)
    print(a.shape,a)
    b=a.unsqueeze(1)
    print(b.shape,b)

    #4.expand,exoand_as:采样复制的方式来扩展tensor的维度
    a=torch.randn(2,2,1)
    b=a.expand(2,2,3)
    print("a",a)
    print("b",b)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值