deeplabv3+源码之慢慢解析5 第一章根目录(5)predict.py--get_argparser函数和main函数

系列文章目录(共五章33节已完结)

第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数

第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数

第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–StreamSegMetrics类和AverageMeter类

第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a1)hrnetv2.py–4个函数和可执行代码
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a2)hrnetv2.py–Bottleneck类和BasicBlock类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a3)hrnetv2.py–StageModule类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a4)hrnetv2.py–HRNet类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b1)mobilenetv2.py–2个类和2个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b2)mobilenetv2.py–MobileNetV2类和mobilenet_v2函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c1)resnet.py–2个基础函数,BasicBlock类和Bottleneck类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c2)resnet.py–ResNet类和10个不同结构的调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d1)xception.py–SeparableConv2d类和Block类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d2)xception.py–Xception类和xception函数
第四章deeplabv3+源码之慢慢解析 network文件夹(2)_deeplab.py–ASPP相关的4个类和1个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(3)_deeplab.py–DeepLabV3类,DeepLabHeadV3Plus类和DeepLabHead类
第四章deeplabv3+源码之慢慢解析 network文件夹(4)modeling.py–5个私有函数(4个骨干网,1个模型载入)
第四章deeplabv3+源码之慢慢解析 network文件夹(5)modeling.py–12个调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(6)utils.py–_SimpleSegmentationModel类和IntermediateLayerGetter类

第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–2个翻转类和ExtCompose类
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)ext_transforms.py.py–2个裁剪类和2个缩放类
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)ext_transforms.py.py–旋转类,填充类,张量转化类和标准化类
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)ext_transforms.py.py–ExtResize类,ExtColorJitter类,Lambda类和Compose类
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)loss.py–FocalLoss类
第五章deeplabv3+源码之慢慢解析 utils文件夹(6)scheduler.py–PolyLR类
第五章deeplabv3+源码之慢慢解析 utils文件夹(7)utils.py–去标准化,momentum设定,标准化层锁定和路径创建
第五章deeplabv3+源码之慢慢解析 utils文件夹(8)visualizer.py–Visualizer类(完结)


说明

  1. main.py的代码已经说完,有详细的训练过程(此处重在过程,不再函数细节)。
  2. 根目录下面只有两个python源代码文件,此篇说第二个predict.py。
  3. 原本是每次只说一个函数,因predict.py中仅有两个函数get_argparser和main,且都和main.py相似,因此,一次说完。
  4. predict.py顾名思义,就是用于预测的代码,实际上是main.py训练好之后,单独用于预测的代码,和validate函数有很多类似之处,大家可以边学习边对比。

predict.py导入

from torch.utils.data import dataset  #因直接预测,需要使用数据部分
#以下同main.py
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np

#以下是数据部分所需
from torch.utils import data
from datasets import VOCSegmentation, Cityscapes, cityscapes
from torchvision import transforms as T
from metrics import StreamSegMetrics
#以下是神经网络所需
import torch
import torch.nn as nn
#以下是可视化和图片操作所需
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from glob import glob
  1. 导入部分,大致分块注释,按此理解即可。
  2. 很多地方和main.py导入一致,建议一定详细看完关于main.py介绍,之后再看predict.py,会感觉很简单。
  3. predict.py分为get_argparser函数和main函数,下面分别详解。

参数解析,get_argparser函数

提示:看过main.py部分的get_argparser函数,再看下面的内容。是不是有种似曾相识的感觉?代码学习要多积累,切记切记。

def get_argparser():
    parser = argparse.ArgumentParser()

    # Datset Options
    #input参数,此处输入预测文件夹的路径,建议即使是单一预测图像也放在文件夹中,养成良好的路径管理习惯。如本代码测试数据选用了samples文件夹下的图像,结果图像放在自建的test_results文件夹。
    parser.add_argument("--input", type=str, required=True,
                        help="path to a single image or image directory")   
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc', 'cityscapes'], help='Name of training set')#同main.py的get_argparser函数第10行。

    # Deeplab Options
    available_models = sorted(name for name in network.modeling.__dict__ if name.islower() and \
                              not (name.startswith("__") or name.startswith('_')) and callable(
                              network.modeling.__dict__[name])
                              )#同main.py的get_argparser函数第19行。

    parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
                        choices=available_models, help='model name')#同main.py的get_argparser函数第26行。
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")#同main.py的get_argparser函数第30行。
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])#同main.py的get_argparser函数第32行。

    # Train Options
    parser.add_argument("--save_val_results_to", default=None,
                        help="save segmentation results to the specified dir")#此处开启了验证保存。这里输入的是保存结果的目录。

    parser.add_argument("--crop_val", action='store_true', default=False,
                        help='crop validation (default: False)')#同main.py的get_argparser函数第50行。
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')#同main.py的get_argparser函数第54行。
    parser.add_argument("--crop_size", type=int, default=513)#同main.py的get_argparser函数第57行。

    
    parser.add_argument("--ckpt", default=None, type=str,
                        help="resume from checkpoint")#同main.py的get_argparser函数第58行。
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")#同main.py的get_argparser函数第65行。
    return parser

