本文以cityScapes和BiSeNet.R18为例
- BaseDataset
TorchSeg的dataset定义在furnace/datasets/BaseDataset.py中,根据setting来提供基本的图像和label的读取功能。而setting则在各个网络的config.py中定义。
- cityscapes数据集
cityScapes数据集的定义在furnace/datasets/cityscapes/cityscapes.py
import numpy as np
from datasets.BaseDataset import BaseDataset
class Cityscapes(BaseDataset):
    """Cityscapes dataset restricted to the 19 classes listed below.

    Class names, class colors and ``trans_labels`` each hold 19 entries and
    are presumably index-aligned (entry ``i`` describes class ``i``).
    """

    # Lookup table used by transform_label: position = value found in the
    # prediction, entry = value written to the output label map.
    # NOTE(review): entries look like the original Cityscapes label ids --
    # confirm against cityscapesScripts.
    trans_labels = [
        7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
        28, 31, 32, 33,
    ]

    @classmethod
    def get_class_colors(*args):
        """Return a new list with one ``[R, G, B]`` triple per class."""
        palette = [
            [128, 64, 128], [244, 35, 232], [70, 70, 70],
            [102, 102, 156], [190, 153, 153], [153, 153, 153],
            [250, 170, 30], [220, 220, 0], [107, 142, 35],
            [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
            [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
            [0, 0, 230], [119, 11, 32],
        ]
        return palette

    @classmethod
    def get_class_names(*args):
        """Return a new list with the 19 class names."""
        # Per-class counts on gtFine (presumably number of images that
        # contain the class -- TODO confirm):
        # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
        # 359 274 142 513 1646
        names = [
            'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
            'traffic light', 'traffic sign',
            'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
            'truck', 'bus', 'train', 'motorcycle', 'bicycle',
        ]
        return names

    @classmethod
    def transform_label(cls, pred, name):
        """Remap a prediction through ``trans_labels`` and derive its file name.

        Args:
            pred: array whose values are valid indices into ``trans_labels``.
            name: source file name; everything after the first ``.`` and the
                last ``_``-separated token before it are dropped.

        Returns:
            ``(label, new_name)`` where ``label`` is a float array (created
            with ``np.zeros``) of the same shape as ``pred`` and ``new_name``
            is the shortened name with a ``.png`` suffix.
        """
        label = np.zeros(pred.shape)
        for class_idx in np.unique(pred):
            positions = np.where(pred == class_idx)
            label[positions] = cls.trans_labels[class_idx]

        # Keep the part before the first dot, then drop its last '_' token.
        stem_tokens = name.split('.')[0].split('_')
        new_name = '_'.join(stem_tokens[:-1]) + '.png'

        # Log which values were seen before and after the remapping.
        values_before = np.unique(np.array(pred, np.uint8))
        values_after = np.unique(np.array(label, np.uint8))
        print('Trans', name, 'to', new_name, ' ',
              values_before, ' ---------> ', values_after)

        return label, new_name
一般来说,文章关注的分类有19类,定义在get_class_names中,get_class_colors为其所对应的颜色RGB值。所有分类的定义可以参见cityscapesScripts。transform_label似乎并没有在任何代码中被使用到。
- BiSeNet.R18的dataloader
对于不同网络的dataloader,他们被定义在对应网络的文件夹中,以BiSeNet.R18为例,dataloader定义在model/bisenet/cityscapes.bisenet.R18/dataloader.py中。
import cv2
import torch
import numpy as np
from torch.utils import data
from config import config
from utils.img_utils import random_scale, random_mirror, normalize, \
generate_random_crop_pos, random_crop_pad_to_shape
class TrainPre(object):
    """Callable that augments and crops one (image, ground truth) pair."""

    def __init__(self, img_mean, img_std):
        """Store the mean and std later handed to ``normalize``."""
        self.img_mean, self.img_std = img_mean, img_std

    def __call__(self, img, gt):
        """Run mirror, optional scale, normalize and random crop/pad.

        Args:
            img: input image array.
            gt: ground-truth array matching ``img``.

        Returns:
            ``(image, gt, extra_dict)``; the image has its last axis moved to
            the front and ``extra_dict`` is always ``None``.
        """
        img, gt = random_mirror(img, gt)

        # Scaling is skipped when the config provides no scale array.
        if config.train_scale_array is not None:
            img, gt, _ = random_scale(img, gt, config.train_scale_array)

        img = normalize(img, self.img_mean, self.img_std)

        # A single crop position is shared so image and gt stay aligned.
        target_shape = (config.image_height, config.image_width)
        origin = generate_random_crop_pos(img.shape[:2], target_shape)

        # Padding value: 0 for the image, 255 for the ground truth.
        cropped_img, _ = random_crop_pad_to_shape(img, origin, target_shape, 0)
        cropped_gt, _ = random_crop_pad_to_shape(gt, origin, target_shape, 255)

        # Last axis first -- presumably HWC -> CHW; confirm image layout.
        cropped_img = cropped_img.transpose(2, 0, 1)

        return cropped_img, cropped_gt, None
def get_train_loader(engine, dataset):
    """Build the training ``DataLoader`` for ``dataset``.

    Args:
        engine: object exposing ``distributed`` and, when that is truthy,
            ``world_size``.
        dataset: callable invoked as
            ``dataset(setting, "train", preprocess, length)``.

    Returns:
        ``(train_loader, train_sampler)``; the sampler is ``None`` unless
        ``engine.distributed`` is truthy.
    """
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
    }
    preprocess = TrainPre(config.image_mean, config.image_std)
    samples_per_epoch = config.batch_size * config.niters_per_epoch
    train_dataset = dataset(data_setting, "train", preprocess,
                            samples_per_epoch)

    if engine.distributed:
        # The sampler takes over ordering, so the loader must not shuffle,
        # and the configured batch size is divided by the world size.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False
    else:
        train_sampler = None
        batch_size = config.batch_size
        is_shuffle = True

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=config.num_workers,
        drop_last=True,
        shuffle=is_shuffle,
        pin_memory=True,
        sampler=train_sampler,
    )
    train_loader = data.DataLoader(train_dataset, **loader_kwargs)

    return train_loader, train_sampler
在TrainPre类中可以定义数据增强的具体函数,这些函数被定义在furnace/utils/img_utils.py中。具体支持的函数如下:
def get_2dshape(shape, *, zero=True)
def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value)
def generate_random_crop_pos(ori_size, crop_size)
def pad_image_to_shape(img, shape, border_mode, value)
def pad_image_size_to_multiples_of(img, multiple, pad_value)
def resize_ensure_shortest_edge(img, edge_length, interpolation_mode=cv2.INTER_LINEAR)
def random_scale(img, gt, scales)
def random_scale_with_length(img, gt, length)
def random_mirror(img, gt)
def random_rotation(img, gt)
def random_gaussian_blur(img)
def center_crop(img, shape)
def random_crop(img, gt, size)
def normalize(img, mean, std)
def findContours(*args, **kwargs)
get_train_loader函数根据config构造数据增强对象(TrainPre),并将其传递给对应的dataset(此处dataset为Cityscapes),最后将dataset送入DataLoader。此外,engine用于判断是否进行分布式训练:分布式训练时使用DistributedSampler,关闭shuffle,并将batch_size除以world_size。