A Method for Evaluating the Confidence of Chatbot Output
Why Chatbot Output Confidence Needs to Be Evaluated
A chatbot built with an LSTM will produce an output for any input, even input that is completely ungrammatical; clearly such output is not what we want. How can we tell whether the model's output is meaningful? We need a metric that evaluates the confidence of the model's output. Evaluating the confidence of an LSTM model essentially amounts to judging whether the input and the model's output belong to the training corpus: the LSTM learns a mapping between the inputs and labels in the corpus, so for inputs outside the training corpus its output is effectively random. The model's confidence can therefore be evaluated by checking whether the input and the LSTM model's output fall within the training corpus.
Algorithm
Segment each question in the training corpus into words key1, key2, … keyn and use them as dictionary keys; the corresponding answers form the dictionary values as answer lists [answer1, answer2, … answerk]. Questions sharing the same keyword contribute to the same answer list, as in the table below:
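A minimal sketch of this dictionary construction. The toy question/answer pairs are hypothetical, and the questions are shown already segmented into keywords (in the full code, jieba performs this segmentation):

```python
# Toy training corpus: (segmented question, answer) pairs -- hypothetical data
corpus = [
    (["今天", "天气", "怎么样"], "今天是晴天"),
    (["明天", "天气", "怎么样"], "明天有雨"),
]

rel_dict = {}
for keys, answer in corpus:
    # Each keyword of the question maps to the question's answer;
    # questions sharing a keyword share one answer list.
    for key in set(keys):
        rel_dict.setdefault(key, []).append(answer)

# The shared keyword "天气" now maps to both answers
print(rel_dict["天气"])  # ['今天是晴天', '明天有雨']
```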
When the LSTM model is called for prediction, segment the input to obtain key1, key2, … keyn, look each keyword up in the dictionary to get its answer list, and compare the answers in that list with the LSTM model's output; each hit increments a counter count. An empty answer list counts as a miss. The confidence is computed as:

Confidence = count / key_size
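A worked example of the basic formula (without the high-frequency-word special case handled in the full code below). The dictionary contents are hypothetical, continuing the toy corpus above:

```python
# Keyword -> answer-list dictionary built from the training corpus (hypothetical)
rel_dict = {
    "今天": ["今天是晴天"],
    "天气": ["今天是晴天", "明天有雨"],
    "怎么样": ["今天是晴天", "明天有雨"],
}

def confidence(ask_keys, model_answer):
    """Confidence = count / key_size: the fraction of input keywords
    whose answer list contains the model's output."""
    key_set = set(ask_keys)
    if not key_set:
        return 0.0
    count = sum(
        1 for key in key_set
        if model_answer in rel_dict.get(key, [])  # empty list -> miss
    )
    return count / len(key_set)

# All three keywords' answer lists contain the output -> confidence 1.0
print(confidence(["今天", "天气", "怎么样"], "今天是晴天"))  # 1.0
# An out-of-corpus keyword ("昨天") misses -> 2 hits out of 3 keys
print(confidence(["昨天", "天气", "怎么样"], "明天有雨"))  # 0.666...
```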
Python code for computing the LSTM model's confidence
import jieba
import pickle


class RelevanceChat():
    def __init__(self, topk=5):
        self.topk = topk
        self.fited = False

    def fit(self, x_data, y_data, ws_decode):
        # Build the keyword -> answer-list dictionary from the training corpus.
        self.dict = {}
        # A keyword collecting answers from more than 60% of the questions is
        # treated as a high-frequency word and marked with the sentinel [0].
        high_fw_max = int(len(x_data) * 0.6)
        for ask, answer in zip(x_data, y_data):
            ask_str = ''.join(ask)
            if len(ask_str) == 0:
                continue
            top_key = jieba.lcut(ask_str)
            #print("top key:", top_key)
            y_code = ws_decode.transform(answer)[0]
            key_set = set(top_key)
            for key in key_set:
                rel_list = []
                if key in self.dict:
                    rel_list = self.dict[key]
                    if rel_list[0] == 0:
                        # Already marked as a high-frequency word.
                        continue
                    elif len(rel_list) >= high_fw_max:
                        print("key list over:", key, "ask:", ask_str)
                        self.dict[key] = [0]
                        continue
                rel_list.append(y_code)
                self.dict[key] = rel_list
        dict_items = self.dict.items()
        print("size:", len(self.dict))
        #print("dict:", dict_items)
        self.fited = True

    def relevance(self, ask, answer):
        assert self.fited, "RelevanceChat has not been fitted yet"
        top_key = jieba.lcut(''.join(ask))
        #print("top key:", top_key)
        key_set = set(top_key)
        key_size = len(key_set)
        if key_size == 0:
            return 0.0
        rel_num = 0   # keywords whose answer list contains the model output
        high_fw = 0   # high-frequency keywords, excluded from the denominator
        for key in key_set:
            rel_list = self.dict.get(key)
            if rel_list is not None:
                if rel_list[0] == 0:
                    high_fw += 1
                elif answer in rel_list:
                    rel_num += 1
        if rel_num == 0:
            relv_val = float(high_fw) / key_size
        else:
            relv_val = float(rel_num) / (key_size - high_fw)
        return relv_val


def test():
    x_data, y_data = pickle.load(open('pkl/chatbot.pkl', 'rb'))
    ws_decode = pickle.load(open('pkl/ws_decode.pkl', 'rb'))
    relv = RelevanceChat(5)
    relv.fit(x_data, y_data, ws_decode)
    # In-corpus question/answer pairs: confidence should be high.
    count = 0
    for ask, answer in zip(x_data, y_data):
        decode = ws_decode.transform(answer)[0]
        relv_val = relv.relevance(ask, decode)
        if relv_val < 0.7:
            print("rel:", relv_val)
            print("ask:", ''.join(ask))
            print("answer:", ''.join(answer), end='\n\n')
            count += 1
    print("same dialogue Confidence<0.7 count:", count)
    # Mismatched question/answer pairs: confidence should be low.
    count = 0
    for i, answer in enumerate(y_data):
        decode = ws_decode.transform(answer)[0]
        for j, ask in enumerate(x_data):
            if i == j:
                continue
            relv_val = relv.relevance(ask, decode)
            if relv_val > 0.7:
                #print("rel:", relv_val)
                #print("ask:", ''.join(ask))
                #print("answer:", ''.join(answer), end='\n\n')
                count += 1
    print("different dialogue Confidence>0.7 count:", count)


if __name__ == '__main__':
    test()
Test Results
Testing on a corpus of 129 dialogues gives the results shown in the figure below. The number of in-corpus dialogues with confidence below 0.7 is 0; across different dialogues, 61 pairs have confidence above 0.7, for a false-positive rate of 61/(129×128) ≈ 0.37%.
This work is licensed under a Creative Commons Attribution 4.0 International License.