pytorch基于GAN生成对抗网络的数据集扩充


前言

GAN对抗生成网络可以在数据集量少不足的情况下,根据这部分少量的数据集的特征来生成更多的新的数据集达到数据集扩充的目的,这篇文章前面部分先做个大概介绍后面有实例,都比较简单好理解,不想看理论的小伙伴可以直接跳到代码。

另外说一下,这篇文章是更新的第二个版本,去年写的是GAN的基础版本,介于很多朋友在问DCGAN和想要生成彩色图的缘故,于是做了一些改进,将GAN加入CNN变成了DCGAN之后,使得网络对于图像的特征提取能力更强参数也更少,效果有了非常好的改善,并且支持生成彩色图像(只需要修改代码前几行中的参数)

一、GAN基本原理

1.GAN结构图

在这里插入图片描述
GAN由两个模型构成, 判别模型和生成模型, 判别模型可用于训练, 也可用于测试, 但生成模型只能用于测试。生成模型捕捉真实样本的分布, 并根据分布生成新的fake样本;判别器是判别输入是真实样本还是fake样本的二分类器。模型G和D通过不断的对抗训练,使D正确判别训练样本来源,同时使G生成的fake样本与真实样本更相像。

2.GAN目标函数

在这里插入图片描述
GAN是生成网络和判别网络的博弈问题,判别网络D希望真实样本x的概率值越大越好,同时希望判断fake样本G(z)为真实样本的概率值越小越好,而生成网络G希望fake样本G(z)与x越相似越好,让判别网络判断其为真实样本的概率D(G(z))越高越好。

二、实例(完整代码:https://github.com/Programmerfei/Pytorch-Gan-based-dataset-expansion.git)

1.项目流程图

(这个流程图是用原始train训练的模型一和扩充的fake加上train训练的模型二准确率的对比流程图,如果只是想通过GAN生成数据就只参考这个流程图的左半部分)
在这里插入图片描述

流程图说明:1.将原始数据划分为train和val。 2.把train的图片送入GAN网络训练得到GAN的生成模型和判别模型,同时将train的图片送入CNN网络中训练得到第一个识别模型。 3.随机生成一些噪声点输入到步骤2中训练的生成模型中,得到若干输出的fake图片 4.将步骤3得到的fake图片和train的图片组合得到一个在原始数据集上加入了fake样本进行扩充后的新训练集 5.将新的训练集送入与步骤1相同的CNN网络中训练得到第二个识别模型 6.将val的图片送入步骤2和步骤5得到的两个识别模型中,对比预测准确率得到实验结论:用GAN生成的fake样本加入到识别模型的训练当中可以有效提高模型的泛化能力从而提高识别准确率。

2.项目代码

(本文只写训练GAN和用生成网络做数据扩充的代码,也就是流程图的左边部分)
注意:运行代码之前先将代码和数据按照目录结构放好,避免找不到库或数据

2.1解析mnist二进制文件保存为图片

(可以用其它数据集,训练什么类型就可以生成什么类型)
解析代码:

import os
import struct
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm

MNIST_data_dir = 'data/MNIST/raw/'  # MNIST数据文件路径,这里存放的是二进制文件
train_val_data_dir = 'data/MNIST/'  # train和val的数据保存路径,train是6w张数据,val是1w张数据
Number_of_requirements = 500  # 每个数字取多少张数据作为训练数据及测试数据,解析到足量则提前结束


def read_idx(filename):
    """
    二进制文件解析函数
    filename:二进制文件路径
    """
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)


