代码基于 dennybritz/cnn-text-classification-tf 及 clayandgithub/zh_cnn_text_classify
参考文章 了解用于NLP的卷积神经网络(译) 及 TensorFlow实现CNN用于文本分类(译)
本文完整代码 - Widiot/cnn-zh-text-classification
1. 项目结构
以下是完整的目录结构示例,包括运行之后形成的目录和文件:
cnn-zh-text-classification/
data/
maildata/
cleaned_ham_5000.utf8
cleaned_spam_5000.utf8
ham_5000.utf8
spam_5000.utf8
runs/
1517572900/
checkpoints/
...
summaries/
...
prediction.csv
vocab
.gitignore
README.md
data_helpers.py
eval.py
text_cnn.py
train.py
各个目录及文件的作用如下:
- data 目录用于存放数据
- maildata 目录用于存放邮件文件,目前有四个文件,ham_5000.utf8 及 spam_5000.utf8 分别为正常邮件和垃圾邮件,带 cleaned 前缀的文件为清洗后的数据
- runs 目录用于存放每次运行产生的数据,以时间戳为目录名
- 1517572900 目录用于存放每次运行产生的检查点、日志摘要、词汇文件及评估产生的结果
- data_helpers.py 用于处理数据
- eval.py 用于评估模型
- text_cnn.py 是 CNN 模型类
- train.py 用于训练模型
2. 数据
2.1 数据格式
以分类正常邮件和垃圾邮件为例,如下是邮件数据的例子:
# 正常邮件
他们自己也是刚到北京不久 跟在北京读书然后留在这里工作的还不一样 难免会觉得还有好多东西没有安顿下来 然后来了之后还要带着四处旅游甚么什么的 却是花费很大 你要不带着出去玩,还真不行 这次我小表弟来北京玩,花了好多钱 就因为本来预定的几个地方因为某种原因没去 舅妈似乎就很不开心 结果就是钱全白花了 人家也是牢骚一肚子 所以是自己找出来的困难 退一万步说 婆婆来几个月
发文时难免欠点理智 我不怎么灌水,没想到上了十大了,拍的还挺欢,呵呵 写这个贴子,是由于自己太郁闷了,其时,我最主要的目的,是觉得,水木上肯定有一些嫁农村GG但现在很幸福的JJMM.我目前遇到的问题,我的确不知道怎么解决,所以发上来,问一下成功解决这类问题的建议.因为没有相同的经历和体会,是不会理解的,我在我身边就找不到可行的建议. 结果,无心得罪了不少人.呵呵,可能我想了太多关于城乡差别的问题,意识的比较深刻,所以不经意写了出来.
所以那些贵族1就要找一些特定的东西来章显自己的与众不同 这个东西一定是穷人买不起的,所以好多奢侈品也就营运诞生了 想想也是,他们要表也没有啊, 我要是香paris hilton那么有钱,就每天一个牌子的表,一个牌子的时装,一个牌子的汽车,哈哈,。。。要得就是这个派 俺连表都不用, 带手上都累赘, 上课又不能开手机, 所以俺只好经常退一下ppt去看右下脚的时间. 其实 贵族又不用赶时间, 要知道精确时间做啥? 表走的
# 垃圾邮件
中信(国际)电子科技有限公司推出新产品: 升职步步高、做生意发大财、连找情人都用的上,详情进入 网 址: http://www.usa5588.com/ccc 电话:020-33770208 服务热线:013650852999
以下不能正确显示请点此 IFRAME: http://www.ewzw.com/bbs/viewthread.php?tid=3809&fpage=1
尊敬的公司您好!打扰之处请见谅! 我深圳公司愿在互惠互利、诚信为本代开3厘---2点国税、地税等发票。增值税和海关缴款书就以2点---7点来代开。手机:13510631209 联系人:邝先生 邮箱:ao998@163.com 祥细资料合作告知,希望合作。谢谢!!
每个句子单独一行,正常邮件和垃圾邮件的数据分别存放在两个文件中。
2.2 数据处理
数据处理 data_helpers.py 的代码如下,与所参考的代码不同的是:
- load_data_and_labels():将函数的参数修改为以逗号分隔的数据文件的路径字符串,比如
'./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'
,这样可以读取多个类别的数据文件以实现多分类问题 - read_and_clean_zh_file():将函数的 output_cleaned_file 修改为 boolean 类型,控制是否保存清洗后的数据,并在函数中判断,如果已经存在清洗后的数据文件则直接加载,否则进行清洗并选择保存
其他函数与所参考的代码相比变动不大:
import numpy as np
import re
import os
def load_data_and_labels(data_files):
"""
1. 加载所有数据和标签
2. 可以进行多分类,每个类别的数据单独放在一个文件中
2. 保存处理后的数据
"""
data_files = data_files.split(',')
num_data_file = len(data_files)
assert num_data_file > 1
x_text = []
y = []
for i, data_file in enumerate(data_files):
# 将数据放在一起
data = read_and_clean_zh_file(data_file, True)
x_text += data
# 形成数据对应的标签
label = [0] * num_data_file
label[i] = 1
labels = [label for _ in data]
y += labels
return [x_text, np.array(y)]
def read_and_clean_zh_file(input_file, output_cleaned_file=False):
"""
1. 读取中文文件并清洗句子
2. 可以将清洗后的结果保存到文件
3. 如果已经存在经过清洗的数据文件则直接加载
"""
data_file_path, file_name = os.path.split(input_file)
output_file = os.path.join(data_file_path, 'cleaned_' + file_name)
if os.path.exists(output_file):
lines = list(open(output_file, 'r').readlines())
lines = [line.strip() for line in lines]
else:
lines = list(open(input_file, 'r').readlines())
lines = [clean_str(seperate_line(line)) for line in lines]
if output_cleaned_file:
with open(output_file, 'w') as f:
for line in lines:
f.write(line + '\n')
return lines
def clean_str(string):
"""
1. 将除汉字外的字符转为一个空格
2. 将连续的多个空格转为一个空格
3. 除去句子前后的空格字符
"""
string = re.sub(r'[^\u4e00-\u9fff]', ' ', string)
string = re.sub(r'\s{2,}', ' ', string)
return string.strip()
def seperate_line(line):
"""
将句子中的每个字用空格分隔开
"""
return ''.join([word + ' ' for word in line])
def batch_iter(data, batch_size, num_epochs, shuffle=True):
'''
生成一个batch迭代器
'''
data = np.array(data)
data_size = len(data)
num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
for epoch in range(num_epochs):
if shuffle:
shuffle_indices = np.random.permutation(np.arange(data_size))
shuffled_data = data[shuffle_indices]
else:
shuffled_data = data
for batch_num in range(num_batches_per_epoch):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, data_size)
yield shuffled_data[start_idx:end_idx]
if __name__ == '__main__':
data_files = './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'
x_text, y = load_data_and_labels(data_files)
print(x_text)
2.3 清洗标准
将原始数据进行清洗,仅保留汉字,并把每个汉字用一个空格分隔开,各个类别清洗后的数据分别存放在 cleaned 前缀的文件中,清洗后的数据格式如下:
本 公 司 有 部 分 普 通 发 票 商 品 销 售 发 票 增 值 税 发 票 及 海 关 代 征 增 值 税 专 用 缴 款 书 及 其 它 服 务 行 业 发 票 公 路 内 河 运 输 发 票 可 以 以 低 税 率 为 贵 公 司 代 开 本 公 司 具 有 内 外 贸 生 意 实 力 保 证 我 司 开 具 的 票 据 的 真 实 性 希 望 可 以 合 作 共 同 发 展 敬 侯 您 的 来 电 洽 谈 咨 询 联 系 人 李 先 生 联 系 电 话 如 有 打 扰 望 谅 解 祝 商 琪
3. 模型
CNN 模型类 text_cnn.py 的代码如下,修改的地方如下:
- 将 concat 和 reshape 的操作结点放在 concat 命名空间下,这样在 TensorBoard 中的节点图更加清晰合理
- 将计算损失值的操作修改为通过 collection 进行,并只计算 W 的 L2 损失值,删去了计算 b 的 L2 损失值的代码
import tensorflow as tf
import numpy as np
class TextCNN(object):
"""
字符级CNN文本分类
词嵌入层->卷积层->池化层->softmax层
"""
def __init__(self,
sequence_length,
num_classes,
vocab_size,
embedding_size,
filter_sizes,
num_filters,
l2_reg_lambda=0.0):
# 输入,输出,dropout的占位符
self.input_x = tf.placeholder(
tf.int32, [None, sequence_length], name='input_x')
self.input_y = tf.placeholder(
tf.float32, [None, num_classes],