Reproducing the classic GAN-based super-resolution network (SRGAN) in PyTorch

Original paper: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

Chinese translation of the paper: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

The network structure is shown in the figure below:

The top and bottom parts are the generator network and the discriminator network, respectively:

[Figure: SRGAN network architecture, generator (top) and discriminator (bottom)]

Without further ado, let's go straight to the code. I'm not a fan of blogs padded with filler; the most useful part is the code!

The implementation follows the style of the network implementations in PyTorch torchvision. Advantages: it is modular, concise, and readable, and it is easy to change the network width and depth (convenient for adjusting the architecture for comparison and ablation experiments).
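
Both networks below read their hyperparameters from a config object with G and D sub-sections. The actual config class/file is not shown in this post, so here is a minimal stand-in (my own assumption, using types.SimpleNamespace) that only mirrors the attributes the code actually reads:

from types import SimpleNamespace

# Hypothetical stand-in for the project's config object; the real project
# presumably loads these values from a config file.
config = SimpleNamespace(
    G=SimpleNamespace(
        large_kernel_size=9,   # first/last conv kernel of the generator
        small_kernel_size=3,   # kernel inside the residual blocks
        n_channels=64,         # feature width
        n_blocks=16,           # number of residual blocks (B)
        base_block_type='conv_residual',  # or 'depthwise_conv_residual'
    ),
    D=SimpleNamespace(
        kernel_size=3,         # discriminator conv kernel
        n_channels=64,         # width of the first conv
        n_blocks=8,            # number of conv blocks
        fc_size=1024,          # size of the first fully connected layer
    ),
)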

Generator network implementation:

# -*- coding: utf-8 -*-
# @Use     :
# @Time    : 2022/8/17 21:32
# @FileName: models.py
# @Software: PyCharm
# @Reference: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

import torch
from torch import nn
import torchvision
from torch import Tensor


class GeneratorBasicBlock(nn.Module):
    """
    生成器重复的部分
    """

    def __init__(self, channel, kernel_size) -> None:
        super(GeneratorBasicBlock, self).__init__()

        self.channel = channel
        self.conv1 = nn.Conv2d(in_channels=channel, out_channels=channel,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(num_features=channel)
        self.p_relu1 = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(1, 1))
        self.bn2 = nn.BatchNorm2d(num_features=channel)

    def forward(self, x: Tensor) -> Tensor:
        """
        前向推断
        :param x:
        :return:
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.p_relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        return out
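
# Note: padding=(1, 1) assumes kernel_size == 3 (the paper's small kernel), so the
# block keeps the spatial size unchanged and the residual addition `out += identity`
# always lines up.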


class PixelShufflerBlock(nn.Module):
    """
    生成器最后的pixelshuffler
    """

    def __init__(self, in_channel, out_channel) -> None:
        super(PixelShufflerBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixels_shuffle = nn.PixelShuffle(upscale_factor=2)
        self.prelu = nn.PReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        前向
        """
        out = self.conv1(x)
        out = self.pixels_shuffle(out)
        out = self.prelu(out)
        return out
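
# Note: with upscale_factor=2, nn.PixelShuffle rearranges (N, 4*C, H, W) into
# (N, C, 2*H, 2*W), so each PixelShufflerBlock doubles the spatial resolution as
# long as the preceding conv outputs 4x the desired channel count (the Generator
# below passes 4 * n_channels as out_channel for exactly this reason).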


class Generator(nn.Module):
    """
    生成器
    """

    def __init__(self, config) -> None:
        # Generator parameters
        super(Generator, self).__init__()
        large_kernel_size = config.G.large_kernel_size  # = 9
        small_kernel_size = config.G.small_kernel_size  # = 3
        n_channels = config.G.n_channels  # = 64
        n_blocks = config.G.n_blocks  # = 16
        base_block_type = config.G.base_block_type  # 'conv_residual' or 'depthwise_conv_residual'

        # base block: 'conv_residual' is the plain residual block defined above;
        # 'depthwise_conv_residual' selects a depthwise variant (GeneratorDepthwiseBlock)
        # whose definition is not shown in this post
        if base_block_type == 'depthwise_conv_residual':
            self.repeat_block = GeneratorDepthwiseBlock
        elif base_block_type == 'conv_residual':
            self.repeat_block = GeneratorBasicBlock

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,
                               kernel_size=(large_kernel_size, large_kernel_size),
                               stride=(1, 1), padding=(4, 4))
        self.prelu1 = nn.PReLU()
        self.B_residual_block = self._make_layer(self.repeat_block, n_channels,
                                                 n_blocks, small_kernel_size)
        self.conv2 = nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                               kernel_size=(small_kernel_size, small_kernel_size),
                               stride=(1, 1), padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.pixel_shuffle_block1 = PixelShufflerBlock(n_channels, 4 * n_channels)
        self.pixel_shuffle_block2 = PixelShufflerBlock(n_channels, 4 * n_channels)
        self.conv3 = nn.Conv2d(in_channels=n_channels, out_channels=3,
                               kernel_size=(large_kernel_size, large_kernel_size),
                               stride=(1, 1), padding=(4, 4))

    def _make_layer(self, base_block, n_channels, n_block, kernel_size) -> nn.Sequential:
        """
        构建重复的B个基本块
        :param base_block: 基本块
        :param n_channels: 块里面的通道数
        :param n_block: 块数
        :return:
        """
        layers = []
        self.base_block = base_block
        for _ in range(n_block):
            layers.append(self.base_block(n_channels, kernel_size))
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """
        前向的实现
        """
        out = self.conv1(x)
        out = self.prelu1(out)
        identity = out
        out = self.B_residual_block(out)
        out = self.conv2(out)
        out = self.bn1(out)
        out += identity
        out = self.pixel_shuffle_block1(out)
        out = self.pixel_shuffle_block2(out)
        out = self.conv3(out)

        return out

    def forward(self, x: Tensor) -> Tensor:
        """
        前向
        """
        return self._forward_impl(x)
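
A quick shape check for the generator (my own sketch, not from the original post; it assumes the config stand-in above and the x4 setting with two x2 pixel-shuffle blocks):

# Hypothetical smoke test: the generator should upscale by 4x overall.
gen = Generator(config)
lr = torch.randn(1, 3, 24, 24)   # a fake 24x24 low-resolution patch
with torch.no_grad():
    sr = gen(lr)
print(sr.shape)                  # expected: torch.Size([1, 3, 96, 96])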



Discriminator network implementation:


class DiscriminatorBaseblock(nn.Module):
    """
    Basic block of the discriminator
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride):
        super(DiscriminatorBaseblock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=kernel_size, stride=stride, padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.act1 = nn.LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        return out
        
class Discriminator(nn.Module):
    """
    Discriminator
    """

    def __init__(self, config):
        super(Discriminator, self).__init__()
        # Discriminator parameters
        kernel_size = config.D.kernel_size  # = 3
        n_channels = config.D.n_channels  # = 64
        n_blocks = config.D.n_blocks  # = 8
        fc_size = config.D.fc_size  # = 1024

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,
                               kernel_size=(kernel_size, kernel_size), stride=(1, 1), padding=(1, 1))
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        conv_configs = [[3, 64, 2],  # (kernel, channel, stride) for each block
                        [3, 128, 1],
                        [3, 128, 2],
                        [3, 256, 1],
                        [3, 256, 2],
                        [3, 512, 1],
                        [3, 512, 2]]
        self.base_blocks = self._make_layer(n_channels, DiscriminatorBaseblock, conv_configs)
        # 512 * 6 * 6 assumes 96x96 inputs: the four stride-2 convs reduce 96 -> 6
        self.dense1 = nn.Linear(512 * 6 * 6, fc_size)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.dense2 = nn.Linear(fc_size, 1)
        self.sigmoid1 = nn.Sigmoid()

    def _make_layer(self, in_channel, base_block, conv_configs: list) -> nn.Sequential:
        """
        Stack the discriminator's basic blocks
        :param in_channel: number of input channels of the first block
        :param base_block: DiscriminatorBaseblock
        :param conv_configs: list of (kernel, channel, stride)
        :return:
        """
        layers = []
        self.base_block = base_block
        self.in_channel = in_channel
        for i in range(len(conv_configs)):
            # in_channel, out_channel, kernel_size, stride
            kernel_size = (conv_configs[i][0], conv_configs[i][0])
            stride = (conv_configs[i][2], conv_configs[i][2])
            out_channel = conv_configs[i][1]
            layers.append(self.base_block(self.in_channel, out_channel, kernel_size, stride))
            self.in_channel = out_channel
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """
        Forward implementation
        """
        out = self.conv1(x)
        out = self.leaky_relu1(out)
        out = self.base_blocks(out)
        out = torch.flatten(out, 1)
        out = self.dense1(out)
        out = self.leaky_relu2(out)
        out = self.dense2(out)
        out = self.sigmoid1(out)
        return out

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass
        """
        return self._forward_impl(x)
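
And a combined check (again my own sketch): the 512 * 6 * 6 flatten size means the discriminator expects 96x96 inputs, which matches the 96x96 HR crops used in the SRGAN paper (24x24 LR crops at the x4 scale):

# Hypothetical smoke test chaining both networks (assumes gen and config from above).
disc = Discriminator(config)
hr = torch.randn(2, 3, 96, 96)   # fake 96x96 high-resolution crops
lr = torch.randn(2, 3, 24, 24)   # matching 24x24 low-resolution crops
with torch.no_grad():
    sr = gen(lr)                 # (2, 3, 96, 96)
    real_score = disc(hr)        # (2, 1), in (0, 1) because of the final Sigmoid
    fake_score = disc(sr)        # (2, 1)
print(real_score.shape, fake_score.shape)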
