(参考 b站大神 日月光华 教程复现)
原理:
这里通过一个简单的手写字体生成网络了解GAN的基本原理,主要包含generator 和 discriminator两部分,其中generator 的输入是正态分布噪声,输出是28x28的图像, discriminator 的输入是28x28的图像,分别是真实图像和generator生成的图像,输出是概率值。对抗的含义体现在优化目标上,generator 的目标是使输出的图像尽量被discriminator判别为真,而discriminator的目标是尽量将噪声生成的图像判别为假,真实图像判别为真。感兴趣的小伙伴可以在此网络上进行修改:
添加可变的学习率
添加卷积层
增加网络深度
话不多说,上代码,可运行:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
# draw the generator's predictions for a fixed noise batch
def draw_genImg(model, input):
    """Plot the images the generator produces for the given noise batch.

    Args:
        model: the generator network; maps (N, 100) noise to images in [-1, 1].
        input: noise tensor of shape (N, 100); the N images are shown on a
            4-row grid.
    """
    # Detach from the autograd graph and move to CPU before numpy conversion;
    # squeeze drops the trailing channel dim -> (N, 28, 28).
    pred = np.squeeze(model(input).detach().cpu().numpy())
    size = input.shape[0]
    # Ceiling division: the original int(size/4) produced too few columns
    # (or zero) when size was not a multiple of 4, making subplot() fail.
    cols = (size + 3) // 4
    for i in range(size):
        plt.subplot(4, cols, i + 1)
        # Generator output is tanh-bounded in [-1, 1]; rescale to [0, 1].
        plt.imshow((pred[i] + 1) / 2)
    plt.show()
# generator: noise -> image
class Generator(nn.Module):
    """Map a 100-dim Gaussian noise vector to a 28x28 image in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        layers = [
            nn.Linear(100, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 28 * 28),
            nn.Tanh(),  # bound outputs to [-1, 1], matching normalized MNIST
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        flat = self.main(x)
        # Reshape the flat 784-vector to an image; the trailing singleton
        # channel dim is what the rest of the script expects.
        img = flat.view(-1, 28, 28, 1)
        return img
# discriminator: image -> probability of being real
class Discriminator(nn.Module):
    """Score a 28x28 image with the probability that it is real."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability in (0, 1), paired with BCELoss
        )

    def forward(self, x):
        # Flatten whatever spatial layout comes in to a 784-vector per sample.
        flattened = x.view(-1, 28 * 28)
        return self.main(flattened)
if __name__ == "__main__":
    batch_size = 64
    epoch_size = 200
    pred_size = 16  # number of preview images generated per draw
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Fixed noise so generated previews are comparable across epochs.
    test_input = torch.randn(pred_size, 100, device=device)
    # data: MNIST scaled to [-1, 1] to match the generator's tanh output
    transform = transforms.Compose([
        transforms.ToTensor(),           # [0, 1]
        transforms.Normalize(0.5, 0.5),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])
    train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)  # data folder
    dataloader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    gen = Generator().to(device)
    dis = Discriminator().to(device)
    # optim is already imported as torch.optim at the top; use it consistently.
    g_optim = optim.Adam(gen.parameters(), lr=0.0001)
    d_optim = optim.Adam(dis.parameters(), lr=0.0001)
    loss_fn = nn.BCELoss()
    D_loss = []
    G_loss = []
    for epoch in range(epoch_size):
        d_epoch_loss = 0.0
        g_epoch_loss = 0.0
        count = len(dataloader)
        for step, (img, _) in enumerate(dataloader):
            img = img.to(device)
            size = img.size(0)
            random_noise = torch.randn(size, 100, device=device)
            # --- generator step: make the discriminator output 1 on fakes ---
            g_optim.zero_grad()
            fake_out = dis(gen(random_noise))
            g_loss = loss_fn(fake_out, torch.ones_like(fake_out))
            g_loss.backward()
            g_optim.step()
            # --- discriminator step: real -> 1, generated -> 0 ---
            d_optim.zero_grad()
            real_out = dis(img)
            d_real_loss = loss_fn(real_out, torch.ones_like(real_out))
            d_real_loss.backward()
            # detach() so this backward pass does not flow into the generator
            fake_out = dis(gen(random_noise).detach())
            d_fake_loss = loss_fn(fake_out, torch.zeros_like(fake_out))
            d_fake_loss.backward()
            d_loss = d_real_loss + d_fake_loss
            d_optim.step()
            # Accumulate as Python floats: appending live tensors kept
            # device memory alive in the bookkeeping lists.
            g_epoch_loss += g_loss.item()
            d_epoch_loss += d_loss.item()
        # per-epoch statistics and (near the end) preview images
        g_epoch_loss /= count
        d_epoch_loss /= count
        G_loss.append(g_epoch_loss)
        D_loss.append(d_epoch_loss)
        print('epoch: ', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)
        if epoch > epoch_size - 5:
            draw_genImg(gen, test_input)
上面分别为第1代、第30代和第200代的结果。