def save_img(data, labels, t_v):
    """
    图片保存函数
    data: 二进制文件解析出来的图片数据
    labels: 标签
    t_v: train或val
    """
    count_dict = {}
    for i in tqdm(range(len(data)), desc=t_v):
        label = labels[i]
        folder = os.path.join(t_v, str(label))
        if not os.path.exists(folder):
            os.makedirs(folder)
        if sum(count_dict.values()) == 10*Number_of_requirements:  # 如果每个数字都达到需求个数,则结束
            break
        # 如果这个数字的个数达到要求则跳过这个数字的保存
        if str(label) in count_dict and count_dict[str(label)] == Number_of_requirements:
            continue
        # if os.path.exists(os.path.join(folder, f'image_{i}.png')):   #如果图片存在先删除之前保存的,再重新保存新的图片(防止之前保存的有问题)
        #     os.remove(os.path.join(folder, f'image_{i}.png'))
        cv2.imwrite(os.path.join(folder, f'image_{i}.jpg'), data[i])
        # 保存一次图片,这个数字的计数+1,如果字典中没有,即为该数字的第一张图,赋值为1
        count_dict[str(label)] = count_dict[str(label)] + \
            1 if str(label) in count_dict else 1
    print('数量已达要求,停止解析:\n', count_dict)


if __name__ == '__main__':
    for data_path, label_path, t_v in zip(['train-images-idx3-ubyte', 't10k-images-idx3-ubyte'],
                                          ['train-labels-idx1-ubyte',
                                              't10k-labels-idx1-ubyte'],
                                          ['train', 'val']):
        data = read_idx(os.path.join(MNIST_data_dir, data_path))  # 解析图片文件
        labels = read_idx(os.path.join(
            MNIST_data_dir, label_path))  # 解析label文件
        save_img(data, labels, os.path.join(train_val_data_dir, t_v))  # 保存图片

2.2训练GAN生成网络和判别网络

注意:修改图片路径和模型保存路径,导入库文件是否存在

# coding=utf-8
# -*- coding=utf-8 -*-
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from utils.data_reader import LoadData  # 数据读取
from utils.network_structure import generator, discriminator  # 网路结构
from utils.functional_functions import init_weights  # 参数初始化

# 数据目录
route = 'data\MNIST'  # 数据目录
result_save_path = 'model/GAN_model'  # 模型和训练过程中生成的fake样本的保存目录
drop_last = False  # 不够一个批次的数据是否舍弃掉,数据量多可以选择True
if not os.path.exists(result_save_path):
    os.mkdir(result_save_path)  # 如果没有保存路径的目录文件夹则进行创建

# 训练相关的参数
lr_d = 0.002  # 判别器学习率
lr_g = 0.002  # 生成器学习率
batch_size = 100  # 一个批次的大小
num_epoch = 300  # 训练迭代次数
output_loss_Interval_ratio = 10  # 间隔多少个epoch打印一次损失
save_model_Interval_ratio = 100  # 间隔多少个epoch保存一次训练过程中的fake图片

# 网络结构相关的参数
g_d_nc = 1  # d的输入通道和g的输出通道,RGB为3,GRAY为1
g_input = 100  # g的输入噪声点个数

# 定义loss的度量方式
criterion = nn.BCELoss()  # 单目标二分类交叉熵函数
# 实例化生成器和判别器
d = discriminator(number_of_channels=g_d_nc).cuda()
g = generator(noise_number=g_input,
              number_of_channels=g_d_nc).cuda()  # 模型迁移至GPU
# 定义 优化函数 学习率
d_optimizer = torch.optim.Adam(
    d.parameters(), lr=lr_d, betas=(0.5, 0.999))  # Adam优化器
g_optimizer = torch.optim.Adam(g.parameters(), lr=lr_g, betas=(0.5, 0.999))

# 调试代码,用于验证输入图像大小和g网络结构的适配性
# # 下面注释的这几行代码用于调试g网络层输入输出大小用
# z = torch.randn(batch_size,g_input,1,1).cuda()  # 随机生成一些噪声
# for i in g.gen:
#     print(i(z).shape)
#     z=i(z)

# # 下面注释的这几行代码用于调试d网络层输入输出大小用
# z = torch.randn(batch_size,g_input,1,1).cuda()  # 随机生成一些噪声
# fake_img=g(z)
# for i in d.dis:
#     print(i(fake_img).shape)
#     fake_img=i(fake_img)

