Paper: "Self-Attention Generative Adversarial Networks"
Paper link: https://arxiv.org/abs/1805.08318
Code: https://github.com/heykeetae/Self-Attention-GAN
These notes follow the flow of the code.
Default parameter settings
adv_loss = 'hinge'
attn_path = './attn'
batch_size = 64
beta1 = 0.0
beta2 = 0.9
d_conv_dim = 64
d_iters = 5
d_lr = 0.0004
dataset = 'celeb'
g_conv_dim = 64
g_lr = 0.0001
g_num = 5
image_path = './data'
imsize = 64
lambda_gp = 10
log_path = './logs'
log_step = 10
lr_decay = 0.95
model = 'sagan'
model_save_path = './models'
model_save_step = 1.0
num_workers = 2
parallel = False
pretrained_model = None
sample_path = './samples'
sample_step = 100
total_step = 1000000
train = True
use_tensorboard = False
version = 'sagan_celeb'
z_dim = 128
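adv_loss = 'hinge' selects the hinge form of the adversarial loss, which is the loss the SAGAN paper trains with. The discriminator minimizes E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))] and the generator minimizes -E[D(G(z))]. The sketch below assumes plain Python lists of floats stand in for batches of discriminator scores (a real trainer would use tensors). The function names d_hinge_loss and g_hinge_loss are hypothetical and are not taken from the repository.

```python
# Minimal sketch of the hinge adversarial loss (adv_loss = 'hinge').
# Lists of floats stand in for batches of discriminator scores; the
# function names are hypothetical, not the repository's own.

def relu(x: float) -> float:
    """max(0, x)."""
    return x if x > 0.0 else 0.0

def mean(values) -> float:
    values = list(values)
    return sum(values) / len(values)

def d_hinge_loss(d_real, d_fake) -> float:
    """Discriminator loss: E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]."""
    loss_real = mean(relu(1.0 - s) for s in d_real)  # push real scores above +1
    loss_fake = mean(relu(1.0 + s) for s in d_fake)  # push fake scores below -1
    return loss_real + loss_fake

def g_hinge_loss(d_fake) -> float:
    """Generator loss: -E[D(G(z))], i.e. raise the scores of generated images."""
    return -mean(d_fake)

if __name__ == "__main__":
    d_real = [1.5, 0.5]   # scores the discriminator gives to real images
    d_fake = [-2.0, 0.5]  # scores it gives to generated images
    print(d_hinge_loss(d_real, d_fake))  # 1.0  (0.25 from real + 0.75 from fake)
    print(g_hinge_loss(d_fake))          # 0.75
```

Scores already beyond the margin (a real score of 1.5, a fake score of -2.0) contribute zero to the discriminator loss, which is what distinguishes the hinge form from the original cross-entropy GAN loss.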
Discriminator network structure
The discriminator is configured with batch size=64, image_size=64, conv_dim=64.
Assume the input tensor has shape torch.Size([64, 3, 64, 64]).
# layer1
Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)
The output shape is now torch.Size([64, 64, 32, 32])
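SpectralNorm() constrains a layer by dividing its weight W by the largest singular value sigma(W), so it acts on the convolution's weight rather than on the activations. The pure-Python sketch below estimates sigma(W) by power iteration on a small 2-D matrix; for a Conv2d weight of shape (out_ch, in_ch, kh, kw) one would first flatten it to (out_ch, in_ch*kh*kw). The helper names are hypothetical. For clarity it runs many iterations from a fixed starting vector, whereas typical training-time implementations run a single iteration per step and carry the vector u over between steps.

```python
# Minimal sketch: estimate the spectral norm sigma(W) by power iteration and
# return the normalized weight W / sigma(W). Helper names are hypothetical.
import math

def matvec(W, x):
    """W @ x for a list-of-rows matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def matTvec(W, y):
    """W^T @ y for a list-of-rows matrix W."""
    return [sum(W[r][c] * y[r] for r in range(len(W))) for c in range(len(W[0]))]

def normalize(x, eps=1e-12):
    norm = math.sqrt(sum(v * v for v in x))
    return [v / (norm + eps) for v in x]

def spectral_norm(W, n_iters=50):
    """Return (sigma, W_sn) where sigma ~ largest singular value of W."""
    u = normalize([1.0] * len(W))      # estimate of the top left singular vector
    v = normalize(matTvec(W, u))       # estimate of the top right singular vector
    for _ in range(n_iters):
        v = normalize(matTvec(W, u))   # v <- W^T u / ||W^T u||
        u = normalize(matvec(W, v))    # u <- W v / ||W v||
    sigma = sum(ui * wi for ui, wi in zip(u, matvec(W, v)))  # sigma = u^T W v
    W_sn = [[w / sigma for w in row] for row in W]
    return sigma, W_sn
```

For example, spectral_norm([[3.0, 0.0], [0.0, 1.0]])[0] is approximately 3.0, and the returned W_sn has a spectral norm of approximately 1.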
# layer2
Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)
The output shape is now torch.Size([64, 128, 16, 16])
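Each block uses kernel_size=4, stride=2, padding=1, which halves the spatial size. This follows from the standard Conv2d output-size formula with dilation 1, out = floor((in + 2*padding - kernel_size) / stride) + 1, e.g. (64 + 2 - 4) // 2 + 1 = 32. The sketch below traces the shapes; it assumes the channel count doubles per layer (64, 128, 256) as in the layer listing here, and the function names are hypothetical.

```python
# Worked check of the discriminator's tensor shapes using the Conv2d
# output-size formula (dilation = 1):
#   out = floor((in + 2 * padding - kernel_size) / stride) + 1

def conv2d_out(size: int, kernel_size: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1

def trace_shapes(batch: int = 64, imsize: int = 64, conv_dim: int = 64, n_layers: int = 3):
    """Return the (N, C, H, W) shape after each 4x4 / stride-2 / pad-1 block.

    Assumes layer i (0-based) outputs conv_dim * 2**i channels.
    """
    shapes = []
    size = imsize
    for i in range(n_layers):
        size = conv2d_out(size)
        shapes.append((batch, conv_dim * 2 ** i, size, size))
    return shapes

if __name__ == "__main__":
    for i, shape in enumerate(trace_shapes(), start=1):
        print(f"layer{i}: {shape}")
```

Running it prints layer1: (64, 64, 32, 32), layer2: (64, 128, 16, 16) and layer3: (64, 256, 8, 8), consistent with the shapes listed for the first two layers.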
# layer3
Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm(