Demo: gathering POS-tag ids with topk attention indices, then embedding them.
>>> words = ['我', '爱', '中', '国']
>>> pos = ['n', 'v', 'n', 'n']
>>> words_attn = torch.rand(4,4)
>>> words_attn
tensor([[0.6279, 0.6234, 0.9831, 0.5267],
        [0.2265, 0.8453, 0.5740, 0.4772],
        [0.7759, 0.6952, 0.1758, 0.3800],
        [0.9998, 0.3138, 0.5078, 0.5565]])
>>> scores, indices = words_attn.topk(k=2, dim=1)
>>> indices
tensor([[2, 0],
        [1, 2],
        [0, 1],
        [0, 3]])
>>> pos_tensor = torch.tensor([111, 222, 333, 444])
>>> pos_tensor[indices]
tensor([[333, 111],
        [222, 333],
        [111, 222],
        [111, 444]])
>>> pos_embedding(pos_tensor[indices])