【NLP】TensorFlow实现CNN用于中文文本分类

代码基于 dennybritz/cnn-text-classification-tfclayandgithub/zh_cnn_text_classify
参考文章 了解用于NLP的卷积神经网络(译)TensorFlow实现CNN用于文本分类(译)
本文完整代码 - Widiot/cnn-zh-text-classification

1. 项目结构

以下是完整的目录结构示例,包括运行之后形成的目录和文件:

cnn-zh-text-classification/
	data/
		maildata/
			cleaned_ham_5000.utf8
			cleaned_spam_5000.utf8
			ham_5000.utf8
			spam_5000.utf8
	runs/
		1517572900/
			checkpoints/
				...
			summaries/
				...
			prediction.csv
			vocab
	.gitignore
	README.md
	data_helpers.py
	eval.py
	text_cnn.py
	train.py

各个目录及文件的作用如下:

  • data 目录用于存放数据
  • maildata 目录用于存放邮件文件,目前有四个文件,ham_5000.utf8 及 spam_5000.utf8 分别为正常邮件和垃圾邮件,带 cleaned 前缀的文件为清洗后的数据
  • runs 目录用于存放每次运行产生的数据,以时间戳为目录名
  • 1517572900 目录用于存放每次运行产生的检查点、日志摘要、词汇文件及评估产生的结果
  • data_helpers.py 用于处理数据
  • eval.py 用于评估模型
  • text_cnn.py 是 CNN 模型类
  • train.py 用于训练模型

2. 数据

2.1 数据格式

以分类正常邮件和垃圾邮件为例,如下是邮件数据的例子:

# 正常邮件
他们自己也是刚到北京不久 跟在北京读书然后留在这里工作的还不一样 难免会觉得还有好多东西没有安顿下来 然后来了之后还要带着四处旅游甚么什么的 却是花费很大 你要不带着出去玩,还真不行 这次我小表弟来北京玩,花了好多钱 就因为本来预定的几个地方因为某种原因没去 舅妈似乎就很不开心 结果就是钱全白花了 人家也是牢骚一肚子 所以是自己找出来的困难 退一万步说 婆婆来几个月
发文时难免欠点理智 我不怎么灌水,没想到上了十大了,拍的还挺欢,呵呵 写这个贴子,是由于自己太郁闷了,其时,我最主要的目的,是觉得,水木上肯定有一些嫁农村GG但现在很幸福的JJMM.我目前遇到的问题,我的确不知道怎么解决,所以发上来,问一下成功解决这类问题的建议.因为没有相同的经历和体会,是不会理解的,我在我身边就找不到可行的建议. 结果,无心得罪了不少人.呵呵,可能我想了太多关于城乡差别的问题,意识的比较深刻,所以不经意写了出来.
所以那些贵族1就要找一些特定的东西来章显自己的与众不同 这个东西一定是穷人买不起的,所以好多奢侈品也就营运诞生了 想想也是,他们要表也没有啊, 我要是香paris hilton那么有钱,就每天一个牌子的表,一个牌子的时装,一个牌子的汽车,哈哈,。。。要得就是这个派 俺连表都不用, 带手上都累赘, 上课又不能开手机, 所以俺只好经常退一下ppt去看右下脚的时间. 其实 贵族又不用赶时间, 要知道精确时间做啥? 表走的

# 垃圾邮件
中信(国际)电子科技有限公司推出新产品: 升职步步高、做生意发大财、连找情人都用的上,详情进入 网  址:  http://www.usa5588.com/ccc 电话:020-33770208   服务热线:013650852999
以下不能正确显示请点此 IFRAME: http://www.ewzw.com/bbs/viewthread.php?tid=3809&fpage=1
尊敬的公司您好!打扰之处请见谅! 我深圳公司愿在互惠互利、诚信为本代开3厘---2点国税、地税等发票。增值税和海关缴款书就以2点---7点来代开。手机:13510631209       联系人:邝先生  邮箱:ao998@163.com     祥细资料合作告知,希望合作。谢谢!!

每个句子单独一行,正常邮件和垃圾邮件的数据分别存放在两个文件中。

2.2 数据处理

数据处理 data_helpers.py 的代码如下,与所参考的代码不同的是:

  • load_data_and_labels():将函数的参数修改为以逗号分隔的数据文件的路径字符串,比如 './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',这样可以读取多个类别的数据文件以实现多分类问题
  • read_and_clean_zh_file():将函数的 output_cleaned_file 修改为 boolean 类型,控制是否保存清洗后的数据,并在函数中判断,如果已经存在清洗后的数据文件则直接加载,否则进行清洗并选择保存

其他函数与所参考的代码相比变动不大:

import numpy as np
import re
import os


def load_data_and_labels(data_files):
    """
    1. 加载所有数据和标签
    2. 可以进行多分类,每个类别的数据单独放在一个文件中
    2. 保存处理后的数据
    """
    data_files = data_files.split(',')
    num_data_file = len(data_files)
    assert num_data_file > 1
    x_text = []
    y = []
    for i, data_file in enumerate(data_files):
        # 将数据放在一起
        data = read_and_clean_zh_file(data_file, True)
        x_text += data
        # 形成数据对应的标签
        label = [0] * num_data_file
        label[i] = 1
        labels = [label for _ in data]
        y += labels
    return [x_text, np.array(y)]


