# Import libraries
# from __future__ import print_function
import json
import six
# import paddlehub as hub
import paddlehub as hub
# Parameter configuration
# Sample sentences to run through the detectors.
test_text = [
    "打击色情犯罪,是每一个人的责任",
    '引导未成年人远离黄赌毒',
]
# Whether prediction should run on the GPU.
use_gpu = True
# Batch size forwarded to each module's detection call.
batch_size = 2
# Model configuration
# Implementation using LSTM
def porn_detection_lstm(test_text, use_gpu, batch_size):
    """Run the PaddleHub ``porn_detection_lstm`` module over a list of texts.

    Args:
        test_text: Texts to classify (the script passes a list of str).
        use_gpu: Whether prediction should run on the GPU.
        batch_size: Batch size forwarded to the module.

    Returns:
        Whatever the module's ``detection`` call returns; the rest of this
        script indexes it as one dict per input text.
    """
    # A distinct local name keeps the module handle from shadowing this function.
    module = hub.Module(name="porn_detection_lstm")
    # Pass the texts directly via ``texts=``, the same way the CNN helper in
    # this file does, instead of wrapping them in a {"text": ...} dict for the
    # legacy ``data=`` argument.
    return module.detection(texts=test_text, use_gpu=use_gpu, batch_size=batch_size)
# Implementation using GRU
def porn_detection_gru(test_text, use_gpu, batch_size):
    """Run the PaddleHub ``porn_detection_gru`` module over a list of texts.

    Args:
        test_text: Texts to classify (the script passes a list of str).
        use_gpu: Whether prediction should run on the GPU.
        batch_size: Batch size forwarded to the module.

    Returns:
        Whatever the module's ``detection`` call returns; the rest of this
        script indexes it as one dict per input text.
    """
    # A distinct local name keeps the module handle from shadowing this function.
    module = hub.Module(name="porn_detection_gru")
    # Pass the texts directly via ``texts=``, the same way the CNN helper in
    # this file does, instead of wrapping them in a {"text": ...} dict for the
    # legacy ``data=`` argument.
    return module.detection(texts=test_text, use_gpu=use_gpu, batch_size=batch_size)
# Implementation using CNN
def porn_detection_cnn(test_text, use_gpu, batch_size):
    """Run the PaddleHub ``porn_detection_cnn`` module over a list of texts.

    Args:
        test_text: Texts to classify, passed to the module as ``texts``.
        use_gpu: Whether prediction should run on the GPU.
        batch_size: Batch size forwarded to the module.

    Returns:
        The value returned by the module's ``detection`` call.
    """
    # Load the CNN-based detector from PaddleHub.
    detector = hub.Module(name="porn_detection_cnn")
    return detector.detection(
        texts=test_text,
        use_gpu=use_gpu,
        batch_size=batch_size,
    )
# Call the three different network architectures
# Feed the same inputs and settings to each of the three detectors.
lstm = porn_detection_lstm(
    test_text=test_text, use_gpu=use_gpu, batch_size=batch_size
)  # LSTM results
gru = porn_detection_gru(
    test_text=test_text, use_gpu=use_gpu, batch_size=batch_size
)  # GRU results
cnn = porn_detection_cnn(
    test_text=test_text, use_gpu=use_gpu, batch_size=batch_size
)  # CNN results
# Output the results
# Tip: the higher the label value, the more likely the text involves pornographic content.
# print(lstm)
# print(gru)
# print(cnn)
def output_dict(dic, pre='LSTM'):
    """Print every key/value pair of each result record under a model header.

    Args:
        dic: Iterable of dict-like records (despite the name, it is iterated
            as a sequence of mappings, not treated as a single dict).
        pre: Model label shown in the header line; defaults to ``'LSTM'``.
    """
    print('------\nPorn detection with {}'.format(pre))
    for record in dic:
        for key, value in record.items():
            # ``!s`` converts the key to text before padding, so non-string
            # keys (e.g. ints) no longer raise on the ``s`` format code.
            # Output for string keys is unchanged.
            print('{!s:20s}: {}'.format(key, value))
# Dump the raw per-model results, one labelled section per network.
for model_results, model_tag in ((lstm, 'LSTM'), (gru, 'GRU'), (cnn, 'CNN')):
    output_dict(model_results, model_tag)

# Combine the three models for every input text: the labels are summed and
# the two probabilities are averaged across the models.
for index, text in enumerate(test_text):
    lstm_row = lstm[index]
    # Store the source text on the LSTM record before reporting.
    lstm_row["text"] = text
    print("文本内容:", text)

    gru_row = gru[index]
    cnn_row = cnn[index]

    label = (
        lstm_row["porn_detection_label"]
        + gru_row["porn_detection_label"]
        + cnn_row["porn_detection_label"]
    )
    porn_total = lstm_row["porn_probs"] + gru_row["porn_probs"] + cnn_row["porn_probs"]
    porn_probs = porn_total / 3
    not_porn_total = (
        lstm_row["not_porn_probs"] + gru_row["not_porn_probs"] + cnn_row["not_porn_probs"]
    )
    not_porn_probs = not_porn_total / 3

    summary = 'label:%0.0f, porn_probs:%0.5f, not_porn_probs:%0.5f' % (
        label,
        porn_probs,
        not_porn_probs,
    )
    print(summary)