SAGAN: Self-attention GAN

Self-attention

Motivation:

  • Since the convolution operator has a local receptive field, long-range dependencies can only be processed after passing through several convolutional layers. This could prevent learning about long-term dependencies for a variety of reasons:
    • (i) a small model may not be able to represent them
    • (ii) optimization algorithms may have trouble discovering parameter values that carefully coordinate multiple layers to capture these dependencies
    • (iii) these parameterizations may be statistically brittle and prone to failure when applied to previously unseen inputs.
  • Increasing the size of the convolution kernels can increase the representational capacity of the network but doing so also loses the computational and statistical efficiency obtained by using local convolutional structure.

SAGAN

  • SAGAN allows attention-driven, long-range dependency modeling for image generation tasks (convolution kernels easily capture local information, while SAGAN introduces long-range dependencies through its attention mechanism).
  • In SAGAN, the proposed attention module is applied to both the generator and the discriminator.
    • (1) Generator: details can be generated using cues from all feature locations.
    • (2) Discriminator: the discriminator can check that highly detailed features in distant portions of the image are consistent with each other.
    • Visualization of the attention layers shows that the generator leverages neighborhoods that correspond to object shapes rather than local regions of fixed shape.
      [Figure: visualization of attention maps in the generator]

Self-attention - computing global spatial information

[Figure: the proposed self-attention module of SAGAN]

  • $x \in \mathbb{R}^{C \times N}$: image features from the previous hidden layer. Here, $C$ is the number of channels and $N$ is the number of feature locations.
  • $f(x) = W_f x$, $g(x) = W_g x$: transform $x$ into two feature spaces $f$ (key) and $g$ (query) to calculate the attention
    $$\beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}, \quad \text{where } s_{ij} = f(x_i)^{\top} g(x_j)$$
    • $\beta_{j,i}$ indicates the extent to which the model attends to the $i$th location when synthesizing the $j$th region
    • $W_g \in \mathbb{R}^{\bar{C} \times C}$, $W_f \in \mathbb{R}^{\bar{C} \times C}$; attention map: $N \times N$
  • The output of the attention layer is $o = (o_1, o_2, \ldots, o_j, \ldots, o_N) \in \mathbb{R}^{C \times N}$, where
    $$o_j = W_v \left( \sum_{i=1}^{N} \beta_{j,i} \, h(x_i) \right), \quad h(x_i) = W_h x_i$$
    • $W_h \in \mathbb{R}^{\bar{C} \times C}$, $W_v \in \mathbb{R}^{C \times \bar{C}}$

$W_g, W_f, W_h, W_v$ are implemented as $1 \times 1$ convolutions. No significant performance decrease was observed when reducing the channel number $\bar{C}$ to $C/k$ (with $k = 1, 2, 4, 8$) after a few training epochs on ImageNet. For memory efficiency, $k = 8$ (i.e., $\bar{C} = C/8$) is chosen in all experiments.
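
To make the shape bookkeeping concrete, here is a minimal sketch (illustrative only; the tensor values and variable names are not from the paper) of the matrix form above, confirming that the attention map is $N \times N$ and the output is $C \times N$:

```python
import torch
import torch.nn.functional as F

C, N, k = 64, 16 * 16, 8       # channels, feature locations (a 16x16 map), reduction factor
C_bar = C // k                 # reduced channel dimension, C_bar = C/8

x = torch.randn(C, N)          # features from the previous hidden layer
W_f = torch.randn(C_bar, C)    # key projection
W_g = torch.randn(C_bar, C)    # query projection
W_h = torch.randn(C_bar, C)    # value projection
W_v = torch.randn(C, C_bar)    # output projection

f, g, h = W_f @ x, W_g @ x, W_h @ x   # each has shape (C_bar, N)
s = f.t() @ g                         # s[i, j] = f(x_i)^T g(x_j), shape (N, N)
beta = F.softmax(s, dim=0)            # beta[i, j] = beta_{j,i}: normalize over i for each j
o = W_v @ (h @ beta)                  # o_j = W_v * sum_i beta_{j,i} h(x_i)

assert beta.shape == (N, N) and o.shape == (C, N)
```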


Self-attention - combining global spatial information with local information

  • In addition, the output of the attention layer is further multiplied by a scale parameter and added back to the input feature map. The final output is therefore
    $$y_i = \gamma \, o_i + x_i$$
    where $\gamma$ is a learnable scalar, initialized to 0 (see the sketch after this list).
  • Introducing the learnable $\gamma$ allows the network to first rely on the cues in the local neighborhood (since this is easier) and then gradually learn to assign more weight to the non-local evidence.
    • The intuition for why we do this is straightforward: learn the easy task first and then progressively increase the complexity of the task.
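
The same idea in a few lines (a toy check, not from the paper): because $\gamma$ starts at 0, the attention branch contributes nothing initially and the block is an identity mapping.

```python
import torch
import torch.nn as nn

gamma = nn.Parameter(torch.zeros(1))   # learnable scalar, initialized to 0
x = torch.randn(1, 64, 8, 8)           # input feature map x
o = torch.randn(1, 64, 8, 8)           # attention-layer output o (placeholder values)
y = gamma * o + x                      # y_i = gamma * o_i + x_i
assert torch.equal(y, x)               # identity mapping at initialization
```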

Loss

Spectral normalization for both generator and discriminator

  • In SNGAN, SN is applied only to $D$. SAGAN applies spectral normalization to both the generator and the discriminator (see the sketch after this list).
    • Spectral normalization in the generator can prevent the escalation of parameter magnitudes and avoid unusual gradients.
    • We find empirically that spectral normalization of both generator and discriminator makes it possible to use fewer discriminator updates per generator update, thus significantly reducing the computational cost of training. The approach also shows more stable training behavior.
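
PyTorch exposes this directly as `torch.nn.utils.spectral_norm`; a minimal sketch (the layer shapes are illustrative, not SAGAN's actual architecture) of wrapping layers in both networks:

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

# spectral_norm constrains each weight's largest singular value to 1
# via power iteration; here it wraps one layer in G and one in D.
gen_layer = spectral_norm(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1))
disc_layer = spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
```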

Imbalanced learning rate for generator and discriminator updates

  • In previous work, regularization of the discriminator (SNGAN; WGAN-GP) often slows down the GANs’ learning process.
    • In practice, methods using regularized discriminators typically require multiple (e.g., 5) discriminator update steps per generator update step during training.
  • Independently, Heusel et al. (2017) have advocated using separate learning rates (TTUR; Two-Timescale Update Rule) for the generator and the discriminator.
  • We propose using TTUR specifically to compensate for the problem of slow learning in a regularized discriminator, making it possible to use fewer discriminator steps per generator step. Using this approach, we are able to produce better results given the same wall-clock time.
    • lr for Discriminator: 0.0004
    • lr for Generator: 0.0001
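
In code this is just two optimizers with different learning rates. A minimal sketch (the Adam betas of (0, 0.9) follow the SAGAN paper's reported settings; the two Linear networks are placeholders for the real G and D):

```python
import torch
import torch.nn as nn

generator = nn.Linear(100, 784)        # placeholder networks; substitute the real G and D
discriminator = nn.Linear(784, 1)

# TTUR: the discriminator's learning rate is 4x the generator's
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0.0, 0.9))
```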
Below is PyTorch code implementing a Self-Attention GAN (SAGAN) with the GAN loss replaced by the WGAN-GP loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions implement the query, key, and value projections
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable scale, initialized to 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, c, h, w = x.size()
        proj_query = self.query_conv(x).view(batch_size, -1, h * w).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, h * w)
        energy = torch.bmm(proj_query, proj_key)           # (N, N) attention logits
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(batch_size, -1, h * w)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, c, h, w)
        out = self.gamma * out + x                         # y = gamma * o + x
        return out


class Generator(nn.Module):
    def __init__(self, z_dim=100, image_size=64, conv_dim=64):
        super(Generator, self).__init__()
        self.image_size = image_size
        self.conv_dim = conv_dim
        self.fc = nn.Linear(z_dim, conv_dim * 8 * (image_size // 16) ** 2)
        self.conv1 = nn.Conv2d(conv_dim * 8, conv_dim * 4, 3, 1, 1)
        self.conv2 = nn.Conv2d(conv_dim * 4, conv_dim * 2, 3, 1, 1)
        self.self_attention = SelfAttention(conv_dim * 2)
        self.conv3 = nn.Conv2d(conv_dim * 2, conv_dim, 3, 1, 1)
        self.conv4 = nn.Conv2d(conv_dim, 3, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(conv_dim * 8)
        self.bn2 = nn.BatchNorm2d(conv_dim * 4)
        self.bn3 = nn.BatchNorm2d(conv_dim * 2)
        self.bn4 = nn.BatchNorm2d(conv_dim)

    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, self.conv_dim * 8, self.image_size // 16, self.image_size // 16)
        x = F.relu(self.bn1(x))
        x = F.interpolate(x, scale_factor=2)               # 4 -> 8
        x = F.relu(self.bn2(self.conv1(x)))
        x = F.interpolate(x, scale_factor=2)               # 8 -> 16
        x = F.relu(self.bn3(self.self_attention(self.conv2(x))))
        x = F.interpolate(x, scale_factor=2)               # 16 -> 32
        x = F.relu(self.bn4(self.conv3(x)))
        x = F.interpolate(x, scale_factor=2)               # 32 -> 64, back to image_size
        x = torch.tanh(self.conv4(x))
        return x


class Discriminator(nn.Module):
    def __init__(self, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, conv_dim, 4, 2, 1)
        self.conv2 = nn.Conv2d(conv_dim, conv_dim * 2, 4, 2, 1)
        self.conv3 = nn.Conv2d(conv_dim * 2, conv_dim * 4, 4, 2, 1)
        self.self_attention = SelfAttention(conv_dim * 4)
        self.conv4 = nn.Conv2d(conv_dim * 4, conv_dim * 8, 4, 2, 1)
        self.conv5 = nn.Conv2d(conv_dim * 8, 1, 4, 1, 0)
        # WGAN-GP penalizes per-sample gradients, so batch norm (which couples
        # the samples of a batch) is avoided in the critic; instance norm is used instead
        self.norm2 = nn.InstanceNorm2d(conv_dim * 2, affine=True)
        self.norm3 = nn.InstanceNorm2d(conv_dim * 4, affine=True)
        self.norm4 = nn.InstanceNorm2d(conv_dim * 8, affine=True)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.1)
        x = F.leaky_relu(self.norm2(self.conv2(x)), 0.1)
        x = F.leaky_relu(self.norm3(self.self_attention(self.conv3(x))), 0.1)
        x = F.leaky_relu(self.norm4(self.conv4(x)), 0.1)
        x = self.conv5(x)
        return x.view(-1, 1)


# WGAN-GP loss; takes the generator instance so it can sample fresh fake images
def wgan_gp_loss(discriminator, real_images, generator, batch_size, device, lambda_gp=10.0):
    # Critic scores for real images
    real_scores = discriminator(real_images)

    # Sample random points in the latent space and generate fake images
    z = torch.randn(batch_size, 100, device=device)
    fake_images = generator(z)

    # Critic scores for fake images
    fake_scores = discriminator(fake_images)

    # Gradient penalty on random interpolates between real and fake images
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (alpha * real_images + (1 - alpha) * fake_images.detach()).requires_grad_(True)
    interpolated_scores = discriminator(interpolated)
    gradients = torch.autograd.grad(outputs=interpolated_scores, inputs=interpolated,
                                    grad_outputs=torch.ones_like(interpolated_scores),
                                    create_graph=True, retain_graph=True)[0]
    # Flatten the per-sample gradients before taking the 2-norm
    gradient_penalty = ((gradients.view(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean()

    # Wasserstein distance estimate and the two losses
    wasserstein_distance = real_scores.mean() - fake_scores.mean()
    d_loss = -wasserstein_distance + lambda_gp * gradient_penalty
    g_loss = -fake_scores.mean()
    return d_loss, g_loss
```

During training, use the `wgan_gp_loss` function in place of the original GAN loss. For example:

```python
from torchvision.utils import save_image

# Initialize the models and optimizers (TTUR learning rates, as discussed above)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discriminator().to(device)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0.0, 0.9))
fixed_z = torch.randn(64, 100, device=device)  # fixed noise for sample grids

# Train the GAN (num_epochs and dataloader are assumed to be defined)
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # Train the discriminator
        discriminator.zero_grad()
        d_loss, _ = wgan_gp_loss(discriminator, real_images, generator, batch_size, device)
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        generator.zero_grad()
        _, g_loss = wgan_gp_loss(discriminator, real_images, generator, batch_size, device)
        g_loss.backward()
        g_optimizer.step()

        if i % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}".format(
                epoch + 1, num_epochs, i + 1, len(dataloader), d_loss.item(), g_loss.item()))

    # Save sample images and checkpoints
    with torch.no_grad():
        fake_images = generator(fixed_z)
    save_image(fake_images, "SAGAN_WGAN_GP_{}.png".format(epoch + 1), nrow=8, normalize=True)
    torch.save(generator.state_dict(), "SAGAN_WGAN_GP_Generator_{}.ckpt".format(epoch + 1))
    torch.save(discriminator.state_dict(), "SAGAN_WGAN_GP_Discriminator_{}.ckpt".format(epoch + 1))
```
