Implementing style transfer with CycleGAN in PyTorch

Contents

1. Overview

2. Network architecture

2.1 Generator code

2.2 Discriminator code

3. Code (to be updated)


Paper: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

Link: https://arxiv.org/abs/1703.10593

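1. Overview

CycleGAN translates images between two different styles (domains) in both directions. Its key feature is that the training data does not have to come in one-to-one pairs. You only need one dataset per domain, for example photos of ordinary horses and photos of zebras. After training, it can perform style transfers like the one shown in the figure below: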
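Unpaired training works because of the cycle-consistency idea in the paper's title: translating an image to the other domain and back should reproduce the original. With generators G: X → Y and F: Y → X and discriminators D_X and D_Y, the paper's full objective adds two adversarial losses to a cycle-consistency loss weighted by λ (the paper uses λ = 10):

```latex
\begin{aligned}
\mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) &= \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\log D_Y(y)\big]
  + \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big] \\
\mathcal{L}_{\text{cyc}}(G, F) &= \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
  + \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big] \\
\mathcal{L}(G, F, D_X, D_Y) &= \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X)
  + \lambda \, \mathcal{L}_{\text{cyc}}(G, F) \\
G^*, F^* &= \arg\min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)
\end{aligned}
```

In practice the paper replaces the log-likelihood adversarial term with a least-squares loss for more stable training.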

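2. Network architecture

The network structure used during training is shown in the figure below:

As the figure shows, CycleGAN contains two generators and two discriminators, one of each per domain. Generator G maps domain X to domain Y and generator F maps Y back to X. Discriminator D_Y judges whether an image is a real member of Y, and D_X does the same for X.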
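The following is a minimal sketch of how the four networks interact in one training step. It assumes a least-squares adversarial loss and λ = 10, which follow the paper, and Adam with lr 2e-4 and betas (0.5, 0.999), which are commonly used CycleGAN settings and should be treated as assumptions. The one-layer `tiny_G` and `tiny_D` networks are hypothetical stand-ins that keep the sketch fast. Swapping in `Cycle_Gan_G` and `Cycle_Gan_D` from the sections below should work the same way. The sketch also omits the identity loss that the paper uses for some tasks and the paper's 50-image history buffer for discriminator updates.

```python
import itertools

import torch
import torch.nn as nn


# Hypothetical tiny stand-ins; replace with Cycle_Gan_G / Cycle_Gan_D in practice.
def tiny_G():
    return nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.Tanh())


def tiny_D():
    return nn.Conv2d(3, 1, 4, 2, 1)  # outputs a map of raw patch scores


G, F = tiny_G(), tiny_G()      # G: X -> Y, F: Y -> X (paper notation)
D_X, D_Y = tiny_D(), tiny_D()  # D_X judges domain X, D_Y judges domain Y

mse, l1 = nn.MSELoss(), nn.L1Loss()  # least-squares GAN loss, L1 cycle loss
lam = 10.0                           # cycle-consistency weight

opt_G = torch.optim.Adam(itertools.chain(G.parameters(), F.parameters()),
                         lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(itertools.chain(D_X.parameters(), D_Y.parameters()),
                         lr=2e-4, betas=(0.5, 0.999))


def train_step(real_x, real_y):
    # Generators: fool both discriminators and reconstruct the inputs.
    opt_G.zero_grad()
    fake_y, fake_x = G(real_x), F(real_y)
    pred_fy, pred_fx = D_Y(fake_y), D_X(fake_x)
    loss_gan = (mse(pred_fy, torch.ones_like(pred_fy))
                + mse(pred_fx, torch.ones_like(pred_fx)))
    loss_cyc = l1(F(fake_y), real_x) + l1(G(fake_x), real_y)  # x->y->x, y->x->y
    loss_G = loss_gan + lam * loss_cyc
    loss_G.backward()
    opt_G.step()

    # Discriminators: real -> 1, fake -> 0; fakes are detached so G is not updated.
    opt_D.zero_grad()  # also clears grads that loss_G.backward() left on D_X / D_Y
    loss_D = 0.0
    for D, real, fake in ((D_Y, real_y, fake_y), (D_X, real_x, fake_x)):
        pred_real, pred_fake = D(real), D(fake.detach())
        # halving the discriminator loss slows D relative to G
        loss_D = loss_D + 0.5 * (mse(pred_real, torch.ones_like(pred_real))
                                 + mse(pred_fake, torch.zeros_like(pred_fake)))
    loss_D.backward()
    opt_D.step()
    return loss_G.item(), loss_D.item()
```

Each call takes one batch from each domain, for example `train_step(x_batch, y_batch)` with images scaled to [-1, 1]. Because the batches do not need to correspond to each other, unpaired data is sufficient.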

2.1 Generator code:

import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict


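# Residual block
class Resnet_block(nn.Module):
    def __init__(self, in_channels):
        super(Resnet_block, self).__init__()
        block = []
        for i in range(2):
            # ReflectionPad2d(1) + 3x3 conv keeps the spatial size. Here ReLU follows
            # only the second conv; the authors' reference code instead applies ReLU
            # after the first conv and no activation before the skip addition.
            block += [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_channels, in_channels, 3, 1, 0),
                      nn.InstanceNorm2d(in_channels),
                      nn.ReLU(True) if i > 0 else nn.Identity()]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = x + self.block(x)  # identity shortcut: x + F(x)
        return out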


class Cycle_Gan_G(nn.Module):
    def __init__(self):
        super(Cycle_Gan_G, self).__init__()
        net_dic = OrderedDict()
        # Three convolutional layers: c7s1-64, d128, d256
        net_dic.update({'first_layer': nn.Sequential(
            nn.ReflectionPad2d(3),  # [3,256,256]  ->  [3,262,262]
            nn.Conv2d(3, 64, 7, 1),  # [3,262,262]  ->  [64,256,256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True)
        )})
        net_dic.update({'second_conv': nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),  # [64,256,256] -> [128,128,128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True)
        )})
        net_dic.update({'third_conv': nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),  # [128,128,128] -> [256,64,64]
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
        )})

        # 6 residual blocks (the paper uses 6 for 128x128 images and 9 for 256x256)
        for i in range(6):
            net_dic.update({'Resnet_block{}'.format(i + 1): Resnet_block(256)})

        # Upsampling: two stride-2 transposed convolutions (u128, u64)
        net_dic.update({'up_sample1': nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),  # [128,128,128]
            nn.ReLU(True)
        )})
        net_dic.update({'up_sample2': nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),  # [64,256,256]
            nn.ReLU(True)
        )})

        # Output layer: [64,256,256] -> pad -> [64,262,262] -> [3,256,256]; Tanh maps to [-1, 1]
        net_dic.update({'last_layer': nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1),
            nn.Tanh()
        )})

        self.net_G = nn.Sequential(net_dic)
        self.init_weight()

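    def init_weight(self):
        # Kaiming initialization (the paper itself initializes weights from N(0, 0.02))
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            # This network uses InstanceNorm2d (affine=False by default, so no weights);
            # the BatchNorm2d branch below never runs here
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)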

    def forward(self, x):
        out = self.net_G(x)
        return out


if __name__ == '__main__':
    G = Cycle_Gan_G().to('cuda')  # requires a GPU; torchsummary's summary() defaults to device='cuda'
    summary(G, (3, 256, 256))

The structure of the residual block is shown in the figure below:
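For comparison, the `Resnet_block` above applies ReLU after its second convolution, just before the skip addition. The authors' public PyTorch code (pytorch-CycleGAN-and-pix2pix) applies ReLU after the first convolution instead and leaves the second convolution's output un-activated. The following is a minimal sketch of that ordering. The class name `ResnetBlockRef` is hypothetical, and the optional dropout of the reference code is omitted.

```python
import torch
import torch.nn as nn


class ResnetBlockRef(nn.Module):
    """Hypothetical residual block with ReLU after the first conv only."""

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, 1, 0),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, 1, 0),
            nn.InstanceNorm2d(channels),  # no activation before the skip addition
        )

    def forward(self, x):
        return x + self.block(x)  # identity shortcut: y = x + F(x)
```

Swapping this class into `Cycle_Gan_G` in place of `Resnet_block` changes where ReLU and Identity appear in the torchsummary listing, but the parameter count stays the same.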

The summary of the complete generator is as follows:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ReflectionPad2d-1          [-1, 3, 262, 262]               0
            Conv2d-2         [-1, 64, 256, 256]           9,472
    InstanceNorm2d-3         [-1, 64, 256, 256]               0
              ReLU-4         [-1, 64, 256, 256]               0
            Conv2d-5        [-1, 128, 128, 128]          73,856
    InstanceNorm2d-6        [-1, 128, 128, 128]               0
              ReLU-7        [-1, 128, 128, 128]               0
            Conv2d-8          [-1, 256, 64, 64]         295,168
    InstanceNorm2d-9          [-1, 256, 64, 64]               0
             ReLU-10          [-1, 256, 64, 64]               0
  ReflectionPad2d-11          [-1, 256, 66, 66]               0
           Conv2d-12          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-13          [-1, 256, 64, 64]               0
         Identity-14          [-1, 256, 64, 64]               0
  ReflectionPad2d-15          [-1, 256, 66, 66]               0
           Conv2d-16          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-17          [-1, 256, 64, 64]               0
             ReLU-18          [-1, 256, 64, 64]               0
     Resnet_block-19          [-1, 256, 64, 64]               0
  ReflectionPad2d-20          [-1, 256, 66, 66]               0
           Conv2d-21          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-22          [-1, 256, 64, 64]               0
         Identity-23          [-1, 256, 64, 64]               0
  ReflectionPad2d-24          [-1, 256, 66, 66]               0
           Conv2d-25          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-26          [-1, 256, 64, 64]               0
             ReLU-27          [-1, 256, 64, 64]               0
     Resnet_block-28          [-1, 256, 64, 64]               0
  ReflectionPad2d-29          [-1, 256, 66, 66]               0
           Conv2d-30          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-31          [-1, 256, 64, 64]               0
         Identity-32          [-1, 256, 64, 64]               0
  ReflectionPad2d-33          [-1, 256, 66, 66]               0
           Conv2d-34          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-35          [-1, 256, 64, 64]               0
             ReLU-36          [-1, 256, 64, 64]               0
     Resnet_block-37          [-1, 256, 64, 64]               0
  ReflectionPad2d-38          [-1, 256, 66, 66]               0
           Conv2d-39          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-40          [-1, 256, 64, 64]               0
         Identity-41          [-1, 256, 64, 64]               0
  ReflectionPad2d-42          [-1, 256, 66, 66]               0
           Conv2d-43          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-44          [-1, 256, 64, 64]               0
             ReLU-45          [-1, 256, 64, 64]               0
     Resnet_block-46          [-1, 256, 64, 64]               0
  ReflectionPad2d-47          [-1, 256, 66, 66]               0
           Conv2d-48          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-49          [-1, 256, 64, 64]               0
         Identity-50          [-1, 256, 64, 64]               0
  ReflectionPad2d-51          [-1, 256, 66, 66]               0
           Conv2d-52          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-53          [-1, 256, 64, 64]               0
             ReLU-54          [-1, 256, 64, 64]               0
     Resnet_block-55          [-1, 256, 64, 64]               0
  ReflectionPad2d-56          [-1, 256, 66, 66]               0
           Conv2d-57          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-58          [-1, 256, 64, 64]               0
         Identity-59          [-1, 256, 64, 64]               0
  ReflectionPad2d-60          [-1, 256, 66, 66]               0
           Conv2d-61          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-62          [-1, 256, 64, 64]               0
             ReLU-63          [-1, 256, 64, 64]               0
     Resnet_block-64          [-1, 256, 64, 64]               0
  ConvTranspose2d-65        [-1, 128, 128, 128]         295,040
   InstanceNorm2d-66        [-1, 128, 128, 128]               0
             ReLU-67        [-1, 128, 128, 128]               0
  ConvTranspose2d-68         [-1, 64, 256, 256]          73,792
   InstanceNorm2d-69         [-1, 64, 256, 256]               0
             ReLU-70         [-1, 64, 256, 256]               0
  ReflectionPad2d-71         [-1, 64, 262, 262]               0
           Conv2d-72          [-1, 3, 256, 256]           9,411
             Tanh-73          [-1, 3, 256, 256]               0
================================================================
Total params: 7,837,699
Trainable params: 7,837,699
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 788.18
Params size (MB): 29.90
Estimated Total Size (MB): 818.83
----------------------------------------------------------------
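The shape comments and the parameter total above follow from PyTorch's size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with dilation 1. The following pure-Python check is a sketch, and the helper names are hypothetical:

```python
def conv_out(n, k, s=1, p=0):
    # nn.Conv2d output size (dilation 1): floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1


def convT_out(n, k, s=1, p=0, op=0):
    # nn.ConvTranspose2d output size (dilation 1): (n - 1)s - 2p + k + output_padding
    return (n - 1) * s - 2 * p + k + op


def conv_params(c_in, c_out, k):
    # c_in * c_out * k * k weights plus one bias per output channel;
    # the same count holds for a transposed convolution
    return c_in * c_out * k * k + c_out


def generator_params(n_blocks):
    # InstanceNorm2d has no learnable parameters by default, so only convs count
    return (conv_params(3, 64, 7) + conv_params(64, 128, 3) + conv_params(128, 256, 3)
            + n_blocks * 2 * conv_params(256, 256, 3)
            + conv_params(256, 128, 3) + conv_params(128, 64, 3) + conv_params(64, 3, 7))


sizes = [256]
sizes.append(conv_out(sizes[-1] + 2 * 3, 7))    # ReflectionPad2d(3) + 7x7 conv
sizes.append(conv_out(sizes[-1], 3, 2, 1))      # stride-2 downsampling
sizes.append(conv_out(sizes[-1], 3, 2, 1))      # stride-2 downsampling
sizes.append(conv_out(sizes[-1] + 2 * 1, 3))    # residual-block conv after ReflectionPad2d(1)
sizes.append(convT_out(sizes[-1], 3, 2, 1, 1))  # stride-2 upsampling
sizes.append(convT_out(sizes[-1], 3, 2, 1, 1))  # stride-2 upsampling
sizes.append(conv_out(sizes[-1] + 2 * 3, 7))    # ReflectionPad2d(3) + 7x7 conv

print(sizes)                # [256, 256, 128, 64, 64, 128, 256, 256]
print(generator_params(6))  # 7837699, matching "Total params" above
print(generator_params(9))  # 11378179 with 9 residual blocks
```

The `output_padding=1` in the transposed convolutions is what makes each upsampling step exactly double the size (64 → 128 → 256). Without it, the output would be one pixel short.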

2.2 Discriminator code

import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict


class Cycle_Gan_D(nn.Module):
    def __init__(self):
        super(Cycle_Gan_D, self).__init__()

        # Basic block: 4x4 conv -> InstanceNorm -> LeakyReLU(0.2); the first layer (3 input channels) has no normalization
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            if in_channels == 3:
                bn = nn.Identity
            else:
                bn = nn.InstanceNorm2d
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                bn(out_channels),
                nn.LeakyReLU(0.2, True)
            )

        D_dic = OrderedDict()
        in_channels = 3
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30] raw patch scores (no sigmoid)
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x):
        return self.D_model(x)



if __name__ == '__main__':
    D = Cycle_Gan_D().to('cuda')
    summary(D, (3, 256, 256))

The summary of the complete discriminator is as follows:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 128, 128]           3,136
          Identity-2         [-1, 64, 128, 128]               0
         LeakyReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4          [-1, 128, 64, 64]         131,200
    InstanceNorm2d-5          [-1, 128, 64, 64]               0
         LeakyReLU-6          [-1, 128, 64, 64]               0
            Conv2d-7          [-1, 256, 32, 32]         524,544
    InstanceNorm2d-8          [-1, 256, 32, 32]               0
         LeakyReLU-9          [-1, 256, 32, 32]               0
           Conv2d-10          [-1, 512, 31, 31]       2,097,664
   InstanceNorm2d-11          [-1, 512, 31, 31]               0
        LeakyReLU-12          [-1, 512, 31, 31]               0
           Conv2d-13            [-1, 1, 30, 30]           8,193
================================================================
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 53.27
Params size (MB): 10.55
Estimated Total Size (MB): 64.57
----------------------------------------------------------------
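Each of the 30×30 outputs scores one overlapping patch of the input rather than the whole image. This is why the design is called a PatchGAN, and the paper describes its discriminators as 70×70 PatchGANs. The following pure-Python sketch reproduces both numbers from the layer settings in `Cycle_Gan_D`. The helper and variable names are hypothetical.

```python
def conv_out(n, k, s=1, p=0):
    # nn.Conv2d output size (dilation 1): floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1


# (kernel, stride, padding) of the five Conv2d layers in Cycle_Gan_D
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

n = 256
for k, s, p in layers:
    n = conv_out(n, k, s, p)  # 256 -> 128 -> 64 -> 32 -> 31 -> 30
print(n)  # 30

# Receptive field of one output score, walking back from the last layer:
# a window of r pixels at a layer's output covers (r - 1) * stride + kernel input pixels
r = 1
for k, s, _ in reversed(layers):
    r = (r - 1) * s + k
print(r)  # 70
```

During training, the 30×30 score map is compared with an all-ones or all-zeros target of the same shape (for example `torch.ones_like(pred)`), not with a single scalar label.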

3. Code (to be updated)
