- 🍨 本文为🔗365天深度学习训练营中的学习记录博客
- 🍖 原作者:K同学啊
GAN(生成对抗网络)简介:
生成对抗网络(GAN) 是一种深度学习模型,由两部分组成
生成器(Generator):生成伪造数据
判别器(Discriminator):判断数据是否真实。
训练过程:
生成器产生假的数据(如图像),然后传给判别器
判别器将数据分为“真实”或“假的”
生成器的目标是让判别器认为它生成的假数据是真实的,而判别器的目标是正确判断数据的真假
核心思想:
生成器和判别器互相对抗,通过不断的训练,最终生成器能够生成足够“真实”的数据,骗过判别器
1.定义超参数
定义模型的超参数:
n_epochs
:训练轮数
batch_size
:每批次训练数据的数量
lr
:学习率
b1
、b2
:Adam优化器的β值(控制动量)
latent_dim
:生成器输入的噪声维度
img_size
、channels
:图像大小和通道数
sample_interval
:保存生成图片的间隔
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import variable
import torch.nn as nn
import torch#创建1文件夹
os.makedirs('./images/',exist_ok=True)
os.makedirs('./save/',exist_ok=