昇思25天学习打卡营第23天|基于MindSpore的Pix2Pix实现图像转换

Pix2Pix实现图像转换

Pix2Pix概述

Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器判别器

传统上,尽管此类任务的目标都是相同的从像素预测像素,但每项都是用单独的专用机器来处理的。而Pix2Pix使用的网络作为一个通用框架,使用相同的架构和目标,只在不同的数据上进行训练,即可得到令人满意的结果,鉴于此许多人已经使用此网络发布了他们自己的艺术作品。

基础原理

cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

在教程开始前,首先定义一些在整个过程中需要用到的符号:

  • 𝑥:代表观测图像的数据。
  • 𝑧:代表随机噪声的数据。
  • 𝑦=𝐺(𝑥,𝑧):生成器网络,给出由观测图像𝑥与随机噪声𝑧生成的“假”图片,其中𝑥来自于训练数据而非生成器。
  • 𝐷(𝑥,𝐺(𝑥,𝑧)):判别器网络,给出图像判定为真实图像的概率,其中𝑥来自于训练数据,𝐺(𝑥,𝑧)来自于生成器。

cGAN的目标可以表示为:

该公式是cGAN的损失函数,D想要尽最大努力去正确分类真实图像与“假”图像,也就是使参数𝑙𝑜𝑔𝐷(𝑥,𝑦)最大化;而G则尽最大努力用生成的“假”图像𝑦欺骗D,避免被识破,也就是使参数𝑙𝑜𝑔(1−𝐷(𝐺(𝑥,𝑧)))最小化。cGAN的目标可简化为:

为了对比cGAN和GAN的不同,我们将GAN的目标也进行了说明:

从公式可以看出,GAN直接由随机噪声𝑧�生成“假”图像,不借助观测图像𝑥�的任何信息。过去的经验告诉我们,GAN与传统损失混合使用是有好处的,判别器的任务不变,依旧是区分真实图像与“假”图像,但是生成器的任务不仅要欺骗判别器,还要在传统损失的基础上接近训练数据。假设cGAN与L1正则化混合使用,那么有:

进而得到最终目标:

图像转换问题本质上其实就是像素到像素的映射问题,Pix2Pix使用完全一样的网络结构和目标函数,仅更换不同的训练数据集就能分别实现以上的任务。本任务将借助MindSpore框架来实现Pix2Pix的应用。

准备环节

配置环境文件

本案例在GPU,CPU和Ascend平台的动静态模式都支持。

python版本:Python 3.9.19

依赖环境安装

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

完整的环境

