基于tensorflow2.0 +CNN的垃圾邮件文本分类

本文介绍了如何使用Tensorflow2.0进行垃圾邮件文本分类,重点是将Tensorflow1.x代码更新为2.0版本,包括修改神经网络结构、添加激活函数、调整损失函数和精度计算,并分享了训练过程及结果展示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

此文章是修改了 https://www.jianshu.com/p/f3ca7c75401b的代码。原代码是基于Tensorflow1.x 的。训练用的数据也可以从原链接中找到:

数据集下载链接: https://pan.baidu.com/s/10QtokJ8_tkK6I3GifalxWg 提取码: uytb

这次修改最主要的问题有三个:

  1. 文字处理要使用Embedding作为输入层
  2. keras中的神经网络激励函数默认是None,需要添加。我就是忘了加上激励函数,结果网络不收敛。
  3. 最后一层的输出函数由tf.nn.softmax换成tf.nn.selu

主要的修改如下:

  1. 增加了loss 函数
    loss函数定义如下:
    def loss_function(self,labels, pred_proba):
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=pred_proba)
        return tf.reduce_mean(loss)

2 . 增加精度函数
精度函数定义如下:

    def compute_accuracy(self,labels, pred_proba):
        predictions = tf.argmax(pred_proba, axis=1)
        return accuracy_score(labels, predictions)
  1. 修改了网络结构
    去掉了Dropout,保留也可以。但是对于网络收敛以及减少训练时间作用并不大。
    减小了CNN的卷积核数目,全连接网络的节点也减少了。

  2. 修改了每次训练的样本数据
    原代码每次的样本数据集是32个,不利于训练。修改为128个后,效果比较好。
    全部代码如下:



from sklearn.model_selection import train_test_split
import pickle
from collections import Counter
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers,optimizers

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, accuracy_score
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import warnings
warnings.filterwarnings('ignore')
import time
import matplotlib.pyplot as plt

class TextConfig():

    vocab_size = 5000
    seq_length = 600
    embedding_dim = 64  # 词向量维度
    num_filters = 16  # 卷积核数目
    kernel_size = 5  # 卷积核尺
    hidden_dim = 32  # 全连接层神经元
    dropout_keep_prob = 0.5  # dropout保留比例
    learning_rate = 1e-3  # 学习率
    batch_size = 128   # 每批训练大小
    num_iteration = 5000 #迭代次数
    print_per_batch = num_iteration / 100 #打印间隔
    
class TextClassification():
    def config(self):
        textConfig = TextConfig()
        self.vocab_size = textConfig.vocab_size
        self.seq_length = textConfig.seq_length
        self.embedding_dim = textConfig.embedding_dim
        self.num_filters = textConfig.num_filters
        self.kernel_size = textConfig.kernel_size
        self.hidden_dim = textConfig.hidden_dim
        self.dropout_keep_prob = textConfig.dropout_keep_prob
        self.learning_rate = textConfig.learning_rate
        self.batch_size = textConfig.batch_size
        self.print_per_batch = textConfig.print_per_batch
        self.num_iteration = textConfig.num_iteration
    
    def __init__(self, *args):
        self.config()
        if len(args) == 2:
            content_list = args[0]
            label_list = args[1]
            train_X, test_X, train_y, test_y = train_test_split(content_list, label_list)
            self.train_content_list = train_X
            self.train_label_list = train_y
            self.test_content_list = test_X
            self.test_label_list = test_y
            self.content_list = self.train_content_list + self.test_content_list
        elif len(args) == 4:
            self.train_content_list = args[0]
            self.train_label_list = args[1]
            self.test_content_list = args[2]
            self.test_label_list = args[3]
            self.content_list = self.train_content_list + self.test_content_list
        else:
            print('false to init TextClassification object')
        self.autoGetNumClasses()
    
    def autoGetNumClasses(self):
        label_list = self.train_label_list + self.test_label_list
        self.num_classes = np.unique(label_list).shape[0]
    
    def getVocabularyList(self, content_list, vocabulary_size):
        allContent_str = ''.join(content_list)
        counter = Counter(allContent_str)
        vocabulary_list = [k[0] for k in counter.most_common(vocabulary_size)]
        return ['PAD'] + vocabulary_list

    def prepareData(self)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值