import numpy as np
def get_intent_acc(preds, labels):
    """Compute intent classification accuracy.

    Accuracy is the fraction of positions where ``preds`` equals ``labels``
    under element-wise ``==`` (exact equality, so float values must match
    exactly to count as correct).

    Args:
        preds: Predicted intent labels; any array-like accepted by
            ``np.asarray`` (NumPy array, list, tuple).
        labels: Ground-truth intent labels; must have the same shape as
            ``preds``.

    Returns:
        dict: ``{"intent_acc": acc}`` where ``acc`` is the mean of the
        element-wise equality mask (True counts as 1, False as 0). For empty
        inputs NumPy's ``mean`` yields ``nan`` with a RuntimeWarning.

    Raises:
        ValueError: If ``preds`` and ``labels`` have different shapes.
    """
    # Coerce to arrays so plain Python lists work: `list == list` would
    # produce a single bool, which has no `.mean()`.
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    # Require identical shapes; otherwise broadcasting (e.g. (n, 1) vs (n,))
    # would silently compare every pair and produce a meaningless score.
    if preds.shape != labels.shape:
        raise ValueError(
            "preds and labels must have the same shape, got {} and {}".format(
                preds.shape, labels.shape
            )
        )
    acc = (preds == labels).mean()
    return {
        "intent_acc": acc
    }
if __name__ == '__main__':
    # Small demo: positions 0, 2 and 3 match, so the printed accuracy is 3/6.
    demo_preds = np.array([0.1, 0.5, 0.1, 0.1, 0.1, 0.1])
    demo_labels = np.array([0.1, 0.6, 0.1, 0.1, 0.05, 0.05])
    print(get_intent_acc(preds=demo_preds, labels=demo_labels))
结果：6 个位置中第 0、2、3 个预测与标签完全相等，因此准确率为 3/6 = 0.5。
D:\ProgramData\Anaconda3\envs\jointBert\python.exe E:/github/2020.03/JointBERT-master/test.py
{'intent_acc': 0.5}
Process finished with exit code 0
注意：`mean()` 是 NumPy 数组的方法。`(preds == labels)` 得到逐元素的布尔数组，对其取 `mean()`（True 计为 1，False 计为 0）即得到准确率。