MindSpore 25-Day Learning Check-in Camp, Day 23 | FCN Image Semantic Segmentation

FCN Image Semantic Segmentation

https://gitee.com/mindspore/docs/blob/r2.3/tutorials/application/source_zh_cn/cv/fcn8s.ipynb

Fully Convolutional Networks (FCN) is a framework for image semantic segmentation proposed in 2015 by Jonathan Long et al. of UC Berkeley in the paper Fully Convolutional Networks for Semantic Segmentation [1].

FCN was the first fully convolutional network to perform end-to-end, pixel-level prediction.

fcn-1

Semantic Segmentation

Before introducing FCN in detail, let us first define semantic segmentation:

Image semantic segmentation is an important part of image understanding in image processing and machine vision, and a major branch of AI. It is commonly applied in face recognition, object detection, medical imaging, satellite image analysis, autonomous-driving perception, and other fields.

The goal of semantic segmentation is to classify every pixel in an image. Unlike an ordinary classification task, which outputs only a single class label, a semantic segmentation task outputs an image of the same size as the input, in which each pixel holds the class of the corresponding input pixel. In the image domain, "semantics" refers to the content of an image, i.e. an understanding of what the picture means. The figure below shows some examples of semantic segmentation:

fcn-2
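As a concrete illustration of per-pixel classification (a minimal NumPy sketch; the shapes and random scores are ours, not from the tutorial): a segmentation model produces a class-score map, and an argmax over the class axis yields a label map with the same spatial size as the input.

```python
import numpy as np

# Toy class-score map: 3 classes over a 4x4 image (random values for illustration).
num_classes, h, w = 3, 4, 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((num_classes, h, w))

# Per-pixel classification: pick the highest-scoring class at every pixel.
label_map = scores.argmax(axis=0)

# The label map has the same spatial size as the input image.
print(label_map.shape)  # (4, 4)
```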

Model Overview

FCN is mainly used in image segmentation. It is an end-to-end segmentation method and the seminal work applying deep learning to image semantic segmentation. Through pixel-level prediction it directly produces a label map of the same size as the original image. Because FCN discards the fully connected layers and replaces them with convolutions, so that every layer in the network is convolutional, it is called a fully convolutional network.

A fully convolutional network mainly relies on the following three techniques:

  1. Convolutionalization

    VGG-16 is used as the FCN backbone. VGG-16 takes a 224*224 RGB image as input and outputs 1000 prediction scores. It accepts only fixed-size inputs, discards spatial coordinates, and produces non-spatial outputs. VGG-16 contains three fully connected layers, and a fully connected layer can be viewed as a convolution whose kernel covers the entire input region. Converting the fully connected layers into convolutional layers turns the network's one-dimensional, non-spatial output into a two-dimensional matrix, from which a heatmap mapped onto the input image can be generated.

    fcn-3

  2. Upsampling

    The convolution and pooling operations during encoding shrink the feature maps. To obtain a dense prediction at the original image size, the resulting feature maps must be upsampled. The upsampling deconvolution is initialized with bilinear interpolation parameters, and nonlinear upsampling is then learned by backpropagation. Upsampling is performed inside the network, so the whole model can be learned end to end by backpropagating the pixel-wise loss.

    fcn-4

  3. Skip structure (Skip Layer)

    Upsampling the last feature map directly to the original image size yields a segmentation whose prediction has a stride of 32 pixels, called FCN-32s. Because the last feature map is very small and loses too much detail, a skip structure combines the final prediction, which carries more global information, with shallower predictions, so the result recovers more local detail. The prediction from the deepest layer (stride 32, i.e. FCN-32s) is upsampled by 2x and fused (summed) with the prediction made from the pool4 layer (stride 16); this part of the network is called FCN-16s. That prediction is then upsampled by 2x once more and fused with the prediction made from the pool3 layer; this part of the network is called FCN-8s. The skip structure combines deep global information with shallow local information.

    fcn-5
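For the upsampling step, the original FCN paper initializes each transposed convolution with bilinear interpolation weights. The tutorial's network code below uses `xavier_uniform` instead, but the classic bilinear initializer can be sketched as follows (the helper name `bilinear_kernel` is ours):

```python
import numpy as np

def bilinear_kernel(size):
    """2D bilinear interpolation kernel of shape (size, size), the classic
    initializer for FCN upsampling (transposed) convolutions."""
    factor = (size + 1) // 2
    # The kernel center differs for odd and even sizes.
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    og = np.ogrid[:size, :size]
    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)

# A 4x4 kernel suits a stride-2 transposed convolution (2x upsampling).
k = bilinear_kernel(4)
print(k[0])  # [0.0625 0.1875 0.1875 0.0625]
```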

Network Characteristics

  1. A fully convolutional (fully conv) network without fully connected (fc) layers, so it can accept input of any size.
  2. Deconvolution (deconv) layers that enlarge the data size, enabling fine-grained output.
  3. A skip structure that fuses results from layers at different depths, ensuring both robustness and precision.

Data Processing

Before starting the experiment, make sure a Python environment and MindSpore are installed locally.

%%capture captured_output
# mindspore==2.2.14 is preinstalled in this environment; to switch MindSpore versions, change the version number below
# !pip uninstall mindspore -y
# !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Check the current mindspore version
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"

download(url, "./dataset", kind="tar", replace=True)
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar (537.2 MB)

file_sizes: 100%|█████████████████████████████| 563M/563M [00:03<00:00, 145MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset





'./dataset'

Data Preprocessing

Because the images in the PASCAL VOC 2012 dataset mostly differ in resolution and therefore cannot be stacked into a single tensor, the inputs need to be standardized before being fed to the network.

Data Loading

The PASCAL VOC 2012 dataset is mixed with the SBD dataset.

import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):

        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale

    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset


# Parameters for creating the dataset
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Model training parameters
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# Instantiate the dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)

dataset = dataset.get_dataset()

Training Set Visualization

Run the following code to view the images in the loaded dataset (normalization was already applied during data processing).

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))

# Display samples from the training set
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # Convert the image to HWC format for display
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()


Network Construction

Network Flow

The flow of the FCN network is shown in the figure below:

  1. The input image passes through pool1, halving its size to 1/2 of the original.
  2. After pool2, the size becomes 1/4 of the original.
  3. After pool3, pool4, and pool5, the size becomes 1/8, 1/16, and 1/32 of the original, respectively.
  4. After the conv6-7 convolutions, the output size is still 1/32 of the original image.
  5. FCN-32s applies a final deconvolution so that the output image matches the input image size.
  6. FCN-16s deconvolves the conv7 output, doubling it to 1/16 of the original, fuses it with the pool4 feature map, and then deconvolves the fused result back to the original size.
  7. FCN-8s deconvolves the conv7 output by 4x, deconvolves the pool4 feature map by 2x, takes the pool3 feature map, fuses all three, and then deconvolves the result back to the original size.

fcn-6
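The size arithmetic in the steps above can be checked with a few lines of plain Python (512 is the crop size used later in this tutorial):

```python
# Track the spatial size through the five 2x2, stride-2 poolings (input 512x512).
size = 512
sizes = {}
for name in ["pool1", "pool2", "pool3", "pool4", "pool5"]:
    size //= 2
    sizes[name] = size
print(sizes)  # {'pool1': 256, 'pool2': 128, 'pool3': 64, 'pool4': 32, 'pool5': 16}

# FCN-8s decoding path: two 2x upsamplings fused with pool4 and pool3,
# then a final 8x upsampling back to the input resolution.
assert sizes["pool5"] * 2 == sizes["pool4"]  # upscore2 output matches pool4
assert sizes["pool4"] * 2 == sizes["pool3"]  # upscore_pool4 output matches pool3
assert sizes["pool3"] * 8 == 512             # upscore8 restores the input size
```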

Build the FCN-8s network with the following code.

import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out

Training Preparation

Loading Part of the VGG-16 Pretrained Weights

FCN uses VGG-16 as its backbone network for image encoding. Use the following code to load part of the pretrained weights of a VGG-16 model.

from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt (513.2 MB)

file_sizes: 100%|█████████████████████████████| 538M/538M [00:03<00:00, 164MB/s]
Successfully downloaded file to fcn8s_vgg16_pretrain.ckpt

Loss Function

Semantic segmentation classifies every pixel in the image, so it is still a classification problem, and the cross-entropy loss is chosen to measure the loss between the FCN output and the mask. Here we use mindspore.nn.CrossEntropyLoss() as the loss function.
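As a sanity check on the ignore_index semantics used later (pixels labeled 255 contribute nothing to the loss), here is a minimal NumPy sketch; the helper name and the toy logits are ours, not part of the tutorial:

```python
import numpy as np

def masked_cross_entropy(logits, labels, ignore_index=255):
    """Mean per-pixel cross-entropy that skips pixels whose label equals
    ignore_index, mirroring nn.CrossEntropyLoss(ignore_index=255).
    logits: (N, C) class scores, labels: (N,) integer classes."""
    keep = labels != ignore_index
    logits, labels = logits[keep], labels[keep]
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3],
                   [1.0, 1.0, 1.0]])
labels = np.array([0, 1, 255])  # the third pixel is ignored
print(masked_cross_entropy(logits, labels))
```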

Custom Evaluation Metrics

This section evaluates the trained model. For ease of explanation, assume there are $k+1$ classes (from $L_0$ to $L_k$, including an empty class or background), and let $p_{ij}$ denote the number of pixels that belong to class $i$ but are predicted as class $j$. That is, $p_{ii}$ is the number of true positives, while $p_{ij}$ and $p_{ji}$ are interpreted as false positives and false negatives respectively, although both are sums of false positives and false negatives.

  • Pixel Accuracy (PA): the simplest metric, the proportion of correctly labeled pixels among all pixels.

$$PA=\frac{\sum_{i=0}^{k} p_{ii}}{\sum_{i=0}^{k} \sum_{j=0}^{k} p_{ij}}$$

  • Mean Pixel Accuracy (MPA): a simple improvement over PA; the proportion of correctly classified pixels is computed within each class, then averaged over all classes.

$$MPA=\frac{1}{k+1} \sum_{i=0}^{k} \frac{p_{ii}}{\sum_{j=0}^{k} p_{ij}}$$

  • Mean Intersection over Union (MIoU): the standard metric for semantic segmentation. It computes the ratio of the intersection to the union of two sets, which for semantic segmentation are the ground truth and the predicted segmentation. The ratio can be rewritten as the number of true positives (the intersection) over the sum of true positives, false negatives, and false positives (the union). The IoU is computed per class and then averaged.

$$MIoU=\frac{1}{k+1} \sum_{i=0}^{k} \frac{p_{ii}}{\sum_{j=0}^{k} p_{ij}+\sum_{j=0}^{k} p_{ji}-p_{ii}}$$

  • Frequency Weighted Intersection over Union (FWIoU): an improvement over MIoU that weights each class's IoU by the frequency with which the class appears.

$$FWIoU=\frac{1}{\sum_{i=0}^{k} \sum_{j=0}^{k} p_{ij}} \sum_{i=0}^{k} \frac{p_{ii} \sum_{j=0}^{k} p_{ij}}{\sum_{j=0}^{k} p_{ij}+\sum_{j=0}^{k} p_{ji}-p_{ii}}$$
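The four formulas can be checked on a toy 2-class confusion matrix (rows are ground truth, columns are predictions; the counts are made up for illustration):

```python
import numpy as np

# Toy confusion matrix for k+1 = 2 classes: rows are ground truth, columns are
# predictions (e.g. 8 background pixels correct, 1 misclassified, and so on).
cm = np.array([[8.0, 1.0],
               [2.0, 5.0]])

diag = np.diag(cm)
pa = diag.sum() / cm.sum()                             # Pixel Accuracy
mpa = np.nanmean(diag / cm.sum(axis=1))                # Mean Pixel Accuracy
iou = diag / (cm.sum(axis=1) + cm.sum(axis=0) - diag)  # per-class IoU
miou = np.nanmean(iou)                                 # Mean IoU
freq = cm.sum(axis=1) / cm.sum()                       # class frequencies
fwiou = (freq[freq > 0] * iou[freq > 0]).sum()         # Frequency Weighted IoU

print(round(pa, 4), round(mpa, 4), round(miou, 4), round(fwiou, 4))
# 0.8125 0.8016 0.6761 0.6825
```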

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy


class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy


class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou


class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou

Model Training

After loading the VGG-16 pretrained parameters, instantiate the loss function and the optimizer, compile the network with the Model interface, and train the FCN-8s network.

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# Initialize the network structure
net = FCN8s(n_class=21)
# Load the VGG-16 pretrained parameters
load_vgg16()
# Compute the learning rate
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])

# Define the loss function
loss = nn.CrossEntropyLoss(ignore_index=255)
# Define the optimizer
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# Define the dynamic loss scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = mindspore.amp.DynamicLossScaleManager(scale_factor, scale_window)
# Initialize the model
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Set the parameters for saving ckpt files
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=save_steps,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)
epoch: 1 step: 1, loss is 3.0538573
epoch: 1 step: 2, loss is 3.0008237
epoch: 1 step: 3, loss is 2.93123
epoch: 1 step: 4, loss is 2.8197358
epoch: 1 step: 5, loss is 2.668547
epoch: 1 step: 6, loss is 2.3382783
epoch: 1 step: 7, loss is 2.0454466
epoch: 1 step: 8, loss is 1.374428
epoch: 1 step: 9, loss is 2.889871
epoch: 1 step: 10, loss is 2.120362
epoch: 1 step: 11, loss is 1.7049471
epoch: 1 step: 12, loss is 0.92450166
epoch: 1 step: 13, loss is 2.4112763
epoch: 1 step: 14, loss is 1.5740921
epoch: 1 step: 15, loss is 1.8964003
epoch: 1 step: 16, loss is 1.17377
epoch: 1 step: 17, loss is 1.8548889
epoch: 1 step: 18, loss is 2.040294
epoch: 1 step: 19, loss is 1.8032327
epoch: 1 step: 20, loss is 1.0739855
epoch: 1 step: 21, loss is 1.276482
epoch: 1 step: 22, loss is 1.3008155
epoch: 1 step: 23, loss is 1.198459
epoch: 1 step: 24, loss is 0.9260506
epoch: 1 step: 25, loss is 1.1306816
epoch: 1 step: 26, loss is 1.302331
epoch: 1 step: 27, loss is 0.7967311
epoch: 1 step: 28, loss is 0.51942986
epoch: 1 step: 29, loss is 1.7102058
epoch: 1 step: 30, loss is 1.8468792
epoch: 1 step: 31, loss is 1.1661999
epoch: 1 step: 32, loss is 2.5722916
epoch: 1 step: 33, loss is 0.9790484
epoch: 1 step: 34, loss is 0.9009774
epoch: 1 step: 35, loss is 1.9052577
epoch: 1 step: 36, loss is 1.4671967
epoch: 1 step: 37, loss is 1.5104197
epoch: 1 step: 38, loss is 1.7782892
epoch: 1 step: 39, loss is 1.8547891
epoch: 1 step: 40, loss is 1.387258
epoch: 1 step: 41, loss is 1.8291082
epoch: 1 step: 42, loss is 0.659582
epoch: 1 step: 43, loss is 1.5941986
epoch: 1 step: 44, loss is 3.7530715
epoch: 1 step: 45, loss is 1.2560936
epoch: 1 step: 46, loss is 1.1510864
epoch: 1 step: 47, loss is 1.6902436
epoch: 1 step: 48, loss is 1.2025737
epoch: 1 step: 49, loss is 1.7223932
epoch: 1 step: 50, loss is 1.1915472
epoch: 1 step: 51, loss is 2.0016415
epoch: 1 step: 52, loss is 1.808786
epoch: 1 step: 53, loss is 1.2604945
epoch: 1 step: 54, loss is 2.556777
epoch: 1 step: 55, loss is 1.2556043
epoch: 1 step: 56, loss is 1.9303658
epoch: 1 step: 57, loss is 1.8376763
epoch: 1 step: 58, loss is 1.9239762
epoch: 1 step: 59, loss is 2.7055879
epoch: 1 step: 60, loss is 1.3846637
epoch: 1 step: 61, loss is 1.8099169
epoch: 1 step: 62, loss is 1.4711019
epoch: 1 step: 63, loss is 1.7987081
epoch: 1 step: 64, loss is 2.4578075
epoch: 1 step: 65, loss is 2.0836825
epoch: 1 step: 66, loss is 1.4903432
epoch: 1 step: 67, loss is 1.720757
epoch: 1 step: 68, loss is 1.3626233
epoch: 1 step: 69, loss is 2.2197456
epoch: 1 step: 70, loss is 0.96732336
epoch: 1 step: 71, loss is 0.9555376
epoch: 1 step: 72, loss is 2.3950076
epoch: 1 step: 73, loss is 1.071553
epoch: 1 step: 74, loss is 1.6058477
epoch: 1 step: 75, loss is 1.8470367
epoch: 1 step: 76, loss is 0.84497297
epoch: 1 step: 77, loss is 1.2816542
epoch: 1 step: 78, loss is 1.5322306
epoch: 1 step: 79, loss is 2.082866
epoch: 1 step: 80, loss is 0.73774403
epoch: 1 step: 81, loss is 1.3971207
epoch: 1 step: 82, loss is 1.1055415
epoch: 1 step: 83, loss is 1.4079438
epoch: 1 step: 84, loss is 1.9287384
epoch: 1 step: 85, loss is 1.1752979
epoch: 1 step: 86, loss is 1.5902065
epoch: 1 step: 87, loss is 0.9796615
epoch: 1 step: 88, loss is 1.5459417
epoch: 1 step: 89, loss is 1.5211993
epoch: 1 step: 90, loss is 1.4763175
epoch: 1 step: 91, loss is 1.1391922
epoch: 1 step: 92, loss is 1.6960126
epoch: 1 step: 93, loss is 0.9214844
epoch: 1 step: 94, loss is 1.4570516
epoch: 1 step: 95, loss is 1.8151841
epoch: 1 step: 96, loss is 1.9713868
epoch: 1 step: 97, loss is 1.4357623
epoch: 1 step: 98, loss is 1.0030726
epoch: 1 step: 99, loss is 0.91766536
epoch: 1 step: 100, loss is 1.9359992
epoch: 1 step: 101, loss is 1.9648862
epoch: 1 step: 102, loss is 1.2478971
epoch: 1 step: 103, loss is 2.2104764
epoch: 1 step: 104, loss is 1.4705557
epoch: 1 step: 105, loss is 1.7346936
epoch: 1 step: 106, loss is 1.6611106
epoch: 1 step: 107, loss is 1.9504085
epoch: 1 step: 108, loss is 1.1252086
epoch: 1 step: 109, loss is 1.2663596
epoch: 1 step: 110, loss is 2.380007
epoch: 1 step: 111, loss is 1.4858794
epoch: 1 step: 112, loss is 1.9616028
epoch: 1 step: 113, loss is 1.4791129
epoch: 1 step: 114, loss is 1.5439643
epoch: 1 step: 115, loss is 1.3739212
epoch: 1 step: 116, loss is 1.0753161
epoch: 1 step: 117, loss is 1.9428264
epoch: 1 step: 118, loss is 1.8953425
epoch: 1 step: 119, loss is 1.4126146
epoch: 1 step: 120, loss is 0.9011042
epoch: 1 step: 121, loss is 1.8002346
epoch: 1 step: 122, loss is 0.6776845
epoch: 1 step: 123, loss is 1.826196
epoch: 1 step: 124, loss is 0.8350794
epoch: 1 step: 125, loss is 1.3619298
epoch: 1 step: 126, loss is 0.83396816
epoch: 1 step: 127, loss is 1.3385247
epoch: 1 step: 128, loss is 2.307642
epoch: 1 step: 129, loss is 1.8814622
epoch: 1 step: 130, loss is 1.038039
epoch: 1 step: 131, loss is 0.85279
epoch: 1 step: 132, loss is 1.2884648
epoch: 1 step: 133, loss is 1.4709547
epoch: 1 step: 134, loss is 1.7986912
epoch: 1 step: 135, loss is 1.345838
epoch: 1 step: 136, loss is 1.3462955
epoch: 1 step: 137, loss is 1.6120331
epoch: 1 step: 138, loss is 1.6924083
epoch: 1 step: 139, loss is 0.9375493
epoch: 1 step: 140, loss is 1.0473466
epoch: 1 step: 142, loss is 0.56790984
epoch: 1 step: 143, loss is 2.1579657
epoch: 1 step: 144, loss is 1.829849
epoch: 1 step: 145, loss is 1.4958788
epoch: 1 step: 146, loss is 0.57452494
epoch: 1 step: 147, loss is 1.3633496
epoch: 1 step: 148, loss is 1.1109313
epoch: 1 step: 149, loss is 1.2465361
epoch: 1 step: 150, loss is 2.3619897
epoch: 1 step: 151, loss is 1.2183244
epoch: 1 step: 152, loss is 1.3268704
epoch: 1 step: 153, loss is 1.6576173
epoch: 1 step: 154, loss is 1.3380648
epoch: 1 step: 155, loss is 2.070176
epoch: 1 step: 156, loss is 1.5775394
epoch: 1 step: 157, loss is 1.7351686
epoch: 1 step: 158, loss is 1.3684554
epoch: 1 step: 159, loss is 1.12731
epoch: 1 step: 160, loss is 1.1406639
epoch: 1 step: 161, loss is 1.7836227
epoch: 1 step: 162, loss is 1.3046056
epoch: 1 step: 163, loss is 1.1335624
epoch: 1 step: 164, loss is 0.83398646
epoch: 1 step: 165, loss is 2.1292508
epoch: 1 step: 166, loss is 1.0319086
epoch: 1 step: 167, loss is 1.1100733
epoch: 1 step: 168, loss is 2.092089
epoch: 1 step: 169, loss is 1.2686191
epoch: 1 step: 170, loss is 1.2458344
epoch: 1 step: 171, loss is 1.7166963
epoch: 1 step: 172, loss is 1.476351
epoch: 1 step: 173, loss is 1.6506145
epoch: 1 step: 174, loss is 1.6585238
epoch: 1 step: 175, loss is 1.0212238
epoch: 1 step: 176, loss is 0.9486788
epoch: 1 step: 177, loss is 2.020333
epoch: 1 step: 178, loss is 0.98453736
epoch: 1 step: 179, loss is 2.7886019
epoch: 1 step: 180, loss is 1.0922515
epoch: 1 step: 181, loss is 2.3652806
epoch: 1 step: 182, loss is 1.0841622
epoch: 1 step: 183, loss is 1.4091527
epoch: 1 step: 184, loss is 1.7552348
epoch: 1 step: 185, loss is 1.9400871
epoch: 1 step: 186, loss is 2.8875933
epoch: 1 step: 187, loss is 1.2455513
epoch: 1 step: 188, loss is 1.7234563
epoch: 1 step: 189, loss is 1.2126485
epoch: 1 step: 190, loss is 1.2171657
epoch: 1 step: 191, loss is 2.4938555
epoch: 1 step: 192, loss is 1.6721321
epoch: 1 step: 193, loss is 1.3619661
epoch: 1 step: 194, loss is 1.0940161
epoch: 1 step: 195, loss is 1.0598927
epoch: 1 step: 196, loss is 1.7291107
epoch: 1 step: 197, loss is 1.7518456
epoch: 1 step: 198, loss is 1.1776814
epoch: 1 step: 199, loss is 0.823361
epoch: 1 step: 200, loss is 1.4129037
epoch: 1 step: 201, loss is 1.9944301
epoch: 1 step: 202, loss is 2.0459101
epoch: 1 step: 203, loss is 1.2689497
epoch: 1 step: 204, loss is 1.4202747
epoch: 1 step: 205, loss is 2.0625627
epoch: 1 step: 206, loss is 0.782083
epoch: 1 step: 207, loss is 1.9816635
epoch: 1 step: 208, loss is 1.3628544
epoch: 1 step: 209, loss is 1.0884308
epoch: 1 step: 210, loss is 1.3431948
...
epoch: 1 step: 1141, loss is 0.6149916
epoch: 1 step: 1142, loss is 1.2111646
epoch: 1 step: 1143, loss is 2.2400432
Train epoch time: 797208.421 ms, per step time: 697.470 ms

Because FCN requires a large amount of training data and many training epochs, only a single epoch on a small dataset is run here to demonstrate how the loss converges. Below, a pre-trained weight file is used for model evaluation and to demonstrate inference.

Model Evaluation

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Download the pre-trained weight file
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)

ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

metrics = {"pixel accuracy": PixelAccuracy(),
           "mean pixel accuracy": PixelAccuracyClass(),
           "mean IoU": MeanIntersectionOverUnion(),
           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()}

if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  loss_scale_manager=loss_scale_manager, metrics=metrics)
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=metrics)

# Instantiate the dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt (1.00 GB)

file_sizes: 100%|███████████████████████████| 1.08G/1.08G [00:04<00:00, 230MB/s]
Successfully downloaded file to FCN8s.ckpt

{'pixel accuracy': 0.9727286050069582,
 'mean pixel accuracy': 0.9405268615185339,
 'mean IoU': 0.8936091905278665,
 'frequency weighted IoU': 0.947469743069777}
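All four metrics reported above can be derived from a single class-by-class confusion matrix accumulated over the evaluation set. The sketch below shows the idea for pixel accuracy and mean IoU using NumPy; the helper names are hypothetical and do not correspond to the tutorial's `PixelAccuracy`/`MeanIntersectionOverUnion` classes.

```python
import numpy as np

def confusion_matrix(pred, label, num_classes, ignore_label=255):
    # Accumulate a (num_classes x num_classes) histogram of (label, pred)
    # pairs, skipping pixels marked with ignore_label.
    mask = label != ignore_label
    idx = num_classes * label[mask].astype(int) + pred[mask].astype(int)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def pixel_accuracy(hist):
    # Fraction of all pixels whose predicted class matches the label.
    return np.diag(hist).sum() / hist.sum()

def mean_iou(hist):
    # Per-class IoU = TP / (TP + FP + FN), averaged over classes.
    iou = np.diag(hist) / (hist.sum(axis=0) + hist.sum(axis=1) - np.diag(hist))
    return np.nanmean(iou)
```

In practice the histogram is summed over every batch of the evaluation set before the final metrics are computed, which is why a single number is reported for the whole dataset.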

Model Inference

Use the trained network to visualize inference results.

import numpy as np
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)
# Load the pre-trained checkpoint
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# Display inference results (top row: input images, bottom row: predictions)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([eval_batch_size, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
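The `argmax` output `res` is a map of per-pixel class indices, so `plt.imshow` renders it with an arbitrary colormap. To display it in fixed per-class colors (as PASCAL VOC visualizations do), the indices can be looked up in a color palette. A minimal sketch; the palette entries below are illustrative, not the full 21-entry VOC color map:

```python
import numpy as np

# Illustrative 4-class RGB palette; a real VOC palette has 21 entries.
PALETTE = np.array([
    [0, 0, 0],        # background
    [128, 0, 0],      # class 1
    [0, 128, 0],      # class 2
    [128, 128, 0],    # class 3
], dtype=np.uint8)

def colorize(index_map, palette=PALETTE):
    # NumPy fancy indexing turns an (H, W) class-index map
    # into an (H, W, 3) RGB image in one lookup.
    return palette[index_map]
```

The colorized array can be passed directly to `plt.imshow` in place of the raw index map.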


Summary

FCN's core contribution is replacing fully connected layers with fully convolutional ones, letting the network learn end-to-end segmentation. Compared with traditional CNN-based segmentation methods, FCN has two clear advantages: first, it accepts input images of arbitrary size, with no requirement that training and test images share a fixed resolution; second, it is more efficient, avoiding the redundant storage and repeated convolution caused by classifying overlapping pixel patches.

At the same time, FCN leaves room for improvement:

1. The results are still not fine-grained. Although 8x upsampling is much better than 32x, the upsampled output remains blurry and over-smoothed, especially at object boundaries; the network is insensitive to fine image details.
2. Each pixel is classified independently, without fully modeling the relationships between pixels (such as discontinuity and similarity). The spatial regularization step common in pixel-classification-based segmentation methods is omitted, so the predictions lack spatial consistency.

References

[1] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully Convolutional Networks for Semantic Segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
