对抗生成网络代码Generative Adversarial Networks (GANs),Vanilla GAN,Deeply Convolutional GANs

这篇博客介绍了对抗生成网络(GANs)的基本概念,包括Vanilla GAN的Discriminator和Generator,以及GAN Loss的计算。作者详细阐述了如何实现和优化这两个网络,并探讨了Least Squares GAN作为替代损失函数的优势。此外,还讨论了Deeply Convolutional GANs的架构及其在训练过程中的应用。文章中还提到了一些关键函数,如sampler、np.prod和clamp的使用。
摘要由CSDN通过智能技术生成

理论部分: CS231n 2022PPT笔记- 生成模型Generative Modeling_iwill323的博客-CSDN博客

目录

导包

加载数据

Vanilla GAN

Discriminator

Generator

GAN Loss

bce loss

Optimizing

主函数

运行GAN

Least Squares GAN

训练

Deeply Convolutional GANs

Discriminator

Generator

训练

需要注意的函数

sampler

np.prod

clamp


We can think of the generator (𝐺) trying to fool the discriminator (𝐷) and the discriminator trying to correctly classify real vs. fake as a minimax game:

 where 𝑧∼𝑝(𝑧)are the random noise samples, 𝐺(𝑧) are the generated images using the neural network generator 𝐺, and 𝐷 is the output of the discriminator, specifying the probability of an input being real.

In this assignment, we will alternate the following updates:

  1. Update the generator (𝐺) to maximize the probability of the discriminator making the incorrect choice on generated data:

    maximize 𝔼𝑧∼𝑝(𝑧)[log𝐷(𝐺(𝑧))]

  2. Update the discriminator (𝐷), to maximize the probability of the discriminator making the correct choice on real and generated data:

    maximize 𝔼𝑥∼𝑝data[log𝐷(𝑥)]+𝔼𝑧∼𝑝(𝑧)[log(1−𝐷(𝐺(𝑧)))]

导包

# Setup cell.
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # Set default size of plots.
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

%load_ext autoreload
%autoreload 2

def show_images(images):
    # images: (N, C, H, W)
    images = np.reshape(images, [images.shape[0], -1]) # Images reshape to (batch_size, D).
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg,sqrtimg]))
    return

dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
NOISE_DIM = 96

加载数据

NUM_TRAIN = 50000  # 总的训练数据其实是60,000个
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

mnist_train = datasets.MNIST(
    './cs231n/datasets/MNIST_data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
loader_train = DataLoader(
    mnist_train,
    batch_size=batch_size,
    sampler=ChunkSampler(NUM_TRAIN, 0)  
)

mnist_val = datasets.MNIST(
    './cs231n/datasets/MNIST_data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
loader_val = DataLoader(
    mnist_val,
    batch_size=batch_size,
    sampler=ChunkSampler(NUM_VAL, NUM_TRAIN)
)
imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
print(imgs.shape) # (128, 784)
show_images(imgs)  # 查看其中一个batch的图片


class ChunkSampler(sampler.Sampler):
    """Samples elements sequentially from some offset.
    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples

Vanilla GAN

Discriminator

The output of the discriminator should have shape [batch_size, 1], and

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值