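Using the Cityscapes Dataset + Pitfalls When Training STDC

Official site: Cityscapes Dataset – Semantic Understanding of Urban Street Scenes (cityscapes-dataset.com)

The gtFine folder stores the annotation files.

The leftImg8bit folder stores the original images.

Both folders share the same directory hierarchy:

        gtFine/leftImg8bit  ->  test / train / val  ->  city directories  ->  labels / imgs

All images in the Cityscapes dataset are 2048 x 1024. There are 5000 finely annotated images in total: 2975 training images, 500 validation images, and 1525 test images. 19 classes are used for training and evaluation.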

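The annotation files under gtFine deserve a closer look:

xxx_gtFine_color.png : a color visualization of the annotation

xxx_gtFine_instanceIds.png : used for training instance segmentation

xxx_gtFine_labelIds.png : required for training semantic segmentation; each pixel value is the class id

xxx_gtFine_polygons.json : the JSON file generated by the polygon annotation tool (labelme-style); it mainly records the point coordinates of each annotated polygon.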
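
The shared directory layout makes it straightforward to pair each image with its label file. A minimal sketch, assuming the two roots are laid out as described above; the helper name `pair_cityscapes_files` is hypothetical and not part of any official tooling:

```python
from pathlib import Path


def pair_cityscapes_files(img_root, gt_root, split="train"):
    """Return sorted (image_path, label_path) pairs for one split.

    img_root / gt_root are the leftImg8bit and gtFine folders; both are
    expected to contain <split>/<city>/ subdirectories.
    """
    pairs = []
    split_dir = Path(img_root) / split
    for img_path in sorted(split_dir.glob("*/*_leftImg8bit.png")):
        city = img_path.parent.name
        stem = img_path.name[: -len("_leftImg8bit.png")]
        gt_path = Path(gt_root) / split / city / f"{stem}_gtFine_labelIds.png"
        # Only keep images whose semantic label file actually exists
        if gt_path.exists():
            pairs.append((img_path, gt_path))
    return pairs
```

On the full dataset, calling it for the train split should yield 2975 pairs if nothing is missing; a smaller number points to misnamed or absent label files.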

The cityscapes_info.json file mirrors the following label definitions:

Label = namedtuple('Label', [
                   'name', 
                   'id', 
                   'trainId', 
                   'category', 
                   'categoryId', 
                   'hasInstances', 
                   'ignoreInEval', 
                   'color'])
 
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]



    'name'        , # The identifier of this label, e.g. 'car', 'person', ... .
                    # We use them to uniquely name a class

    'id'          , # An integer ID that is associated with this label.
                    # The IDs are used to represent the label in ground truth images
                    # An ID of -1 means that this label does not have an ID and thus
                    # is ignored when creating ground truth images (e.g. license plate).
                    # Do not modify these IDs, since exactly these IDs are expected by the
                    # evaluation server.

    'trainId'     , # Feel free to modify these IDs as suitable for your method. Then create
                    # ground truth images with train IDs, using the tools provided in the
                    # 'preparation' folder. However, make sure to validate or submit results
                    # to our evaluation server using the regular IDs above!
                    # For trainIds, multiple labels might have the same ID. Then, these labels
                    # are mapped to the same class in the ground truth images. For the inverse
                    # mapping, we use the label that is defined first in the list below.
                    # For example, mapping all void-type classes to the same ID in training,
                    # might make sense for some approaches.
                    # Max value is 255!

    'category'    , # The name of the category that this label belongs to

    'categoryId'  , # The ID of this category. Used to create ground truth images
                    # on category level.

    'hasInstances', # Whether this label distinguishes between single instances or not

    'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
                    # during evaluations or not

    'color'       , # The color of this label
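
To make the relationship between `id` and `trainId` concrete, here is a minimal sketch of remapping a labelIds image to train IDs with a lookup table. It assumes NumPy is available and uses the two-class person mapping that appears later in this post (id 0 -> 0, id 24 -> 1, everything else ignored); the names `ID_TO_TRAINID` and `ids_to_train_ids` are hypothetical and not part of the official scripts:

```python
import numpy as np

# Custom id -> trainId mapping; every id not listed is treated as ignored (255)
ID_TO_TRAINID = {0: 0, 24: 1}


def ids_to_train_ids(label_ids, mapping=ID_TO_TRAINID, ignore=255):
    """Remap a uint8 labelIds array to trainIds via a 256-entry lookup table."""
    lut = np.full(256, ignore, dtype=np.uint8)
    for label_id, train_id in mapping.items():
        lut[label_id] = train_id
    # Fancy indexing applies the table to every pixel at once
    return lut[label_ids]


label_ids = np.array([[0, 24, 7], [24, 33, 0]], dtype=np.uint8)
print(ids_to_train_ids(label_ids))
```

Here the pixels with id 7 and 33 come out as 255, so they are skipped by a loss or metric that uses 255 as its ignore label.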

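To convert your own dataset to the Cityscapes format:

1. First decide on your number of classes, n_classes; the background counts as a class too! For example, for a binary person-segmentation image there are two classes: the unlabeled black background is one class and the labeled white person silhouette is the other.

2. Edit the cityscapes_info.json file. The file contains 34 dictionaries in total, each representing one label, in the following format: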

{
    "hasInstances": false,
    "category": "void",
    "catid": 0,
    "name": "unlabeled",
    "ignoreInEval": false,
    "id": 0,
    "color": [
        0,
        0,
        0
    ],
    "trainId": 1
}

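When editing, do not touch the overall structure!! For the classes you do not need, do just two things:

        1) Set "ignoreInEval" to true

        2) Set "trainId" to 255

For the classes you want to keep:

        1) "catid" counts up from 0; with 3 classes in total, the values are 0, 1, 2

        2) "id" counts up from 0; with 3 classes in total, the values are 0, 1, 2. This value is the pixel value of the corresponding class region in the annotation image.

        3) Set "ignoreInEval" to false

        4) "trainId" counts up from 0; with 3 classes in total, the values are 0, 1, 2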
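
The rules above can also be applied programmatically. A rough sketch, assuming the file is a JSON list of dictionaries with the keys shown in the example; the function name `keep_only` and the order in which kept classes are numbered are illustrative assumptions, and it only rewrites `trainId` and `ignoreInEval` (renumbering `id` and `catid` is left to you):

```python
import json


def keep_only(labels, kept_names, ignore_id=255):
    """Mark labels not in kept_names as ignored; number the kept ones 0, 1, 2, ..."""
    next_train_id = 0
    for label in labels:
        if label["name"] in kept_names:
            label["ignoreInEval"] = False
            label["trainId"] = next_train_id
            next_train_id += 1
        else:
            label["ignoreInEval"] = True
            label["trainId"] = ignore_id
    return labels


# Tiny in-memory stand-in for the real file contents
labels = [
    {"name": "unlabeled", "id": 0, "trainId": 255, "ignoreInEval": True},
    {"name": "road", "id": 7, "trainId": 0, "ignoreInEval": False},
    {"name": "person", "id": 24, "trainId": 11, "ignoreInEval": False},
]
labels = keep_only(labels, kept_names={"unlabeled", "person"})
print(json.dumps(labels, indent=2))
```

To apply this to the real file, load it with `json.load`, call the function, and write the result back with `json.dump`; keeping a backup of the original file is advisable.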

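3. Adjust your annotation images:

        1. The original images and the annotated grayscale mask images of your dataset should be named following the official Cityscapes convention, i.e.:

                Original image name:  ***_leftImg8bit.png

                Annotation image name: ***_gtFine_labelIds.png

        Renaming code: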

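import os
import tqdm


gt_path = './gtFine_trainvaltest/gtFine/train/qdu/'
img_path = './leftImg8bit_trainvaltest/leftImg8bit/train/qdu/'

gt_suffix = '_gtFine_labelIds.png'
img_suffix = '_leftImg8bit.png'

# Assumes every mask in gt_path has an image with the same file name in img_path
gtlist = os.listdir(gt_path)
for gt in tqdm.tqdm(gtlist):
    # Skip files that already carry the Cityscapes suffix, so reruns are safe
    if gt.endswith(gt_suffix):
        continue

    stem = os.path.splitext(gt)[0]

    # Rename the mask: xxx.png -> xxx_gtFine_labelIds.png
    src = os.path.join(gt_path, gt)
    dst = os.path.join(gt_path, stem + gt_suffix)
    os.rename(src, dst)

    # Rename the matching image: xxx.png -> xxx_leftImg8bit.png
    src = os.path.join(img_path, gt)
    dst = os.path.join(img_path, stem + img_suffix)
    os.rename(src, dst)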

       

         2. The annotation images must be single-channel. Code to convert multi-channel images to single-channel:

import os
import cv2
import tqdm


gt_path = './gtFine_trainvaltest/gtFine/train/qdu/'


gtlist = os.listdir(gt_path)
for gt in tqdm.tqdm(gtlist):
    read_path = os.path.join(gt_path, gt)

    # Read the image as stored on disk (IMREAD_UNCHANGED keeps the channel count)
    img = cv2.imread(read_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        continue  # not a readable image

    # Already single-channel: nothing to do
    if img.ndim == 2:
        continue

    # OpenCV loads color images in BGR(A) order; convert to one channel
    if img.shape[2] == 3:
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    elif img.shape[2] == 4:
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
    else:
        continue  # unexpected channel count, leave the file untouched

    # Overwrite the mask with its single-channel version
    cv2.imwrite(read_path, img_gray)

    # break

        3. Both the annotation images and the original images should be 2048 x 1024. Code to resize the images:

import os
import tqdm
from PIL import Image


gt_path = './gtFine_trainvaltest/gtFine/train/qdu/'
img_path = './leftImg8bit_trainvaltest/leftImg8bit/train/qdu/'

# PIL's resize takes (width, height)
target_size = (2048, 1024)

gtlist = os.listdir(gt_path)
for gt in tqdm.tqdm(gtlist):

    # Masks: nearest-neighbor keeps every pixel a valid class id
    rdpath = os.path.join(gt_path, gt)
    mask = Image.open(rdpath)
    mask = mask.resize(target_size, Image.NEAREST)
    mask.save(rdpath)

    # Images: LANCZOS is the current name of the former ANTIALIAS filter
    gt = gt.replace('_gtFine_labelIds.png', '_leftImg8bit.png')
    rdpath = os.path.join(img_path, gt)
    image = Image.open(rdpath)
    image = image.resize(target_size, Image.LANCZOS)
    image.save(rdpath)

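        4. For semantic segmentation, the pixel values in the annotation mask are the class values: the pixels of a class region should equal that class's "id" value, matching the "id" in cityscapes_info.json. All other, unlabeled regions can simply be set to 0. Reference code for changing the pixel values: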

import os
import tqdm
import numpy as np
from PIL import Image

gt_path = './gtFine_trainvaltest/gtFine/train/qdu/'
gtlist = os.listdir(gt_path)

for gt in tqdm.tqdm(gtlist):

    read_path = os.path.join(gt_path, gt)
    mask = Image.open(read_path)

    # Binarize: bright (labeled) pixels become class 1, everything else class 0.
    # Using <= also covers pixels equal to 100, which would otherwise stay at 100.
    img = np.array(mask)
    img[img <= 100] = 0
    img[img > 100] = 1

    image = Image.fromarray(img)
    image.save(read_path)

    # break  # uncomment to try the script on a single file first

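STDC pitfalls:

RuntimeError: transform: failed to synchronize: cudaErrorIllegalAddress: an illegal memory access

 Possible causes of the error above:

        1) The annotation files / labels are not set up correctly

        2) The network model or the training data (images and labels) were not moved to the GPU

        3) The batch size is too large; try reducing it

        4) Mismatched Python version (unlikely; Python 3.7 + torch 1.7 runs fine for me)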
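
For cause 1), a quick check of the label values before training can save a lot of debugging: a mask value that is neither a valid class index nor the ignore label can lead to out-of-range indexing on the GPU. A minimal sketch with NumPy, where `find_invalid_labels` is a hypothetical helper and the mask is assumed to be already loaded as an array:

```python
import numpy as np


def find_invalid_labels(mask, n_classes, ignore_label=255):
    """Return the sorted mask values outside [0, n_classes) that are not the ignore label."""
    values = np.unique(mask)
    valid = (values < n_classes) | (values == ignore_label)
    return values[~valid].tolist()


mask = np.array([[0, 1, 1], [255, 100, 0]], dtype=np.uint8)
print(find_invalid_labels(mask, n_classes=2))  # [100]
```

Running this over every mask in the training set and printing the file names with a non-empty result is usually enough to locate a bad annotation.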

RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size

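This error means the annotation files have the wrong number of channels! For semantic segmentation with a Cityscapes-format dataset, the annotation masks must be single-channel images; multi-channel masks trigger this error. For other causes, see: RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size - DuanYongchun - 博客园 (cnblogs.com)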
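
The shape difference behind this error is easy to see on arrays: a multi-channel mask has shape (H, W, C), while the loss expects one class index per pixel, i.e. (H, W) per image. A small sketch, assuming the extra channels are identical copies (typical for a grayscale mask saved as RGB); `to_single_channel` is a hypothetical helper:

```python
import numpy as np


def to_single_channel(mask):
    """Collapse an (H, W, C) mask with identical color channels to (H, W)."""
    if mask.ndim == 2:
        return mask
    first = mask[..., 0]
    # Only the first three channels are compared, so an alpha channel is ignored
    n_color = min(mask.shape[2], 3)
    if not all(np.array_equal(first, mask[..., c]) for c in range(1, n_color)):
        raise ValueError("channels differ; map colors to class ids explicitly")
    return first


rgb_mask = np.stack([np.array([[0, 1], [1, 0]], dtype=np.uint8)] * 3, axis=-1)
print(rgb_mask.shape, "->", to_single_channel(rgb_mask).shape)  # (2, 2, 3) -> (2, 2)
```

Taking one channel keeps the class ids untouched, which is why the sketch refuses masks whose channels differ instead of blending them.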

RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function.

 This error means the wrong BatchNorm2d implementation is being used; use PyTorch's official batch normalization layer, nn.BatchNorm2d, instead. For other causes, see: RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function. · Issue #267 · zhanghang1989/PyTorch-Encoding (github.com)

RuntimeError: stack expects each tensor to be equal size, but got [3, 256, 341] at entry 0

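Errors like the one above mean the dimensions of the annotation masks, the original images, and the network output do not match. The network output is normally (batch, class, height, width).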
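
Because the default collate step stacks the samples into one tensor, every image in a batch must have the same size, and each mask must match its image in height and width. A minimal sketch of such a consistency check on shapes alone (pure Python; `check_sample_shapes` is a hypothetical helper, with images given as (C, H, W) and masks as (H, W)):

```python
def check_sample_shapes(samples):
    """samples: list of (image_shape, mask_shape) tuples.

    Returns a list of human-readable problems; an empty list means the
    samples can be stacked into one batch.
    """
    problems = []
    ref_hw = tuple(samples[0][0][-2:]) if samples else None
    for idx, (img_shape, mask_shape) in enumerate(samples):
        # The mask must cover exactly the same pixels as its image
        if tuple(img_shape[-2:]) != tuple(mask_shape[-2:]):
            problems.append(f"sample {idx}: image {img_shape} vs mask {mask_shape}")
        # All images in the batch must share the size of the first one
        if tuple(img_shape[-2:]) != ref_hw:
            problems.append(f"sample {idx}: size differs from sample 0")
    return problems


samples = [((3, 256, 341), (256, 341)), ((3, 256, 384), (256, 384))]
print(check_sample_shapes(samples))
```

With these made-up shapes the check reports that sample 1 differs in size from sample 0, which is the situation the `stack expects each tensor to be equal size` message describes.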

RuntimeError: shape '[2, 2]' is invalid for input of size 202

Errors like the one above mean that cityscapes_info.json is wrong or that the mask pixel values are wrong.
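
One way this error can arise, assuming the evaluation builds its confusion matrix with a `bincount` over `label * n_classes + pred` as the evaluation code later in this post does: a stray label value enlarges the histogram so that it no longer reshapes to (n_classes, n_classes). The NumPy sketch below reproduces the size 202 with n_classes = 2 and a single mask pixel left at 100; the arrays are made up for illustration:

```python
import numpy as np

n_classes = 2
label = np.array([0, 1, 1, 100])   # 100 is a stray pixel value, not a class id
pred = np.array([0, 1, 0, 1])

# Each (label, pred) pair is encoded as one index into a flattened matrix
flat = np.bincount(label * n_classes + pred, minlength=n_classes ** 2)
print(flat.size)  # 202, because 100 * 2 + 1 = 201 is the largest index

try:
    flat.reshape(n_classes, n_classes)
except ValueError as err:
    print("reshape failed:", err)

# With only valid labels the histogram has exactly n_classes ** 2 entries
keep = label < n_classes
hist = np.bincount(label[keep] * n_classes + pred[keep],
                   minlength=n_classes ** 2).reshape(n_classes, n_classes)
print(hist)
```

So when the reported input size is much larger than n_classes squared, checking the masks for values other than the class ids and 255 is a good first step.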

Testing:

Saving the predicted masks

#!/usr/bin/python
# -*- encoding: utf-8 -*-

import logging
import math
import os
import os.path as osp
import time
from PIL import Image
import matplotlib.pyplot as plot
from pathlib import Path
import torchvision.transforms as transforms

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from cityscapes import CityScapes
from logger import setup_logger
from models.model_stages import BiSeNet
from torch.utils.data import DataLoader
from tqdm import tqdm
from transform import *


np.set_printoptions(threshold = 1e9999)


to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])



class MscEvalV0(object):

    def __init__(self, readpath, save_path, scale=0.5, ignore_label=255):
        self.ignore_label = ignore_label
        self.scale = scale
        self.readpath = readpath
        self.save_path = save_path

    def __call__(self, net, n_classes):
        # evaluate
        hist = torch.zeros(n_classes, n_classes).cuda().detach()


        for i, imagepath in enumerate(self.readpath.iterdir()):

            imgs = Image.open(imagepath).convert('RGB') 
            imgs = to_tensor(imgs)
            imgs = imgs[None,:]
            imgnames = imagepath.name
            imgs = imgs.cuda()

            N, C, H, W = imgs.size()
            new_hw = [int(H * self.scale), int(W * self.scale)]

            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            preds = preds.cpu().numpy()
            preds[preds>0] = 255
            img = np.array(preds[0])
            predgt = Image.fromarray(np.uint8(img)).convert('L')
            predgt.save(f'{self.save_path}/{imgnames}')
            
            print('\r', i, end='', flush=True)




def evaluatev0(respth='./pretrained', backbone='CatNetSmall', test_path=None, save_path=None, scale=0.75, 
               use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):

    print('scale', scale)
    print('use_boundary_2', use_boundary_2)
    print('use_boundary_4', use_boundary_4)
    print('use_boundary_8', use_boundary_8)
    print('use_boundary_16', use_boundary_16)
    print("backbone:", backbone)
    
    # dataset
    batchsize = 5
    n_workers = 2
    n_classes = 2
    
    # model
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()

    with torch.no_grad():
        single_scale = MscEvalV0(readpath=test_path, save_path=save_path, scale=scale)
        single_scale(net, n_classes)




if __name__ == "__main__":
    log_dir = 'evaluation_logs/origin/'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    setup_logger(log_dir)
    
    test_path = Path('data/leftImg8bit/test/qdu')
    model_path = './checkpoints/train_STDC2-Seg/pths/model_maxmIOU75.pth'

    # STDC1-Seg50 mIoU 0.7222
    evaluatev0(model_path, test_path=test_path, save_path=log_dir, backbone='STDCNet1446', scale=0.75, use_boundary_8=True)
    
    

Generating outer-contour JSON from black-and-white masks

import cv2
import json
import tqdm
import random
import numpy as np
from pathlib import Path

PAD = 10
mask_root = Path('evaluation_logs/origin')
pad_mask = Path('evaluation_logs/padding')
image_root = Path('data/leftImg8bit/test/qdu')
json_root = Path('evaluation_logs/jsons')


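# Pad the masks with a black border to make contour detection easier
def padMask():
    for mask_path in mask_root.iterdir():

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue  # skip non-image files (e.g. log files) in the folder
        maskpad = np.pad(mask, PAD, 'constant')
        print(f'{mask.shape}  -->  {maskpad.shape}')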

        pad_mask.mkdir(parents=True, exist_ok=True)
        save_path = pad_mask / mask_path.name
        cv2.imwrite(str(save_path), maskpad)


# Find the contours and generate the JSON
def generateJson():
    for mask_path in tqdm.tqdm(pad_mask.iterdir()):
        name = mask_path.name
        label = "person"
        item = dict(name = label, points = None)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        contours, hierarchy = cv2.findContours(mask,
                                               cv2.RETR_EXTERNAL,
                                               cv2.CHAIN_APPROX_SIMPLE)
        # print(type(contours), len(contours))

        maxlen, outline = 0, []
        H, W = mask.shape
        for contour in contours:
            # print(type(contour), len(contour), contour.shape)

            if len(contour) <= maxlen:
                continue
            maxlen = len(contour)

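            # Convert to coordinates relative to the original image
            # print("process...")
            contour = contour[:, 0, :]
            contour -= PAD                      # remove the offset introduced by padding
            contour[contour < 0] = 0            # clamp negative coordinates to 0
            contour[contour[:, 0] > W, 0] = W   # clamp x beyond the right edge to the width
            contour[contour[:, 1] > H, 1] = H   # clamp y beyond the bottom edge to the height

            # Store as a plain list (only the contour with the most points is kept)
            contour = contour.astype(np.float64)  # np.float was removed in NumPy 1.24
            outline = contour.tolist()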
        # print(maxlen)

        item['points'] = outline
        item = [item]
        json_root.mkdir(parents=True, exist_ok=True)
        with open(json_root / name.replace('jpg', 'json'), 'w') as handle:
            json.dump(item, handle)

        # break


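# Draw the contours on the original image for visual inspection
def vasualization():
    # Read the padded masks, since PAD is subtracted from the contours below
    for mask_path in pad_mask.iterdir():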
        name = mask_path.name

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        contours, hierarchy = cv2.findContours(mask,
                                               cv2.RETR_EXTERNAL,
                                               cv2.CHAIN_APPROX_SIMPLE)
        print(type(contours), len(contours))

        outlines = []
        H, W = mask.shape
        for contour in contours:
            print(type(contour), len(contour), contour.shape)

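            # Skip contours with fewer than 20 points
            if len(contour) < 20:
                continue

            # Convert to coordinates relative to the original image
            print("process...")
            contour = contour[:, 0, :]
            contour -= PAD                      # remove the offset introduced by padding
            contour[contour < 0] = 0            # clamp negative coordinates to 0
            contour[contour[:, 0] > W, 0] = W   # clamp x beyond the right edge to the width
            contour[contour[:, 1] > H, 1] = H   # clamp y beyond the bottom edge to the height

            # Append to the list of outlines
            contour = contour.tolist()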
            outlines.append(contour)

        print("outlines : ", len(outlines))

        imagepath = image_root / name
        image = cv2.imread(str(imagepath), cv2.IMREAD_COLOR)

        for otl in outlines:
            c1 = random.randint(0,255)
            c2 = random.randint(0,255)
            c3 = random.randint(0,255)
            for point in otl:
                # Pass the point as a tuple; some OpenCV versions reject a list
                cv2.circle(image, tuple(point), 2, (c1,c2,c3))

        cv2.namedWindow('show', 0)
        cv2.imshow('show', image)
        cv2.waitKey(0)


        break


if __name__ == '__main__':
    padMask()
    generateJson()

Visualizing the segmentation mask, i.e. mask + original image

import cv2
import numpy as np
from tqdm import tqdm
from PIL import Image
from pathlib import Path

image_root = Path('data/leftImg8bit/test/02/images')
mask_root = Path('evaluation_logs/origin')
save_root = Path('evaluation_logs/visual')

for mask in tqdm(mask_root.iterdir()):
    name = mask.name
    imagepath = image_root / name

    # mask = Image.open(mask)
    # image = Image.open(imagepath)
    # print(mask.mode)        # L
    # print(image.mode)       # RGB
    mask = cv2.imread(str(mask), cv2.IMREAD_GRAYSCALE)
    image = cv2.imread(str(imagepath), cv2.IMREAD_COLOR)
    # print(mask.shape)       # 1080 1920
    # print(image.shape)      # 1080 1920 3


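    # Darken the masked region and tint it blue (OpenCV uses BGR order),
    # then clip to the valid range before converting back to 8-bit
    image = image.astype(np.float64)
    region = mask > 100
    image[region] = image[region] * 0.6 + np.array([100, 0, 0], dtype=np.float64)
    image = np.clip(image, 0, 255).astype(np.uint8)

    save_root.mkdir(parents=True, exist_ok=True)
    sp = save_root / name
    cv2.imwrite(str(sp), image)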

Visualizing the segmentation results for the whole scene:

#!/usr/bin/python
# -*- encoding: utf-8 -*-

import logging
import math
import os
import os.path as osp
import time
from PIL import Image
import matplotlib.pyplot as plot

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from cityscapes import CityScapes
from logger import setup_logger
from models.model_stages import BiSeNet
from torch.utils.data import DataLoader
from tqdm import tqdm


#ignore_label=255,
#label_mapping = {-1: ignore_label, 0: ignore_label,
#                  1: ignore_label, 2: ignore_label,
#                  3: ignore_label, 4: ignore_label,
#                  5: ignore_label, 6: ignore_label,
#                  7: 0, 8: 1, 9: ignore_label,
#                  10: ignore_label, 11: 2, 12: 3,
#                  13: 4, 14: ignore_label, 15: ignore_label,
#                  16: ignore_label, 17: 5, 18: ignore_label,
#                  19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
#                  25: 12, 26: 13, 27: 14, 28: 15,
#                  29: ignore_label, 30: ignore_label,
#                  31: 16, 32: 17, 33: 18}
ignore_label = 255  # no trailing comma: "255," would create the tuple (255,)
label_mapping = {-1: ignore_label, 0: 0,
                  1: ignore_label, 2: ignore_label,
                  3: ignore_label, 4: ignore_label,
                  5: ignore_label, 6: ignore_label,
                  7: ignore_label, 8: ignore_label, 9: ignore_label,
                  10: ignore_label, 11: ignore_label, 12: ignore_label,
                  13: ignore_label, 14: ignore_label, 15: ignore_label,
                  16: ignore_label, 17: ignore_label, 18: ignore_label,
                  19: ignore_label, 20: ignore_label, 21: ignore_label, 
                  22: ignore_label, 23: ignore_label, 24: 1,
                  25: ignore_label, 26: ignore_label, 27: ignore_label, 
                  28: ignore_label, 29: ignore_label, 30: ignore_label,
                  31: ignore_label, 32: ignore_label, 33: ignore_label}

def convert_label(label, inverse=False):
    temp = label.copy()
    if inverse:
        for v, k in label_mapping.items():
            label[temp == k] = v
    else:
        for k, v in label_mapping.items():
            label[temp == k] = v
    return label

def get_palette(n):
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette

def save_pred(preds, sv_path, name):
    palette = get_palette(256)
    preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
    for i in range(preds.shape[0]):
        pred = convert_label(preds[i], inverse=True)
        save_img = Image.fromarray(pred)
        save_img.putpalette(palette)
        save_img.save(os.path.join(sv_path, f'{name[i]}.png'))


class MscEvalV0(object):

    def __init__(self, scale=0.5, ignore_label=255):
        self.ignore_label = ignore_label
        self.scale = scale

    def __call__(self, net, dl, n_classes):
        # evaluate
        hist = torch.zeros(n_classes, n_classes).cuda().detach()
        if dist.is_initialized() and dist.get_rank() != 0:
            diter = enumerate(dl)
        else:
            diter = enumerate(tqdm(dl))
        for i, (imgs, label, imgnames) in diter:
        
            # Stop after 1500 batches
            if i == 1500:
                break

            N, _, H, W = label.shape

            label = label.squeeze(1).cuda()
            size = label.size()[-2:]

            imgs = imgs.cuda()

            N, C, H, W = imgs.size()
            new_hw = [int(H * self.scale), int(W * self.scale)]

            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)

            # print(imgnames)
            save_pred(logits, './evaluation_logs/origin/', imgnames)

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            keep = label != self.ignore_label
            hist += torch.bincount(
                label[keep] * n_classes + preds[keep],
                minlength=n_classes ** 2
            ).view(n_classes, n_classes).float()
        if dist.is_initialized():
            dist.all_reduce(hist, dist.ReduceOp.SUM)
        ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag())
        miou = ious.mean()
        return miou.item()


def evaluatev0(respth='./pretrained', dspth='./data', backbone='CatNetSmall', scale=0.75, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
    print('scale', scale)
    print('use_boundary_2', use_boundary_2)
    print('use_boundary_4', use_boundary_4)
    print('use_boundary_8', use_boundary_8)
    print('use_boundary_16', use_boundary_16)
    # dataset
    batchsize = 5
    n_workers = 2
    dsval = CityScapes(dspth, mode='val')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    n_classes = 2
    print("backbone:", backbone)
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()

    with torch.no_grad():
        single_scale = MscEvalV0(scale=scale)
        mIOU = single_scale(net, dl, n_classes)
    logger = logging.getLogger()
    logger.info('mIOU is: %s\n', mIOU)


if __name__ == "__main__":
    log_dir = 'evaluation_logs/origin/'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    setup_logger(log_dir)

    # STDC1-Seg50 mIoU 0.7222
    evaluatev0('./checkpoints/train_STDC2-Seg/pths/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet1446', scale=0.75, use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)

    # STDC1-Seg75 mIoU 0.7450
    # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet813', scale=0.75,
    # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)

    # STDC2-Seg50 mIoU 0.7424
    # evaluatev0('./checkpoints/STDC2-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet1446', scale=0.5,
    # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)

#    # STDC2-Seg75 mIoU 0.7704
#    imagepath = './data'
#    checkpoints_path = './checkpoints/train_STDC2-Seg/pths/model_maxmIOU75.pth'
#    #checkpoints_path = './checkpoints/STDC2-Seg/model_maxmIOU50.pth'
#    evaluatev0(respth=checkpoints_path, dspth=imagepath, backbone='STDCNet1446',
#               scale=0.75, use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
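
As a worked example of the mIoU formula used in `MscEvalV0` above, here is the same computation on a small, made-up 2 x 2 confusion matrix, written with NumPy instead of torch (rows are ground-truth classes, columns are predictions):

```python
import numpy as np

# hist[i, j] = number of pixels with ground truth i predicted as j
hist = np.array([[3.0, 1.0],
                 [2.0, 4.0]])

intersection = np.diag(hist)  # correctly classified pixels per class
union = hist.sum(axis=0) + hist.sum(axis=1) - intersection
ious = intersection / union
print(ious)         # class 0: 3 / 6 = 0.5, class 1: 4 / 7, about 0.571
print(ious.mean())  # about 0.536
```

The column sum counts all pixels predicted as a class and the row sum all pixels labeled as that class; subtracting the diagonal once avoids counting the intersection twice.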