for number in range(0, 10):  # 0-9每一个数字单独训练
    # 初始化网络每一层的参数
    d.apply(init_weights), g.apply(init_weights)

    # #恢复训练
    # g=torch.load(os.path.join(result_save_path,str(number),str(number)+'_g__last.pth'))
    # d=torch.load(os.path.join(result_save_path,str(number),str(number)+'_d__last.pth'))

    # 初始化训练数据读取器
    train_dataset = LoadData(os.path.join(route, 'train', str(
        number)), number_of_channels=g_d_nc)  # dataset
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=drop_last)  # dataloader

    loss_list_g, loss_list_d = [], []  # 保存每一个epoch的损失值
    for epoch in tqdm(range(0, num_epoch+1), desc='epoch'):  # 迭代num_epoch个epoch
        batch_d_loss, batch_g_loss = 0, 0  # 累加每个epoch中全部batch的损失值,最后平均得到每个epoch的损失值
        for img, label in train_loader:  # 每个batch_size的图片
            img_number = len(img)  # 该批次有多少张图片
            real_img = img.cuda()  # 将tensor放入cuda中
            real_label = torch.ones(img_number).cuda()  # 定义真实的图片label为1
            fake_label = torch.zeros(img_number).cuda()  # 定义假的图片的label为0

            # ==================训练判别器==================
            # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
            # 计算真实图片的损失
            real_out = d(real_img)  # 将真实图片放入判别器中
            real_label = real_label.reshape([-1, 1])  # shape (n) -> (n,1)
            d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss
            real_scores = real_out  # 得到真实图片的判别值,输出的值越接近1越好
            # 计算假的图片的损失
            z = torch.randn(img_number, g_input, 1, 1).cuda()  # 随机生成一些噪声
            # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
            fake_img = g(z).detach()
            fake_out = d(fake_img)  # 判别器判断假的图片,
            fake_label = fake_label.reshape([-1, 1])  # shape (n) -> (n,1)
            d_loss_fake = criterion(fake_out, fake_label)  # 得到假的图片的loss
            fake_scores = fake_out  # 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
            # 合计判别器的总损失
            d_loss = d_loss_real + d_loss_fake  # 损失包括判真损失和判假损失
            # 反向传播,参数更新
            d_optimizer.zero_grad()  # 在反向传播之前,先将梯度归0
            d_loss.backward()  # 将误差反向传播
            d_optimizer.step()  # 更新参数

            # ==================训练生成器==================
            # 原理:目的是希望生成的假的图片被判别器判断为真的图片,
            # 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
            # 反向传播更新的参数是生成网络里面的参数,
            # 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的
            # 这样就达到了对抗的目的
            # 计算假的图片的损失
            z = torch.randn(img_number, g_input, 1, 1).cuda()  # 得到随机噪声
            fake_img = g(z)  # 随机噪声输入到生成器中,得到一副假的图片
            output = d(fake_img)  # 经过判别器得到的结果
            g_loss = criterion(output, real_label)  # 得到的假的图片与真实的图片的label的loss
            # 反向传播,参数更新
            g_optimizer.zero_grad()  # 梯度归0
            g_loss.backward()  # 进行反向传播
            g_optimizer.step()  # .step()一般用在反向传播后面,用于更新生成网络的参数

            # ==================累加总损失值,后面进行损失值可视化==================
            batch_d_loss += d_loss  # 累加每一个batch的损失值
            batch_g_loss += g_loss  # 累加每一个batch的损失值

        # # 调整学习率,当判别器损失足够小的时候,大幅度降低d的学习率,防止d过于完美,导致g无法训练(增加epoch次数可以开启)
        # if d_loss < 0.5:
        #     for i in d_optimizer.param_groups:
        #         i['lr']=lr_d/10

        # 将该轮的损失函数值保存到列表当中
        # 保存g损失值为列表,将所有batch累加的损失值除以batch数即该轮epoch的损失值
        loss_list_g.append(batch_g_loss.item()/len(train_loader))
        loss_list_d.append(batch_d_loss.item()/len(train_loader))  # 保存d损失值为列表

        # 打印中间的损失  #间隔output_loss_Interval_ratio个epoch打印一次损失
        if epoch % output_loss_Interval_ratio == 0:
            print('\nnumber:{} Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D real: {:.6f},D fake: {:.6f}'.format(
                      number, epoch, num_epoch,
                      batch_d_loss.item()/len(train_loader),
                      batch_g_loss.item()/len(train_loader),
                      real_scores.data.mean(),
                      fake_scores.data.mean()
                  ))  # 打印每个epoch的d和g损失值(越小越好)和d的判别值(real越接近1越好,fake越接近0越好)

        # 创建保存模型和生成fake样本以及loss图的目录
        if not os.path.exists(os.path.join(result_save_path, str(number))):
            os.mkdir(os.path.join(result_save_path, str(number)))

        # 保存生成的fake图片,间隔save_model_Interval_ratio个epoch保存一次
        if epoch % save_model_Interval_ratio == 0:
            save_image(fake_img, os.path.join(result_save_path, str(number),
                                              str(number)+'_fake_epoch'+str(epoch)+'.jpg'))

        # 保存模型,for分别保存g和d,每个epoch都保存一次last.pth
        for g_or_d, g_d_name in zip([g, d], ['_g_', '_d_']):
            torch.save(g_or_d, os.path.join(result_save_path,
                       str(number), str(number)+g_d_name+'last.pth'))

        # 保存loss图像
        plt.plot(range(len(loss_list_g)), loss_list_g, label="g_loss")
        plt.plot(range(len(loss_list_d)), loss_list_d, label="d_loss")
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.legend()
        plt.savefig(os.path.join(result_save_path, str(number), 'loss.jpg'))
        plt.clf()

    print('\n')

