【DeeplabV3+ get_miou_png】DeeplabV3+获取数据集预测结果灰度图

1 为什么有这么一篇文章

其实之前有写过deeplabv3+图像输入->处理->输出全过程,里面包含了如下内容:

deeplabv3+图像输入->处理->输出全过程 目录
该有的似乎都有了,只是想着大家平时针对数据集操作还挺多的,保存数据集的分割预测结果也是一小部分工作内容,故又加了这一篇,内容和上述文章区别不是很大,很容易。

2 获取并保存数据集分割预测结果

get_miou.py代码中,给出了下列代码,完成图片从输入到得到数据集预测结果灰度图的全部过程。

import os

from PIL import Image
from tqdm import tqdm

# ----------------------------------------------------------#
#	DeeplabV3表示分割网络结构,其代码在deeplab.py中,解读见下一节
# ----------------------------------------------------------#
from deeplab import DeeplabV3   
# ---------------------------------------------------------------------#
#	compute_mIoU和show_results,其代码在utils/utils_metrics.py中,
#	解读见链接:会有的
#	!本文中并未用到!
# ---------------------------------------------------------------------#
from utils.utils_metrics import compute_mIoU, show_results


"""
进行指标评估需要注意:
该文件生成的图为灰度图,因为值比较小,按照PNG形式的图看是没有显示效果的,所以看到近似全黑的图是正常的。
"""
if __name__ == "__main__":
    #---------------------------------------------------------------------------#
    #   miou_mode用于指定该文件运行时计算的内容
    #   miou_mode为0代表整个miou计算流程,包括获得预测结果、计算miou。
    #   miou_mode为1代表仅仅获得预测结果。
    #   miou_mode为2代表仅仅计算miou。          !!本文中并未用到!!
    #---------------------------------------------------------------------------#
    miou_mode       = 1
    #------------------------------------#
    #   分类个数+1、如2+1
    #   VOC数据集,所需要区分的类的个数+1
    #------------------------------------#
    num_classes     = 21
    #--------------------------------------------#
    #   区分的种类,和json_to_dataset里面的一样
    #   种类名称,此例为VOC
    #--------------------------------------------#
    name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    # name_classes    = ["_background_","cat","dog"]
    #-------------------------------------------------------------------#
    #   指向VOC数据集所在的文件夹
    #   默认指向根目录下的VOC数据集
    #   链接:https://pan.baidu.com/s/1OZfxoyVUKlESsyqs1nuuuw 提取码:wlna
    #-------------------------------------------------------------------#
    VOCdevkit_path  = '../VOCdevkit'

    #--------------------------------------------#
    #   image_ids:['图片名1', '图片名2',...]
    #--------------------------------------------#
    image_ids       = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines() 
    gt_dir          = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/")
    miou_out_path   = "miou_out"
    #-------------------------------------------------#
    #   pred_dir预测结果png图片路径,只有8位深度,灰度图
    #   正常jpg,RGB三通道,24位深度
    #   彩色png,RGBA四通道,32位深度
    #-------------------------------------------------#
    pred_dir        = os.path.join(miou_out_path, 'detection-results')  

    #-------------------------------------------------#
    #   获得预测结果,输出为8位深度的灰度图
    #-------------------------------------------------#
    if miou_mode == 0 or miou_mode == 1:
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
            
        #-----------------------------------------------------------------------------------#
        #   下方有给出代码
        #   详细解读见:https://blog.csdn.net/weixin_45377629/article/details/124124238
        #-----------------------------------------------------------------------------------#
        print("Load model.")
        deeplab = DeeplabV3()
        print("Load model done.")

        print("Get predict result.")
        for image_id in tqdm(image_ids):
            image_path  = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
            image       = Image.open(image_path)
            # ------------------------------------#
            #   image是png图片,8位深度,灰度图
            #   deeplab.get_miou_png(image)见下方解读
            #   # image size:(原图宽, 原图高)
            # ------------------------------------#
            image       = deeplab.get_miou_png(image)   
            image.save(os.path.join(pred_dir, image_id + ".png"))
        print("Get predict result done.")

结果输出:

结果输出

该文件生成的图为灰度图,因为值比较小,按照PNG形式的图看是没有显示效果的,所以看到近似全黑的图是正常的。

3 deeplab.get_miou_png()函数代码解析

通过deeplab.py完成get_miou.pyimage= deeplab.get_miou_png(image),用来获取数据集分割结果灰度图。deeplabv3+网络结构详细介绍可见 DeeplabV3+网络结构详解,通过网络结构获取8位深度的分割结果灰度图见下方代码。

import colorsys
import copy
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import FloatTensor, nn, tensor