主函数,main函数

提示:同样,对照main.py部分的main函数,更容易理解此处的代码。

def main():
    opts = get_argparser().parse_args()   #同main.py
    if opts.dataset.lower() == 'voc':     #数据集选择
        opts.num_classes = 21
        decode_fn = VOCSegmentation.decode_target     #使用解码后的数据
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        decode_fn = Cityscapes.decode_target

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id   #GPU选择
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup dataloader
    image_files = []
    if os.path.isdir(opts.input):    #从get_argparser函数中的input参数获得路径,逐一添加图像
        for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
            files = glob(os.path.join(opts.input, '**/*.%s'%(ext)), recursive=True)
            if len(files)>0:
                image_files.extend(files)
    elif os.path.isfile(opts.input):
        image_files.append(opts.input)
    
    # Set up model (all models are 'constructed at network.modeling)
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)     #同main.py中main函数第37行。
    if opts.separable_conv and 'plus' in opts.model:  #同main.py中main函数第38行。
        network.convert_to_separable_conv(model.classifier)  #同main.py中main函数第39行。
    utils.set_bn_momentum(model.backbone, momentum=0.01)   #同main.py中main函数第40行。
    
    if opts.ckpt is not None and os.path.isfile(opts.ckpt): #以下同main.py中main函数第81-86行。(是81-98行的简化)
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % opts.ckpt)
        del checkpoint
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

    if opts.crop_val:      #裁剪验证,对数据进行尺寸变化
        transform = T.Compose([
                T.Resize(opts.crop_size),
                T.CenterCrop(opts.crop_size),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
    else:
        transform = T.Compose([
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
    if opts.save_val_results_to is not None:    #新建验证结果保存文件夹
        os.makedirs(opts.save_val_results_to, exist_ok=True)
    with torch.no_grad():
        model = model.eval()    #预测,不训练,关闭梯度更新。此部分的理解,可以参考main.py中validate函数20-28行。
        for img_path in tqdm(image_files):
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0) # To tensor of NCHW 
            #对上一句的补充,unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度,比如原本有个三行的数据(3),unsqueeze(0)后就会在0的位置加了一维就变成一行三列(1,3)。
            img = img.to(device)
            
            pred = model(img).max(1)[1].cpu().numpy()[0] # HW   #可以参考main.py中validate函数27行。
            colorized_preds = decode_fn(pred).astype('uint8')    #使用上一句得到的图像索引(pred),在解码目标(本main函数第5行)中得到对应的图像。
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png')) #将验证结果(即得到的解码图像)保存到前面指定的文件夹。

Tips

  1. 解析参数函数get_argparser函数,对比main.py中的get_argparser函数,每一条都标出了注释,可逐一对理解,不同的也做出了解释,相对简单。

  2. 本文的main函数中可借鉴main.py中的validate和main函数进行学习,相对容易理解。

  3. 根目录下的两个代码已解析完毕。按显示的顺序,下一个介绍datasets文件夹。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
DeepLab是一种基于深度学习的语义分割算法,可以用来图像进行像素级别的分割,实现图像中物体或场景的分割和识别。 以下是使用DeepLab算法进行图像分割的基本步骤: 1. 数据预处理:将原始图像转换为算法输入所需的格式,例如将图像缩放到固定大小、进行归一化等。 2. 加载模型:使用DeepLab的预训练模型或自己训练的模型,加载模型参数和结构。 3. 输入图像:将预处理后的图像输入到模型中,进行前向计算。 4. 获取分割结果:从模型的输出中获取分割结果,通常是一个与输入图像大小相同的分割图像,其中每个像素都表示该像素所属的类别。 5. 后处理:根据需要对分割结果进行后处理,例如去除噪声、合并相邻的像素等。 下面是一个示例代码,使用DeepLabv3+模型进行图像分割: ``` import tensorflow as tf import numpy as np import cv2 # 数据预处理 def preprocess(image): image = cv2.resize(image, (513, 513)) image = image.astype(np.float32) / 255.0 image = image[np.newaxis, ...] return image # 加载模型 model = tf.keras.models.load_model('deeplabv3plus.h5') # 加载图像 image = cv2.imread('input_image.jpg') # 数据预处理 image = preprocess(image) # 输入图像 output = model.predict(image) # 获取分割结果 output = np.squeeze(output) output = np.argmax(output, axis=-1) output = output.astype(np.uint8) # 后处理 output = cv2.resize(output, (image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST) # 显示分割结果 cv2.imshow('Segmentation Result', output) cv2.waitKey(0) cv2.destroyAllWindows() ``` 在上面的代码中,我们首先定义了一个preprocess函数,用于对输入的图像进行预处理。然后,我们通过tf.keras.models.load_model函数加载了DeepLabv3+模型,并使用预处理后的图像作为模型的输入,进行前向计算。最后,我们获取了分割结果,并进行了后处理,最终显示了分割后的图像。 需要注意的是,在使用DeepLab算法进行图像分割时,需要使用较高的计算资和较长的计算时间,因此建议在GPU环境下运行代码。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值