This article revises the code from https://www.jianshu.com/p/f3ca7c75401b, which was written for TensorFlow 1.x. The training data can also be found via the original link:
Dataset download: https://pan.baidu.com/s/10QtokJ8_tkK6I3GifalxWg (extraction code: uytb)
The revision addresses three main problems:
- Text input must go through an Embedding layer.
- Keras layers default to no activation function (activation=None), so activations must be added explicitly. I initially forgot to add them, and the network failed to converge.
- The output function of the last layer was changed from tf.nn.softmax to tf.nn.selu.
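The three fixes above can be illustrated with a minimal Keras model. This is only a sketch, not the author's exact model (which appears in the full code later); the layer sizes are taken from the TextConfig below, and num_classes is an assumed value for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers

num_classes = 10  # assumption for illustration; the real value is derived from the labels

# Sketch of the fixes: Embedding as the input layer, explicit activation
# functions (Keras layers default to activation=None), and selu on the
# final layer in place of softmax, as described above.
model = tf.keras.Sequential([
    layers.Embedding(input_dim=5000, output_dim=64),  # vocab_size, embedding_dim
    layers.Conv1D(16, 5, activation='relu'),          # num_filters, kernel_size
    layers.GlobalMaxPool1D(),
    layers.Dense(32, activation='relu'),              # hidden_dim
    layers.Dense(num_classes, activation='selu'),     # selu output, per the fix above
])

out = model(tf.zeros([2, 600], dtype=tf.int32))       # batch of 2, seq_length 600
```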
The main changes are as follows:

1. Added a loss function, defined as:

```python
def loss_function(self, labels, pred_proba):
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=pred_proba)
    return tf.reduce_mean(loss)
```

Note that despite its name, the pred_proba argument is passed as logits here: tf.nn.softmax_cross_entropy_with_logits expects raw (unnormalized) scores and applies softmax internally, and the labels are expected to be one-hot encoded.
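As a sanity check on what softmax_cross_entropy_with_logits computes, the same quantity can be reproduced in plain numpy. The logits and labels below are made up for illustration and are not part of the original code:

```python
import numpy as np

def softmax_xent(labels, logits):
    # Numerically stable softmax cross-entropy, mirroring
    # tf.nn.softmax_cross_entropy_with_logits for one-hot labels.
    z = logits - logits.max(axis=1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -(labels * log_softmax).sum(axis=1)

logits = np.array([[2.0, 1.0, 0.1]])   # hypothetical raw scores for 3 classes
labels = np.array([[1.0, 0.0, 0.0]])   # one-hot true class
loss = softmax_xent(labels, logits).mean()
```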
2. Added an accuracy function, defined as:

```python
def compute_accuracy(self, labels, pred_proba):
    predictions = tf.argmax(pred_proba, axis=1)
    return accuracy_score(labels, predictions)
```

Unlike the loss function, accuracy_score expects integer class labels here, not one-hot vectors.
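The same accuracy can be computed without sklearn: take the argmax over the class axis and compare against the integer labels. The probabilities below are hypothetical values for illustration:

```python
import numpy as np

# Hypothetical predicted probabilities for 4 samples over 3 classes
pred_proba = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.3, 0.3, 0.4],
                       [0.6, 0.3, 0.1]])
labels = np.array([0, 1, 2, 1])            # integer class labels

predictions = pred_proba.argmax(axis=1)    # [0, 1, 2, 0]
accuracy = (predictions == labels).mean()  # same value accuracy_score would return
# → 0.75 (the last sample is misclassified)
```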
3. Modified the network structure: Dropout was removed (keeping it also works, but it contributes little to convergence or to reducing training time), the number of CNN filters was reduced, and the fully connected layer was shrunk.

4. Changed the batch size: the original code trained on 32 samples per batch, which hindered training; after switching to 128 the results improved noticeably.
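The batch size is easy to experiment with. A minimal random mini-batch sampler might look like the sketch below; next_batch is a hypothetical helper for illustration, not a function from the original code:

```python
import numpy as np

def next_batch(X, y, batch_size=128):
    """Draw one random mini-batch without replacement (hypothetical helper)."""
    idx = np.random.choice(len(X), size=batch_size, replace=False)
    return X[idx], y[idx]

# Dummy data: 500 samples with 2 features each
X = np.arange(1000).reshape(500, 2)
y = np.arange(500)
bx, by = next_batch(X, y, batch_size=128)
```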
The full code is as follows:

```python
from sklearn.model_selection import train_test_split
import pickle
from collections import Counter
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, optimizers
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, accuracy_score
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import warnings
warnings.filterwarnings('ignore')
import time
import matplotlib.pyplot as plt


class TextConfig():
    vocab_size = 5000
    seq_length = 600
    embedding_dim = 64                     # word-vector dimension
    num_filters = 16                       # number of convolution filters
    kernel_size = 5                        # convolution kernel size
    hidden_dim = 32                        # units in the fully connected layer
    dropout_keep_prob = 0.5                # dropout keep probability
    learning_rate = 1e-3                   # learning rate
    batch_size = 128                       # samples per training batch
    num_iteration = 5000                   # number of iterations
    print_per_batch = num_iteration / 100  # print interval


class TextClassification():
    def config(self):
        textConfig = TextConfig()
        self.vocab_size = textConfig.vocab_size
        self.seq_length = textConfig.seq_length
        self.embedding_dim = textConfig.embedding_dim
        self.num_filters = textConfig.num_filters
        self.kernel_size = textConfig.kernel_size
        self.hidden_dim = textConfig.hidden_dim
        self.dropout_keep_prob = textConfig.dropout_keep_prob
        self.learning_rate = textConfig.learning_rate
        self.batch_size = textConfig.batch_size
        self.print_per_batch = textConfig.print_per_batch
        self.num_iteration = textConfig.num_iteration

    def __init__(self, *args):
        self.config()
        if len(args) == 2:
            content_list = args[0]
            label_list = args[1]
            train_X, test_X, train_y, test_y = train_test_split(content_list, label_list)
            self.train_content_list = train_X
            self.train_label_list = train_y
            self.test_content_list = test_X
            self.test_label_list = test_y
            self.content_list = self.train_content_list + self.test_content_list
        elif len(args) == 4:
            self.train_content_list = args[0]
            self.train_label_list = args[1]
            self.test_content_list = args[2]
            self.test_label_list = args[3]
            self.content_list = self.train_content_list + self.test_content_list
        else:
            print('failed to init TextClassification object')
        self.autoGetNumClasses()

    def autoGetNumClasses(self):
        label_list = self.train_label_list + self.test_label_list
        self.num_classes = np.unique(label_list).shape[0]

    def getVocabularyList(self, content_list, vocabulary_size):
        allContent_str = ''.join(content_list)
        counter = Counter(allContent_str)
        vocabulary_list = [k[0] for k in counter.most_common(vocabulary_size)]
        return ['PAD'] + vocabulary_list

    def prepareData(self):
```