tensor 之间的 [] 运算

例子1, tensor 之间的 indices 运算

aa = [[1,2,3],[4,5,6]]
aa = torch.tensor(aa,dtype=torch.long)  # (x,y)

bb = torhc.zeros(2,5).type(torch.long)  #(z,k)

# 接着使用运算
cc = aa[bb]

aa = [[1,2,3]
       [4,5,6]]

bb = [[0,0,0,0,0]
        [0,0,0,0,0]]

首先 cc的shape 为: x * k * y == 2*5*3
接着 cc的

tensor([[[4,5,6],
        [4,5,6],
         [4,5,6],
         [4,5,6],
         [4,5,6]],
        [[4,5,6],
         [4,5,6],
         [4,5,6],
        [4,5,6],
         [4,5,6]]])

例子2. 

a = [[[1,2,3],
      [4,5,6]],
     [[1,2,3],
      [4,5,6]],
     [[1,2,3],
      [4,5,6]]]

a = torch.tensor(a,dtype=torch.long)

# print "a.shape": torch.Size([3, 2, 3])

b = [[1,-1],[1,-1],[-1,1]]
b = torch.tensor(b,dtype=torch.long)
# print "b.shape": torch.Size([3, 2])

c = a[b>=0]  # shape = 3 * 3

tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6]])

再来一个view 操作

c = torch.Tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6]])  # shape: 3 * 3
 # 通过 view 操作
d = c.view(3,1,-1)  # 3 代表原先 a 中 的batch_size =3, 1 代表 b 中每一个维度只有一个 -1, -1 代表最后一个维度不变

# print "d.shap": torch.Size([3, 1, 3])
tensor([[[1, 2, 3]],
        [[1, 2, 3]],
        [[4, 5, 6]]])

x 为 一个 tensor,然后进行如下计算: y = x[..., 1:].ne(0).view(-1)==1 

 x[..., 1:]  等价于 x[: , 1:]

x = torch.Tensor([[1,2,3,4],[4,5,6,7]])
x.shape
Out[38]: torch.Size([2, 4])
x
Out[39]: 
tensor([[1., 2., 3., 4.],
        [4., 5., 6., 7.]])

y = x[...,1:]
Out[41]: 
tensor([[2., 3., 4.],
        [5., 6., 7.]])

z = y.ne(0)
Out[43]: 
tensor([[True, True, True],
        [True, True, True]])

k = z.view(-1)
Out[46]: tensor([True, True, True, True, True, True])

x 是一个tensor, x.view(-1, y.size(-1)); y.size(-1) 最后一维的size.

k = torch.Tensor([[[1,2,3],[4,5,6]],
                  [[7,8,8],[8,9,10]],
                  [[11,12,13],[14,15,16]]])
k.shape
Out[57]: torch.Size([3, 2, 3])
k
Out[58]: 
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],
        [[ 7.,  8.,  8.],
         [ 8.,  9., 10.]],
        [[11., 12., 13.],
         [14., 15., 16.]]])

z = k.view(-1,3)  # 合并前面两维,最后一维不变
Out[60]: 
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  8.],
        [ 8.,  9., 10.],
        [11., 12., 13.],
        [14., 15., 16.]])

z.shape
Out[61]: torch.Size([6, 3])

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值