# Reference: https://morvanzhou.github.io/tutorials/
# Detailed explanations are given in the comments below.
"""
View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/
My Youtube Channel: https://www.youtube.com/user/MorvanZhou
Dependencies:
torch: 1.1.0
numpy
matplotlib
Modified by: kindy
Modified on: 2019-05-28
Suppose we want to imitate a master painter's work (the real samples). The imitations (fake samples
produced by the Generator) are repeatedly compared against the master's work (by the Discriminator);
when the gap between the two is smallest (as measured by the Discriminator's criterion function),
the imitation is good enough to pass for the real thing. This is the mutual game played by the two networks.
Helper functions:
np.vstack(tup): stacks arrays vertically (along the first axis, which is the easiest way to
picture it). tup is a tuple; e.g. if a has shape (2, 3) and b has shape (4, 3), then:
>>> c = np.vstack((a, b))
gives c with shape (6, 3): b is stacked below a (vertically).
np.random.uniform(a, b, size): samples from the uniform distribution U(a, b).
np.newaxis: inserts a new length-1 dimension, e.g.:
>>> a = np.random.rand(10)
>>> a.shape
(10,)
>>> a[:, np.newaxis].shape
(10, 1)
>>> a[np.newaxis, :].shape
(1, 10)
"""
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# torch.manual_seed(1) # reproducible
# np.random.seed(1)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001 # learning rate for generator
LR_D = 0.0001 # learning rate for discriminator
N_IDEAS = 5 # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15 # total number of points G can draw on the canvas (the dimensionality of each painting)
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)]) # (64, 15)
def artist_works(): # painting from the famous artist (real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis] # (64, 1)
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)         # quadratic curves y = a*x^2 + (a-1)
    paintings = torch.from_numpy(paintings).float()
    return paintings
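As a quick sanity check, here is a standalone sketch (with its own small stand-in constants,
independent of the hyper-parameters above) showing that each real batch is a set of quadratic
curves y = a*x^2 + (a-1) with a ~ U(1, 2), so every curve lies between x^2 (a=1) and 2*x^2 + 1 (a=2):

```python
import numpy as np

BATCH, POINTS = 8, 15                               # small stand-ins for BATCH_SIZE / ART_COMPONENTS
x = np.vstack([np.linspace(-1, 1, POINTS) for _ in range(BATCH)])
a = np.random.uniform(1, 2, size=BATCH)[:, np.newaxis]
paintings = a * np.power(x, 2) + (a - 1)            # y = a*x^2 + (a-1)

lower = np.power(x, 2)                              # the a = 1 curve
upper = 2 * np.power(x, 2) + 1                      # the a = 2 curve
print(paintings.shape)                              # (8, 15)
print(np.all(paintings >= lower - 1e-9) and np.all(paintings <= upper + 1e-9))  # True
```

These are exactly the "lower bound" and "upper bound" curves drawn in the plotting section below.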
G = nn.Sequential(                  # Generator
    nn.Linear(N_IDEAS, 128),        # random ideas (could be from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas
)

D = nn.Sequential(                  # Discriminator
    nn.Linear(ART_COMPONENTS, 128), # receives art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                   # tells the probability that the art work was made by the artist
)
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
plt.ion() # turn on interactive mode for continuous plotting
for step in range(10000):
    artist_paintings = artist_works()           # real paintings from the artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake paintings from G (random ideas)

    # probability that D judges a real painting to be genuine; try to increase this prob
    prob_artist0 = D(artist_paintings)
    # probability that D judges a fake painting to be genuine; try to reduce this prob
    prob_artist1 = D(G_paintings)

    # G_loss and D_loss below are custom objective (loss) functions, not losses built into torch.
    # For the discriminator, the goal is to correctly separate real paintings from fakes.
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    # For the generator, the higher the probability that its fake is judged genuine, the better:
    # a higher probability means the fake is closer to the real thing. So the goal is to increase
    # prob_artist1 = D(G_paintings); since torch minimizes objectives, it is written as below.
    # Plotting f(x) = log(1 - x) on its own gives a feel for the shape of this function.
    # G_loss = torch.mean(torch.log(1. - prob_artist1))
    # Try swapping G_loss for the following objective instead:
    G_loss = - torch.mean(torch.log(prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)  # reusing computational graph
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    print('step: {}'.format(step))
    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3)); plt.legend(loc='upper right', fontsize=10); plt.draw(); plt.pause(0.01)
plt.ioff()
plt.show()
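The difference between the two generator objectives in the loop above is easiest to see from
their gradients with respect to p = D(G(z)). Early in training p is small, and the slope of
log(1 - p) is -1/(1 - p), which is tiny, while the slope of -log(p) is -1/p, which is large;
that is why the second (non-saturating) form trains faster. A small numpy sketch with a
hypothetical early-training value p = 0.01:

```python
import numpy as np

p = 0.01                                # D's (low) score for an early fake sample
grad_saturating = -1.0 / (1.0 - p)      # d/dp of log(1 - p): weak signal when p is small
grad_nonsaturating = -1.0 / p           # d/dp of -log(p): strong signal when p is small
print(grad_saturating)                  # roughly -1.01
print(grad_nonsaturating)               # roughly -100
```

Many GAN variants differ precisely in this choice of objective, which is key point (2) below.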
# Key points:
# (1) how the D and G networks are built;
# (2) how the objective functions for D and G are chosen: many GAN variants are derived
#     from different choices of these objectives.