DeepLabV3+ (TensorFlow) Engineering Application Series (3): Key Points of the Training Code (train.py)

I. Official Code Organization

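  • core: definitions of the various network architectures.
  • datasets: location of the raw data, plus the scripts for building, reading, and augmenting the dataset.
  • g3doc: documentation for the project.
  • pretrain_model: models pre-trained on ImageNet or COCO.
  • utils: miscellaneous helper functions.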


II. Overall Training Flow (Data Loading, Network Architecture, Loss Function)

Based on the official DeepLab-V3+ code 【train.py】, the steps below walk through its data loading, model construction, and loss computation:

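  1. Initialize the Dataset class
    In 【train.py】, the first step is to initialize the data class, which mainly means setting a number of parameters, as shown below: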

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):  # '/device:CPU:0'
            # Data loader
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)
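    The three scale parameters (min_scale_factor, max_scale_factor, scale_factor_step_size) drive random-scale augmentation during training. The sketch below shows how one scale per image could be drawn, assuming the behaviour of DeepLab's preprocessing helper: a uniform draw when the step size is 0, otherwise a random pick from an evenly spaced grid. The function names `scale_grid` and `sample_scale` are hypothetical, and 0.5 / 2.0 / 0.25 are only example settings.

```python
import random


def scale_grid(min_scale, max_scale, step):
    """Evenly spaced candidate scales in [min_scale, max_scale], both ends included."""
    num_steps = int((max_scale - min_scale) / step + 1)
    if num_steps <= 1:
        return [float(min_scale)]
    delta = (max_scale - min_scale) / (num_steps - 1)
    return [min_scale + i * delta for i in range(num_steps)]


def sample_scale(min_scale, max_scale, step, rng=random):
    """Draws one random scale factor for augmentation (sketch, not DeepLab's code)."""
    if min_scale < 0 or min_scale > max_scale:
        raise ValueError('Unexpected value of min_scale.')
    if min_scale == max_scale:
        return float(min_scale)
    if step == 0:
        # Continuous case: uniform draw between min_scale and max_scale.
        return rng.uniform(min_scale, max_scale)
    # Discrete case: pick one value from the evenly spaced grid.
    return rng.choice(scale_grid(min_scale, max_scale, step))


if __name__ == '__main__':
    print(scale_grid(0.5, 2.0, 0.25))  # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    print(sample_scale(0.5, 2.0, 0.25))
```

    With these example values, each training image would be rescaled by one of seven factors before being padded and randomly cropped to train_crop_size.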
    
  2. Reading TFRecords data
    In 【train.py】, calling the 【dataset】 object's member function 【get_one_shot_iterator()】 returns the interface to the training data stream:

        with tf.device(config.variables_device()):  # '/device:CPU:0'
            global_step = tf.train.get_or_create_global_step()
            print('Init global_step: ', global_step)
    
            # Define the model and create clones
            model_fn = _build_deeplab
    
            # Create the data iterator; it also parses the records and applies data augmentation
            model_args = (dataset.get_one_shot_iterator(),
                          {common.OUTPUT_TYPE: dataset.num_of_classes},
                          dataset.ignore_label)
            print("model_args: ", model_args)
    

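    In TensorFlow, 【TFRecords】 files are read with tf.data.TFRecordDataset(). Its arguments are the list of all 【TFRecords】 files and the number of files to read in parallel (num_parallel_reads); parsing and augmentation are then chained with map(). The code is: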

        def get_one_shot_iterator(self):
            """Gets an iterator that iterates across the dataset once.
    
            Returns:
                An iterator of type tf.data.Iterator.
            """
    
            files = self._get_all_files()
            print('train files: ', files)
    
            dataset = (tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
                       .map(self._parse_function, num_parallel_calls=self.num_readers)     # parse each serialized tf.Example
                       .map(self._preprocess_image, num_parallel_calls=self.num_readers))  # preprocessing and data augmentation
            print("dataset: ", dataset)
    
            if self.should_shuffle:
                dataset = dataset.shuffle(buffer_size=100)
    
            if self.should_repeat:
                dataset = dataset.repeat()  # Repeat forever for training.
            else:
                dataset = dataset.repeat(1)
    
            dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
            return dataset.make_one_shot_iterator()
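    Note the order of the chained calls: examples are shuffled individually, repeat() is applied before batch(), and prefetch() comes last. The plain-Python generators below imitate that order so it can be inspected without TensorFlow. This is a sketch, not the tf.data implementation: the bounded buffer only approximates Dataset.shuffle (fill buffer_size elements, emit a random one, refill), and the `r * 10` step is a hypothetical stand-in for the two map() calls.

```python
import itertools
import random


def shuffle_buffer(iterable, buffer_size, rng):
    """Bounded-buffer shuffle in the spirit of Dataset.shuffle(buffer_size)."""
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random element, freeing one slot for the next input.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


def repeat(make_epoch, count=None):
    """Replays the dataset `count` times, or forever when count is None."""
    epochs = itertools.count() if count is None else range(count)
    for _ in epochs:
        yield from make_epoch()


def batch(iterable, batch_size):
    """Groups consecutive elements; the final batch may be smaller."""
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield chunk


def pipeline(records, batch_size, buffer_size, repeat_count=None, seed=0):
    """Mimics TFRecordDataset -> map -> shuffle -> repeat -> batch."""
    rng = random.Random(seed)

    def one_epoch():
        parsed = (r * 10 for r in records)  # stand-in for the two map() calls
        return shuffle_buffer(parsed, buffer_size, rng)

    return batch(repeat(one_epoch, repeat_count), batch_size)


if __name__ == '__main__':
    for b in pipeline(range(5), batch_size=2, buffer_size=3, repeat_count=2):
        print(b)
```

    Because repeat() precedes batch(), a batch can contain the end of one epoch and the start of the next. Also, a buffer of 100 only mixes examples that are close together in the input stream, so it is much weaker than a full shuffle of a dataset with thousands of images.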
    
    

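    The 【TFRecords】 records are then parsed to obtain the training data. First, the features dictionary must match the fields of the 【Example】 defined when the 【TFRecords】 files were generated (see the previous post in this series). Next, tf.parse_single_example(example_proto, features) parses each record. Finally, a sample dictionary is built for training.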

        def _parse_function(self, example_proto):
            """Function to parse the example proto.
    
            Args:
                example_proto: Proto in the format of tf.Example.
    
            Returns:
                A dictionary with parsed image, label, height, width and image name.
    
            Raises:
                ValueError: Label is of wrong shape.
            """
    
            # Currently only supports jpeg and png.
            # Need to use this logic because the shape is not known for
            # tf.image.decode_image and we rely on this info to
            # extend label if necessary.
            def _decode_image(content, channels):
                return tf.cond(tf.image.is_jpeg(content),
                               lambda: tf.image.decode_jpeg(content, channels),
                               lambda: tf.image.decode_png(content, channels))
    		
            # Fields of the tf.Example, as defined when the TFRecords were created
            features = {
                'image/encoded':
                    tf.FixedLenFeature((), tf.string, default_value=''),
                'image/filename':
                    tf.FixedLenFeature((), tf.string, default_value=''),
                'image/format':
                    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
                'image/height':
                    tf.FixedLenFeature((), tf.int64, default_value=0),
                'image/width':
                    tf.FixedLenFeature((), tf.int64, default_value=0),
                'image/segmentation/class/encoded':
                    tf.FixedLenFeature((), tf.string, default_value=''),
                'image/segmentation/class/format':
                    tf.FixedLenFeature((), tf.string, default_value='png'),
            }
    		
            # Parse a single serialized Example
            parsed_features = tf.parse_single_example(example_proto, features)
    
            image = _decode_image(parsed_features['image/encoded'], channels=3)
    
            label = None
            if self.split_name != common.TEST_SET:
                label = _decode_image(
                    parsed_features['image/segmentation/class/encoded'], channels=1)
    
            image_name = parsed_features['image/filename']
            if image_name is None:
                image_name = tf.constant('')
    		
            # Dictionary of tensors describing one training sample
            sample = {
                common.IMAGE: image,
                common.IMAGE_NAME: image_name,
                common.HEIGHT: parsed_features['image/height'],
                common.WIDTH: parsed_features['image/width'],
            }
    
            if label is not None:
                if label.get_shape().ndims == 2:
                    label = tf.expand_dims(label, 2)
                elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
                    pass
                else:
                    raise ValueError('Input label shape must be [height, width], or '
                                     '[height, width, 1].')
    
                label.set_shape([None, None, 1])
    
                sample[common.LABELS_CLASS] = label
    
            return sample
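    The label-shape check at the end of `_parse_function` accepts [height, width] or [height, width, 1] and always passes on [height, width, 1]. The same rule written with NumPy, as a stand-alone sketch (`normalize_label` is a hypothetical name, not part of the DeepLab code):

```python
import numpy as np


def normalize_label(label):
    """Returns the label as [height, width, 1], mirroring the check in _parse_function."""
    label = np.asarray(label)
    if label.ndim == 2:
        # [H, W] -> [H, W, 1]
        return np.expand_dims(label, 2)
    if label.ndim == 3 and label.shape[2] == 1:
        return label
    raise ValueError('Input label shape must be [height, width], or '
                     '[height, width, 1].')


if __name__ == '__main__':
    print(normalize_label(np.zeros((4, 5), dtype=np.uint8)).shape)  # (4, 5, 1)
```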
    
  3. Network model
    The figure below shows the overall network architecture, which has three parts, DCNN + ASPP + Decoder. Each module is described in turn:

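    DCNN module: the input image is 【513x513x3】. This backbone is the foundation of the segmentation network; the 【DeepLab-V3+】 paper experiments with two backbones, ResNet-101-beta and Xception65.
    ASPP module: proposed in 【DeepLab-V2】, this atrous (dilated) convolution module improves the segmentation of objects at multiple scales while remaining efficient, and it has been kept in every later version. With output_stride = 16, its output shape is (?, 33, 33, 256).
    Decoder: 【DeepLab-V3+】 adds a decoder that fuses low-level and high-level features to gradually recover spatial resolution; the final output has shape=(?, 129, 129, 21). Compared with the direct upsampling used in DeepLab-V1, V2, and V3, the decoder recovers fine segmentation details better.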
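    The spatial sizes above follow from the output stride. A small arithmetic sketch, assuming the rule int((dim - 1) / stride + 1) for odd input sizes such as 513 (chosen here because it reproduces the 33 and 129 quoted above), together with the effective extent k + (k - 1)(r - 1) of a k x k atrous convolution with rate r. The rates 6, 12, 18 are the output_stride = 16 defaults that appear in the eval.py flags later in this post.

```python
def feature_size(dim, stride):
    """Spatial size of a feature map at `stride` for an input of size `dim` (e.g. 513)."""
    return int((dim - 1) / stride + 1)


def effective_kernel(k, rate):
    """Effective extent of a k x k atrous convolution with dilation `rate`."""
    return k + (k - 1) * (rate - 1)


if __name__ == '__main__':
    print(feature_size(513, 16))  # 33: ASPP output at output_stride = 16
    print(feature_size(513, 4))   # 129: decoder output (stride 4)
    print([effective_kernel(3, r) for r in (6, 12, 18)])  # [13, 25, 37]
```

    At rate 18 the 37-pixel extent is already larger than the 33 x 33 feature map, so many filter taps fall into padding; the DeepLab-V3 paper cites this degeneration of large-rate filters as a reason for adding image-level pooling to ASPP.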

    [Figure: overall DeepLab-V3+ architecture (DCNN + ASPP + Decoder)]
    The model is defined in 【train.py】 by the following function:

    def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
        """Builds a clone of DeepLab.
    
        Args:
            iterator: An iterator of type tf.data.Iterator for images and labels.
            outputs_to_num_classes: A map from output type to the number of classes. For
                example, for the task of semantic segmentation with 21 semantic classes,
                we would have outputs_to_num_classes['semantic'] = 21.
            ignore_label: Ignore label.
        """
        samples = iterator.get_next()
        print('samples: ', samples)
    
        # Add name to input and label nodes so we can add to summary.
        samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
        samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)
    
        model_options = common.ModelOptions(outputs_to_num_classes=outputs_to_num_classes,
                                            crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                                            atrous_rates=FLAGS.atrous_rates,
                                            output_stride=FLAGS.output_stride)
        print("model option: ", model_options)
    
        # define and load model
        outputs_to_scales_to_logits = model.multi_scale_logits(samples[common.IMAGE],
                                                               model_options=model_options,
                                                               image_pyramid=FLAGS.image_pyramid,
                                                               weight_decay=FLAGS.weight_decay,
                                                               is_training=True,
                                                               fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
    
    
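  4. Computing the loss
    The loss is invoked in 【train.py】 as shown below. Since semantic segmentation is essentially a per-pixel classification problem, a cross-entropy loss can be used: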

        # Semantic segmentation is treated as per-pixel classification, so softmax cross-entropy is used
        for output, num_classes in six.iteritems(outputs_to_num_classes):
            train_utils.add_softmax_cross_entropy_loss_for_each_scale(
                outputs_to_scales_to_logits[output],
                samples[common.LABEL],
                num_classes,
                ignore_label,
                loss_weight=model_options.label_weights,
                upsample_logits=FLAGS.upsample_logits,
                hard_example_mining_step=FLAGS.hard_example_mining_step,
                top_k_percent_pixels=FLAGS.top_k_percent_pixels,
                scope=output)
    

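    In 【DeepLab-V3+】, the decoder output has shape=(?, 129, 129, 256). Because the PASCAL VOC 2012 segmentation task has 【21】 classes (20 object classes plus background), the network's final logits have shape=(?, 129, 129, 21). The labels, however, have shape=(?, 513, 513, 1), so when upsample_logits is enabled DeepLab upsamples the predictions to the label size before computing the loss. The labels could instead be downsampled to shape=(?, 129, 129, 1), but that hurts segmentation accuracy: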

            if upsample_logits:
                print("Loss Run Here 1\n")
                # Label is not downsampled, and instead we upsample logits.
                # print(preprocess_utils.resolve_shape(labels, 4)[1:3])
                # print(logits)
    
                logits = tf.image.resize_bilinear(logits,
                                                  preprocess_utils.resolve_shape(labels, 4)[1:3],
                                                  align_corners=True)
                print("upsample logits: ", logits)
                # upsample logits:  Tensor("ResizeBilinear_1:0", shape=(?, 513, 513, 21), dtype=float32, device=/device:GPU:0)
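    With align_corners=True, the corner pixels of input and output are aligned, so output index j reads the input at position j * (in - 1) / (out - 1). For 129 -> 513 that ratio is exactly 128 / 512 = 1/4: every fourth output pixel coincides with an input pixel and the pixels in between are interpolated. Bilinear resizing applies this 1-D linear interpolation along height and then along width; the helper below is a hypothetical pure-Python sketch of one axis, not TensorFlow's kernel.

```python
def resize_linear_1d(values, out_size):
    """1-D linear resize with align_corners=True semantics (sketch)."""
    in_size = len(values)
    if out_size == 1 or in_size == 1:
        return [float(values[0])] * out_size
    scale = (in_size - 1) / (out_size - 1)
    out = []
    for j in range(out_size):
        x = j * scale                   # source coordinate of output pixel j
        x0 = int(x)                     # left neighbour
        x1 = min(x0 + 1, in_size - 1)   # right neighbour, clamped at the border
        w = x - x0                      # interpolation weight of the right neighbour
        out.append((1 - w) * values[x0] + w * values[x1])
    return out


if __name__ == '__main__':
    print(resize_linear_1d([0, 4], 5))  # [0.0, 1.0, 2.0, 3.0, 4.0]
```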
    

    For convenience, the code flattens the labels into a one-dimensional vector, as shown below (note the comments):

            scaled_labels = tf.reshape(scaled_labels, shape=[-1])  # flatten the labels into a 1-D vector; each label map becomes one contiguous block
            print(scaled_labels)
            # Labels before flattening: Tensor("label:0", shape=(?, 513, 513, 1), dtype=int32, device=/device:GPU:0)
            # Labels after flattening:  Tensor("Reshape:0", shape=(?,), dtype=int32, device=/device:GPU:0)
            # With n images in a batch, the layout is [labelmap-1, ..., labelmap-n]
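    tf.reshape uses row-major order, just like NumPy, so flattening [batch, H, W, 1] places all pixels of the first image before those of the second. A tiny NumPy check, with 3 x 3 label maps standing in for 513 x 513:

```python
import numpy as np

# A batch of two 3x3 label maps (standing in for 513x513). Values are
# arbitrary ids chosen so that every pixel is distinguishable.
batch, H, W = 2, 3, 3
labels = np.arange(batch * H * W, dtype=np.int32).reshape(batch, H, W, 1)

flat = labels.reshape(-1)  # same row-major order as tf.reshape(labels, [-1])
print(flat.shape)  # (18,)

# All pixels of image 0 come first, then all pixels of image 1.
assert np.array_equal(flat[:H * W], labels[0].reshape(-1))
assert np.array_equal(flat[H * W:], labels[1].reshape(-1))

# In general, pixel (b, y, x) lands at index b*H*W + y*W + x.
for b in range(batch):
    for y in range(H):
        for x in range(W):
            assert flat[b * H * W + y * W + x] == labels[b, y, x, 0]
```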
    

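    Pixels carrying the ignore label must also be handled: their weight is set to 0 so they do not take part in the loss. After upsampling, the network output is 【batch x 513 x 513 x 21】; see the comments in the code below: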

            # Weights: 1 (by default) for pixels whose label is not 255 (the ignore label),
            # 0 for ignored pixels such as annotated object boundaries
            weights = utils.get_label_weight_mask(scaled_labels, ignore_label, num_classes, label_weights=loss_weight)

            # Dimension of keep_mask is equal to the total number of pixels.
            keep_mask = tf.cast(tf.not_equal(scaled_labels, ignore_label), dtype=tf.float32)

            train_labels = None
            print("logits======: ", logits)
            logits = tf.reshape(logits, shape=[-1, num_classes])
            # origin logits:  Tensor("ResizeBilinear_1:0", shape=(?, 513, 513, 21), dtype=float32, device=/device:GPU:0)
            # reshape logits: Tensor("Reshape_1:0", shape=(?, 21), dtype=float32, device=/device:GPU:0)
            # Each row now holds the 21 class logits of one pixel, in the same pixel order as scaled_labels.
            # featuremap-bk-c denotes channel c of image k in the batch (batch = 4 here); the rows of
            # image k together cover its 513x513 pixels, so column c is that channel flattened to 1-D.
            # Layout:
            # [[featuremap-b1-1, featuremap-b1-2, ..., featuremap-b1-21],
            #  ...,
            #  [featuremap-b1-1, featuremap-b1-2, ..., featuremap-b1-21],
            #  [featuremap-b2-1, featuremap-b2-2, ..., featuremap-b2-21],
            #  ...,
            #  [featuremap-b2-1, featuremap-b2-2, ..., featuremap-b2-21],
            #  ...,
            #  [featuremap-b4-1, featuremap-b4-2, ..., featuremap-b4-21],
            #  ...,
            #  [featuremap-b4-1, featuremap-b4-2, ..., featuremap-b4-21]]
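    Because the logits are flattened in the same row-major order as the labels, row i of the [-1, num_classes] matrix and element i of the flattened label vector refer to the same pixel. A NumPy sketch of this alignment and of keep_mask, using tiny random shapes and 255 as the ignore label (values chosen only for illustration):

```python
import numpy as np

batch, H, W, C = 2, 3, 3, 4   # tiny stand-ins for (4, 513, 513, 21)
ignore_label = 255
rng = np.random.default_rng(0)

logits = rng.normal(size=(batch, H, W, C)).astype(np.float32)
labels = rng.integers(0, C, size=(batch, H, W, 1))
labels[0, 0, 0, 0] = ignore_label          # mark one pixel as "ignore"

flat_labels = labels.reshape(-1)           # (batch*H*W,)
flat_logits = logits.reshape(-1, C)        # (batch*H*W, C)

# Row i of flat_logits and flat_labels[i] describe the same pixel.
b, y, x = 1, 2, 1
i = b * H * W + y * W + x
assert np.array_equal(flat_logits[i], logits[b, y, x])
assert flat_labels[i] == labels[b, y, x, 0]

# keep_mask: 1.0 for pixels that count towards the loss, 0.0 for ignored ones.
keep_mask = (flat_labels != ignore_label).astype(np.float32)
print(flat_logits.shape, keep_mask.sum())  # (18, 4) 17.0
```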
    

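    When training a classification network, labels are usually converted to 【one-hot】 form. Suppose the batch size is 2 and each label map is 【513x513x1】. The labels 【2x513x513x1】 are first flattened into a 1-D vector in which all 513x513 pixels of the first image are followed by all pixels of the second. Each pixel's value is its class index (when the segmentation labels are generated, every class is assigned a fixed index; in PASCAL VOC, for example, 0 is background, 2 is bicycle, and 15 is person), so the 【one-hot】 conversion turns every pixel into a 21-dimensional 【one-hot】 vector. The code is: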

                train_labels = tf.one_hot(scaled_labels, num_classes, on_value=1.0, off_value=0.0)
                # train label:  Tensor("one_hot:0", shape=(?, 21), dtype=float32, device=/device:GPU:0)
                # With batch = 4, labelmap-k is the one-hot label of image k, with shape (513*513, 21).
                # Layout (stacked along the first axis):
                # [labelmap-1,
                #  labelmap-2,
                #  ...,
                #  labelmap-4]
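    A NumPy sketch of the same conversion on a flattened label vector. It also reproduces a tf.one_hot detail: an index outside [0, depth) produces an all-off_value row, so pixels carrying the ignore label 255 become all-zero vectors (their weight is 0 anyway). The helper name `one_hot` is hypothetical.

```python
import numpy as np


def one_hot(indices, depth, on_value=1.0, off_value=0.0):
    """One-hot encodes a 1-D index vector; out-of-range indices give all-off rows."""
    indices = np.asarray(indices).reshape(-1)
    out = np.full((indices.size, depth), off_value, dtype=np.float32)
    rows = np.nonzero((indices >= 0) & (indices < depth))[0]
    out[rows, indices[rows]] = on_value
    return out


# Flattened labels: background, class 2, an ignored pixel, class 1 (depth 3 for brevity).
print(one_hot(np.array([0, 2, 255, 1]), depth=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 0. 0.]
#  [0. 1. 0.]]
```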
    

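    After these conversions the labels and the network output have the same shape, [number of pixels, 21], so they can be passed directly to the loss function: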

            with tf.name_scope(loss_scope, default_loss_scope, [logits, train_labels, weights]):
                # Compute the loss for all pixels.
                pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(train_labels, name='train_labels_stop_gradient'),
                                                                          logits=logits,
                                                                          name='pixel_losses')
    
                weighted_pixel_losses = tf.multiply(pixel_losses, weights)
    
                if top_k_percent_pixels == 1.0:
                    print('Run Here Loss 6')
                    total_loss = tf.reduce_sum(weighted_pixel_losses)
                    num_present = tf.reduce_sum(keep_mask)
                    loss = _div_maybe_zero(total_loss, num_present)
                    tf.losses.add_loss(loss)
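    Putting the pieces together, the top_k_percent_pixels == 1.0 branch sums the weighted per-pixel softmax cross-entropy and divides by the number of non-ignored pixels, falling back to 0 when there are none (the case `_div_maybe_zero` guards against). A NumPy sketch of that computation, assuming all class weights are 1.0; the function names are illustrative, not DeepLab's:

```python
import numpy as np


def softmax_cross_entropy(logits, onehot):
    """Per-row softmax cross-entropy, computed stably via log-sum-exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(onehot * log_probs).sum(axis=1)


def segmentation_loss(logits, labels, num_classes, ignore_label=255):
    """Mean cross-entropy over non-ignored pixels (the top_k_percent_pixels == 1.0 case)."""
    flat_labels = np.asarray(labels).reshape(-1)
    flat_logits = np.asarray(logits, dtype=np.float64).reshape(-1, num_classes)
    keep_mask = (flat_labels != ignore_label).astype(np.float64)
    weights = keep_mask                      # assumes every class weight is 1.0
    onehot = np.zeros_like(flat_logits)
    valid = np.nonzero(keep_mask)[0]
    onehot[valid, flat_labels[valid]] = 1.0  # ignored rows stay all-zero
    pixel_losses = softmax_cross_entropy(flat_logits, onehot)
    total_loss = np.sum(pixel_losses * weights)
    num_present = np.sum(keep_mask)
    # Same guard as _div_maybe_zero: no valid pixels -> loss 0.
    return float(total_loss / num_present) if num_present > 0 else 0.0


labels = np.array([0, 1, 2, 255]).reshape(1, 2, 2, 1)
print(segmentation_loss(np.zeros((1, 2, 2, 3)), labels, num_classes=3))  # ~1.0986, i.e. ln(3)
```

    With all-zero logits every valid pixel has loss ln(3), and the ignored pixel changes neither the numerator nor the denominator.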
    

III. Model Evaluation

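Semantic segmentation classifies every pixel to produce a pixel-level result. Models are evaluated with 【mIOU】 (mean intersection over union): for each class, the IoU between the predicted region and the ground-truth region is computed, and the per-class values are averaged. The result lies in 【0,1】: it is 0 when every pixel is predicted incorrectly and 1 when every pixel is predicted correctly. For a detailed explanation of 【IOU】, see my earlier post YOLO-V3 Code Analysis Series (5): Loss Function (yolov3.py). The official evaluation script 【eval.py】 shows the whole process in detail: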

"""Evaluation script for the DeepLab model.

See model.py for more details and usage.
"""

import numpy as np
import six
import tensorflow as tf
from tensorflow.contrib import metrics as contrib_metrics
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import tfprof as contrib_tfprof
from tensorflow.contrib import training as contrib_training
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

# Settings for log directories.
flags.DEFINE_string('eval_logdir', './eval_log/', 'Where to write the event logs.')

flags.DEFINE_string('checkpoint_dir', './checkpoint-0326/', 'Directory of model checkpoints.')

# Settings for evaluating the model.
flags.DEFINE_integer('eval_batch_size', 1,
                     'The number of images in each batch during evaluation.')

flags.DEFINE_list('eval_crop_size', '513, 513',
                  'Image crop size [height, width] for evaluation.')

flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                     'How often (in seconds) to run evaluation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', [6, 12, 18],
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

# Change to True for adding flipped images during test.
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images for evaluation or not.')

flags.DEFINE_integer('quantize_delay_step', -1,
                     'Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('eval_split', 'val',
                    'Which split of the dataset used for evaluation')

flags.DEFINE_string('dataset_dir',
                    './datasets/pascal_voc_seg/tfrecord/trainaug',
                    # './datasets/pascal_voc_seg/tfrecord/offical',
                    'Where the dataset reside.')

flags.DEFINE_integer('max_number_of_evaluations', 0,
                     'Maximum number of eval iterations. Will loop '
                     'indefinitely upon nonpositive values.')


def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = data_generator.Dataset(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.eval_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=FLAGS.eval_batch_size,
        crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        model_variant=FLAGS.model_variant,
        num_readers=2,
        is_training=False,
        should_shuffle=False,
        should_repeat=False)

    tf.gfile.MakeDirs(FLAGS.eval_logdir)
    tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

    with tf.Graph().as_default():
        samples = dataset.get_one_shot_iterator().get_next()

        model_options = common.ModelOptions(
            outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
            crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        # Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
        samples[common.IMAGE].set_shape(
            [FLAGS.eval_batch_size,
             int(FLAGS.eval_crop_size[0]),
             int(FLAGS.eval_crop_size[1]),
             3])

        if tuple(FLAGS.eval_scales) == (1.0,):
            tf.logging.info('Performing single-scale test.')
            predictions = model.predict_labels(samples[common.IMAGE], model_options,
                                               image_pyramid=FLAGS.image_pyramid)
        else:
            tf.logging.info('Performing multi-scale test.')
            if FLAGS.quantize_delay_step >= 0:
                raise ValueError(
                    'Quantize mode is not supported with multi-scale test.')

            predictions = model.predict_labels_multi_scale(
                samples[common.IMAGE],
                model_options=model_options,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images)
        print(predictions)
        predictions = predictions[common.OUTPUT_TYPE]
        print('predictions: ', predictions)

        predictions = tf.reshape(predictions, shape=[-1])
        labels = tf.reshape(samples[common.LABEL], shape=[-1])
        print('reshape predictions: ', predictions)
        print('reshape label: ', labels)

        weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

        # Set ignore_label regions to label 0, because metrics.mean_iou requires
        # range of labels = [0, dataset.num_classes). Note the ignore_label regions
        # are not evaluated since the corresponding regions contain weights = 0.
        labels = tf.where(
            tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)

        predictions_tag = 'miou'
        for eval_scale in FLAGS.eval_scales:
            predictions_tag += '_' + str(eval_scale)
        if FLAGS.add_flipped_images:
            predictions_tag += '_flipped'

        # Define the evaluation metric.
        metric_map = {}
        num_classes = dataset.num_of_classes
        metric_map['eval/%s_overall' % predictions_tag] = tf.metrics.mean_iou(
            labels=labels, predictions=predictions, num_classes=num_classes,
            weights=weights)

        # IoU for each class.
        one_hot_predictions = tf.one_hot(predictions, num_classes)
        one_hot_predictions = tf.reshape(one_hot_predictions, [-1, num_classes])
        one_hot_labels = tf.one_hot(labels, num_classes)
        one_hot_labels = tf.reshape(one_hot_labels, [-1, num_classes])
        for c in range(num_classes):
            predictions_tag_c = '%s_class_%d' % (predictions_tag, c)
            tp, tp_op = tf.metrics.true_positives(
                labels=one_hot_labels[:, c], predictions=one_hot_predictions[:, c],
                weights=weights)
            fp, fp_op = tf.metrics.false_positives(
                labels=one_hot_labels[:, c], predictions=one_hot_predictions[:, c],
                weights=weights)
            fn, fn_op = tf.metrics.false_negatives(
                labels=one_hot_labels[:, c], predictions=one_hot_predictions[:, c],
                weights=weights)

            tp_fp_fn_op = tf.group(tp_op, fp_op, fn_op)
            iou = tf.where(tf.greater(tp + fn, 0.0),
                           tp / (tp + fn + fp),
                           tf.constant(np.NaN))
            metric_map['eval/%s' % predictions_tag_c] = (iou, tp_fp_fn_op)

        (metrics_to_values, metrics_to_updates) = contrib_metrics.aggregate_metric_map(metric_map)

        summary_ops = []
        for metric_name, metric_value in six.iteritems(metrics_to_values):
            op = tf.summary.scalar(metric_name, metric_value)
            op = tf.Print(op, [metric_value], metric_name)
            summary_ops.append(op)

        summary_op = tf.summary.merge(summary_ops)
        summary_hook = contrib_training.SummaryAtEndHook(
            log_dir=FLAGS.eval_logdir, summary_op=summary_op)
        hooks = [summary_hook]

        num_eval_iters = None
        if FLAGS.max_number_of_evaluations > 0:
            num_eval_iters = FLAGS.max_number_of_evaluations

        if FLAGS.quantize_delay_step >= 0:
            contrib_quantize.create_eval_graph()

        contrib_tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=contrib_tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

        contrib_tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)

        contrib_training.evaluate_repeatedly(
            checkpoint_dir=FLAGS.checkpoint_dir,
            master=FLAGS.master,
            eval_ops=list(metrics_to_updates.values()),
            max_number_of_evaluations=num_eval_iters,
            hooks=hooks,
            eval_interval_secs=FLAGS.eval_interval_secs)


if __name__ == '__main__':
    flags.mark_flag_as_required('checkpoint_dir')
    flags.mark_flag_as_required('eval_logdir')
    flags.mark_flag_as_required('dataset_dir')
    tf.compat.v1.app.run()
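As a cross-check of what the metric computes, the NumPy sketch below accumulates a confusion matrix (rows are ground truth, columns are predictions) and derives per-class IoU = TP / (TP + FP + FN), giving ignore-label pixels weight 0 exactly as eval.py does. Classes whose denominator is 0 are excluded from the mean; this is meant to mirror how tf.metrics.mean_iou treats absent classes, and should be read as an assumption. Note that the per-class loop in eval.py instead reports NaN whenever TP + FN = 0, i.e. when the class never occurs in the ground truth.

```python
import numpy as np


def confusion_matrix(labels, predictions, num_classes, weights=None):
    """Weighted confusion matrix: rows are ground-truth classes, columns are predictions."""
    labels = np.asarray(labels).reshape(-1)
    predictions = np.asarray(predictions).reshape(-1)
    if weights is None:
        weights = np.ones(labels.shape, dtype=np.float64)
    cm = np.zeros((num_classes, num_classes), dtype=np.float64)
    np.add.at(cm, (labels, predictions), weights)
    return cm


def mean_iou(cm):
    """Returns (mIoU over classes with a non-zero denominator, per-class IoU)."""
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as c, labelled as something else
    fn = cm.sum(axis=1) - tp   # labelled as c, predicted as something else
    denom = tp + fp + fn
    valid = denom > 0
    iou = np.where(valid, tp / np.where(valid, denom, 1.0), np.nan)
    return (float(iou[valid].mean()) if valid.any() else 0.0), iou


ignore_label = 255
labels = np.array([0, 0, 1, 1, 255])
predictions = np.array([0, 1, 1, 1, 0])
# As in eval.py: weight 0 for ignored pixels, then map them to a valid id (0).
weights = (labels != ignore_label).astype(np.float64)
labels = np.where(labels == ignore_label, 0, labels)

cm = confusion_matrix(labels, predictions, num_classes=2, weights=weights)
miou, per_class = mean_iou(cm)
print(per_class, miou)
```

In this toy example class 0 has IoU 1/2 and class 1 has IoU 2/3, giving mIoU = 7/12 ≈ 0.583; the ignored fifth pixel affects neither count.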

IV. Test Script

With the official test script 【vis.py】, the following problems come up in practice.

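Problem 1: the program does not exit after it finishes, but keeps running in a waiting state, as shown in the figure below.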
[Figure: vis.py still running after all images have been processed]

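Problem 2: the script reads its input from 【TFRecords】 files, which is clearly inconvenient for practical testing. To address these problems, and to meet the actual needs of the project, the script was modified so that it can read and test a single image.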

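Problem 3: the official code is spread across several folders. Module imports are easy to configure in PyCharm, but a standalone test script may fail to import these modules; in the code below this is handled by appending the slim directory to sys.path. The modified code is as follows: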

"""Segmentation results visualization on a given set of images.

See model.py for more details and usage.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
sys.path.append('./DeepLab/slim/')
import os.path
import time
import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import training as contrib_training
# from DeepLab.deeplab import common
# from DeepLab.deeplab import model
# from DeepLab.deeplab.datasets import data_generator
# from DeepLab.deeplab.utils import save_annotation
from deeplab import common
from deeplab import model
# from deeplab.datasets import *
from deeplab.datasets import data_generator
from deeplab.utils import save_annotation
import cv2
from PIL import Image

flags = tf.app.flags

FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

# Settings for log directories.
flags.DEFINE_string("FileName", None, "Path of the single image to segment.")
# flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
#
# flags.DEFINE_string('checkpoint_dir',None, 'Directory of model checkpoints.')

# Settings for visualizing the model.

flags.DEFINE_integer('vis_batch_size', 1,
                     'The number of images in each batch during evaluation.')

flags.DEFINE_list('vis_crop_size', '512, 512',
                  'Crop size [height, width] for visualization.')

flags.DEFINE_integer('eval_interval_secs', 0,
                     'How often (in seconds) to run evaluation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', [6, 12, 18],
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

# Change to True for adding flipped images during test.
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images for evaluation or not.')

flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.

flags.DEFINE_string('dataset', 'mydata',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('vis_split', 'val',
                    'Which split of the dataset used for visualizing results')

# flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')

flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes', 'ade20k'],
                  'Visualization colormap type.')

flags.DEFINE_boolean('also_save_raw_predictions', False,
                     'Also save raw predictions.')

flags.DEFINE_integer('max_number_of_iterations', 0,
                     'Maximum number of visualization iterations. Will loop '
                     'indefinitely upon nonpositive values.')

# The folder where semantic segmentation predictions are saved.
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'

# The folder where raw semantic segmentation predictions are saved.
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'

# The format to save image.
_IMAGE_FORMAT = '%06d_image'

# The format to save prediction
_PREDICTION_FORMAT = '%06d_prediction'

# To evaluate Cityscapes results on the evaluation server, the labels used
# during training should be mapped to the labels for evaluation.
_CITYSCAPES_TRAIN_ID_TO_EVAL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                   23, 24, 25, 26, 27, 28, 31, 32, 33]

vis_logdir="./Data/OutputData/"
checkpoint_dir = "./DeepLab/deeplab/train-0124-2/"
dataset_dir="./Data/OutputData/renderBMP/"


def _convert_train_id_to_eval_id(prediction, train_id_to_eval_id):
    """Converts the predicted label for evaluation.

    There are cases where the training labels are not equal to the evaluation
    labels. This function is used to perform the conversion so that we could
    evaluate the results on the evaluation server.

    Args:
        prediction: Semantic segmentation prediction.
        train_id_to_eval_id: A list mapping from train id to evaluation id.

    Returns:
        Semantic segmentation prediction whose labels have been changed.
    """
    converted_prediction = prediction.copy()
    for train_id, eval_id in enumerate(train_id_to_eval_id):
        converted_prediction[prediction == train_id] = eval_id

    return converted_prediction


def _process_batch(sess, original_images, semantic_predictions, image_names,
                   image_heights, image_widths, image_id_offset, save_dir,
                   raw_save_dir, train_id_to_eval_id=None, save_name=None):
    """Evaluates one single batch qualitatively.

    Args:
        sess: TensorFlow session.
        original_images: One batch of original images.
        semantic_predictions: One batch of semantic segmentation predictions.
        image_names: Image names.
        image_heights: Image heights.
        image_widths: Image widths.
        image_id_offset: Image id offset for indexing images.
        save_dir: The directory where the predictions will be saved.
        raw_save_dir: The directory where the raw predictions will be saved.
        train_id_to_eval_id: A list mapping from train id to eval id.
        save_name: Path of the input image; its base name minus the last four
            characters (the extension, e.g. '.bmp') names the saved files.
    """
    (original_images_,
     semantic_predictions_,
     image_names_,
     image_heights_,
     image_widths_) = sess.run([original_images, semantic_predictions, image_names, image_heights, image_widths])

    num_image = semantic_predictions_.shape[0]
    for i in range(num_image):
        # image_height = np.squeeze(image_heights_[i])
        # image_width = np.squeeze(image_widths_[i])

        image_height = 512
        image_width = 512

        original_image = np.squeeze(original_images_[i])
        semantic_prediction = np.squeeze(semantic_predictions_[i])
        crop_semantic_prediction = semantic_prediction[:image_height, :image_width]

        # Save image.
        # save_annotation.save_annotation(
        #     original_image, save_dir, _IMAGE_FORMAT % (image_id_offset + i),
        #     add_colormap=False)

        save_name = save_name.split('/')[-1]
        save_annotation.save_annotation(original_image, save_dir, save_name[0:-4], add_colormap=False)

        # Save prediction.
        save_annotation.save_annotation(
            crop_semantic_prediction, save_dir,
            save_name[0:-4]+'_mask', add_colormap=True,
            colormap_type=FLAGS.colormap_type)

        if FLAGS.also_save_raw_predictions:
            image_filename = os.path.basename(image_names_[i])

            if train_id_to_eval_id is not None:
                crop_semantic_prediction = _convert_train_id_to_eval_id(
                    crop_semantic_prediction,
                    train_id_to_eval_id)
            save_annotation.save_annotation(
                crop_semantic_prediction, raw_save_dir, image_filename,
                add_colormap=False)


def main_entry():
    # flags.File_name = '2021_01_24.11_58_59.roc.bmp'
    file_name = FLAGS.FileName
    print("file name: ", file_name)
    tf.logging.set_verbosity(tf.logging.INFO)
    # Get dataset-dependent information.

    # dataset = data_generator.Dataset(
    #     dataset_name=FLAGS.dataset,
    #     split_name=FLAGS.vis_split,
    #     dataset_dir=FLAGS.dataset_dir,
    #     batch_size=FLAGS.vis_batch_size,
    #     crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
    #     min_resize_value=FLAGS.min_resize_value,
    #     max_resize_value=FLAGS.max_resize_value,
    #     resize_factor=FLAGS.resize_factor,
    #     model_variant=FLAGS.model_variant,
    #     is_training=False,
    #     should_shuffle=False,
    #     should_repeat=False)
    # print("1========================================")
    # print(dataset.dataset_name)
    # exit()

    train_id_to_eval_id = None
    # if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
    #     tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    #     train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
    # Prepare for visualization.
    tf.gfile.MakeDirs(vis_logdir)
    save_dir = os.path.join(vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(save_dir)
    # raw_save_dir = os.path.join(
    #     FLAGS.vis_logdir, _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
    # tf.gfile.MakeDirs(raw_save_dir)

    # tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

    with tf.Graph().as_default():
        # samples = dataset.get_one_shot_iterator().get_next()
        # sess = tf.Session()
        # sess.run(tf.global_variables_initializer())
        # # saver = tf.train.Saver()
        # saver = tf.train.import_meta_graph(FLAGS.checkpoint_dir+'model.ckpt-4056.meta')
        # saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))

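        # The number of classes is hard-coded to 2; it must equal the number of
        # classes the restored checkpoint was trained with.
        model_options = common.ModelOptions(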
            outputs_to_num_classes={common.OUTPUT_TYPE: 2},
            crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        path = dataset_dir
        # name = os.listdir(path)
        print("path: ", path)
        for i in range(1):
            # file = path + name[i]
            file = file_name
            print("file: ", file)
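            # Read the raw bytes and decode the BMP into a uint8 tensor of shape [H, W, 3].
            with tf.gfile.GFile(file, 'rb') as f:
                image_raw_data = f.read()
            img_data = tf.image.decode_bmp(image_raw_data, channels=3)
            # decode_bmp already returns uint8, so this conversion is a no-op safeguard.
            image = tf.image.convert_image_dtype(img_data, dtype=tf.uint8)
            image = tf.image.flip_up_down(image)
            # Add the batch dimension; this assumes the input image is exactly 512x512.
            image = tf.reshape(image, (1, 512, 512, 3))

            # Image name and (square) height/width tensors consumed by _process_batch().
            image_name = tf.convert_to_tensor(file_name, dtype=tf.string)
            input_shape = tf.convert_to_tensor(512, dtype=tf.int64)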

            # if tuple(FLAGS.eval_scales) == (1.0,):
            tf.logging.info('Performing single-scale test.')
            # predictions = model.predict_labels(
            #     samples[common.IMAGE],
            #     model_options=model_options,
            #     image_pyramid=FLAGS.image_pyramid)

            predictions = model.predict_labels(
                image,
                model_options=model_options,
                image_pyramid=FLAGS.image_pyramid)

            # else:
            #   tf.logging.info('Performing multi-scale test.')
            #   if FLAGS.quantize_delay_step >= 0:
            #     raise ValueError(
            #         'Quantize mode is not supported with multi-scale test.')
            #   predictions = model.predict_labels_multi_scale(
            #       samples[common.IMAGE],
            #       model_options=model_options,
            #       eval_scales=FLAGS.eval_scales,
            #       add_flipped_images=FLAGS.add_flipped_images)
            # image_id_offset = 0
            # _process_batch(sess=sess,
            #                original_images=image,
            #                semantic_predictions=predictions,
            #                image_names=image_name,
            #                image_heights=input_shape,
            #                image_widths=input_shape,
            #                image_id_offset=image_id_offset,
            #                save_dir=save_dir,
            #                raw_save_dir=raw_save_dir,
            #                train_id_to_eval_id=train_id_to_eval_id,
            #                save_name=name[i])

            predictions_ = predictions[common.OUTPUT_TYPE]

            # if FLAGS.min_resize_value and FLAGS.max_resize_value:
            #   print('================S1')
            #   # Only support batch_size = 1, since we assume the dimensions of original
            #   # image after tf.squeeze is [height, width, 3].
            #   assert FLAGS.vis_batch_size == 1
            #
            #   # Reverse the resizing and padding operations performed in preprocessing.
            #   # First, we slice the valid regions (i.e., remove padded region) and then
            #   # we resize the predictions back.
            #   original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
            #   original_image_shape = tf.shape(original_image)
            #   predictions = tf.slice(
            #       predictions,
            #       [0, 0, 0],
            #       [1, original_image_shape[0], original_image_shape[1]])
            #   resized_shape = tf.to_int32([tf.squeeze(samples[common.HEIGHT]),
            #                                tf.squeeze(samples[common.WIDTH])])
            #   predictions = tf.squeeze(
            #       tf.image.resize_images(tf.expand_dims(predictions, 3),
            #                              resized_shape,
            #                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
            #                              align_corners=True), 3)

            tf.train.get_or_create_global_step()
            if FLAGS.quantize_delay_step >= 0:
                contrib_quantize.create_eval_graph()

            # num_iteration = 0
            # max_num_iteration = FLAGS.max_number_of_iterations

            # checkpoints_iterator = contrib_training.checkpoints_iterator(
            #     FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)

            # Restore one fixed checkpoint (step 4056) instead of looping over
            # contrib_training.checkpoints_iterator as the official vis.py does.
            checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt-4056')
            # for checkpoint_path in checkpoints_iterator:
            # num_iteration += 1
            tf.logging.info(
                'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                             time.gmtime()))
            tf.logging.info('Visualizing with model %s', checkpoint_path)

            config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
            config.gpu_options.allow_growth = True
            # config.gpu_options.allow_soft_placement = True
            scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
            session_creator = tf.train.ChiefSessionCreator(
                scaffold=scaffold,
                master=FLAGS.master,
                config=config,
                checkpoint_filename_with_path=checkpoint_path)

            with tf.train.MonitoredSession(
                    session_creator=session_creator, hooks=None) as sess:
                batch = 0
                image_id_offset = 0

                # while not sess.should_stop():
                tf.logging.info('Visualizing batch %d', batch + 1)
                # _process_batch(sess=sess,
                #                original_images=samples[common.ORIGINAL_IMAGE],
                #                semantic_predictions=predictions,
                #                image_names=samples[common.IMAGE_NAME],
                #                image_heights=samples[common.HEIGHT],
                #                image_widths=samples[common.WIDTH],
                #                image_id_offset=image_id_offset,
                #                save_dir=save_dir,
                #                raw_save_dir=raw_save_dir,
                #                train_id_to_eval_id=train_id_to_eval_id)

                _process_batch(sess=sess,
                               original_images=image,
                               semantic_predictions=predictions_,
                               image_names=image_name,
                               image_heights=input_shape,
                               image_widths=input_shape,
                               image_id_offset=image_id_offset,
                               save_dir=save_dir,
                               raw_save_dir=None,
                               train_id_to_eval_id=train_id_to_eval_id,
                               save_name=file_name)

                image_id_offset += FLAGS.vis_batch_size
                batch += 1

                # tf.logging.info(
                #     'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))

            # if max_num_iteration > 0 and num_iteration >= max_num_iteration:
            #     break



if __name__ == '__main__':
    main_entry()
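With add_colormap=True, save_annotation turns the predicted label map into a color image according to colormap_type. For the pascal colormap type, the palette is built by spreading the bits of each label index across the R, G and B channels. The pure-Python sketch below (no TensorFlow; the function names are hypothetical) reproduces that construction as we understand it from create_pascal_label_colormap in DeepLab's utils/get_dataset_colormap.py, and shows what the two classes used in main_entry() look like.

```python
def bit_get(value, index):
    """Return the bit of `value` at position `index` (0 = least significant)."""
    return (value >> index) & 1


def create_pascal_colormap(num_entries=256):
    """Build a PASCAL-VOC-style palette: label index -> (R, G, B)."""
    colormap = []
    for label in range(num_entries):
        r = g = b = 0
        ind = label
        # Consume the label 3 bits at a time, writing them from the MSB downwards.
        for shift in reversed(range(8)):
            r |= bit_get(ind, 0) << shift
            g |= bit_get(ind, 1) << shift
            b |= bit_get(ind, 2) << shift
            ind >>= 3
        colormap.append((r, g, b))
    return colormap


def colorize(label_map, colormap):
    """Map a 2-D list of integer labels to a 2-D list of RGB tuples."""
    return [[colormap[label] for label in row] for row in label_map]


if __name__ == '__main__':
    palette = create_pascal_colormap()
    # Two classes, as in main_entry(): label 0 is black, label 1 is dark red.
    print(colorize([[0, 1], [1, 0]], palette))
    # -> [[(0, 0, 0), (128, 0, 0)], [(128, 0, 0), (0, 0, 0)]]
```

This is why a two-class "_mask" image saved with the pascal colormap shows only black and dark red pixels.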
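main_entry() sets train_id_to_eval_id = None (the Cityscapes check is commented out), so the conversion branch inside _process_batch never runs in this script. When it does run, every predicted train id is replaced by the id used for evaluation. The sketch below uses plain lists instead of a NumPy array and a hypothetical three-entry lookup table; the point it illustrates is that values must be read from the unmodified input, otherwise a table that swaps ids (such as [1, 0]) would convert some pixels twice.

```python
def convert_train_id_to_eval_id(prediction, train_id_to_eval_id):
    """Map every train id in a 2-D label map to its eval id.

    Values are looked up in the *original* map and written to a new one, so an
    id that has already been converted is never converted a second time. Ids
    outside the table (e.g. an ignore label) are left unchanged.
    """
    num_ids = len(train_id_to_eval_id)
    return [[train_id_to_eval_id[p] if 0 <= p < num_ids else p for p in row]
            for row in prediction]


if __name__ == '__main__':
    # Hypothetical table: train ids 0, 1, 2 -> eval ids 7, 8, 11.
    print(convert_train_id_to_eval_id([[0, 1], [2, 255]], [7, 8, 11]))
    # -> [[7, 8], [11, 255]]
```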
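The tensor predictions_ = predictions[common.OUTPUT_TYPE] holds class ids, not scores. To our understanding, model.predict_labels takes an arg-max over the class axis of the logits (depending on prediction_with_upsampled_logits, either after upsampling the logits or before a nearest-neighbour upsampling of the result), so with outputs_to_num_classes={common.OUTPUT_TYPE: 2} every pixel ends up as 0 or 1. The pure-Python sketch below shows only that arg-max step on a nested [H][W][C] list; the function name is hypothetical, and breaking ties toward the lowest index is a choice of this sketch, not a documented property of tf.argmax.

```python
def argmax_labels(logits):
    """Per-pixel arg-max over the class axis: [H][W][C] scores -> [H][W] class ids.

    On ties the lowest class index wins, because max() returns the first maximum.
    """
    return [[max(range(len(scores)), key=scores.__getitem__) for scores in row]
            for row in logits]


if __name__ == '__main__':
    # Toy 1x3 "image" with 2 classes.
    logits = [[[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5]]]
    print(argmax_labels(logits))  # -> [[0, 1, 1]]
```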
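The large commented-out block after predictions_ = ... comes from the official vis.py: when min_resize_value and max_resize_value are set, it slices away the padded region and resizes the predictions back to the original image size with NEAREST_NEIGHBOR and align_corners=True. Nearest neighbour is the right choice for a label map, because interpolating class ids (for example averaging 0 and 2) would invent a class that was never predicted. In main_entry() the 512x512 image is fed without resizing or padding, and the block stays disabled. The pure-Python sketch below (hypothetical function name) approximates TensorFlow 1.x's align_corners=True nearest-neighbour rule as we understand it: scale by (in - 1) / (out - 1), round to the nearest index, clamp to the valid range. Exact ties may round differently from TensorFlow's float32 arithmetic.

```python
import math


def resize_nearest_align_corners(label_map, out_h, out_w):
    """Nearest-neighbour resize of a 2-D label map, align_corners=True style.

    Output pixel (y, x) copies input pixel (round(y * sy), round(x * sx)) with
    sy = (in_h - 1) / (out_h - 1) and sx = (in_w - 1) / (out_w - 1), so the
    corner pixels of input and output line up exactly.
    """
    in_h, in_w = len(label_map), len(label_map[0])
    sy = (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    sx = (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0

    def nearest(i, scale, size):
        # floor(v + 0.5) rounds halves up for v >= 0; clamp to the last valid index.
        return min(int(math.floor(i * scale + 0.5)), size - 1)

    return [[label_map[nearest(y, sy, in_h)][nearest(x, sx, in_w)]
             for x in range(out_w)]
            for y in range(out_h)]


if __name__ == '__main__':
    for row in resize_nearest_align_corners([[0, 1], [1, 0]], 4, 4):
        print(row)
    # -> [0, 0, 1, 1]
    #    [0, 0, 1, 1]
    #    [1, 1, 0, 0]
    #    [1, 1, 0, 0]
```

Only values already present in the input can appear in the output, which is exactly the property a class map needs.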
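main_entry() restores the fixed prefix model.ckpt-4056, so after retraining the step number has to be edited by hand. tf.train.latest_checkpoint(checkpoint_dir) is the usual way to pick the newest checkpoint, but it relies on the "checkpoint" state file written by the saver. As a hedged alternative, the stdlib-only sketch below scans the directory for model.ckpt-<step>.index files and returns the prefix with the largest step. The helper name and the assumption that checkpoints follow the model.ckpt-<step> naming seen in the code are ours. Steps are compared as integers, so model.ckpt-4056 wins over model.ckpt-999 even though "999" sorts after "4056" as a string.

```python
import os
import re

# Matches the index file of a checkpoint named like 'model.ckpt-4056.index'.
_CKPT_INDEX_RE = re.compile(r'^model\.ckpt-(\d+)\.index$')


def find_highest_step_checkpoint(checkpoint_dir):
    """Return the 'model.ckpt-<step>' prefix with the largest step, or None.

    The returned path has no '.index' suffix, i.e. it has the same form as the
    hard-coded 'model.ckpt-4056' prefix passed to the session creator.
    """
    best_step, best_prefix = -1, None
    for name in os.listdir(checkpoint_dir):
        match = _CKPT_INDEX_RE.match(name)
        if match is None:
            continue
        step = int(match.group(1))  # compare numerically, not as strings
        if step > best_step:
            best_step = step
            best_prefix = os.path.join(checkpoint_dir, name[:-len('.index')])
    return best_prefix
```

In main_entry() the result could be passed as checkpoint_path to tf.train.ChiefSessionCreator(checkpoint_filename_with_path=...); a None return means no matching checkpoint was found and should be handled before building the session.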