训练过程中生成网络的效果变化:
在这里插入图片描述

2.3使用生成网络制造fake样本,扩充数据集

注意:修改生成图片保存路径和模型存放路径
扩充代码:

import os
from tqdm import tqdm

import torch
from torchvision.utils import save_image 

img_number=500  #每一个数字生成多少张fake图片

result_save_path='model/GAN_model'  #训练好的生成网络模型的目录
fakedata_save_path='data/MNIST_fake/train/'  #生成的fake图片保存目录

if not os.path.exists(fakedata_save_path):
    os.makedirs(fakedata_save_path)

for number in range(0,10):
    g=torch.load(os.path.join(result_save_path,str(number),str(number)+'_g_last.pth'))  #加载模型
    fake_save_dir=os.path.join(fakedata_save_path,str(number))    #保存图片的目录路径
    if not os.path.exists(fake_save_dir):  #如果没有这个路径则创建
        os.mkdir(fake_save_dir)
    
    g.eval()#进入验证模式,不用计算梯度和参数更新
    g_input=next(g.children())[0].in_channels  #获取模型的输入通道数

    for i in tqdm(range(img_number),desc=f'number{number}'):
        z = torch.randn(1,g_input,1,1).cuda()  # 随机生成一些噪声
        fake_img = g(z).detach()  # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
        save_image(fake_img,os.path.join(fake_save_dir,
                        str(number)+'_fake_'+str(i)+'.jpg'))  #保存fake样本
        
    

列举部分真实样本和fake样本
真实样本:在这里插入图片描述
fake样本:
在这里插入图片描述

三、目录结构展示

1、目录结构图

在这里插入图片描述
说明:MNIST数据集直接下载官方文件,大小52.4M,其它没有后缀的就是文件夹,有后缀的就是对应类型的文件。
运行代码之前先按照这个目录结构创建目录和存放数据集。训练好的模型和生成的图片最终也会存放到model和data对应的目录下。

