np.where()用于三目运算:
如果A%2==0成立,则执行A+1,否则执行A-1
a=
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> np.where(a < 5, a, 10*a)
array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90]) 小于5的,保持不变,不满足小于5这个条件的,*10
np.where(condition, x, y)
满足条件(condition),输出x,不满足输出y。
只有条件 (condition),没有x和y,则输出满足条件 (即非0) 元素的坐标 (等价于numpy.nonzero)。这里的坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标。
v1 = embedding[word2index[instance[1]]] - embedding[word2index[instance[0]]] + embedding[word2index[instance[2]]] #1-0+2=3 按照类比词对计算出来的预测的3的embedding target = word2index[instance[3]] #id为3的词实际对应的embedding对应的idx distance2 = np.linalg.norm(embedding - v1, axis=1) #将embedding矩阵和计算出的3的embedding求差,和实际的3的embedding越接近的那一行,distance2越小 top = np.where(distance2 < distance2[target])[0].tolist() #满足条件distance2<distance[target]的非0元素的下标,因为返回的是tuple,将其转为list