一种评估LSTM模型置信度方法

为什么需要评估聊天机器人输出置信度

使用LSTM构建聊天机器人,无论输入是什么,就算输入完全不符合语法,模型都会给出一个输出,显然,这个输出不是我们想要的,如何识别模型输出是不是我们想要的?因此我们需要一种评估指标,评估模型输出的置信度。评估LSTM模型的置信度本质上是判断输入与模型输出是否属于训练语料集之内,因为LSTM模型是在语料集的输入与标签之间建立了映射关系,对于训练语料集之外的输入,LSTM模型输出是随机的。因此,可以通过判断输入与LSTM模型的输出是否属于训练语料集之内来评估LSTM模型的置信度。

算法

把训练语料集的提问分词:key1,key2,…keyn,做为字典的关键字,对应的应答列表为字典的值:[answer1,answer2,… answerk],相同的关键字加入同一应答列表,如下表:

在这里插入图片描述
调用LSTM模型做预测时,把输入做分词,得到:key1,key2,…keyn,分别查找关系字典,得到对应的应答列表,对比应答列表的answer与LSTM模型输出,如果命中则计数器count加1。如果应答列表为空则为不命中。置信度confidence用下式计算:
Confidence = count / key size

计算LSTM模型置信度python代码

import jieba
import pickle
class RelevanceChat():
    def __init__(self,topk=5):
            self.topk = topk
            self.fited = False
    def fit(self,x_data,y_data,ws_decode):
            self.dict = {}
            high_fw_max = int(len(x_data) * 0.6)
            for ask, answer in zip(x_data, y_data):
                 ask_str = ''.join(ask)
                 if len(ask_str) == 0:
                      continue
                 top_key = jieba.lcut(ask_str)
            	   #print("top key:", top_key)
            	   y_code = ws_decode.transform(answer)[0]
            	   key_set = set(top_key)
            	   for key in key_set:
            	       rel_list = []        
            	            if key in self.dict:                 
            	               rel_list = self.dict[key]                    
            	               if rel_list[0]==0:                        
                                   continue                    
            	               elif len(rel_list)>=high_fw_max:                  
            	                   print("key list over:", key,"ask:", 
            	                   		ask_str)                        
            	                   	self.dict[key] = [0]                        
            	                   	continue
                		 rel_list.append(y_code)                
                		 self.dict[key] = rel_list
        	dict_items = self.dict.items()        
        	print("size:",len(self.dict))        
        	#print("dict:", dict_items)
        	self.fited = True
        	
    def relevance(self,ask,answer):
        assert self.fited, "RelevanceChat 尚未进行 fit 操作"
        top_key = jieba.lcut(''.join(ask))
        #print("top key:", top_key)                
        key_set = set(top_key)
        key_size = len(key_set)        
        if key_size == 0:            
        	return 0.0
        rel_num = 0        
        high_fw = 0        
        for key in key_set:            
        	rel_list = self.dict.get(key)
         	if rel_list is not None:                
         		if rel_list[0] == 0:                    
         			high_fw += 1                
         		elif answer in rel_list:                    
         			rel_num += 1
           
       if rel_num == 0:            
       	  relv_val = float(high_fw)/key_size        
       else:            
       	  relv_val = float(rel_num)/(key_size - high_fw)
       return relv_val
                    
def test():    
	x_data, y_data = pickle.load(open('pkl/chatbot.pkl', 'rb'))    
	ws_decode = pickle.load(open('pkl/ws_decode.pkl', 'rb'))
	relv = RelevanceChat(5)    
	relv.fit(x_data,y_data,ws_decode)
	count = 0    
	for ask,answer in zip(x_data,y_data):        
	decode = ws_decode.transform(answer)[0]        
	relv_val = relv.relevance(ask,decode)
	if relv_val<0.7:            
		print("rel:", relv_val)            
		print("ask:",''.join(ask))            
		print("answer:", ''.join(answer),end='\n\n')            
		count += 1
		
	print("same dialogue Confidence<0.7 count:", count)
	
	count = 0    
	for i,answer in enumerate(y_data):        
		decode = ws_decode.transform(answer)[0]
		for j,ask in enumerate(x_data):            
			if i==j:                
				continue
		relv_val = relv.relevance(ask,decode)            
		if relv_val>0.7:                
			#print("rel:", relv_val)                
			#print("ask:",''.join(ask))                
			#print("answer:", ''.join(answer),end='\n\n')                
			count += 1                
	print("different dialogue Confidence<0.7 count:",count)
	
if __name__ == '__main__':    
	test()

测试结果

使用129条对话的语料集进行测试,结果见下图。语料集内对话置信度小于0.7的条数为0,不同对话间置信度大于0.7的有61,误报率:61/(129*128)=0.37%

在这里插入图片描述
知识共享许可协议
本作品采用知识共享署名 4.0 国际许可协议进行许可。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值