活到老学到老之index操作

快速想一想,你能想到torch有哪些常见的index操作??

1. gather

1
2
3
4
5
>>> a = torch.tensor([[1, 2, 3],
[4, 5, 6]])
>>> a.gather(dim=1, index=torch.tensor([[0,1], [1,2]]))
tensor([[1, 2],
[5, 6]])

2. index_select

1
2
3
4
5
6
>>> a
tensor([[1, 2, 3],
[4, 5, 6]])
>>> a.index_select(dim=1, index=torch.tensor([1,2]))
tensor([[2, 3],
[5, 6]])

3. 骚气的来了哦

根据上面例子可以看到,a为矩阵,选择a中的index,但是下面介绍一个map操作.

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> index
tensor([[1, 2, 3],
[4, 5, 6]])

>>> a = torch.tensor([11, 22, 33, 44, 55, 66, 77])
>>> a
tensor([11, 22, 33, 44, 55, 66, 77])
>>> index
tensor([[1, 2, 3],
[4, 5, 6]])
>>> a[index]
tensor([[22, 33, 44],
[55, 66, 77]])

这种操作有一个真实场景,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 1. 这是两个特征
>>> words = ['我', '爱', '中', '国']
>>> pos = ['n', 'v', 'n', 'n']

# 2. 假设words变成了一个4 * 4的临接矩阵,用于表示每个token和其他token的一个关联重要程度

>>> words_attn = torch.rand(4,4)

>>> words_attn
tensor([[0.6279, 0.6234, 0.9831, 0.5267],
[0.2265, 0.8453, 0.5740, 0.4772],
[0.7759, 0.6952, 0.1758, 0.3800],
[0.9998, 0.3138, 0.5078, 0.5565]])


>>> scores, indices = words_attn.topk(k=2, dim=1)

>>> indices
tensor([[2, 0],
[1, 2],
[0, 1],
[0, 3]])

# 3. 假设pos转为了
>>> pos_tensor = torch.tensor([111, 222, 333, 444])

# 4. map操作
>>> pos_tensor[indices]
tensor([[333, 111],
[222, 333],
[111, 222],
[111, 444]])

# 5. 随后就可以接一个embedding搞事情了
pos_embedding(pos_tensor[indices])

# 6. 总结,这个示例的优点可以看出是快速计算,取topK然后再结合其他的特征进行操作。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值