pip list
Package                        Version
------------------------------ --------------
absl-py                        2.1.0
aiofiles                       22.1.0
aiosqlite                      0.20.0
altair                         5.3.0
annotated-types                0.7.0
anyio                          4.4.0
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
astroid                        3.2.2
asttokens                      2.0.5
astunparse                     1.6.3
attrs                          23.2.0
auto-tune                      0.1.0
autopep8                       1.5.5
Babel                          2.15.0
backcall                       0.2.0
beautifulsoup4                 4.12.3
black                          24.4.2
bleach                         6.1.0
certifi                        2024.6.2
cffi                           1.16.0
charset-normalizer             3.3.2
click                          8.1.7
cloudpickle                    3.0.0
colorama                       0.4.6
comm                           0.2.1
contextlib2                    21.6.0
contourpy                      1.2.1
cycler                         0.12.1
dataflow                       0.0.1
debugpy                        1.6.7
decorator                      5.1.1
defusedxml                     0.7.1
dill                           0.3.8
dnspython                      2.6.1
download                       0.3.5
easydict                       1.13
email_validator                2.2.0
entrypoints                    0.4
exceptiongroup                 1.2.0
executing                      0.8.3
fastapi                        0.111.0
fastapi-cli                    0.0.4
fastjsonschema                 2.20.0
ffmpy                          0.3.2
filelock                       3.15.3
flake8                         3.8.4
fonttools                      4.53.0
fqdn                           1.5.1
fsspec                         2024.6.0
gitdb                          4.0.11
GitPython                      3.1.43
gradio                         4.26.0
gradio_client                  0.15.1
h11                            0.14.0
hccl                           0.1.0
hccl-parser                    0.1
httpcore                       1.0.5
httptools                      0.6.1
httpx                          0.27.0
huggingface-hub                0.23.4
idna                           3.7
importlib-metadata             7.0.1
importlib_resources            6.4.0
iniconfig                      2.0.0
ipykernel                      6.28.0
ipympl                         0.9.4
ipython                        8.15.0
ipython-genutils               0.2.0
ipywidgets                     8.1.3
isoduration                    20.11.0
isort                          5.13.2
jedi                           0.17.2
Jinja2                         3.1.4
joblib                         1.4.2
json5                          0.9.25
jsonpointer                    3.0.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 7.4.9
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter-resource-usage         0.7.2
jupyter_server                 2.14.1
jupyter_server_fileid          0.9.2
jupyter-server-mathjax         0.2.6
jupyter_server_terminals       0.5.3
jupyter_server_ydoc            0.8.0
jupyter-ydoc                   0.2.5
jupyterlab                     3.6.7
jupyterlab_code_formatter      2.2.1
jupyterlab_git                 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp                 4.3.0
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab-system-monitor      0.8.0
jupyterlab-topbar              0.6.1
jupyterlab_widgets             3.0.11
kiwisolver                     1.4.5
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.6
mccabe                         0.6.1
mdurl                          0.1.2
mindspore                      2.2.14
mindvision                     0.1.0
mistune                        3.0.2
ml_collections                 0.1.1
mpmath                         1.3.0
msadvisor                      1.0.0
mypy-extensions                1.0.0
nbclassic                      1.1.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbdime                         4.0.1
nbformat                       5.10.4
nest-asyncio                   1.6.0
notebook                       6.5.7
notebook_shim                  0.2.4
numpy                          1.26.4
op-compile-tool                0.1.0
op-gen                         0.1
op-test-frame                  0.1
opc-tool                       0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python                  4.10.0.84
opencv-python-headless         4.10.0.84
orjson                         3.10.5
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.2
pandocfilters                  1.5.1
parso                          0.7.1
pathlib2                       2.3.7.post1
pathspec                       0.12.1
pexpect                        4.8.0
pickleshare                    0.7.5
pillow                         10.3.0
pip                            24.1
platformdirs                   4.2.2
pluggy                         1.5.0
prometheus_client              0.20.0
prompt-toolkit                 3.0.43
protobuf                       5.27.1
psutil                         5.9.0
ptyprocess                     0.7.0
pure-eval                      0.2.2
pycodestyle                    2.6.0
pycparser                      2.22
pydantic                       2.7.4
pydantic_core                  2.18.4
pydocstyle                     6.3.0
pydub                          0.25.1
pyflakes                       2.2.0
Pygments                       2.15.1
pylint                         3.2.3
pyparsing                      3.1.2
pytest                         8.0.0
python-dateutil                2.9.0.post0
python-dotenv                  1.0.1
python-json-logger             2.0.7
python-jsonrpc-server          0.4.0
python-language-server         0.36.2
python-multipart               0.0.9
pytoolconfig                   1.3.1
pytz                           2024.1
PyYAML                         6.0.1
pyzmq                          25.1.2
referencing                    0.35.1
requests                       2.32.3
rfc3339-validator              0.1.4
rfc3986-validator              0.1.1
rich                           13.7.1
rope                           1.13.0
rpds-py                        0.18.1
ruff                           0.4.10
schedule-search                0.0.1
scikit-learn                   1.5.0
scipy                          1.13.1
semantic-version               2.10.0
Send2Trash                     1.8.3
setuptools                     69.5.1
shellingham                    1.5.4
six                            1.16.0
smmap                          5.0.1
sniffio                        1.3.1
snowballstemmer                2.2.0
soupsieve                      2.5
stack-data                     0.2.0
starlette                      0.37.2
sympy                          1.12.1
synr                           0.5.0
te                             0.4.0
terminado                      0.18.1
threadpoolctl                  3.5.0
tinycss2                       1.3.0
toml                           0.10.2
tomli                          2.0.1
tomlkit                        0.12.0
toolz                          0.12.1
tornado                        6.4.1
tqdm                           4.66.4
traitlets                      5.14.3
typer                          0.12.3
types-python-dateutil          2.9.0.20240316
typing_extensions              4.11.0
tzdata                         2024.1
ujson                          5.10.0
uri-template                   1.3.0
urllib3                        2.2.2
uvicorn                        0.30.1
uvloop                         0.19.0
watchfiles                     0.22.0
wcwidth                        0.2.5
webcolors                      24.6.0
webencodings                   0.5.1
websocket-client               1.8.0
websockets                     11.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
y-py                           0.6.2
yapf                           0.40.2
ypy-websocket                  0.8.4
zipp                           3.17.0

