环境: python 3.7 + pytorch 1.0.1
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_weight(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0., 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1., 0.02)
m.bias.data.fill_(0.)
class DCGenerator(nn.Module):
def __init__(self, convs):
super(DCGenerator, self).__init__()
self.convs = nn.ModuleList()
in_channels = 1
for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
self.convs.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False))
if i < len(convs)-1:
# we use BN and RELU for each layer except the output
self.convs.append(nn.BatchNorm2d(out_channels))
self.convs.append(nn.ReLU())
else:
# in output, we use Tanh to generate data in [-1, 1]
self.convs.append(nn.Tanh())
in_channels = out_channels
self.apply(init_weight)
def forward(self, input):
out = input
for module in self.convs:
out = module(out)
return out
class Discriminator(nn.Module):
def __init__(self, convs):
super(Discriminator, self).__init__()
self.convs = nn.ModuleList()
in_channels = 1
for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
self.convs.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False))
if i != 0 and i != len(convs)-1:
# we donot use BN in the input layer of D
self.convs.append(nn.BatchNorm2d(out_channels))
if i != len(convs)-1:
self.convs.append(nn.LeakyReLU(0.2))
in_channels = out_channels
#self.cls = nn.Linear(out_channels*in_width*in_height, nout)
self.apply(init_weight)
def forward(self, input):
out = input
for layer in self.convs:
out = layer(out)
out = out.view(out.size(0), -1)
out = F.sigmoid(out)
return out
train.py
from __future__ import print_function
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
def sample_noise(batch_size, channels):
return torch.randn(batch_size, channels, 1, 1).float()
max_iter = 25
download = True
trans = transforms.Compose([transforms.ToTensor(),
transforms.Normalize([0.5,], [0.5,])])
mnist = datasets.MNIST('./', train=True, transform=trans, download=download)
batch_size = 64
use_cuda = True
if __name__ == '__main__':
d_convs = [(32, 4, 2, 1), (64, 4, 2, 1), (1, 7, 1, 0)]
discriminator = Discriminator(d_convs)
g_convs = [(64, 7, 1, 0), (32, 4, 2, 1), (1, 4, 2, 1)]
generator = DCGenerator(g_convs)
print(discriminator)
print(generator)
if use_cuda:
discriminator, generator = discriminator.cuda(), generator.cuda()
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
real_label, fake_label = 1, 0
criterion = nn.BCELoss()
if use_cuda:
criterion = criterion.cuda()
fixed_noise = sample_noise(batch_size, 1)
if use_cuda:
fixed_noise = fixed_noise.cuda()
fixed_noise = Variable(fixed_noise, volatile=True)
for epoch in range(1, max_iter+1):
for i, (x, _) in enumerate(dataloader):
batch_size = x.size(0)
# training D on real data
optimizer_d.zero_grad()
x = Variable(x)
if use_cuda:
x = x.cuda()
output = discriminator(x)
real_v = Variable(torch.Tensor(batch_size).fill_(real_label).float())
if use_cuda:
real_v = real_v.cuda()
loss_d = criterion(output, real_v)
loss_d.backward()
Dx = output.data.mean(dim=0)[0]
# training D on fake data
z = sample_noise(batch_size, 1)
z = Variable(z)
if use_cuda:
z = z.cuda()
fake = generator(z)
output = discriminator(fake.detach())
fake_v = Variable(torch.Tensor(batch_size).fill_(fake_label).float())
if use_cuda:
fake_v = fake_v.cuda()
loss_g = criterion(output, fake_v)
loss_g.backward()
optimizer_d.step()
err_D = loss_d.data + loss_g.data
# training G
optimizer_g.zero_grad()
output = discriminator(fake)
real_v = Variable(torch.Tensor(batch_size).fill_(real_label).float())
if use_cuda:
real_v = real_v.cuda()
loss = criterion(output, real_v)
loss.backward()
optimizer_g.step()
err_G = loss.data
DGz = output.data.mean(dim=0)[0]
print('[{:02d}/{:02d}],[{:03d}/{:03d}], errD: {:.4f}, D(x): {:.4f}, errG: {:.4f}, D(G(z)): {:.4f}'.format(
epoch, max_iter, i, len(dataloader), err_D, Dx, err_G, DGz))
fake = generator(fixed_noise)
save_image(fake.data, './mnist-fake-{:02d}.png'.format(epoch),
normalize=True)