CycleGAN模型——pytorch实现

 论文传送门:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks​​​​​​​

# 输入图像shape默认为(3,256,256)
class Discriminator(nn.Module):  # 定义判别器
    def __init__(self):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)  # conv
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)  # conv
        self.isn2 = nn.InstanceNorm2d(128)  # instancenorm,实例标准化,在图像风格转化任务中,生成图像依赖于某个图像的实例,所以batchnorm并不适用于风格转化任务
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)  # conv
        self.isn3 = nn.InstanceNorm2d(256)  # in
        self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)  # conv
        self.isn4 = nn.InstanceNorm2d(512)  # in
        self.conv5 = nn.Conv2d(512, 1, 3, 1, 1)  # conv

        self.leakyrelu = nn.LeakyReLU(0.2)  # leakyrelu
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):  # 前传函数
        x = self.conv1(x)  # conv,(n,3,256,256)-->(n,64,128,128)
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv2(x)  # conv,(n,64,128,128)-->(n,128,64,64)
        x = self.isn2(x)  # in
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv3(x)  # conv,(n,128,64,64)-->(n,256,32,32)
        x = self.isn3(x)  # in
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv4(x)  # conv,(n,256,32,32)-->(n,512,16,16)
        x = self.isn4(x)  # in
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv5(x)  # conv,(n,512,16,16)-->(n,1,16,16)
        x = self.sigmoid(x)  # sigmoid,输入映射至(0,1)

        return x  # 返回图像真假的得分(16,16),相当于对16x16个区域的真假进行评分,而非对整体图片的真假进行评分


class IdentityBlock(nn.Module):  # 定义残差块
    def __init__(self):  # 初始化方法
        super(IdentityBlock, self).__init__()  # 继承初始化方法
        self.conv1 = nn.Conv2d(256, 256, 3, 1, 1)  # conv
        self.isn1 = nn.InstanceNorm2d(256)  # in
        self.conv2 = nn.Conv2d(256, 256, 3, 1, 1)  # conv
        self.isn2 = nn.InstanceNorm2d(256)  # in
        self.relu = nn.ReLU()  # relu

    def forward(self, x):  # 前传函数
        y = self.conv1(x)  # conv,(n,256,64,64)-->(n,256,64,64)
        y = self.isn1(y)  # in
        y = self.relu(y)  # relu
        y = self.conv2(y)  # conv,(n,256,64,64)-->(n,256,64,64)
        y = self.isn2(y)  # in
        y += x  # F(x) + x,(n,256,64,64)+(n,256,64,64)-->(n,256,64,64)
        y = self.relu(y)  # relu
        return y


class Generator(nn.Module):  # 定义生成器
    def __init__(self):  # 初始化方法
        super(Generator, self).__init__()  # 继承初始化方法
        self.conv1 = nn.Conv2d(3, 64, 7, 1, 3)  # conv
        self.isn1 = nn.InstanceNorm2d(64)  # in
        self.conv2 = nn.Conv2d(64, 128, 3, 2, 1)  # conv
        self.isn2 = nn.InstanceNorm2d(128)  # in
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1)  # conv
        self.isn3 = nn.InstanceNorm2d(256)  # in
        self.relu = nn.ReLU()  # relu
        self.layers = []  # 用于存放残差块结构
        for i in range(9):  # 共9个残差块
            self.layers.append(IdentityBlock())  # 向layers中添加残差块结构
        self.resnet = nn.Sequential(*self.layers)  # 将layers列表转化为模型结构序列
        self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # upsample,上采样
        self.conv4 = nn.Conv2d(256, 128, 3, 1, 1)  # conv
        self.isn4 = nn.InstanceNorm2d(128)  # in
        self.conv5 = nn.Conv2d(128, 64, 3, 1, 1)  # conv
        self.isn5 = nn.InstanceNorm2d(64)  # in
        self.conv6 = nn.Conv2d(64, 3, 7, 1, 3)  # conv
        self.tanh = nn.Tanh()  # tanh

    def forward(self, x):  # 前传函数
        x = self.conv1(x)  # conv,(n,3,256,256)-->(n,64,256,256)
        x = self.isn1(x)  # in
        x = self.relu(x)  # relu
        x = self.conv2(x)  # conv,(n,64,256,256)-->(n,128,128,128)
        x = self.isn2(x)  # in
        x = self.relu(x)  # relu
        x = self.conv3(x)  # conv,(n,128,128,128)-->(n,256,64,64)
        x = self.isn3(x)  # in
        x = self.relu(x)  # relu
        x = self.resnet(x)  # 9次残差结构计算,(n,256,64,64)-->(n,256,64,64)
        x = self.ups(x)  # upsample,(n,256,64,64)-->(n,256,128,128)
        x = self.conv4(x)  # conv,(n,256,128,128)-->(n,128,128,128)
        x = self.isn4(x)  # in
        x = self.relu(x)  # relu
        x = self.ups(x)  # upsample,(n,128,128,128)-->(n,128,256,256)
        x = self.conv5(x)  # conv,(n,128,256,256)-->(n,64,256,256)
        x = self.isn5(x)  # in
        x = self.relu(x)  # relu
        x = self.conv6(x)  # conv,(n,64,256,256)-->(n,3,256,256)
        x = self.tanh(x)  # tanh,输出映射至(-1,1)

        return x  # 返回风格迁移后的图像

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CV_Peach

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值