# Simple fully-connected GAN on MNIST (training script; original blog header removed)
import torch
import torchvision
import numpy
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import save_image
# Hyper-parameters and data pipeline for a simple fully-connected GAN on MNIST.
latent_dim = 100  # size of the generator's input noise vector
batch_size = 128
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Scale images from [0, 1] to [-1, 1] so they match the generator's Tanh output range.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# Downloads MNIST on first run (network/disk side effect at import time).
data = datasets.MNIST(root='./data/mnist', train=True, transform=transform, download=True)
data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=4)
class Genetator(nn.Module):
    """Generator: maps a latent noise vector z to a 1x28x28 image in [-1, 1].

    The class name keeps the original (misspelled) identifier so existing
    callers continue to work.
    """

    def __init__(self, latent_dim=100):
        """
        Args:
            latent_dim: size of the input noise vector. Defaults to 100,
                matching the module-level constant, so existing no-arg
                callers are unaffected.
        """
        super(Genetator, self).__init__()
        self.latent_dim = latent_dim
        self.models = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            # BUG FIX: the original passed 0.8 positionally, which set
            # eps=0.8 (a nonsensical numerical-stability epsilon). The
            # classic GAN recipe this follows intends momentum=0.8.
            nn.BatchNorm1d(256, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh(),  # output in [-1, 1], matching Normalize([0.5], [0.5])
        )

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to images (batch, 1, 28, 28)."""
        batch_size = x.shape[0]
        x = self.models(x)
        x = x.view(batch_size, 1, 28, 28)
        return x
class Discriminator(nn.Module):
    """Discriminator: flattens a 1x28x28 image and scores it in (0, 1),
    where 1 means "real" and 0 means "generated"."""

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten (batch, 1, 28, 28) input and return (batch, 1) scores."""
        flat = x.view(x.shape[0], -1)
        return self.model(flat)
# Instantiate the two adversarial networks and move them to the device.
generator = Genetator()
generator.to(device)
# NOTE(review): "discrimator" is a typo for "discriminator"; kept as-is
# because the training loop below refers to this exact name.
discrimator = Discriminator()
discrimator.to(device)
# One Adam optimizer per network, so each training step updates only
# that network's weights.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discrimator.parameters(), lr=0.0002)
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
def train(epoch):
    """Run one epoch of adversarial training over `data_loader`.

    Trains the generator first, then the discriminator, on each batch
    (the original comment called this "the first training method").
    Uses the module-level generator, discrimator, optimizers, criterion,
    latent_dim and device.

    Args:
        epoch: current epoch index, used only for logging and the
            sample-image filename.
    """
    for i, (imgs, _) in enumerate(data_loader):
        imgs = imgs.to(device)
        # Target labels: 1.0 for real images, 0.0 for generated ones.
        # torch.ones/zeros replace the deprecated Variable + fill_ idiom.
        valid = torch.ones(imgs.size(0), 1, device=device)
        fake = torch.zeros(imgs.size(0), 1, device=device)

        # --- Train the generator: it wants D to label its samples "real".
        # torch.randn replaces the numpy->Tensor round-trip; same N(0, 1).
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        optimizer_G.zero_grad()
        g_loss = criterion(discrimator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # --- Train the discriminator on real vs. generated batches.
        optimizer_D.zero_grad()
        real_loss = criterion(discrimator(imgs), valid)
        # detach() keeps D's loss from back-propagating into the generator.
        fake_loss = criterion(discrimator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % 50 == 0:
            print(
                "[Epoch %d/100] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, i, len(data_loader), d_loss.item(), g_loss.item())
            )
            # NOTE(review): the flattened source makes this line's original
            # indentation ambiguous; saving inside the logging branch avoids
            # writing the file on every batch. The name is keyed by epoch
            # only, so later batches in the same epoch overwrite it — confirm
            # against the original blog post if per-batch samples were wanted.
            save_image(gen_imgs.data[:25], 'images{}.png'.format(epoch), nrow=5, normalize=True)
if __name__ == '__main__':
    # Entry point: run 100 epochs of adversarial training.
    n_epochs = 100
    for current_epoch in range(n_epochs):
        train(current_epoch)