# GAN-based network anomaly detection

"""
@Time    : 2021/7/6 15:20
-------------------------------------------------
@Author  : sailorlee(lizeyi)
@email   : chrislistudy@163.com
-------------------------------------------------
@FileName: train_gan.py
@Software: PyCharm
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn.metrics import classification_report, roc_curve, auc
from torch.utils.data import DataLoader, TensorDataset
from torchsummary import summary
import torch.nn as nn
from load_data.preprocessing import main_process, get_typeofmalware
from models.GAN import NetD,NetG
from utils.plot_culve import plot_ROC, plot_loss_new
from torch.autograd import Variable

# Number of full passes over the normal-traffic training set.
num_epochs = 40

def calculate_losses(x, preds):
    """Return the per-sample mean squared error between ``x`` and ``preds``.

    Used as the anomaly score: samples the generator reconstructs poorly
    (large MSE) are flagged as malicious.

    Args:
        x: array-like of shape ``(n_samples, ...)`` — one set of samples.
        preds: array-like of the same shape — the samples to compare against.
            (The metric is symmetric, so argument order does not matter.)

    Returns:
        ``np.ndarray`` of shape ``(n_samples,)`` where entry ``i`` is the mean
        of the squared element-wise differences of sample ``i``.
    """
    x = np.asarray(x, dtype=np.float64)
    preds = np.asarray(preds, dtype=np.float64)
    # Guard: reshape(0, -1) is ambiguous for empty input; mirror the
    # original's np.zeros(0) result.
    if len(x) == 0:
        return np.zeros(0)
    # Vectorized replacement for the original per-sample Python loop:
    # square the difference, then average over all non-batch axes.
    return ((preds - x) ** 2).reshape(len(x), -1).mean(axis=1)

if __name__ == '__main__':


    #
    # Dataset preprocessing #
    #

    # main_process() is a project helper; presumably it returns the normal
    # traffic features (data), malware traffic features, and a held-out test
    # split (x_test, y_test) — TODO confirm against load_data/preprocessing.
    data, malware_flow, x_test, y_test = main_process()

    feature = data.shape[1]  # number of input features per flow sample
    # data = (data.astype(np.float32) - 127.5) / 127.5
    # X_normal = data.values.reshape(data.shape[0], 44)  # reshape into matrix form
    # print(X_normal.shape)
    X_normal = data.values
    # Extract the raw test matrix
    x_test = x_test.values
    x_test_tensor = torch.FloatTensor(x_test)

    # Only NORMAL traffic is used for training: the GAN learns the normal
    # distribution and anomalies are detected via reconstruction loss.
    X_normal = torch.FloatTensor(X_normal)
    X_normal_data = TensorDataset(X_normal)  # wrap the tensor into a dataset
    train_loader = DataLoader(dataset=X_normal_data, batch_size=64,
                              shuffle=True)  # DataLoader yields an iterator for convenient mini-batching

    #
    # Build the models
    #
    D = NetD(feature)
    G = NetG(feature,64)

    # if torch.cuda.is_available():
    #     D = D.cuda()
    #     G = G.cuda()

    loss_fn = nn.BCELoss()  # binary cross entropy
    D_optim = torch.optim.Adam(D.parameters(),lr = 0.00008)
    G_optim = torch.optim.Adam(G.parameters(),lr= 0.00008)

    # Per-step loss histories (one entry per mini-batch, not per epoch).
    d_loss_list = []
    g_loss_list = []

    # NOTE(review): torch.autograd.Variable is deprecated; plain tensors
    # behave identically here. Left unchanged to preserve the original code.
    for epoch in range(num_epochs):
        total_loss = 0.
        for step, (x,) in enumerate(train_loader):

            batch_num = x.size(0)  # number of samples in this mini-batch
            real_label = Variable(torch.ones(batch_num,1)) # label 1 marks real samples
            fake_label = Variable(torch.zeros(batch_num,1))  # label 0 marks generated (fake) samples

            # ######## Discriminator training #####################
            # Two parts: (1) real samples should be judged real;
            #            (2) generated samples should be judged fake.
            # Loss on real samples:

            real_out = D(x)  # feed real samples to the discriminator
            # print('real_out:',real_out.shape)
            # print('real_label:',real_label.shape)
            # loss of the real samples
            d_loss_real = loss_fn(real_out, real_label)
            real_scores = real_out  # D's score on real samples; closer to 1 is better
            # Loss on generated (fake) samples:
            z = Variable(torch.randn(batch_num,feature)) # sample random noise
            # print('z',z.shape)
            fake_img = G(z).detach()  # generate fakes; detach so no gradient flows into G (G is not updated in this phase)
            # print('fake_img:',fake_img.shape)
            fake_out = D(fake_img)  # discriminator judges the fake samples
            d_loss_fake = loss_fn(fake_out, fake_label)  # loss of the fake samples
            fake_scores = fake_out  # D's score on fakes; for the discriminator, closer to 0 is better
            # Combine losses and optimize
            d_loss = d_loss_real + d_loss_fake  # total D loss = real-part + fake-part
            d_loss_list.append(d_loss.data.item())
            D_optim.zero_grad()  # clear accumulated gradients before backprop
            d_loss.backward()  # backpropagate the discriminator error
            D_optim.step()  # update the discriminator's parameters

            # ================== Generator training ============================
            # ############################### Generator training ###############################
            # Goal: make generated samples be judged REAL by the discriminator.
            # The discriminator is held fixed; the fakes are scored against the
            # real label, so backprop only updates the generator's parameters.
            # Updating G this way trains it to fool D — the adversarial objective.
            # Loss on the generated samples:
            z = Variable(torch.randn(batch_num,feature)) # fresh random noise
            fake_img = G(z)  # generator maps the noise to a fake sample (no detach: G must receive gradients)
            output = D(fake_img)  # discriminator's verdict on the fake

            g_loss = loss_fn(output, real_label)  # loss between D's verdict on fakes and the "real" label
            g_loss_list.append(g_loss.data.item())
            # bp and optimize
            G_optim.zero_grad()  # clear gradients
            g_loss.backward()  # backpropagate through D into G
            G_optim.step()  # .step() follows backward(); updates the generator's parameters

            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D real: {:.6f},D fake: {:.6f}'.format(
                epoch, num_epochs, d_loss.data.item(), g_loss.data.item(),
                real_scores.data.mean(), fake_scores.data.mean()  # mean D scores on real and fake batches
            ))

    # Evaluation: score every test sample by its reconstruction MSE.
    # NOTE(review): G maps noise to samples during training; feeding real test
    # samples through G treats it as a reconstructor — verify this is intended.
    test_prediction = G(x_test_tensor)
    test_prediction = test_prediction.detach().numpy()
    losses = calculate_losses(test_prediction, x_test)
    # Assumes the first 37000 test rows are normal traffic and the rest are
    # malware — TODO confirm against main_process()'s test-set ordering.
    losses_normal  = losses[0:37000]
    losses_malware = losses[37000:]
    sns.distplot(losses, kde=True)
    plt.show()
    sns.distplot(losses_normal, kde=True)
    plt.show()
    sns.distplot(losses_malware, kde=True)
    plt.show()
    print(test_prediction.shape)

    # Hard-coded decision threshold on the loss (sic: "threhold" is a typo in
    # the original variable name); losses above it are labelled anomalous (1).
    threhold = 2.16
    testing_set_predictions = np.zeros(len(losses))
    testing_set_predictions[np.where(losses > threhold)] = 1
    accuracy = accuracy_score(y_test, testing_set_predictions)
    recall = recall_score(y_test, testing_set_predictions)
    precision = precision_score(y_test, testing_set_predictions)
    f1 = f1_score(y_test, testing_set_predictions)
    print("accuracy:", accuracy)
    print("recall:", recall)
    print("precision:", precision)
    print("f1:", f1)

    # plot_loss_new(8,d_loss_list)
    # plot_loss_new(8,g_loss_list)
    # optimizer = torch.optim.Adam(model.parameters(), 0.00001)
    # loss_func = torch.nn.MSELoss(reduction='mean')
    # summary(model, ((data.shape[1], data.shape[1])))

First, the loss-distribution plot for the traffic of the whole test set:
[figure: distribution of reconstruction losses over the full test set]
The distribution plot for normal traffic:
[figure: distribution of reconstruction losses for normal traffic]
The distribution plot for malicious traffic:
[figure: distribution of reconstruction losses for malicious traffic]

  • 0
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值