pytorch按照索引取batch中的数

比如说bert的输出表征是基于子词的,想要用于基于词的任务,需要将词对应的最后一个子词的表征取出来,代码如下:

import torch
import torch.nn as nn

torch.manual_seed(1)

mix = torch.randn([2, 5, 3])
range_vector = torch.tensor([[0], [1]])
offsets2d = torch.tensor([[1, 3, 0], [1, 2, 4]])
print(mix)
selected_embeddings = mix[range_vector, offsets2d]
print(selected_embeddings.size())
print(selected_embeddings)

结果:


tensor([[[-1.5256, -0.7502, -0.6540],
         [-1.6095, -0.1002, -0.6092],
         [-0.9798, -1.6091, -0.7121],
         [ 0.3037, -0.7773, -0.2515],
         [-0.2223,  1.6871, -0.3206]],

        [[-0.2993,  1.8793, -0.0721],
         [ 0.1578, -0.7735,  0.1991],
         [ 0.0457, -1.3924,  2.6891],
         [-0.1110,  0.2927, -0.1578],
         [-0.0288,  2.3571, -1.0373]]])
torch.Size([2, 3, 3])
tensor([[[-1.6095, -0.1002, -0.6092],
         [ 0.3037, -0.7773, -0.2515],
         [-1.5256, -0.7502, -0.6540]],

        [[ 0.1578, -0.7735,  0.1991],
         [ 0.0457, -1.3924,  2.6891],
         [-0.0288,  2.3571, -1.0373]]])

Process finished with exit code 0

此段代码,在模型转成onnx时候会报错,可改成

mix = torch.randn([2, 5, 3]).cuda()
offsets = torch.tensor([[1, 3, 0], [1, 2, 4]]).cuda()

# 按索引取数
B, S, D = mix.size()
new_mix = mix.view(-1, D)
_, W = offsets.size()
right_add = torch.arange(0, B).unsqueeze(-1).cuda()
right_add = right_add * S
right_add.expand([B, W])

new_offsets = right_add + offsets

new_offsets = new_offsets.view(-1)
print(new_offsets)
out1 = new_mix.index_select(0, new_offsets)
# index_select 必须是一维向量
# torch.gather输出维度和输入的维度必须相同
print(out1.view(B, W, -1))

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旺旺棒棒冰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值