1 问题描述
- 任务是评估一个预测模型的准确性,该模型通过计算embedding的内积进行位置预测
-
具体来说,有以下数据:
-
张量
a
(大小为 40x10),包含 40 个位置的embedding(每个位置的embedding维度为10)。 -
张量
b
(大小为 4x10),包含一个batch中4个预测位置的embedding。 -
张量
c
(大小为 4x1),包含四个ground-truth位置的索引。
-
-
任务是对于
b
中的每个位置,计算它与a
中所有40个位置的embedding的内积,然后选取内积值最高的10个位置。 -
接下来,检查
c
中的ground-truth位置索引是否在这10个最高内积位置中
2 python 实现
2.1 获取a,b,c
import torch
a = torch.randn(40, 10) # 40个位置的embedding
b = torch.randn(4, 10) # 4个预测位置的embedding
b
'''
tensor([[-1.7334e-02, -4.0604e-01, 9.0610e-01, 1.7381e+00, -1.1258e+00,
-4.9452e-01, -2.1200e+00, -7.3516e-01, -1.6682e+00, 1.7613e-01],
[-2.0892e-01, 8.8255e-01, -4.0820e-02, -1.4790e+00, -1.5859e+00,
5.1649e-01, -3.2593e-01, 6.4271e-01, 8.9277e-01, -6.6575e-01],
[ 1.1684e-01, -7.3740e-01, 1.0661e+00, -1.0934e+00, -6.1928e-01,
1.2838e+00, 5.3154e-01, 2.1426e+00, 3.3756e-02, -7.1108e-01],
[-2.0004e+00, 1.6285e+00, 1.5834e-01, -7.2439e-01, -4.5901e-01,
9.4934e-01, -1.3431e+00, -2.1714e+00, 1.3512e-01, 5.9950e-04]])
'''
c = torch.randint(0, 40, (4, 1)) # ground-truth位置的索引
c
'''
tensor([[12],
[27],
[ 6],
[12]])
'''
2.2 计算b与a的所有位置的内积
similarity = torch.matmul(b, a.t())
similarity
2.3 对每一行(每个预测位置)取最大的10个值的索引
top10_indices = torch.topk(similarity, 10, dim=1).indices
top10_indices
'''
tensor([[25, 12, 1, 13, 11, 10, 19, 8, 37, 3],
[26, 35, 20, 5, 30, 10, 22, 16, 0, 38],
[27, 19, 22, 30, 23, 14, 35, 26, 2, 9],
[26, 38, 11, 10, 1, 32, 12, 5, 6, 21]])
'''
2.4 判断是否在top10中
is_in_top10 = torch.zeros(4, dtype=torch.bool)
for i in range(4):
is_in_top10[i] = c[i] in top10_indices[i]
is_in_top10
#tensor([ True, False, False, True])