GAN原理 & 代码解读

模型架构

在这里插入图片描述
在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch

# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)
    transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transform,
                               download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''
    输入:正态分布随机数噪声(长度为100)
    输出:生成的图片,(1,28,28)
    中间过程:
        linear1: 100 -> 256
        linear2: 256 -> 512
        linear3: 512 -> 28*28
        reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数
        self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),
                                   nn.Linear(256,512),nn.ReLU(),
                                    # 最后一层用tanh激活,将数据压缩到-1到1
                                   nn.Linear(512,28*28),nn.Tanh())
    def forward(self,x):
        img = self.model(x)
        img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)
        return img

定义判别器

'''
    判别器
    输入:(1,28,28)的图片
    输出:二分类的概率值 用sigmoid压缩到0-1之间
    内容:
    判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28,512),nn.LeakyReLU(),
            nn.Linear(512,256),nn.LeakyReLU(),
            nn.Linear(256,1),nn.Sigmoid(),
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.model(x)
        return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):
    prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpy
    prediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

显示图片的函数

def img_plot(img):
    img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

定义训练函数


def train(num_epoch,test_input):
    D_loss = []
    G_loss = []
    # 训练循环
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader) # 返回批次数
        for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)
            # print(f'step={step},img.shape={img.shape}')
            # img_plot(img)
            img = img.to(device)
            size = img.size(0) # 得到一个批次的图片
            random_noise = torch.randn(size,100,device=device) # 生成器的输入

            '''一. 训练判别器'''
            '''用真实图片训练判别器'''
            dis_optim.zero_grad()
            real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果
            # 判别器在真实图像上的损失
            d_real_loss = bce_loss(real_output,
                                   # torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。
                                   torch.ones_like(real_output))
            d_real_loss.backward()

            '''用生成的图片训练判别器'''
            gen_img = gen(random_noise)
            # 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensor
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output,
                                   torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss+d_fake_loss
            dis_optim.step() # 对参数进行优化

            '''二.训练生成器'''
            gen_optim.zero_grad()
            # 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别
            fake_output = dis(gen_img)
            # 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。
            # 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。
            g_loss = bce_loss(fake_output,
                              torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            # 计算一个epoch的损失
            with torch.no_grad(): #  禁止梯度计算和参数更新
                d_epoch_loss +=d_loss
                g_epoch_loss +=g_loss
        # 计算整体loss每个epoch的平均Loss
        with torch.no_grad(): #  禁止梯度计算和参数更新
            d_epoch_loss /= count
            g_epoch_loss /= count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)
            print('Epoch:', epoch+1)
            print(f'd_epoch_loss={d_epoch_loss}')
            print(f'g_epoch_loss={g_epoch_loss}')
            # 将16个长度为100的噪音输入到生成器并画图
            gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()

'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')

'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

在这里插入图片描述
GAN的生成是随机的,不同的噪声,生成不同的数字

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
StarGAN v2是一种先进的图像生成模型,旨在将一组输入图像转换为多个可能的目标域图像。该模型具有许多有用的功能和创新。 首先,StarGAN v2建立在StarGAN的基础上,通过引入一个新的概念,即多个生成器和判别器,大大提高了模型的生成能力。每个生成器与一个特定目标域相关联,并且可以从输入图像生成与目标域相关的图像。多个判别器用于提供有关输入图像和生成图像之间的真实性的反馈,从而帮助生成更高质量的图像。 其次,StarGAN v2引入了一个新的概念称为样式代码。样式代码是一个向量,代表了输入图像和目标域之间的潜在特征。通过改变样式代码的值,可以在目标域中生成具有不同外观和特征的图像。这使得模型更加灵活和可控,用户可以根据需要对图像进行个性化的转换。 另外,StarGAN v2还引入了两个重要的改进,称为判别器样式适应和循环一致性损失。判别器样式适应用于提高判别器的性能,使其能够更好地区分生成图像和目标域中真实图像之间的区别。循环一致性损失则用于确保生成器能够在两个目标域之间进行无缝转换,而不会丢失细节或信息。 最后,StarGAN v2通过使用特征对齐损失进一步提高了生成图像的质量。特征对齐损失用于确保在生成图像和真实图像之间的特征分布保持一致,从而使得生成图像更加逼真和真实。 总之,StarGAN v2是一个令人印象深刻的图像生成模型,通过引入多个生成器和判别器、样式代码、判别器样式适应、循环一致性损失和特征对齐损失,实现了高质量和高度可控的图像转换。它在许多应用领域,如人脸生成和图像风格迁移中具有巨大的潜力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

computer_vision_chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值