2、utils中的代码

data_reader.py

from torch.utils.data import Dataset
from torchvision.transforms import transforms
from PIL import Image
import os

imgsz=28  #缩放图像的大小

#定义数据读取器
class LoadData(Dataset):
    def __init__(self, dir_path, number_of_channels):
        self.imgs_info = [(os.path.join(dir_path,img),dir_path[-1]) for img in os.listdir(dir_path)]

        self.tf = transforms.Compose([
            # 将图片尺寸resize到512*512
            transforms.Resize((imgsz,imgsz)),
            # 将图片转化为Tensor格式
            transforms.ToTensor(),
            #将图片通道数转化为模型输入通道数
            transforms.Grayscale(number_of_channels),
            # 标准化(当模型出现过拟合的情况时,用来降低模型的复杂度)
            transforms.Normalize([0.5]*number_of_channels, [0.5]*number_of_channels)  # 图像标准化
            ])

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')
        img = self.tf(img)
        return img,float(label)

    def __len__(self):
        return len(self.imgs_info)

functional_functions.py

import torch.nn as nn

#初始化网络参数函数,用于下一个数字开始训练之前
def init_weights(m):
    if hasattr(m,'weight'):
        nn.init.uniform_(m.weight,-0.1,0.1)

network_structure.py

import torch.nn as nn


ndf=64 #判别网络卷积核个数的倍数
ngf=64 #生成网络卷积核个数的倍数


"""
关于转置卷积:
当padding=0时,卷积核刚好和输入边缘相交一个单位。因此pandding可以理解为卷积核向中心移动的步数。 
同时stride也不再是kernel移动的步数,变为输入单元彼此散开的步数,当stride等于1时,中间没有间隔。
"""

#生成器网络G
class generator(nn.Module):
    def __init__(self,noise_number,number_of_channels):
        """
        noise_number:输入噪声点个数
        number_of_channels:生成图像通道数
        """
        super(generator,self).__init__()
        self.gen = nn.Sequential(
            # 输入大小  batch x noise_number x 1 * 1
            nn.ConvTranspose2d(noise_number , ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 输入大小 batch x (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 输入大小 batch x (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf , 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf ),
            nn.ReLU(True),
            # 输入大小 batch x (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf   , number_of_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # 输出大小 batch x (nc) x 64 x 64
       )

    def forward(self, x):
        out = self.gen(x)
        return out
    
#判别器网络D
class discriminator(nn.Module):
    def __init__(self,number_of_channels):
        """
        number_of_channels:输入图像通道数
        """
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            # 输入大小 batch x g_d_nc x 64*64
            nn.Conv2d(number_of_channels, ndf  , 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf ),
            nn.LeakyReLU(0.2, inplace=True),
            # 输入大小 batch x ndf x 32*32
            nn.Conv2d(ndf , ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输入大小 batch x (ndf*2) x 16*16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输入大小 batch x (ndf*8) x 4*4
            nn.Conv2d(ndf * 4 , 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # 输出大小 batch x 1 x 1*1
        )

    def forward(self, x):
        x=self.dis(x).view(x.shape[0],-1)
        return x
    
#分类网络CNN
class classification_model(nn.Module):
    def __init__(self,n_classes,number_of_channels):
        """
        n_classes:类别数
        """
        super(classification_model,self).__init__()
        self.structure=nn.Sequential(
            nn.Conv2d(number_of_channels, 6, kernel_size=5, stride=1, padding=2),  # (m,6,28,28)
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (m,6,14,14)

            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),  # (m,16,10,10)
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (m,16,5,5)

            nn.Conv2d(16, n_classes, kernel_size=5, stride=1, padding=0),  # (m,10,1,1)
            nn.Softmax(dim=1)
        )   
    
    def forward(self,x):
        out=self.structure(x)
        out=out.reshape(out.shape[0],-1)
        return out
评论 92
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

programmer.Mr.Fei,

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值