11.15-16
终于找出困扰我一天的bug了,哭哭哦
数组与集合数组!
运行代码时发现:
preds = np.argmax(probs, axis=1)
probs好像是一个二维数组
竟然返回的是列最大值的索引!
我非常惊讶,再三确认了argmax的函数。
没错啊 argmax:返回数组最大值的索引,如果数组是二维数组的话:axis=0返回的是每列最大的索引,axis=1返回的是每行最大的索引。
错误的原因在于probs不是数组!是一个集合,里面有一个数组!
数组与集合数组的区别在于:
1.数组(np)可以调用shape函数,而集合数组不能调用
2.使用print方法时,返回的东西是不同的。如图:
可以看到,使用print函数时,如果是集合数组,他会比数组多[array(…)]这种东西!
使用argmax时自然也有不同,如果对数组使用argmax,结果是一个躺平的数组。如果时对集合的数组使用argmax,结果是一个二维数组!
发现问题后(将集合数组命名为probs),我第一反应是在probs套一层np,将probs变为三维数组,再用三维数组套上一个axis=2来寻找最大值下标,但是问题来了,argmax(三维数组)的结果是两维数组。
和上图对比,结果是我们想要的结果,但是多了一层[ ]。对于三维数组的argmax,贴一个还不错的解释:最后我把问题分享给了室友,室友一语中的:在probs后面加一个[0]就好了,意思就是取数组集合中的第一个元素,该元素类型为数组。淦!
这样总算大功告成了。
我也想了一下,这个probs函数是调用nlp模型的predict函数得到的结果。我调用的是我们大佬的函数,为什么他的lstm的predict就能运行得好好的,我的bert就返回的是一个list呢?可能是因为bert的参数更复杂吧。
查阅了资料,这个是tf.karas.Model库(也是大佬代码的函数原型)对predict返回值的定义:
可能是因为bert的predict用的是transfrom包的。
天哪,水了这么久,代码还没跑完,先写这么多吧。