GAN(生成对抗网络)

该文介绍了GAN(生成对抗网络)的基本原理,包括生成器和判别器的作用,以及它们之间的对抗学习过程。文中提供了一个使用PyTorch实现的GAN模型,用于在MNIST数据集上生成手写数字图像,展示了训练过程中的损失变化,并提供了代码示例。
摘要由CSDN通过智能技术生成

算法流程

G:G是一个生成器,随机噪声输入,图片输出G(z),选择噪声输入的原因是引入这种随机性,带来生成的多样性。

D:D是一个判别器,判断图片是否真实输入为图片,生成二分类0-1(sigmoid激活输出)。

流程:G由设计噪声生成一张图片,判别器接受真实的图片和生成的图片,尽量将两者区分开,将正确辨别真实和生成图片与否作为判别器的损失,生成器的损失是将能否生成近似真实图片而且使得判别器将生成的图片判定为真。

对抗:个人理解是生成器和判别器的对抗,相互促进作用,标签还是存在的(真-假),判别器是根据标签来做为进化的方向,生成器把“欺骗”判别器作为进化的方向,进一步判别器继续根据标签进化,这就使得生成器的欺骗能力越来越强,判别器的判断能力越来越强,防“欺骗”能力越来越强。判别器的输出是一个概率值,可以通过交叉熵来计算

但是这种网络的损失最终会不会收敛是一个问题,不收敛就代表生成器和判别器的功效不确定,但好在Goodfellow给出了证明,证明用不到不附。

 

这里给出GAN的公式

左式是真实数据,右式是生成数据。

对于D而言,左式越大越好,右式越大越好

对于G而言,右式越小越好

应用

生成人脸

图像增强

风格转换

声音转换

实验

import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

#数据加载部分
transform = transforms. Compose([transforms.ToTensor (),
                                 transforms.Normalize(0.5, 0.5)
                                 ])
train_ds = torchvision.datasets.MNIST( "data",
                                        train=True,
                                        transform=transform,
                                        download=True)
dataloader = torch.utils. data.Dataloader(train_ds, batch_size=64, shuffle=True)

#模型部分
#生成器部分
class Generator(nn.Module) :
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn. Sequential(
            nn.Linear(100,256),
            nn. ReLU () ,
            nn.Linear(256,512),
            nn. ReLU () ,
            nn.Linear(512,28*28),
            nn. Tanh ()
        )
def forward(self,x):
        img = self.main(x)
        img = img.view(-1,28,28,1)
        return img
#判别器部分

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).init()
        self.main = nn. Sequential(
            nn.Linear(28*28,512),
            nn. LeakyReLU(),
            nn. Linear(512,256),
            nn. LeakyReLU(),
            nn. Linear(256,1),
            nn.Sigmoid()
        )

        def forward(self, x):
            x = x.view(-1,28 * 28)
            x = self.main(x)
            return x




device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

d_optim = torch.optim.Adam(dis.parameters (),lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters (),lr=0.0001)

loss_fn = torch.nn.BCEWithLogitsLoss()

def gen_img_plot(model, epoch, test_input):
    prediction = np.squeeze (model(test_input). detach ().cpu ().numpy ())
    fig = plt.figure(figsize=(4,4))
    for i in range(16):
        plt.subplot(4,4, i+1)
        plt.imshow((prediction[i] + 1)/2)
        plt.axis(' off')
        plt.show()

D_loss = []
G_loss = []

for epoch in range(20) :
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    for step, (img,_) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device-device)
        d_optim.zero_grad()
        # real output对真实图片的预测
        real_output = dis(img)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output),device = device)

        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_fn(fake_output,
                                   torch.zeros_like((fake_output),)
                                   ,device = device)
        d_fake_loss.backward()
        d_loss = d_real_loss+d_fake_loss
        d_optim.step()

        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output,
                            torch.ones_like(fake_output),  # 生成器的损失
                            device = device)
        g_loss.backward()
        g_optim.step()

        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch:',epoch)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值