DCGAN
DCGAN structure
def weights_init(m):
    """Initialize a module's parameters in place, DCGAN style.

    Meant to be used via ``model.apply(weights_init)``. The module is
    dispatched on its class name:

    * name contains ``'Conv'`` (``Conv2d``, ``ConvTranspose2d``, ...):
      weight ~ N(0, 0.02); the bias is left untouched.
    * name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.

    Any other module (e.g. ``Linear``) is left untouched.

    Args:
        m: the ``nn.Module`` to initialize.
    """
    classname = m.__class__.__name__
    # A BatchNorm built with affine=False (or a custom "Conv..." wrapper
    # module) has no weight/bias tensor, so look them up defensively instead
    # of raising AttributeError.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)
class Generator(nn.Module):
    """DCGAN generator: map a latent vector to a 64x64 RGB image.

    Input shape: (N, in_dim)
    Output shape: (N, 3, 64, 64), values in [-1, 1] (Tanh output)

    Args:
        in_dim: length of the latent vector.
        dim: base channel width; the first feature map has ``dim * 8``
            channels and each upsampling stage halves the channel count.
    """

    def __init__(self, in_dim, dim=64):
        super().__init__()

        def up_block(c_in, c_out):
            # Stride-2 transposed conv with k=5, padding=2, output_padding=1
            # doubles the spatial size, followed by BatchNorm + ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        # Project z to a flattened (dim*8) x 4 x 4 feature map; forward()
        # reshapes it back to 4x4.
        seed_features = dim * 8 * 4 * 4
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, seed_features, bias=False),
            nn.BatchNorm1d(seed_features),
            nn.ReLU(),
        )

        # 4x4 -> 8x8 -> 16x16 -> 32x32 via three up-blocks, then a final
        # transposed conv to 64x64 with 3 channels squashed by Tanh.
        widths = [dim * 8, dim * 4, dim * 2, dim]
        stages = [up_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        stages += [
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh(),
        ]
        self.l2_5 = nn.Sequential(*stages)
        self.apply(weights_init)

    def forward(self, x):
        flat = self.l1(x)
        seed = flat.view(x.size(0), -1, 4, 4)
        return self.l2_5(seed)
class Discriminator(nn.Module):
    """DCGAN-style discriminator, used in this notebook as the WGAN critic.

    Input shape: (N, in_dim, 64, 64)  -- in_dim = 3 for RGB images
    Output shape: (N, )

    By default the network ends at the last convolution and returns
    unbounded real-valued scores. The WGAN losses used for training,
    ``-mean(D(real)) + mean(D(fake))`` and ``-mean(D(fake))``, need an
    unbounded critic; a trailing Sigmoid would squash the scores into
    (0, 1) and saturate the gradients.

    Args:
        in_dim: number of input image channels.
        dim: base channel width; doubled at each downsampling stage.
        use_sigmoid: if True, append ``nn.Sigmoid()`` so the output is a
            probability, as needed for the original BCE-loss DCGAN
            objective. Defaults to False (WGAN critic).
    """

    def __init__(self, in_dim, dim=64, use_sigmoid=False):
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            # Stride-2 conv with k=5, padding=2 halves the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        layers = [
            # 64x64 -> 32x32 (no BatchNorm on this first layer).
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),      # -> 16x16
            conv_bn_lrelu(dim * 2, dim * 4),  # -> 8x8
            conv_bn_lrelu(dim * 4, dim * 8),  # -> 4x4
            nn.Conv2d(dim * 8, 1, 4),         # 4x4 -> 1x1 score per image
        ]
        # Sigmoid has no parameters, so state_dict keys are the same with or
        # without it.
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.ls = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, x):
        y = self.ls(x)
        # (N, 1, 1, 1) -> (N,)
        y = y.view(-1)
        return y
-
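nn.ConvTranspose2d():转置卷积(transposed convolution)
常被称为"反卷积",但它并不是卷积的逆运算;这里用它把 feature map 的尺寸放大(上采样)。
设第 L 层 feature map 的宽为 $W$、高为 $H$、通道数为 $C$;ConvTranspose2d 的 kernel size、stride、padding 分别记为 $k$、$s$、$p$。
转置卷积等价于两步:
1. 插值:在相邻像素之间插入 $s-1$ 个 0,得到 $H' = H + (s - 1)(H - 1)$,$W' = W + (s - 1)(W - 1)$;
2. 对插值后的 feature map 做普通卷积,新卷积的参数为:kernel size $= k$,stride $= 1$,padding $= k - p - 1$。
于是第 L+1 层 feature map 的尺寸为(output_padding 只在输出的一侧额外补行/列):
$H_{out} = (H - 1)\,s - 2p + k + \text{output\_padding}$,$W_{out} = (W - 1)\,s - 2p + k + \text{output\_padding}$
例如 Generator 中 $k = 5$、$s = 2$、$p = 2$、output_padding $= 1$,输入 $4 \times 4$ 时输出为 $(4 - 1)\cdot 2 - 4 + 5 + 1 = 8$,即尺寸翻倍。
-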
WGAN
loss and Optimizer (WGAN)
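- Loss:Wasserstein loss(不再使用 BCE loss,critic 最后一层去掉 Sigmoid):$L_D = -\mathbb{E}_{x}[D(x)] + \mathbb{E}_{z}[D(G(z))]$,$L_G = -\mathbb{E}_{z}[D(G(z))]$
- Optimizer:RMSProp,并在每次更新后对 critic 的参数做 weight clipping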
Train Loop
- 固定 generator 参数,训练 Discriminator(critic)。
- 每更新 n_critic 次 discriminator 后,固定其参数,更新一次 generator 参数。
# WGAN training loop.
# Uses names defined in earlier cells: G, D, opt_G, opt_D, dataloader,
# n_epoch, n_critic, z_dim, z_sample, log_dir, ckpt_dir, qqdm,
# torchvision, plt.
# NOTE(review): D should be built without a final Sigmoid so the critic
# output is unbounded, as the WGAN loss below requires.

# Weight-clipping bound for the critic. Reuse a `clip_value` from an
# earlier config cell if one exists; otherwise fall back to 0.01, the value
# used in the WGAN paper.
clip_value = globals().get('clip_value', 0.01)

steps = 0
for e in range(n_epoch):
    progress_bar = qqdm(dataloader)
    for data in progress_bar:
        r_imgs = data.cuda()
        bs = r_imgs.size(0)

        # ============================================
        #  Train D (critic)
        # ============================================
        z = torch.randn(bs, z_dim, device=r_imgs.device)
        # detach(): the critic update must not backpropagate into G.
        f_imgs = G(z).detach()

        # WGAN critic loss: -E[D(real)] + E[D(fake)]
        loss_D = -torch.mean(D(r_imgs)) + torch.mean(D(f_imgs))

        # Model backwarding
        D.zero_grad()
        loss_D.backward()
        # Update the discriminator.
        opt_D.step()

        # Clip critic weights to keep it (approximately) Lipschitz.
        with torch.no_grad():
            for p in D.parameters():
                p.clamp_(-clip_value, clip_value)

        # ============================================
        #  Train G: one update per n_critic critic updates
        # ============================================
        if steps % n_critic == 0:
            # Generate some fake images.
            z = torch.randn(bs, z_dim, device=r_imgs.device)
            f_imgs = G(z)
            # Model forwarding
            f_logit = D(f_imgs)
            # WGAN generator loss: -E[D(fake)]
            loss_G = -torch.mean(f_logit)

            # Model backwarding
            G.zero_grad()
            loss_G.backward()
            # Update the generator.
            opt_G.step()

        steps += 1

        # Set the info of the progress bar.
        # Note that the value of the GAN loss is not directly related to
        # the quality of the generated images.
        # (loss_G is always defined here: steps == 0 on the first iteration,
        # so G is updated at least once before this line is reached.)
        progress_bar.set_infos({
            'Loss_D': round(loss_D.item(), 4),
            'Loss_G': round(loss_G.item(), 4),
            'Epoch': e+1,
            'Step': steps,
        })

    # End of epoch: render samples from the fixed latent batch z_sample.
    G.eval()
    with torch.no_grad():
        # Map the Tanh output range [-1, 1] to [0, 1] for saving/plotting.
        f_imgs_sample = (G(z_sample) + 1) / 2.0
    filename = os.path.join(log_dir, f'Epoch_{e+1:03d}.jpg')
    torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
    print(f' | Save some samples to {filename}.')
    # Show generated images in the jupyter notebook.
    grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
    plt.figure(figsize=(10,10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()
    G.train()

    if (e+1) % 5 == 0 or e == 0:
        # Save the checkpoints.
        torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth'))
        torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth'))
注:
创建 DataLoader 时需设置 num_workers=0。