准备数据

在本教程中,我们将使用指定数据集,该数据集是已经经过处理的外墙(facades)数据,可以直接使用mindspore.dataset的方法读取。

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

数据展示

调用Pix2PixDatasetcreate_train_dataset读取训练集,这里我们直接下载已经处理好的数据集。

from mindspore import dataset as ds
import matplotlib.pyplot as plt

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

创建网络

当处理完数据后,就可以来进行网络的搭建了。网络搭建将逐一详细讨论生成器、判别器和损失函数。生成器G用到的是U-Net结构,输入的轮廓图𝑥编码再解码成真是图片,判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图 𝑥的条件下,对于生成的图片𝐺(𝑥)判断为假,对于真实判断为真。

生成器G结构

U-Net是德国Freiburg大学模式识别和图像处理组提出的一种全卷积结构。它分为两个部分,其中左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径,扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。网络模型整体是一个U形的结构,因此被叫做U-Net。和常见的先降采样到低维度,再升采样到原始分辨率的编解码结构的网络相比,U-Net的区别是加入skip-connection,对应的feature maps和decode之后的同样大小的feature maps按通道拼一起,用来保留不同分辨率下像素级的细节信息。

定义UNet Skip Connection Block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out
基于UNet的生成器
class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

原始cGAN的输入是条件x和噪声z两种信息,这里的生成器只使用了条件信息,因此不能生成多样性的结果。因此Pix2Pix在训练和测试时都使用了dropout,这样可以生成多样性的结果。

基于PatchGAN的判别器

判别器使用的PatchGAN结构,可看做卷积。生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

Pix2Pix的生成器和判别器初始化

实例化Pix2Pix生成器和判别器。

import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

训练

训练分为两个主要部分:训练判别器和训练生成器。训练判别器的目的是最大程度地提高判别图像真伪的概率。训练生成器是希望能产生更好的虚假图像。在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计。

下面进行训练:

import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

推理

获取上述训练过程完成后的ckpt文件,通过load_checkpoint和load_param_into_net将ckpt中的权重参数导入到模型中,获取数据进行推理并对推理的效果图进行演示(由于时间问题,训练过程只进行了3个epoch,可根据需求调整epoch)。

from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

引用

[1] Phillip Isola,Jun-Yan Zhu,Tinghui Zhou,Alexei A. Efros. Image-to-Image Translation with Conditional Adversarial Networks.[J]. CoRR,2016,abs/1611.07004.

  • 19
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
基于pix2pixHD的图像融合是指利用pix2pixHD模型将两个输入图像进行融合,生成一个新的合成图像。这种图像融合方法可以应用于多个领域,例如合成艺术、虚拟场景生成、风格迁移等。 下面是一种基于pix2pixHD的图像融合方法的简要流程: 1. 数据准备:收集成对的输入图像和对应的目标融合图像作为训练数据集。每个成对图像都包含两个输入图像和一个目标融合图像。 2. 生成器网络:利用pix2pixHD的生成器网络进行图像融合。输入图像被分别输入到生成器中,并生成两个中间合成图像。这些中间合成图像将被送入一个融合模块,最终生成合成图像。 3. 判别器网络:与传统的pix2pixHD不同,图像融合任务中的判别器网络需要判断生成的合成图像与目标融合图像的相似度。 4. 条件GAN训练:训练过程中,生成器的目标是生成逼真的合成图像,使其尽可能接近目标融合图像。判别器的目标是准确判断生成的合成图像的真实性。 5. 损失函数:与传统的pix2pixHD相似,图像融合任务中的损失函数也包括对抗损失、内容损失和边界损失等,用于指导生成器和判别器的训练。 6. 训练过程:通过迭代训练生成器和判别器,使得生成器能够逐渐生成更接近目标融合图像的合成结果,并使判别器能够准确判断生成图像的真实性。 7. 测试和推理:训练完成后,可以使用生成器将两个输入图像输入进行图像融合。生成器将生成一个合成图像,该图像是两个输入图像融合的结果。 需要注意的是,基于pix2pixHD的图像融合方法可以根据具体需求进行调整和改进。这只是一个一般的流程描述,具体实现可能会有所不同。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值