lastDay
GAN;DCGAN
Learning resources for GANs:
GAN variants (the GAN Zoo): https://github.com/hindupuravinash/the-gan-zoo (GitHub, about 7000 stars)
Blog by “sakura小樱”: https://blog.csdn.net/Sakura55/article/details/81512600
On Zhihu: https://blog.csdn.net/Sakura55/article/details/81512600
1. GAN (2014)
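Compared with RNN-style generative models, the key idea of GAN is to use a discriminative model (the discriminator) to train a generative model (the generator).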
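For reference, the original 2014 GAN paper formalizes this idea as a two-player minimax game between a generator $G$, which maps noise $z \sim p_z$ to samples, and a discriminator $D$, which outputs the probability that its input is real:

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

$D$ maximizes $V$ by pushing $D(x) \to 1$ on real data and $D(G(z)) \to 0$ on generated data, which is the same real = 1, fake = 0 label convention used when training the discriminator.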
https://www.boyuai.com/elites/course/cZu18YmweLv10OeV/jupyter/2dQ7K5yOLT4Dea6TrN6Fb
1.1 Generator
Our generator network is the simplest network possible: a single-layer linear model. That is enough here because we drive the linear network with Gaussian noise and the real data also comes from a Gaussian data generator, so the generator only needs to learn the right weight and bias to fake the data perfectly.
import torch.nn as nn

class net_G(nn.Module):
    def __init__(self):
        super(net_G, self).__init__()
        # One linear layer maps 2-D Gaussian noise to a 2-D fake sample
        self.model = nn.Sequential(
            nn.Linear(2, 2),
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.model(x)
        return x

    def _initialize_weights(self):
        # Weights drawn from N(0, 0.02), biases set to zero
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()
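To see why a single linear layer suffices, here is a small numerical check. It assumes the real data is standard Gaussian noise pushed through an affine map; the matrix `A` and vector `b` are hypothetical values chosen only for illustration. For row vectors, $x = zA + b$ with $z \sim \mathcal{N}(0, I)$ gives $x \sim \mathcal{N}(b, A^\top A)$, which is exactly the family of distributions a `Linear(2, 2)` generator can produce from Gaussian input.

```python
import numpy as np

# Hypothetical parameters of the "true" affine map (illustration only)
A = np.array([[1.0, 2.0], [-0.1, 0.5]])
b = np.array([1.0, 2.0])

rng = np.random.default_rng(0)
Z = rng.normal(size=(100_000, 2))  # standard Gaussian noise, one row per sample
X = Z @ A + b                      # affine transform of the noise

# Row-vector map z -> zA + b sends N(0, I) to N(b, A^T A)
print(X.mean(axis=0))              # close to b
print(np.cov(X, rowvar=False))     # close to A.T @ A
print(A.T @ A)
```

Note that the generator does not have to recover `A` itself: any weight $W$ with $W^\top W = A^\top A$ (for example $W = QA$ with $Q$ orthogonal) yields the same output distribution.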
1.2 Discriminator
For the discriminator we will be a bit more discriminating: we use a 3-layer MLP (2 → 5 → 3 → 1) with tanh activations and a sigmoid output, so it returns the probability that its input is real.
import torch.nn as nn

class net_D(nn.Module):
    def __init__(self):
        super(net_D, self).__init__()
        # MLP 2 -> 5 -> 3 -> 1; the sigmoid turns the last output into P(input is real)
        self.model = nn.Sequential(
            nn.Linear(2, 5),
            nn.Tanh(),
            nn.Linear(5, 3),
            nn.Tanh(),
            nn.Linear(3, 1),
            nn.Sigmoid()
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.model(x)
        return x

    def _initialize_weights(self):
        # Weights drawn from N(0, 0.02), biases set to zero
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()
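A quick sanity check of the discriminator's layer stack, rebuilt inline with `nn.Sequential` so the snippet runs on its own (it uses PyTorch's default initialization rather than the N(0, 0.02) scheme above; the batch of 8 random points is arbitrary). It confirms one probability per input row and counts the trainable parameters.

```python
import torch
import torch.nn as nn

# Same layer stack as net_D, rebuilt inline so this snippet is self-contained
D = nn.Sequential(
    nn.Linear(2, 5), nn.Tanh(),
    nn.Linear(5, 3), nn.Tanh(),
    nn.Linear(3, 1), nn.Sigmoid(),
)

x = torch.randn(8, 2)   # 8 arbitrary 2-D points
p = D(x)                # one probability per point
print(p.shape)          # torch.Size([8, 1])

# (2*5 + 5) + (5*3 + 3) + (3*1 + 1) = 15 + 18 + 4
n_params = sum(t.numel() for t in D.parameters())
print(n_params)         # 37
```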
1.3 Training
1.3.1 Defining a function to update the discriminator
Real data: label = 1
Generated data: label = 0
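With these labels, and assuming `loss` is binary cross-entropy (for example `nn.BCELoss`, which fits the sigmoid output of `net_D`; the section does not say which loss is passed in), the discriminator loss is the average BCE on real samples (target 1) and on generated samples (target 0). A hand-computed example with hypothetical discriminator outputs:

```python
import math

def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

d_real = 0.9  # hypothetical D output on a real sample (label 1)
d_fake = 0.2  # hypothetical D output on a generated sample (label 0)

# Average of -log D(x) and -log(1 - D(G(z)))
loss_D = (bce(d_real, 1) + bce(d_fake, 0)) / 2
print(round(loss_D, 4))  # 0.1643
```

The loss shrinks as D(real) → 1 and D(fake) → 0, i.e. as the discriminator gets better at telling the two apart.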
import torch

# Saved in the d2l package for later use
def update_D(X, Z, net_D, net_G, loss, trainer_D):
    """Perform one optimization step on the discriminator."""
    batch_size = X.shape[0]
    # Targets: real samples are labeled 1, generated samples 0
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)
    trainer_D.zero_grad()  # clear gradients left over from earlier steps
    real_Y = net_D(X.float())
    fake_X = net_G(Z)
    # detach() so this step does not backpropagate into the generator
    fake_Y = net_D(fake_X.detach())
    loss_D = (loss(real_Y, ones) + loss(fake_Y, zeros)) / 2
    loss_D.backward()
    trainer_D.step()
    return float(loss_D.sum())
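The discriminator update is only half of a training step. Below is a minimal end-to-end sketch under several assumptions not stated in this section: the real data is Gaussian noise through a hypothetical affine map, `loss` is `nn.BCELoss`, both networks use Adam with illustrative learning rates, and `update_G` is a hypothetical generator step that trains G to make D label fake samples as real (1). The networks are rebuilt inline with `nn.Sequential` (default initialization) so the block runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical real data: Gaussian noise pushed through a fixed affine map
A = torch.tensor([[1.0, 2.0], [-0.1, 0.5]])
b = torch.tensor([1.0, 2.0])
data = torch.randn(1000, 2) @ A + b

# Same architectures as net_G / net_D above, rebuilt inline
net_G = nn.Sequential(nn.Linear(2, 2))
net_D = nn.Sequential(nn.Linear(2, 5), nn.Tanh(), nn.Linear(5, 3), nn.Tanh(),
                      nn.Linear(3, 1), nn.Sigmoid())
loss = nn.BCELoss()  # assumed loss; matches the sigmoid output of net_D
trainer_D = torch.optim.Adam(net_D.parameters(), lr=0.05)   # illustrative lr
trainer_G = torch.optim.Adam(net_G.parameters(), lr=0.005)  # illustrative lr

def update_D(X, Z):
    """One discriminator step: real -> label 1, fake -> label 0."""
    ones, zeros = torch.ones(X.shape[0], 1), torch.zeros(X.shape[0], 1)
    trainer_D.zero_grad()
    fake_X = net_G(Z).detach()  # no gradients into the generator here
    loss_D = (loss(net_D(X), ones) + loss(net_D(fake_X), zeros)) / 2
    loss_D.backward()
    trainer_D.step()
    return loss_D.item()

def update_G(Z):
    """Hypothetical generator step: push D(G(z)) toward the 'real' label 1."""
    ones = torch.ones(Z.shape[0], 1)
    trainer_G.zero_grad()
    loss_G = loss(net_D(net_G(Z)), ones)
    loss_G.backward()  # also fills net_D grads; update_D zeroes them before use
    trainer_G.step()   # only the generator's parameters are updated
    return loss_G.item()

batch_size = 8
history = []
for epoch in range(5):
    perm = torch.randperm(data.shape[0])
    for i in range(0, data.shape[0], batch_size):
        X = data[perm[i:i + batch_size]]
        Z = torch.randn(X.shape[0], 2)  # fresh noise for each minibatch
        history.append((update_D(X, Z), update_G(Z)))
    d_last, g_last = history[-1]
    print(f"epoch {epoch}: loss_D={d_last:.3f}, loss_G={g_last:.3f}")
```

Alternating one D step and one G step per minibatch is the simplest schedule; the relative learning rates are the main knob for keeping either network from overpowering the other.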