def read_and_clean_zh_file(input_file, output_cleaned_file=False):
    """
    1. 读取中文文件并清洗句子
    2. 可以将清洗后的结果保存到文件
    3. 如果已经存在经过清洗的数据文件则直接加载
    """
    data_file_path, file_name = os.path.split(input_file)
    output_file = os.path.join(data_file_path, 'cleaned_' + file_name)
    if os.path.exists(output_file):
        lines = list(open(output_file, 'r').readlines())
        lines = [line.strip() for line in lines]
    else:
        lines = list(open(input_file, 'r').readlines())
        lines = [clean_str(seperate_line(line)) for line in lines]
        if output_cleaned_file:
            with open(output_file, 'w') as f:
                for line in lines:
                    f.write(line + '\n')
    return lines


def clean_str(string):
    """
    1. 将除汉字外的字符转为一个空格
    2. 将连续的多个空格转为一个空格
    3. 除去句子前后的空格字符
    """
    string = re.sub(r'[^\u4e00-\u9fff]', ' ', string)
    string = re.sub(r'\s{2,}', ' ', string)
    return string.strip()


def seperate_line(line):
    """
    将句子中的每个字用空格分隔开
    """
    return ''.join([word + ' ' for word in line])


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    '''
    生成一个batch迭代器
    '''
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for epoch in range(num_epochs):
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_idx = batch_num * batch_size
            end_idx = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_idx:end_idx]


if __name__ == '__main__':
    data_files = './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'
    x_text, y = load_data_and_labels(data_files)
    print(x_text)

2.3 清洗标准

将原始数据进行清洗,仅保留汉字,并把每个汉字用一个空格分隔开,各个类别清洗后的数据分别存放在 cleaned 前缀的文件中,清洗后的数据格式如下:

本 公 司 有 部 分 普 通 发 票 商 品 销 售 发 票 增 值 税 发 票 及 海 关 代 征 增 值 税 专 用 缴 款 书 及 其 它 服 务 行 业 发 票 公 路 内 河 运 输 发 票 可 以 以 低 税 率 为 贵 公 司 代 开 本 公 司 具 有 内 外 贸 生 意 实 力 保 证 我 司 开 具 的 票 据 的 真 实 性 希 望 可 以 合 作 共 同 发 展 敬 侯 您 的 来 电 洽 谈 咨 询 联 系 人 李 先 生 联 系 电 话 如 有 打 扰 望 谅 解 祝 商 琪

3. 模型

CNN 模型类 text_cnn.py 的代码如下,修改的地方如下:

  • 将 concat 和 reshape 的操作结点放在 concat 命名空间下,这样在 TensorBoard 中的节点图更加清晰合理
  • 将计算损失值的操作修改为通过 collection 进行,并只计算 W 的 L2 损失值,删去了计算 b 的 L2 损失值的代码
import tensorflow as tf
import numpy as np


class TextCNN(object):
    """
    字符级CNN文本分类
    词嵌入层->卷积层->池化层->softmax层
    """

    def __init__(self,
                 sequence_length,
                 num_classes,
                 vocab_size,
                 embedding_size,
                 filter_sizes,
                 num_filters,
                 l2_reg_lambda=0.0):

        # 输入,输出,dropout的占位符
        self.input_x = tf.placeholder(
            tf.int32, [None, sequence_length], name='input_x')
        self.input_y = tf.placeholder(
            tf.float32, [None, num_classes],
  • 5
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 19
    评论
### 语言模型 #### 数据预处理 中文语言模型基本都是基于字的模型,因此不需要做太多的操作 #### 文件结构介绍 * config文件:配置各种模型的配置参数 * data:存放训练集和测试集 * data_helpers:提供数据处理的方法 * ckpt_model:存放checkpoint模型文件 * pb_model:存放pb模型文件 * outputs:存放vocab,word_to_index, label_to_index, 处理后的数据 * models:存放模型代码 * trainers:存放训练代码 * predictors:存放预测代码 #### 训练模型 * python train.py --config_path="config.json" #### 预测模型 * 预测代码都在predict.py中,初始化Predictor对象,调用predict方法即可。 * 执行python test.py文件可以生成诗词 #### 模型的配置参数详述 #### char rn:字符级的rnn,基于字符的语言模型 * model_name:模型名称 * epochs:全样本迭代次数 * checkpoint_every:迭代多少步保存一次模型文件 * eval_every:迭代多少步验证一次模型 * learning_rate:学习速率 * optimization:优化算法 * embedding_size:embedding层大小 * hidden_sizes:rnn隐层大小 * batch_size:批样本大小 * sequence_length:序列长度 * vocab_size:词汇表大小 * keep_prob:保留神经元的比例 * max_grad_norm:梯度阶段临界值 * train_data:训练数据的存储路径 * eval_data:验证数据的存储路径 * output_path:输出路径,用来存储vocab,处理后的训练数据,验证数据 * word_vectors_path:词向量的路径 * ckpt_model_path:checkpoint 模型的存储路径 * pb_model_path:pb 模型的存储路径

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值