deeplabv3+源码之慢慢解析3 第一章根目录(3)main.py--validate函数

系列文章目录(共五章33节已完结)

第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数

第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数

第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–StreamSegMetrics类和AverageMeter类

第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a1)hrnetv2.py–4个函数和可执行代码
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a2)hrnetv2.py–Bottleneck类和BasicBlock类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a3)hrnetv2.py–StageModule类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a4)hrnetv2.py–HRNet类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b1)mobilenetv2.py–2个类和2个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b2)mobilenetv2.py–MobileNetV2类和mobilenet_v2函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c1)resnet.py–2个基础函数,BasicBlock类和Bottleneck类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c2)resnet.py–ResNet类和10个不同结构的调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d1)xception.py–SeparableConv2d类和Block类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d2)xception.py–Xception类和xception函数
第四章deeplabv3+源码之慢慢解析 network文件夹(2)_deeplab.py–ASPP相关的4个类和1个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(3)_deeplab.py–DeepLabV3类,DeepLabHeadV3Plus类和DeepLabHead类
第四章deeplabv3+源码之慢慢解析 network文件夹(4)modeling.py–5个私有函数(4个骨干网,1个模型载入)
第四章deeplabv3+源码之慢慢解析 network文件夹(5)modeling.py–12个调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(6)utils.py–_SimpleSegmentationModel类和IntermediateLayerGetter类

第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–2个翻转类和ExtCompose类
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)ext_transforms.py.py–2个裁剪类和2个缩放类
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)ext_transforms.py.py–旋转类,填充类,张量转化类和标准化类
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)ext_transforms.py.py–ExtResize类,ExtColorJitter类,Lambda类和Compose类
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)loss.py–FocalLoss类
第五章deeplabv3+源码之慢慢解析 utils文件夹(6)scheduler.py–PolyLR类
第五章deeplabv3+源码之慢慢解析 utils文件夹(7)utils.py–去标准化,momentum设定,标准化层锁定和路径创建
第五章deeplabv3+源码之慢慢解析 utils文件夹(8)visualizer.py–Visualizer类(完结)


第一章deeplabv3+源码之慢慢解析根目录(3)main.py–validate函数

本篇介绍main.py中的第三个函数validate,主要涉及函数中所使用的验证数据。

验证数据设置,validate函数

提示:这个函数相对简单独立,主要是为了main函数服务,做交叉验证并返回具体的数据。很多地方需要结合main函数的语句一起了解,会在注释地方补充。

def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
#私以为,此处最重要的就是搞清楚这6个参数都是做啥的,理清思路比较重要,剩下的语法是相对简单的。在main函数中,对validate函数的使用语句如下val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id),接下来对每个参数详解。
#opts前文get_argparser函数提过,就是在命令行窗口输入的命令参数解析后的结果。
#model就是所用的模型,,在main函数中model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride),具体需要在network文件夹中的modeling.py查看。后文详解。
#loader就是生成交叉验证的数据集。对应main函数中的val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2),其中data就是导入中的第11行,为了调用torch自带的dataloader(末尾有补充详细内容的链接。)
#device是指CPU或者选择的GPU,对应main函数中device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')。
#metrics度量,就是处理结果指标的矩阵,具体对应导入部分第14行,详见metrics文件夹下的stream_metrics.py文件,后文详解。
#ret_samples_ids默认值为None,对应main函数中 vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,np.int32) if opts.enable_vis else None  即可视化样本的索引,需要用到get_argparser函数中第84和90行,可视化是否开启和可视化样本数量。

    """Do validation and return specified samples"""
    metrics.reset()    #每次validation重置metrics矩阵,即初始化。
    ret_samples = []
    if opts.save_val_results:       #get_argparser函数中第38行,即是否保存validation的结果。
        if not os.path.exists('results'):
            os.mkdir('results')      #如果没有results文件夹,就建一个。
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])  #归一化还原,详见utils文件夹下的utils.py。具体值来自imagnet训练结果,详见get_dataset函数第14行注释和对应的链接。
        img_id = 0

    with torch.no_grad():    #交叉验证阶段,非训练,不需要梯度更新
        for i, (images, labels) in tqdm(enumerate(loader)): #从loader中读取序号,图像数据和标签

            images = images.to(device, dtype=torch.float32)     #数值加载到device
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)    #图像images进入模型,然后输出为outputs
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()  #detach()方法从原计算图返回tensor并不影响原图。max返回其中最大结果的值([0])和索引([1]),即此处返回索引。放到cpu,转为numpy格式。
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)  #比较实际结果和预测值,详见metrics文件夹下的stream_metrics.py文件。
            if ret_samples_ids is not None and i in ret_samples_ids:  # get vis samples获得可视化样本
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:     #如保存validation结果,见get_argparser函数第38行
                for i in range(len(images)):
                    image = images[i].detach().cpu().numpy()
                    target = targets[i]
                    pred = preds[i]

					#下面的.transpose(1,2,0),将数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),转换后才可以显示。
                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)  #归一化除以了255,还原*255.
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator()) #主刻度设置,详见下面的补充链接。
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
                    plt.close()
                    img_id += 1

        score = metrics.get_results()   #获得结果,详见metrics文件夹下的stream_metrics.py文件。
    return score, ret_samples

Tips

  1. 补充torch.utils.data.DataLoader参数说明
  2. 补充matplotlib绘图常用方法总结
  3. 下一个函数是主函数main。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值