DCGAN应用的简单示例(基于MNIST手写数据集的训练)
1. 说明
已经是第三篇了,前面说了那么多,这一篇来实战,采用的是PyTorch框架。使用DCGAN对MNIST数据集进行训练的简单示例。注意,篇幅有限,代码中省略了一些具体的实现细节,与原理一一对应。
2. 代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
# Define the generator
class Generator(nn.Module):
    """DCGAN generator for MNIST.

    Maps a noise tensor of shape ``(N, nz, 1, 1)`` to an image batch of
    shape ``(N, 1, 28, 28)`` with values in ``[-1, 1]`` (Tanh output,
    matching the ``Normalize((0.5,), (0.5,))`` preprocessing).

    The original snippet left the class body empty ("implementation
    omitted"), which is a SyntaxError in Python; this is a minimal
    working implementation.
    """

    def __init__(self, nz: int = 100):
        super().__init__()
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (256, 7, 7)
            nn.ConvTranspose2d(nz, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # (256, 7, 7) -> (128, 14, 14)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 14, 14) -> (1, 28, 28)
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise ``z`` of shape (N, nz, 1, 1)."""
        return self.main(z)
# Define the discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator for MNIST.

    Maps an image batch of shape ``(N, 1, 28, 28)`` to a probability
    vector of shape ``(N,)`` in ``[0, 1]`` (Sigmoid output, as required
    by the ``nn.BCELoss`` used in the training loop).

    The original snippet left the class body empty ("implementation
    omitted"), which is a SyntaxError in Python; this is a minimal
    working implementation.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # (1, 28, 28) -> (64, 14, 14); no BatchNorm on the first
            # layer, per the DCGAN paper's guidelines.
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (64, 14, 14) -> (128, 7, 7)
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # (128, 7, 7) -> (1, 1, 1)
            nn.Conv2d(128, 1, 7, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample real/fake probabilities, shape (N,)."""
        return self.main(x).view(-1)
# ---------------------------------------------------------------------------
# Hyperparameters (these were referenced but never defined in the original).
# ---------------------------------------------------------------------------
batch_size = 64   # mini-batch size for the DataLoader
nz = 100          # dimensionality of the generator's noise input
num_epochs = 5    # passes over the training set

# Load the MNIST dataset. Normalize pixels to [-1, 1] so real images match
# the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Initialize models and optimizers (Adam with beta1=0.5, per the DCGAN paper).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        real, _ = data
        real = real.to(device)
        # Use the actual batch size: the final batch of an epoch may be
        # smaller than `batch_size`, which would crash the fixed-size
        # label tensors the original code built.
        b = real.size(0)

        # ---- Train the discriminator: maximize log D(x) + log(1 - D(G(z))) ----
        optimizerD.zero_grad()
        fake = netG(torch.randn(b, nz, 1, 1, device=device))
        # Labels must be float: integer fill values produce long tensors,
        # which nn.BCELoss rejects with a dtype error.
        real_label = torch.full((b,), 1., device=device)
        fake_label = torch.full((b,), 0., device=device)
        real_loss = criterion(netD(real), real_label)
        # detach() blocks gradients from flowing into the generator here.
        fake_loss = criterion(netD(fake.detach()), fake_label)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizerD.step()

        # ---- Train the generator: maximize log D(G(z)) ----
        optimizerG.zero_grad()
        fake_label.fill_(1.)  # the generator wants fakes classified as real
        g_loss = criterion(netD(fake), fake_label)
        g_loss.backward()
        optimizerG.step()

        # Periodic progress report
        if i % 100 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, num_epochs, i, len(trainloader),
                     d_loss.item(), g_loss.item()))

# Save a grid of generated samples. The original called an undefined
# `imshow`; torchvision.utils.save_image is already available via the
# file's imports and normalize=True rescales the Tanh output to [0, 1].
noise = torch.randn(batch_size, nz, 1, 1, device=device)
fake = netG(noise)
torchvision.utils.save_image(fake.cpu().detach(), 'dcgan_samples.png', normalize=True)