# Import packages and load the data
from __future__ import division
from torchvision import models, transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Use the GPU when available; every tensor and model is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 32

# ToTensor yields pixels in [0, 1]; Normalize maps them to [-1, 1] so real
# images match the Tanh output range of the generator.
# NOTE: mean/std must be per-channel sequences — the original `(0.5)` is just
# the float 0.5, not a 1-tuple, which torchvision rejects as a sequence.
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Lambda(lambda x: x.repeat(3,1,1)),
    transforms.Normalize(mean=(0.5,),
                         std=(0.5,))
])

# Downloads MNIST on first run; each item is a (1, 28, 28) tensor plus a label.
mnist_data = torchvision.datasets.MNIST('./mnist_data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset=mnist_data, batch_size=batch_size, shuffle=True
)
# Define the models
image_size = 784    # 28 * 28, MNIST digits flattened to a vector
hidden_size = 256
latent_size = 64    # dimensionality of the generator's noise input

# Discriminator: flattened image -> probability that the image is real.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
)

# Generator: latent noise vector -> flattened image with values in [-1, 1].
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),
)

D, G = D.to(device), G.to(device)
# Train the models
# Binary cross-entropy pairs with the Sigmoid at the end of D.
loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def reset_grad():
    """Clear accumulated gradients of both optimizers before a backward pass."""
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


total_step = len(dataloader)
num_epoch = 2
for epoch in range(num_epoch):
    for i, (images, _) in enumerate(dataloader):
        batch_size = images.size(0)  # the last batch may be smaller than 32
        images = images.reshape(batch_size, image_size).to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ---- Discriminator step ----
        output = D(images)
        d_loss_real = loss_fn(output, real_labels)
        real_score = output  # D's score on real images; higher means D is doing well

        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        # detach(): the D update must not backpropagate into G. The original
        # called D(fake_images) directly and built the full G graph, relying on
        # reset_grad() to throw the wasted G gradients away.
        output = D(fake_images.detach())
        d_loss_fake = loss_fn(output, fake_labels)
        fake_score = output  # D's score on fakes; lower means D spots the fake

        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        output = D(fake_images)
        # G improves when D classifies its fakes as real, hence real_labels.
        g_loss = loss_fn(output, real_labels)
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if i % 1000 == 0:
            print('epoch [{}/{}], step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(
                epoch, num_epoch, i, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()
            ))
# Evaluate the results
# Sample one latent vector and show what the trained generator produces.
z = torch.randn(1, latent_size).to(device)
generated = G(z)
fake_images = generated.view(28, 28).data.cpu().numpy()
plt.imshow(fake_images)

# For comparison, show the first real image from the last training batch.
real_digit = images[0].view(28, 28).data.cpu().numpy()
plt.imshow(real_digit)