TorchSeg代码学习笔记(二:data和augmentation)

本文以cityScapes和BiSeNet.R18为例

  • BaseDataset

TorchSeg的dataset定义在furnace/datasets/BaseDataset.py中,根据setting来提供基本的图像和label的读取功能。而setting定义在各个网络的config.py中定义。

  • cityscapes数据集

cityScapes数据集的定义在furnace/datasets/cityscapes/cityscapes.py

import numpy as np
from datasets.BaseDataset import BaseDataset

class Cityscapes(BaseDataset):
    trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                    28, 31, 32, 33]

    @classmethod
    def get_class_colors(*args):
        return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
                [102, 102, 156], [190, 153, 153], [153, 153, 153],
                [250, 170, 30], [220, 220, 0], [107, 142, 35],
                [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
                [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
                [0, 0, 230], [119, 11, 32]]

    @classmethod
    def get_class_names(*args):
        # class counting(gtFine)
        # 2953 2811 2934  970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
        # 359  274  142  513 1646
        return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
                'traffic light', 'traffic sign',
                'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
                'truck', 'bus', 'train', 'motorcycle', 'bicycle']

    @classmethod
    def transform_label(cls, pred, name):
        label = np.zeros(pred.shape)
        ids = np.unique(pred)
        for id in ids:
            label[np.where(pred == id)] = cls.trans_labels[id]

        new_name = (name.split('.')[0]).split('_')[:-1]
        new_name = '_'.join(new_name) + '.png'

        print('Trans', name, 'to', new_name, '    ',
              np.unique(np.array(pred, np.uint8)), ' ---------> ',
              np.unique(np.array(label, np.uint8)))
        return label, new_name

一般来说,文章关注的分类有19类,定义在get_class_names中,get_class_names为其所对应的颜色RGB值。所有分类的定义可以参见cityscapesScriptstransform_label似乎并没有在任何代码中被使用到。

  • BiSeNet.R18的dataloader

对于不同网络的dataloader,他们被定义在对应网络的文件夹中,以BiSeNet.R18为例,dataloader定义在model/bisenet/cityscapes.bisenet.R18/dataloader.py中。

import cv2
import torch
import numpy as np
from torch.utils import data
from config import config
from utils.img_utils import random_scale, random_mirror, normalize, \
    generate_random_crop_pos, random_crop_pad_to_shape

class TrainPre(object):
    def __init__(self, img_mean, img_std):
        self.img_mean = img_mean
        self.img_std = img_std

    def __call__(self, img, gt):
        img, gt = random_mirror(img, gt)
        if config.train_scale_array is not None:
            img, gt, scale = random_scale(img, gt, config.train_scale_array)
        img = normalize(img, self.img_mean, self.img_std)
        crop_size = (config.image_height, config.image_width)
        crop_pos = generate_random_crop_pos(img.shape[:2], crop_size)
        p_img, _ = random_crop_pad_to_shape(img, crop_pos, crop_size, 0)
        p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 255)
        p_img = p_img.transpose(2, 0, 1)
        extra_dict = None
        return p_img, p_gt, extra_dict

def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std)
    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.batch_size * config.niters_per_epoch)
    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size
    
    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler

TrainPre类中可以定义数据增强的具体函数,这些函数被定义在furnace/utils/img_utils.py中。具体支持的函数如下:

def get_2dshape(shape, *, zero=True)
def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value)
def generate_random_crop_pos(ori_size, crop_size)
def pad_image_to_shape(img, shape, border_mode, value)
def pad_image_size_to_multiples_of(img, multiple, pad_value)
def resize_ensure_shortest_edge(img, edge_length, interpolation_mode=cv2.INTER_LINEAR)
def random_scale(img, gt, scales)
def random_scale_with_length(img, gt, length)
def random_mirror(img, gt)
def random_rotation(img, gt)
def random_gaussian_blur(img)
def center_crop(img, shape)
def random_crop(img, gt, size)
def normalize(img, mean, std)
def findContours(*args, **kwargs)

get_train_loader函数根据config设置数据增强的函数,并将其传递给对应的dataset中(在此处dataset=CityScapes)。最终将dataset送进DataLoader中。此外,engine类被用于设置分布式训练的设置。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值