#---------------------------------------------------------------------------------#
#   DeepLab网络及代码
#   详细介绍可见 https://blog.csdn.net/weixin_45377629/article/details/124083978
#---------------------------------------------------------------------------------#
from nets.deeplabv3_plus import DeepLab
#----------------------------------------------------------------------------------#
#   三个函数代码下方给出
#   cvtColor:          将图像转换成RGB图像,防止灰度图在预测时报错。
#   preprocess_input:  归一化
#   resize_image:      对输入图像进行resize,letterbox_image方式,不失真resize
#----------------------------------------------------------------------------------#
from utils.utils import cvtColor, preprocess_input, resize_image


#-----------------------------------------------------------------------------------#
#   使用自己训练好的模型预测需要修改3个参数
#   model_path、backbone和num_classes都需要修改!
#   如果出现shape不匹配,一定要注意训练时的model_path、backbone和num_classes的修改
#-----------------------------------------------------------------------------------#
class DeeplabV3(object):
    _defaults = {
        #-------------------------------------------------------------------#
        #   model_path指向logs文件夹下的权值文件
        #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
        #   验证集损失较低不代表miou较高,仅代表该权值在验证集上泛化性能较好。
        #   链接:https://pan.baidu.com/s/1TrBlnZUd6xwxUvgFjbz7TQ 提取码:cj80
        #-------------------------------------------------------------------#
        "model_path"        : 'model_data/deeplab_mobilenetv2.pth',
        #----------------------------------------#
        #   所需要区分的类的个数+1
        #----------------------------------------#
        "num_classes"       : 21,
        #----------------------------------------#
        #   所使用的的主干网络:
        #   mobilenet  
        #----------------------------------------#
        "backbone"          : "mobilenet",
        #----------------------------------------#
        #   输入图片的大小
        #----------------------------------------#
        "input_shape"       : [512, 512],
        #----------------------------------------#
        #   下采样的倍数,一般可选的为8和16
        #   与训练时设置的一样即可
        #----------------------------------------#
        "downsample_factor" : 16,
        #-------------------------------------------------#
        #   mix_type参数用于控制检测结果的可视化方式
        #
        #   mix_type = 0的时候代表原图与生成的图进行混合
        #   mix_type = 1的时候代表仅保留生成的图
        #   mix_type = 2的时候代表仅扣去背景,仅保留原图中的目标
        #   下方有给出三种可视化结果的区别
        #-------------------------------------------------#
        "mix_type"          : 0,
        #-------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   初始化Deeplab
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        #---------------------------------------------------#
        #   _defaults字典原来是这么用起来的
        #---------------------------------------------------#
        self.__dict__.update(self._defaults)

        for name, value in kwargs.items():
            #-----------------------------------------------#
            #   设置属性 name 值,即self.name==value
            #-----------------------------------------------#
            setattr(self, name, value)
        #---------------------------------------------------#
        #   画框设置不同的颜色
        #---------------------------------------------------#
        if self.num_classes <= 21:
            self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 
                            (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 
                            (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), 
                            (128, 64, 12)]
        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        #---------------------------------------------------#
        #   获得模型
        #---------------------------------------------------#
        self.generate()
                    
    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def generate(self):
        #-----------------------------------------------------------------------------------#
        #   载入模型与权值
        #   详细介绍可见 https://blog.csdn.net/weixin_45377629/article/details/124083978
        #-----------------------------------------------------------------------------------#
        self.net = DeepLab(num_classes=self.num_classes, backbone=self.backbone, downsample_factor=self.downsample_factor, pretrained=False)

        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net    = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        
        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    #---------------------------------------------------#
    #   预测图片,得到灰度图结果
    #---------------------------------------------------#
    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #   VOC为例,self.net(images) shape:torch.size([1,21,512,512])
            #   pr :tensor, shape:torch.size([21,512,512])
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #   pr.permute(1,2,0):通道交换
            #   F.softmax(input, dim=-1):在行上softmax,和为1
            #   pr :array, shape:(512,512,21)
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #----------------------------------------------------#
            #   将灰条部分截取掉
            #   letterbox_image一般会引入灰条
            #   pr :array, shape:(512,512,21),有灰条w、h尺寸会变
            #----------------------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #   灰条去掉后,resize回原图大小
            #   pr :array, shape:(orininal_w, orininal_h,21)
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #   pr :array, shape:(orininal_w, orininal_h)
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
    
        image = Image.fromarray(np.uint8(pr))   # size:(orininal_w, orininal_h)
        return image

4 感谢链接

https://blog.csdn.net/weixin_44791964/article/details/120113686
https://www.bilibili.com/video/BV173411q7xF?p=15
  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 11
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值