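Huawei's Open-Source AI Framework MindSpore Application Case: Image Inpainting with ICT

Image completion (also called image inpainting) aims to fill the missing regions of an image with visually realistic and semantically appropriate content, and is widely used in image processing. Convolutional neural networks (CNNs) have driven major progress in computer vision thanks to their strong texture modeling ability, but they are weak at understanding global structure. In recent years, transformers have proven capable of modeling long-range relationships (global structure), but their computational complexity has hindered their use on high-resolution images. ICT combines the strengths of both approaches for image completion: it first uses a transformer to reconstruct an appearance prior, recovering pluralistic coherent structures together with some coarse textures, and then uses a CNN to replenish texture, enhancing the local texture details of the coarse prior under the guidance of the high-resolution masked image. The basic principle is shown in the figure below.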

 

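I. Environment Setup
1. Open the ModelArts website
The cloud platform helps users quickly create and deploy models and manage the full AI workflow lifecycle. Select the cloud platform below to get started with MindSpore, obtain the installation command, and install MindSpore 2.0.0-alpha. The ModelArts website can be reached from the MindSpore tutorials.

Select CodeLab below to try it right away.

Wait for the environment to finish setting up.

2. Try a Notebook instance in CodeLab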
Download the Notebook sample code for this case (ICT image inpainting); the .ipynb file is the sample code.

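Choose ModelArts Upload Files to upload the .ipynb file.

Select the kernel environment.

Switch to the GPU environment, choosing the first option (free for a limited time).

Go to the MindSpore website and click Install at the top.

Get the installation command.

Back in the Notebook, add the following commands before the first code cell.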


conda update -n base -c defaults conda


Install the GPU version of MindSpore 2.0

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge


Install mindvision

pip install mindvision


Install the download package

pip install download
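
Before moving on, it can help to confirm that the packages installed above are visible to the notebook kernel. The sketch below uses only the Python standard library; the helper name `installed_version` is hypothetical and is not part of MindSpore or of the original tutorial.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# Report the three packages installed in the steps above.
for name in ("mindspore", "mindvision", "download"):
    version = installed_version(name)
    print(f"{name}: {version if version is not None else 'not installed'}")
```

If `mindspore` is reported as not installed, the kernel selected in the Notebook is probably not the environment where the conda command was run.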


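II. Data Preparation
ImageNet dataset
In the tutorial below, we use example code to show how to set up the network and the optimizer and how to compute the loss function. This tutorial uses the ImageNet dataset for training and inference. Download it from the official ImageNet website and extract it so that the files have the following directory structure: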

# ImageNet dataset
.imagenet/
├── train/  (1000 directories and 1281167 images)
│   ├── n04347754/
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   └── ....
│   └── n04347756/
│       ├── 000001.jpg
│       ├── 000002.jpg
│       └── ....
└── val/  (1000 directories and 50000 images)
    ├── n04347754/
    │   ├── 000001.jpg
    │   ├── 000002.jpg
    │   └── ....
    └── n04347756/
        ├── 000001.jpg
        ├── 000002.jpg
        └── ....
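
A quick way to check that an extracted split matches this layout is to count its class directories and images. The following is a minimal sketch using only the standard library; the function name `summarize_split` and the accepted extensions are assumptions for illustration, and the demo builds a tiny throwaway directory rather than touching real data.

```python
import tempfile
from pathlib import Path


def summarize_split(split_dir, extensions=(".jpg", ".jpeg", ".png")):
    """Count class sub-directories and image files under one split (train/ or val/)."""
    split_dir = Path(split_dir)
    class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
    n_images = sum(
        1
        for d in class_dirs
        for f in d.iterdir()
        if f.is_file() and f.suffix.lower() in extensions
    )
    return len(class_dirs), n_images


# Demo on a tiny temporary tree that mimics the structure above.
with tempfile.TemporaryDirectory() as root:
    for class_name in ("n04347754", "n04347756"):
        class_dir = Path(root) / class_name
        class_dir.mkdir()
        for i in (1, 2):
            (class_dir / f"{i:06d}.jpg").touch()
    print(summarize_split(root))
```

On a complete copy, the train split should report 1000 class directories and 1281167 images, and the val split 1000 directories and 50000 images, as listed in the tree.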
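
Mask dataset
The image completion task also needs a mask dataset, which is used to mask the images and produce inputs with some pixels corrupted. Download the mask dataset yourself and extract it so that the files have the following directory structure: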

# Mask dataset
.mask/
└── testing_mask_dataset/
    ├── 000001.png
    ├── 000002.png
    ├── 000003.png
    └── ....
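
To make the role of the mask concrete, here is a small NumPy sketch of how a binary mask corrupts an image. It follows the convention used by the Generator code later in this article (`images * (1 - masks) + masks`), where a mask value of 1 marks a missing pixel. The function name `apply_mask` and the toy arrays are illustrative assumptions.

```python
import numpy as np


def apply_mask(image, mask):
    """Blank out masked pixels.

    image: float array in [0, 1], shape (H, W, 3)
    mask:  array of 0/1, shape (H, W); 1 marks a missing pixel
    Missing pixels are set to 1.0 (white); known pixels are unchanged.
    """
    mask = mask.astype(image.dtype)[..., None]  # (H, W, 1) so it broadcasts over RGB
    return image * (1.0 - mask) + mask


image = np.full((4, 4, 3), 0.25, dtype=np.float32)
mask = np.zeros((4, 4), dtype=np.float32)
mask[1:3, 1:3] = 1.0  # a 2x2 hole in the middle
masked = apply_mask(image, mask)
print(masked[:, :, 0])
```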
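
Parameter settings
First set the training and inference parameters and choose the execution mode. By default this case runs both inference and training in dynamic graph (PyNative) mode.
Note that the training stage depends on a VGG19 model: the weight file VGG19.ckpt must be downloaded in advance, and its location must match the path given in the parameters.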

import os
import argparse

import mindspore
import numpy as np

from mindspore import context


def parse_args():
    """Parse args."""
    parser = argparse.ArgumentParser()
    # Parameter of train
    parser.add_argument('--data_path', type=str, default='/data0/imagenet2012/train',
                        help='Indicate where is the training set')
    parser.add_argument('--mask_path', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
                        help='Where is the mask')
    parser.add_argument('--n_layer', type=int, default=35)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--n_embd', type=int, default=1024)
    parser.add_argument('--GELU_2', action='store_true', help='use the new activation function')
    parser.add_argument('--use_ImageFolder', action='store_true', help='Using the original folder for ImageNet dataset')
    parser.add_argument('--random_stroke', action='store_true', help='Use the generated mask')
    parser.add_argument('--train_epoch', type=int, default=5, help='How many epochs')
    parser.add_argument('--learning_rate', type=float, default=3e-4, help='Value of learning rate.')
    parser.add_argument('--input', type=str, default='/data0/imagenet2012/train',
                        help='path to the input images directory or an input image')
    parser.add_argument('--mask', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
                        help='path to the masks directory or a mask file')
    parser.add_argument('--prior', type=str, default='', help='path to the edges directory or an edge file')
    parser.add_argument('--kmeans', type=str, default='../kmeans_centers.npy', help='path to the kmeans')
    # Adjust this to the actual location of VGG19.ckpt; the VGG19 weights can be
    # downloaded from the checkpoint link given in the inference section
    parser.add_argument('--vgg_path', type=str, default='../ckpts_ICT/VGG19.ckpt', help='path to the VGG')
    parser.add_argument('--image_size', type=int, default=256, help='the size of origin image')
    parser.add_argument('--prior_random_degree', type=int, default=1, help='during training, how far deviate from')
    parser.add_argument('--use_degradation_2', action='store_true', help='use the new degradation function')
    parser.add_argument('--mode', type=int, default=1, help='1:train, 2:test')
    parser.add_argument('--mask_type', type=int, default=2)
    parser.add_argument('--max_iteration', type=int, default=25000, help='How many run iteration')
    parser.add_argument('--lr', type=float, default=0.0001, help='Value of learning rate.')
    parser.add_argument('--D2G_lr', type=float, default=0.1,
                        help='Value of discriminator/generator learning rate ratio')
    parser.add_argument("--l1_loss_weight", type=float, default=1.0)
    parser.add_argument("--style_loss_weight", type=float, default=250.0)
    parser.add_argument("--content_loss_weight", type=float, default=0.1)
    parser.add_argument("--inpaint_adv_loss_weight", type=float, default=0.1)
    parser.add_argument('--ckpt_path', type=str, default='', help='model checkpoints path')
    parser.add_argument('--save_path', type=str, default='./checkpoint', help='save checkpoints path')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--device_target', type=str, default='GPU', help='GPU or Ascend')
    parser.add_argument('--prior_size', type=int, default=32, help='Input sequence length = prior_size*prior_size')
    parser.add_argument('--batch_size', type=int, default=2, help='The number of train batch size')
    parser.add_argument("--beta1", type=float, default=0.9, help="Value of beta1")
    parser.add_argument("--beta2", type=float, default=0.95, help="Value of beta2")

    # Parameter of infer
    parser.add_argument('--image_url', type=str, default='/data0/imagenet2012/val', help='the folder of image')
    parser.add_argument('--mask_url', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
                        help='the folder of mask')
    parser.add_argument('--top_k', type=int, default=40)
    parser.add_argument('--save_url', type=str, default='./sample', help='save the output results')
    parser.add_argument('--condition_num', type=int, default=1, help='Use how many BERT output')
    args = parser.parse_known_args()[0]
    return args


# Use dynamic graph (PyNative) mode and run on GPU
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
opts = parse_args()
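
Note that `parse_args` above calls `parser.parse_known_args()` rather than `parser.parse_args()`. The article does not say why; a likely reason (an assumption) is that a notebook kernel is launched with extra command-line arguments that a strict parser would reject. The toy parser below, with a made-up argument list, shows the difference in behavior.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--prior_size', type=int, default=32)
parser.add_argument('--batch_size', type=int, default=2)

# Simulate a command line containing an option this parser does not define.
argv = ['--batch_size', '4', '-f', '/tmp/kernel.json']

# parse_known_args returns the parsed namespace plus the leftover arguments
# instead of exiting with an error on the unknown '-f' option.
known, unknown = parser.parse_known_args(argv)
print(known.batch_size, known.prior_size)
print(unknown)
```

Because unknown options are silently collected, a misspelled flag will not raise an error, so it is worth printing `opts` once to confirm the values in use.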


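Training
Transformer training
Transformer network structure
Because the Transformer is computationally expensive, images are downsampled before being fed to the network. For ImageNet, images are downsampled to 32*32 resolution to obtain a low-resolution image prior. The structure of the Transformer is shown below: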


The Transformer is implemented as follows:

import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.initializer import Normal


class CausalSelfAttention(nn.Cell):
    """
    The CausalSelfAttention part of transformer.

    Args:
        n_embd (int): The size of the vector space in which words are embedded.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        # key, query, value projections for all heads
        self.key = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                            weight_init=Normal(sigma=0.02, mean=0.0))
        self.query = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                              weight_init=Normal(sigma=0.02, mean=0.0))
        self.value = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                              weight_init=Normal(sigma=0.02, mean=0.0))
        # regularization
        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_pdrop)
        self.resid_drop = nn.Dropout(keep_prob=1.0 - resid_pdrop)
        # output projection
        self.proj = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                             weight_init=Normal(sigma=0.02, mean=0.0))

        tril = nn.Tril()
        self.mask = mindspore.Parameter(
            tril(P.Ones()((block_size, block_size), mindspore.float32)).view(1, 1, block_size, block_size),
            requires_grad=False)

        self.n_head = n_head

    def construct(self, x):
        B, T, C = P.Shape()(x)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head)
        k = k.transpose(0, 2, 1, 3)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head)
        q = q.transpose(0, 2, 1, 3)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head)
        v = v.transpose(0, 2, 1, 3)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        k_shape = k.shape[-1]
        sz = 1.0 / (Tensor(k_shape) ** Tensor(0.5))
        q = mindspore.ops.Cast()(q, mindspore.float16)
        k = mindspore.ops.Cast()(k, mindspore.float16)
        sz = mindspore.ops.Cast()(sz, mindspore.float16)
        att = (ops.matmul(q, k.transpose(0, 1, 3, 2)) * sz)
        att = mindspore.ops.Cast()(att, mindspore.float32)
        att = P.Softmax()(att)
        att = self.attn_drop(att)
        att = mindspore.ops.Cast()(att, mindspore.float16)
        v = mindspore.ops.Cast()(v, mindspore.float16)
        y = mindspore.ops.matmul(att, v)
        y = mindspore.ops.Cast()(y, mindspore.float32)
        y = y.transpose(0, 2, 1, 3).view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_drop(self.proj(y))
        return y


class GELU2(nn.Cell):
    """
    The new gelu2 activation function.

    Returns:
        Tensor, output tensor.
    """

    def construct(self, x):
        return x * P.Sigmoid()(1.702 * x)


class Block_2(nn.Cell):
    """
    Transformer block with original GELU2.

    Args:
        n_embd (int): The size of the vector space in which words are embedded.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.ln2 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, resid_pdrop, attn_pdrop)
        self.mlp = nn.SequentialCell([
            nn.Dense(in_channels=n_embd, out_channels=4 * n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            GELU2(),
            nn.Dense(in_channels=4 * n_embd, out_channels=n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            nn.Dropout(keep_prob=1.0 - resid_pdrop),
        ])

    def construct(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class Block(nn.Cell):
    """
    Transformer block with original GELU.

    Args:
        n_embd (int): The size of the vector space in which words are embedded.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.ln2 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, resid_pdrop, attn_pdrop)
        self.mlp = nn.SequentialCell([
            nn.Dense(in_channels=n_embd, out_channels=4 * n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            nn.GELU(),
            nn.Dense(in_channels=4 * n_embd, out_channels=n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            nn.Dropout(keep_prob=1.0 - resid_pdrop),
        ])

    def construct(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Cell):
    """
    The full GPT language model, with a context size of block_size.

    Args:
        vocab_size (int): The size of the vocabulary in the embedded data.
        n_embd (int): The size of the vector space in which words are embedded.
        n_layer (int): The number of attention layer.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        use_gelu2 (bool): Use the new gelu2 activation function.
        embd_pdrop (float): The probability of embd_pdrop. Default: 0.1
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, block_size: int, use_gelu2: bool,
                 embd_pdrop: float = 0.1, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embd, embedding_table=Normal(sigma=0.02, mean=0.0))
        self.pos_emb = mindspore.Parameter(P.Zeros()((1, block_size, n_embd), mindspore.float32))
        self.drop = nn.Dropout(keep_prob=1.0 - embd_pdrop)
        # transformer
        if use_gelu2:
            self.blocks = nn.SequentialCell(
                [*[Block_2(n_embd, n_head, block_size, resid_pdrop, attn_pdrop) for _ in range(n_layer)]])
        else:
            self.blocks = nn.SequentialCell(
                [*[Block(n_embd, n_head, block_size, resid_pdrop, attn_pdrop) for _ in range(n_layer)]])
        # decoder head
        self.ln_f = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.head = nn.Dense(in_channels=n_embd, out_channels=vocab_size, has_bias=False,
                             weight_init=Normal(sigma=0.02, mean=0.0))

        self.block_size = block_size

    def get_block_size(self):
        return self.block_size

    def construct(self, idx, masks):
        _, t = idx.shape
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        masks = P.ExpandDims()(masks, 2)
        token_embeddings = token_embeddings * (1 - masks)
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
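
The tensor shapes inside `CausalSelfAttention.construct` can be hard to follow, so here is a NumPy sketch of the same scaled dot-product, multi-head computation. It is a simplified illustration, not the MindSpore implementation: the learned projections, dropout, and float16 casts are omitted, and the function names are assumptions. As in `construct` above, no mask is applied to the attention scores.

```python
import numpy as np


def split_heads(x, n_head):
    """(B, T, C) -> (B, n_head, T, C // n_head)."""
    B, T, C = x.shape
    return x.reshape(B, T, n_head, C // n_head).transpose(0, 2, 1, 3)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(q, k, v, n_head):
    """Scaled dot-product attention over already-projected q, k, v of shape (B, T, C)."""
    B, T, C = q.shape
    q, k, v = (split_heads(t, n_head) for t in (q, k, v))
    scale = 1.0 / np.sqrt(k.shape[-1])                   # 1 / sqrt(head size)
    att = softmax(q @ k.transpose(0, 1, 3, 2) * scale)   # (B, nh, T, T)
    y = att @ v                                          # (B, nh, T, hs)
    # re-assemble all head outputs side by side
    return y.transpose(0, 2, 1, 3).reshape(B, T, C), att


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 32)).astype(np.float32)  # B=2, T=16, C=32
y, att = multi_head_attention(x, x, x, n_head=8)
print(y.shape, att.shape)
```

With the tutorial's defaults (`prior_size=32`, so T = 1024 tokens), each attention matrix is 1024 x 1024 per head, which is why the prior is computed at low resolution.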


Define the loss function
class TransformerWithLoss(nn.Cell):
    """
    Wrap the network with loss function to return Transformer with loss.

    Args:
        backbone (Cell): The target network to wrap.
    """

    def __init__(self, backbone):
        super(TransformerWithLoss, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

    def construct(self, x, targets, masks):
        logits = self.backbone(x, masks)
        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), targets.view(-1))
        masks = P.ExpandDims()(masks, 2)
        masks = masks.view(-1)
        loss *= masks
        loss = P.ReduceMean()(loss)
        return loss
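
The loss wrapper above computes a per-token cross-entropy, zeroes it at known positions, and then averages. The NumPy sketch below reproduces that arithmetic on toy data so the effect of the mask is easy to see. It is an illustration under the assumption that the cross-entropy returns one loss value per token; the function name is hypothetical.

```python
import numpy as np


def masked_cross_entropy(logits, targets, masks):
    """logits: (B, T, V); targets: (B, T) int; masks: (B, T) with 1 = missing token."""
    V = logits.shape[-1]
    flat_logits = logits.reshape(-1, V)
    flat_targets = targets.reshape(-1)
    # log-softmax, computed stably
    shifted = flat_logits - flat_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_token = -log_probs[np.arange(flat_targets.size), flat_targets]
    per_token = per_token * masks.reshape(-1)  # zero out known positions
    return per_token.mean()                    # mean over all B*T positions


logits = np.zeros((2, 4, 4))                   # uniform prediction over V=4 tokens
targets = np.zeros((2, 4), dtype=np.int64)
masks = np.array([[1, 1, 0, 0], [1, 1, 0, 0]], dtype=np.float64)
toy_loss = masked_cross_entropy(logits, targets, masks)
print(toy_loss)
```

Since the mean runs over every position, including the zeroed ones, the reported loss scales with the fraction of masked tokens in the batch.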


Define the dataset and model
# Change the working directory to ICT/Transformer
if os.getcwd().endswith('ICT') or os.getcwd().endswith('ict'):
    os.chdir("./Transformer")

# Enable the ImageFolder option for the ImageNet dataset
opts.use_ImageFolder = True

from datasets.dataset import load_dataset
from transformer_utils.util import AverageMeter

kmeans = np.load('../kmeans_centers.npy')
kmeans = np.rint(127.5 * (kmeans + 1.0))

# Define the dataset
train_dataset = load_dataset(opts.data_path, kmeans, mask_path=opts.mask_path, is_train=True,
                             use_imagefolder=opts.use_ImageFolder, prior_size=opts.prior_size,
                             random_stroke=opts.random_stroke)
train_dataset = train_dataset.batch(opts.batch_size)
step_size = train_dataset.get_dataset_size()

# Define the model
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=kmeans.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
                  block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
model = TransformerWithLoss(backbone=transformer)

# Define the optimizer
optimizer = nn.Adam(model.trainable_params(), learning_rate=opts.learning_rate, beta1=opts.beta1, beta2=opts.beta2)

train_net = nn.TrainOneStepCell(model, optimizer)



train_loss = AverageMeter()
best_loss = 10000000000.
if not os.path.exists(opts.save_path):
    os.mkdir(opts.save_path)
opts.train_epoch = 1
for epoch in range(opts.train_epoch):
    train_loss.reset()
    for i, sample in enumerate(train_dataset.create_dict_iterator()):
        x = sample['data']
        y = sample['mask']
        y = mindspore.ops.Cast()(y, mindspore.float32)
        loss = train_net(x, x, y)
        train_loss.update(loss, 1)
        if i % 2 == 0:
            print(f"Epoch: [{epoch} / {opts.train_epoch}], "
                  f"step: [{i} / {step_size}], "
                  f"loss: {train_loss.avg}")
            if train_loss.avg < best_loss:
                best_loss = train_loss.avg
                mindspore.save_checkpoint(transformer, os.path.join(opts.save_path, 'ImageNet_best.ckpt'))
        if i != 0 and i % 100 == 0:
            break

mindspore.save_checkpoint(transformer, os.path.join(opts.save_path, 'ImageNet_latest.ckpt'))
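
The training loop relies on `AverageMeter`, which is imported from the repository's `transformer_utils.util` and whose source is not shown in this article. A common form of such a helper looks like the sketch below; this is an assumption about its behavior, not the repository's actual code, and the class is named `SimpleAverageMeter` to make that clear.

```python
class SimpleAverageMeter:
    """Track the running average of a scalar (hypothetical stand-in for AverageMeter)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the value observed for n samples; float() also accepts 0-d arrays
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count


meter = SimpleAverageMeter()
for step_loss in (2.0, 4.0):
    meter.update(step_loss, 1)
print(meter.avg)
```

Under this reading, `train_loss.avg` in the loop is the mean loss since the last `reset()`, so the "best" checkpoint is chosen by the running epoch average rather than by a single step.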



class Generator(nn.Cell):
    def __init__(self, residual_blocks=8):
        super(Generator, self).__init__()

        self.encoder = nn.SequentialCell(
            nn.Pad(((0, 0), (0, 0), (3, 3), (3, 3)), mode='REFLECT'),
            nn.Conv2d(in_channels=6, out_channels=64, kernel_size=7, pad_mode='pad', padding=0, has_bias=True),
            nn.ReLU(),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                      has_bias=True),
            nn.ReLU(),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                      has_bias=True),
            nn.ReLU()
        )

        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock_remove_IN(256, 2)
            blocks.append(block)
        self.middle = nn.SequentialCell(*blocks)
        self.decoder = nn.SequentialCell(
            nn.Conv2dTranspose(in_channels=256, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                               has_bias=True),
            nn.ReLU(),

            nn.Conv2dTranspose(in_channels=128, out_channels=64, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                               has_bias=True),
            nn.ReLU(),

            nn.Pad(((0, 0), (0, 0), (3, 3), (3, 3)), mode='REFLECT'),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, pad_mode='pad', padding=0, has_bias=True),
        )

    def construct(self, images, edges, masks):
        images_masked = (images * P.Cast()((1 - masks), mindspore.float32)) + masks
        x = P.Concat(axis=1)((images_masked, edges))
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = (P.Tanh()(x) + 1) / 2

        return x


class ResnetBlock_remove_IN(nn.Cell):
    def __init__(self, dim, dilation=1):
        super(ResnetBlock_remove_IN, self).__init__()
        self.conv_block = nn.SequentialCell([
            nn.Pad(((0, 0), (0, 0), (dilation, dilation), (dilation, dilation)), mode='REFLECT'),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, pad_mode='pad', dilation=dilation,
                      has_bias=True),
            nn.ReLU(),

            nn.Pad(((0, 0), (0, 0), (1, 1), (1, 1)), mode='SYMMETRIC'),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, pad_mode='pad', dilation=1, has_bias=True)
        ])

    def construct(self, x):
        out = x + self.conv_block(x)
        return out
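
To see how the Generator changes the spatial resolution, the sketch below traces the feature-map size through the encoder, one residual block, and the decoder using the standard convolution and transposed-convolution size formulas. The helper names are assumptions for illustration; the kernel, stride, padding, and dilation values are taken from the code above.

```python
def conv_out(n, k, s=1, p=0, d=1):
    """Output length of a convolution along one spatial axis."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


def deconv_out(n, k, s=1, p=0, d=1):
    """Output length of a transposed convolution (no output padding)."""
    return (n - 1) * s - 2 * p + d * (k - 1) + 1


def generator_sizes(n):
    """Trace the spatial size through encoder, one residual block, and decoder."""
    sizes = [n]
    n = conv_out(n + 2 * 3, k=7)        # reflect-pad 3 per side, then 7x7 conv
    sizes.append(n)
    for _ in range(2):                  # two 4x4, stride-2 convolutions
        n = conv_out(n, k=4, s=2, p=1)
        sizes.append(n)
    # residual block with dilation 2: pad 2 + dilated 3x3 conv, then pad 1 + 3x3 conv
    m = conv_out(n + 2 * 2, k=3, d=2)
    m = conv_out(m + 2 * 1, k=3)
    assert m == n, "residual blocks should preserve the spatial size"
    for _ in range(2):                  # two 4x4, stride-2 transposed convolutions
        n = deconv_out(n, k=4, s=2, p=1)
        sizes.append(n)
    n = conv_out(n + 2 * 3, k=7)        # reflect-pad 3 per side, then 7x7 conv
    sizes.append(n)
    return sizes


print(generator_sizes(256))
```

For the default `image_size` of 256, the residual blocks therefore operate on 64 x 64 feature maps, and the output has the same resolution as the input.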




class Discriminator(nn.Cell):
    def __init__(self, in_channels, use_sigmoid=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid

        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, pad_mode='pad', padding=1),
            nn.LeakyReLU()
        )

        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1),
            nn.LeakyReLU()
        )

        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, pad_mode='pad', padding=1),
            nn.LeakyReLU()
        )

        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, pad_mode='pad', padding=1),
            nn.LeakyReLU()
        )

        self.conv5 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, pad_mode='pad', padding=1)

    def construct(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        outputs = conv5
        if self.use_sigmoid:
            outputs = P.Sigmoid()(conv5)

        return outputs, [conv1, conv2, conv3, conv4, conv5]
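
The Discriminator does not reduce the image to a single score: its last layer outputs a grid in which each value judges one patch of the input. The sketch below computes the size of that grid and the receptive field of each output value from the layer settings in the code above. The function name is an assumption for illustration.

```python
# (kernel, stride, padding) of conv1..conv5 in the Discriminator above
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def patch_discriminator_stats(n, layers):
    """Return (output size, receptive field) for an n x n input."""
    receptive_field, jump = 1, 1
    for k, s, p in layers:
        n = (n + 2 * p - k) // s + 1
        receptive_field += (k - 1) * jump  # each layer widens the field by (k-1)*jump
        jump *= s                          # distance in input pixels between neighbours
    return n, receptive_field


print(patch_discriminator_stats(256, LAYERS))
```

For a 256 x 256 input this gives a 30 x 30 output in which each value sees a 70 x 70 patch of the input.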


Define the dataset and network model
# Change the working directory to ICT/Guided_Upsample
if os.getcwd().endswith('Transformer'):
    os.chdir("../Guided_Upsample")
elif os.getcwd().endswith('ICT') or os.getcwd().endswith('ict'):
    os.chdir("./Guided_Upsample")

from Guided_Upsample.datasets.dataset import load_dataset
from Guided_Upsample.models.loss import GeneratorWithLoss, DiscriminatorWithLoss

train_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
                             image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
                             kmeans=opts.kmeans, use_degradation_2=opts.use_degradation_2,
                             prior_random_degree=opts.prior_random_degree,
                             augment=True, training=True)
train_dataset = train_dataset.batch(opts.batch_size)
step_size = train_dataset.get_dataset_size()

generator = Generator()
discriminator = Discriminator(in_channels=3)
model_G = GeneratorWithLoss(generator, discriminator, opts.vgg_path, opts.inpaint_adv_loss_weight, opts.l1_loss_weight,
                            opts.content_loss_weight, opts.style_loss_weight)
model_D = DiscriminatorWithLoss(generator, discriminator)


Define the optimizer and loss function
class PSNR(nn.Cell):
    def __init__(self, max_val):
        super(PSNR, self).__init__()

        base10 = P.Log()(mindspore.Tensor(10.0, mindspore.float32))
        max_val = P.Cast()(mindspore.Tensor(max_val), mindspore.float32)

        self.base10 = mindspore.Parameter(base10, requires_grad=False)
        self.max_val = mindspore.Parameter(20 * P.Log()(max_val) / base10, requires_grad=False)

    def __call__(self, a, b):
        a = P.Cast()(a, mindspore.float32)
        b = P.Cast()(b, mindspore.float32)
        mse = P.ReduceMean()((a - b) ** 2)

        if mse == 0:
            return mindspore.Tensor(0)

        return self.max_val - 10 * P.Log()(mse) / self.base10


psnr_func = PSNR(255.0)

# Define the optimizer
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=opts.lr, beta1=opts.beta1, beta2=opts.beta2)
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=opts.lr * opts.D2G_lr, beta1=opts.beta1,
                      beta2=opts.beta2)
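
The `PSNR` class above evaluates 20*log10(max_val) - 10*log10(MSE), with the first term precomputed in `__init__`. A NumPy version of the same formula is sketched below as a cross-check. The function name `psnr_numpy` is an assumption, and returning 0 for identical inputs mirrors the class above rather than the mathematical definition, under which PSNR is unbounded in that case.

```python
import numpy as np


def psnr_numpy(a, b, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two arrays of the same shape."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return 0.0  # mirrors the class above; the true PSNR is unbounded here
    return 20.0 * np.log10(max_val) - 10.0 * np.log10(mse)


reference = np.zeros((8, 8, 3))
shifted = reference + 1.0                 # every pixel off by one level, so MSE = 1
print(psnr_numpy(reference, shifted))
```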




from Guided_Upsample.models.loss import TrainOneStepCell
from Guided_Upsample.upsample_utils.util import postprocess

train_net = TrainOneStepCell(model_G, model_D, optimizer_G, optimizer_D)
epoch = 0
iteration = 0
keep_training = True
psnr = AverageMeter()
mae = AverageMeter()
psnr_best = 0.0
opts.max_iteration = 100
if not os.path.exists(opts.save_path):
    os.mkdir(opts.save_path)
while keep_training:
    epoch += 1
    for i, sample in enumerate(train_dataset.create_dict_iterator()):
        images = sample['images']
        edges = sample['edges']
        masks = sample['masks']
        loss_D, loss_G = train_net(images, edges, masks)
        iteration += 1
        if i % 2 == 0:
            outputs = generator(images, edges, masks)
            outputs_merged = (outputs * masks) + (images * (1 - masks))
            psnr.update(psnr_func(postprocess(images), postprocess(outputs_merged)), 1)
            mae.update((P.ReduceSum()(P.Abs()(images - outputs_merged)) / P.ReduceSum()(images)), 1)
            print(f"Epoch: [{epoch}], "
                  f"step: [{i} / {step_size}], "
                  f"psnr: {psnr.avg}, mae: {mae.avg}")
            if psnr.avg > psnr_best:
                psnr_best = psnr.avg
                mindspore.save_checkpoint(generator, os.path.join(opts.save_path, 'InpaintingModel_gen_best.ckpt'))
                mindspore.save_checkpoint(discriminator, os.path.join(opts.save_path, 'InpaintingModel_dis_best.ckpt'))

        if iteration >= opts.max_iteration:
            keep_training = False
            break

mindspore.save_checkpoint(generator, os.path.join(opts.save_path, 'InpaintingModel_gen_latest.ckpt'))
mindspore.save_checkpoint(discriminator, os.path.join(opts.save_path, 'InpaintingModel_dis_latest.ckpt'))
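
Inside the loop, the generator output is merged with the original image so that only the masked region is replaced, and the MAE is normalized by the total image intensity. The NumPy sketch below repeats those two expressions on toy arrays. The function name and the toy values are assumptions for illustration.

```python
import numpy as np


def merge_and_mae(images, outputs, masks):
    """Keep known pixels from images, take masked pixels from outputs, and score the result."""
    merged = outputs * masks + images * (1.0 - masks)
    mae = np.abs(images - merged).sum() / images.sum()
    return merged, mae


images = np.full((1, 3, 4, 4), 0.5)    # ground truth, NCHW layout
outputs = np.full((1, 3, 4, 4), 0.75)  # pretend generator output
masks = np.zeros((1, 1, 4, 4))
masks[:, :, 1:3, 1:3] = 1.0            # 2x2 hole; broadcasts over the channel axis
merged, mae = merge_and_mae(images, outputs, masks)
print(mae)
```

Because known pixels are copied from the input, any error measured this way comes only from the masked region.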




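Inference
Next we run inference with the trained model on the ImageNet validation set. Because ImageNet is very large, this case picks only a few images from it for the inference test.
The trained weights can be downloaded from ict_ckpt; after downloading, place them under ckpts_ICT. The weight files are laid out as follows, where origin holds weights converted from the PyTorch model and ms_train holds weights trained with the MindSpore framework: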

.ckpts_ICT/
├── origin
│   ├── Transformer/
│   │   ├── ImageNet.ckpt
│   │   ├── FFHQ.ckpt
│   │   └── Places2_Nature.ckpt
│   └── Upsample/
│       ├── FFHQ
│       │   ├── InpaintingModel_dis.ckpt
│       │   └── InpaintingModel_gen.ckpt
│       ├── ImageNet
│       │   ├── InpaintingModel_dis.ckpt
│       │   └── InpaintingModel_gen.ckpt
│       └── Places2_Nature
│           ├── InpaintingModel_dis.ckpt
│           └── InpaintingModel_gen.ckpt
├── ms_train
│   ├── Transformer/
│   │   └── ImageNet_best.ckpt
│   └── Upsample/
│       ├── InpaintingModel_dis_best.ckpt
│       └── InpaintingModel_gen_best.ckpt
└── VGG19.ckpt

import numpy as np
import mindspore

# Move up one level only when we are not already in the ICT root directory
if not (os.getcwd().endswith('ICT') or os.getcwd().endswith('ict')):
    os.chdir("..")

opts = parse_args()
input_path = './input'

if not os.path.exists(input_path):
    os.mkdir(input_path)

if os.path.exists(opts.image_url):
    opts.image_url = os.path.join(opts.image_url, 'n01440764')
    os.system('cp {} {}'.format(os.path.join(opts.image_url, 'ILSVRC2012_val_00000293.JPEG'), input_path))
    os.system('cp {} {}'.format(os.path.join(opts.image_url, 'ILSVRC2012_val_00002138.JPEG'), input_path))

opts.image_url = input_path
C = np.load('./kmeans_centers.npy')
C = np.rint(127.5 * (C + 1.0))
C = mindspore.Tensor.from_numpy(C)

img_list = sorted(os.listdir(opts.image_url))
mask_list = sorted(os.listdir(opts.mask_url))
n_samples = opts.condition_num
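
The rescaling `np.rint(127.5 * (C + 1.0))` above maps cluster centers stored in [-1, 1] to RGB values in [0, 255]. Each pixel of the low-resolution image is then represented by the index of its nearest center, which is the token the Transformer works with. The sketch below shows this with a toy four-color palette; the real `kmeans_centers.npy` is not reproduced here, and the palette and pixel values are made up for illustration.

```python
import numpy as np

# Toy cluster centers in [-1, 1] (assumption: the real file has the same layout, one RGB row per center)
centers = np.array([[-1, -1, -1],
                    [1, 1, 1],
                    [1, -1, -1],
                    [-1, 1, -1]], dtype=np.float32)
palette = np.rint(127.5 * (centers + 1.0))  # black, white, red, green in [0, 255]


def quantize(pixels, palette):
    """pixels: (N, 3) in [0, 255]; returns (N,) indices of the nearest palette color."""
    distances = ((pixels[:, None, :] - palette[None, :, :]) ** 2).sum(-1)  # (N, K)
    return distances.argmin(1)


pixels = np.array([[10, 5, 0],
                   [250, 250, 240],
                   [200, 30, 20]], dtype=np.float32)
tokens = quantize(pixels, palette)
reconstructed = palette[tokens]  # tokens map back to colors by simple indexing
print(tokens)
print(reconstructed)
```

The same indexing (`C[outputs[i]]`) is how the sampled tokens are turned back into a 32 x 32 RGB prior in the first stage.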


First Stage: Transformer
Define the network and load the weights
from mindspore.train import Model

# Define the model
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=C.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
                  block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
# Change the path to match your checkpoint file; if you change it, prefer an absolute path to avoid unnecessary errors
opts.ckpt_path = './ckpts_ICT/ms_train/Transformer/ImageNet_best.ckpt'
if os.path.exists(opts.ckpt_path):
    print('Start loading the model parameters from %s' % (opts.ckpt_path))
    checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
    mindspore.load_param_into_net(transformer, checkpoint)
    print('Finished load the model')
transformer.set_train(False)
model = Model(transformer)




from PIL import Image

from transformer_utils.util import sample_mask

for x_name, y_name in zip(img_list, mask_list):
    # load image
    print(x_name)
    image_url = os.path.join(opts.image_url, x_name)
    x = Image.open(image_url).convert("RGB")
    x = x.resize((opts.prior_size, opts.prior_size), resample=Image.BILINEAR)
    x = mindspore.Tensor.from_numpy(np.array(x)).view(-1, 3)
    x = P.Cast()(x, mindspore.float32)
    x = ((x[:, None, :] - C[None, :, :]) ** 2).sum(-1).argmin(1)

    # load mask
    mask_url = os.path.join(opts.mask_url, y_name)
    y = Image.open(mask_url).convert("L")
    y = y.resize((opts.prior_size, opts.prior_size), resample=Image.NEAREST)
    y = (np.array(y) / 255.) > 0.5
    y = mindspore.Tensor.from_numpy(y).view(-1)
    y = P.Cast()(y, mindspore.float32)

    x_list = [x] * n_samples
    x_tensor = P.Stack()(x_list)
    y_list = [y] * n_samples
    y_tensor = P.Stack()(y_list)
    x_tensor = P.Cast()(x_tensor * (1 - y_tensor), mindspore.int32)
    outputs = sample_mask(model, x_tensor, y_tensor, length=opts.prior_size * opts.prior_size,
                          top_k=opts.top_k)

    img_name = x_name[:x_name.find('.')] + x_name[x_name.find('.'):]
    for i in range(n_samples):
        current_url = os.path.join(opts.save_url, 'condition_%d' % (i + 1))
        os.makedirs(current_url, exist_ok=True)
        current_img = C[outputs[i]].view(opts.prior_size, opts.prior_size, 3).asnumpy().astype(np.uint8)
        tmp = Image.fromarray(current_img)
        tmp.save(os.path.join(current_url, img_name))
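
`sample_mask` is imported from the repository and its source is not shown in this article, but its `top_k` argument suggests top-k sampling over the predicted token distribution. The sketch below shows one common way to do top-k sampling for a single position in NumPy. It is an assumption about the general technique, not the repository's implementation, and the function name is hypothetical.

```python
import numpy as np


def top_k_sample(logits, k, rng):
    """Sample one index from the k largest logits (ties at the cut-off are kept)."""
    logits = np.asarray(logits, dtype=np.float64)
    kth_largest = np.sort(logits)[-k]
    filtered = np.where(logits < kth_largest, -np.inf, logits)  # drop everything below the cut-off
    probs = np.exp(filtered - filtered.max())                   # softmax over the survivors
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))


rng = np.random.default_rng(0)
toy_logits = [0.1, 5.0, 4.0, -2.0, 0.3]
samples = [top_k_sample(toy_logits, k=2, rng=rng) for _ in range(50)]
print(sorted(set(samples)))
```

Restricting sampling to the most likely tokens keeps the filled-in prior plausible while still allowing different results across the `condition_num` samples.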




Second Stage: Upsample
Use the image prior produced by the Transformer to upsample the image and restore a high-resolution picture.

def stitch_images(inputs, *outputs, img_per_row=2):
    gap = 5
    columns = len(outputs) + 1

    width, height = inputs[0][:, :, 0].shape
    img = Image.new('RGB',
                    (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
    images = [inputs, *outputs]

    for ix in range(len(inputs)):
        xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
        yoffset = int(ix / img_per_row) * height

        for cat in range(len(images)):
            im = images[cat][ix].asnumpy().astype(np.uint8).squeeze()
            im = Image.fromarray(im)
            img.paste(im, (xoffset + cat * width, yoffset))

    return img


from upsample_utils.util import postprocess, imsave

opts.mask_type = 3
opts.mode = 2
opts.input = './input/'
opts.kmeans = './kmeans_centers.npy'
# Make sure the second-stage input is the first-stage output
opts.prior = opts.save_url

generator = Generator()
generator.set_train(False)
# Change the path to match your checkpoint file; if you change it, prefer an absolute path to avoid unnecessary errors
opts.ckpt_path = './ckpts_ICT/ms_train/Upsample/InpaintingModel_gen_best.ckpt'
if os.path.exists(opts.ckpt_path):
    print('Start loading the model parameters from %s' % (opts.ckpt_path))
    checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
    mindspore.load_param_into_net(generator, checkpoint)
    print('Finished load the model')

psnr_func = PSNR(255.0)

test_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
                            image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
                            kmeans=opts.kmeans, condition_num=opts.condition_num,
                            augment=False, training=False)

index = 0
psnr = AverageMeter()
mae = AverageMeter()
test_batch_size = 1
test_dataset = test_dataset.batch(test_batch_size)
for sample in test_dataset.create_dict_iterator():
    name = sample['name'].asnumpy()[0]
    images = sample['images']
    edges = sample['edges']
    masks = sample['masks']
    inputs = (images * (1 - masks)) + masks
    index += test_batch_size
    outputs = generator(images, edges, masks)
    outputs_merged = (outputs * masks) + (images * (1 - masks))
    psnr.update(psnr_func(postprocess(images), postprocess(outputs_merged)), 1)
    mae.update((P.ReduceSum()(P.Abs()(images - outputs_merged)) / P.ReduceSum()(images)), 1)
    result_merge = stitch_images(
        postprocess(images),
        postprocess(inputs),
        postprocess(outputs_merged),
        img_per_row=1
    )
    result_merge.show()
    output = postprocess(outputs_merged)[0]
    path = os.path.join(opts.save_url, name[:-4] + "_%d" % (index % opts.condition_num) + '.png')
    imsave(output, path)
print('PSNR: {}, MAE: {}'.format(psnr.avg, mae.avg))



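Copyright notice: This is an original article by the blogger, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.

Original link: https://blog.csdn.net/qq_46207024/article/details/141317651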
