昇思25天学习打卡营第25天 | Pix2Pix实现图像转换

27 篇文章 0 订阅
25 篇文章 0 订阅

Pix2Pix实现图像转换

在这里插入图片描述

Pix2Pix概述

Pix2Pix是一种基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks)的图像转换模型,由Phillip Isola等人在2017年提出。它能够将语义/标签图像转换为真实图片、灰度图转换为彩色图、航空图转换为地图、白天图转换为夜晚图、线稿图转换为实物图等。Pix2Pix的创新之处在于使用相同的架构和目标函数,通过不同的数据集训练实现多种图像转换任务。

基础原理

cGAN与传统GAN的区别在于,cGAN的生成器以输入图片为指导信息生成“假”图像,而GAN的生成器则以随机噪声为输入。Pix2Pix的生成器使用U-Net结构,通过编码和解码输入图像生成输出图像;判别器使用PatchGAN结构,通过判断图像的局部区域(Patch)来区分真实图像和生成图像。cGAN的目标是通过生成器和判别器的博弈,使生成器生成的图像越来越接近真实图像。

准备工作

配置环境文件

本案例支持在GPU、CPU和Ascend平台的动静态模式下运行。

准备数据

使用指定的数据集,已处理好的外墙(facades)数据集,可直接使用MindSpore读取。

数据展示

from mindspore import dataset as ds
import matplotlib.pyplot as plt

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

创建网络

生成器G结构

生成器使用U-Net结构,通过编码和解码输入图像生成输出图像。

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False, submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        # 定义下采样和上采样的卷积层、激活函数和归一化层
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4, stride=2, padding=1, has_bias=False, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        # 定义下采样和上采样的卷积层、激活函数和归一化层
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, pad_mode='pad')
            model = [down_conv] + [submodule] + [up_relu, up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=False, pad_mode='pad')
            model = [down_relu, down_conv] + [up_relu, up_conv, nn.BatchNorm2d(outer_nc)]
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=False, pad_mode='pad')
            model = [down_relu, down_conv, nn.BatchNorm2d(inner_nc)] + [submodule] + [up_relu, up_conv, nn.BatchNorm2d(outer_nc)]
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out
基于UNet的生成器
class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        # 定义UNet生成器的各个层次
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None, norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block, norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block, outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)
判别器D结构

判别器使用PatchGAN结构,通过判断图像的局部区域(Patch)来区分真实图像和生成图像。

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self, in_planes, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='batch', pad_mode='CONSTANT', use_relu=True, padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=6, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

Pix2Pix的生成器和判别器初始化

import mindspore.nn as nn
from mindspore.common import initializer as init

继续训练过程

随着训练的进行,我们可以在每个周期结束时保存生成器的检查点,并可视化一些中间结果。我们也可以监控损失的变化来确定模型的训练情况。

训练完成

一旦训练完成,我们可以使用生成器来生成新的图像。通过提供输入图像,生成器可以生成相应的输出图像。

验证生成效果

在训练完成后,我们使用部分数据进行验证,看看生成器的效果。

# 验证生成器效果
from mindspore import load_checkpoint

# 加载生成器模型
load_checkpoint("results/ckpt/Generator.ckpt", net_generator)

# 可视化函数
def visualize_result(input_image, generated_image, target_image, epoch, idx):
    plt.figure(figsize=(15, 5), dpi=140)
    plt.subplot(1, 3, 1)
    plt.axis("off")
    plt.title("Input Image")
    plt.imshow((input_image.transpose(1, 2, 0) + 1) / 2)

    plt.subplot(1, 3, 2)
    plt.axis("off")
    plt.title("Generated Image")
    plt.imshow((generated_image.transpose(1, 2, 0) + 1) / 2)

    plt.subplot(1, 3, 3)
    plt.axis("off")
    plt.title("Target Image")
    plt.imshow((target_image.transpose(1, 2, 0) + 1) / 2)
    
    plt.suptitle(f"Epoch {epoch}, Step {idx}")
    plt.show()

# 验证集
val_dataset = ds.MindDataset("./dataset/dataset_pix2pix/val.mindrecord", columns_list=["input_images", "target_images"], shuffle=False)

# 验证生成器
val_data_iter = val_dataset.create_dict_iterator(output_numpy=True)
for idx, data in enumerate(val_data_iter):
    input_image = Tensor(data["input_images"])
    target_image = Tensor(data["target_images"])
    generated_image = net_generator(input_image)
    visualize_result(input_image.asnumpy(), generated_image.asnumpy(), target_image.asnumpy(), epoch_num, idx)
    if idx >= 5:  # 仅展示部分验证结果
        break

在这里插入图片描述

通过本文教程,我们使用MindSpore实现了Pix2Pix模型,包括数据准备、生成器和判别器的搭建、训练和验证。在训练过程中,我们使用了cGAN和L1损失的组合来优化生成器。最终,我们展示了模型的训练效果和生成结果。

在实际应用中,Pix2Pix模型可以用于各种图像到图像的转换任务,比如从卫星图像生成地图、将灰度图像转换为彩色图像等。希望本文的教程对您理解Pix2Pix模型有所帮助,并能在实际项目中应用此模型。

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值