RESA: Recurrent Feature-Shift Aggregator for Lane Detection (Paper Reading + Code Reproduction, Lane Detection)

RESA: Recurrent Feature-Shift Aggregator for Lane Detection

  • Current methods typically adopt a pixel-wise segmentation formulation: lane detection is cast as a segmentation problem in which each pixel gets a binary label indicating whether it belongs to a lane.
  • These works use an encoder-decoder framework.
  • A CNN encoder first extracts high-level semantic features into a feature map; an upsampling decoder then restores the feature map to the original size and performs pixel-wise prediction.
  • Lanes are long and thin and the foreground/background samples are heavily imbalanced, which makes the extracted features fragile and tends to ignore the shape prior and the correlations between lanes, degrading detection performance.
  • Lanes may also be occluded by crowded vehicles, in which case their positions can only be inferred from common sense.
  • SCNN introduced spatial convolution to propagate information across rows and columns, but this message passing is very time-consuming and slows inference; moreover, since information is passed sequentially between adjacent rows or columns, the many iterations can lose long-range information.
  • This paper proposes the REcurrent Feature-Shift Aggregator (RESA), which aggregates information within the feature map and passes spatial information more directly and efficiently.

Advantages:

  1. RESA passes information in a parallel way, which greatly reduces time cost.
  2. Information is passed with different strides, so features from different sliced feature maps can be gathered without information loss during propagation.
  3. RESA is simple and flexible to incorporate into other networks.

Upsampling

  • A bilateral up-sampling decoder is proposed: one branch captures coarse-grained features, the other fine-grained details.
  • The coarse-grained branch directly applies bilinear upsampling, which produces a blurry image; the fine-detailed branch uses transpose convolution for upsampling, with two non-bottleneck blocks to repair the loss of detail.
  • Combining the two branches, the decoder can finely recover the low-resolution feature map into pixel-wise predictions.

Datasets

CULane and TuSimple

Contributions

  • RESA, a structure for aggregating spatial information
  • A bilateral up-sampling decoder
  • Strong results on both datasets

Related work

Traditional methods
  • Sun, Tsai, and Chan (2006) detect lanes in the HSI color representation.
  • Yu and Jain (1997) extract lane boundaries via the Hough Transform.
Deep learning methods
  • Huval et al. (2015) are the first to apply deep learning to lane detection with CNNs.
  • Neven et al. (2018) propose to cast the lane detection problem as an instance segmentation problem.
  • Philion (2019) integrates the lane decoding step into the network and draws lanes iteratively without a recurrent neural network.
  • Self-attention distillation (SAD) allows a model to learn from itself and gains substantial improvement without any additional supervision or labels (Hou et al. 2019).
Exploiting spatial information
  • ION (Bell et al. 2016) explores the use of spatial Recurrent Neural Networks (RNNs).
  • Liang et al. (2016) construct a Graph LSTM to provide an information propagation route for semantic object parsing.
  • SCNN (Pan et al. 2018) generalizes traditional layer-by-layer convolutions to slice-by-slice convolutions within feature maps, enabling message passing between rows and columns.

Method

  • The model consists of three parts: encoder, aggregator, and decoder.
  • A commonly used backbone such as ResNet (He et al. 2016) or VGG (Simonyan and Zisserman 2015) serves as the encoder to extract preliminary features from the raw image.
  • The RESA module aggregates lane information.
  • A bilateral up-sampling decoder recovers the resolution.
Framework design

The network structure is shown below:
[Figure: RESANet architecture: encoder, RESA module, bilateral up-sampling decoder]

  1. Encoder: a common backbone (VGG, ResNet, ...) extracts preliminary features; the feature map is reduced to 1/8 of the input size.
  2. The RESA module aggregates spatial information, passing information in four directions in each iteration; RESA iterates K times in total so that every location receives information from the whole feature map.
  3. Bilateral up-sampling decoder: each block upsamples by 2x, recovering the 1/8 feature map to full resolution.
  • After up-sampling by the decoder, the output feature map is used to predict the existence and the probability distribution of each lane.
  • For existence prediction, a fully-connected layer follows and performs binary (0/1) classification.
RESA

[Paper figures and equations: the RESA feature-shift operation and its iterative update]

Advantages
  1. Computationally efficient.
  2. Feature information is gathered effectively.
  3. Easy to plug into other networks.
Bilateral up-sampling decoder
  • Most decoders use bilinear upsampling to produce pixel-level predictions; this cheaply yields coarse results but may lose details.
  • Some methods (Romera et al. 2017) use stacked convolutional and deconvolutional operations to obtain refined upsampling results.
  • Motivated by the above, we combine the advantages of both and propose the bilateral up-sampling decoder.
    [Figure: BUSD structure]
Coarse grained branch
  • 1x1 convolution -> BN -> bilinear upsampling -> ReLU
Fine detailed branch

transpose convolution with stride 2 -> ReLU -> non-bottleneck block -> non-bottleneck block (a non-bottleneck block consists of four 3 × 1 and 1 × 3 convolutions with BN and ReLU)

Experiment

We use SGD (Bottou 2010) with momentum 0.9 and weight decay 1e-4 as the optimizer to train our model, and the learning rate is set to 1.6e-2 for CULane and 2.5e-2 for Tusimple respectively. We use a warm-up strategy (Dollár, Girshick, and Noordhuis 2017) in the first 500 batches and then apply a polynomial learning rate decay policy (Mishra and Sarawadekar 2019) with the power set to 0.9.

  • Loss: segmentation cross-entropy loss plus a binary lane-existence classification loss (in the released code: NLLLoss over log-softmax for segmentation, BCEWithLogitsLoss for existence; see runner/resa_trainer.py below).
  • Considering the imbalanced labels between background and lane markings, the segmentation loss of the background class is multiplied by 0.4 (see the sketch after this list).
  • The batch size is set to 8 for CULane and 16 for Tusimple respectively.
  • The total number of training epochs is 50 for TuSimple and 12 for CULane.
  • All models are trained with 4 Nvidia 1080Ti GPUs. All experiments are implemented in PyTorch.
  • In our experiments, we use ResNet (He et al. 2016) and VGG (Simonyan and Zisserman 2014) as backbones. For ResNet, we add an extra 1 × 1 convolution to reduce the output channels to 128. The VGG modification is the same as in SCNN (Pan et al. 2018).
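
As a quick, hedged illustration of this class weighting (toy shapes; it mirrors the NLLLoss setup in runner/resa_trainer.py further below):

import torch
import torch.nn.functional as F

num_classes = 7                    # 6 lanes + background, as in the config below
weights = torch.ones(num_classes)
weights[0] = 0.4                   # down-weight the background class
criterion = torch.nn.NLLLoss(weight=weights, ignore_index=255)

logits = torch.randn(2, num_classes, 46, 80)         # (N, C, H, W)
labels = torch.randint(0, num_classes, (2, 46, 80))  # (N, H, W)
print(criterion(F.log_softmax(logits, dim=1), labels).item())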

Code reproduction

code:https://github.com/ZJULearning/resa/tree/fe4e7314ebfb2de17f5b75539cb2ed3c28bf6f96

Environment setup

conda create -n resa python=3.8 -y
conda activate resa
pip install torch torchvision
pip install -r requirement.txt

Dataset preparation

TuSimple
  • Directory structure:
$TUSIMPLEROOT/clips # data folders
$TUSIMPLEROOT/label_data_xxxx.json # label json file x3
$TUSIMPLEROOT/test_tasks_0627.json # test tasks json file
$TUSIMPLEROOT/test_label.json # test label json file
  • Segmentation labels are not provided; generate them from the json files:
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
# this will generate seg_label directory

Testing

Trained models provided by the authors:
(Tusimple: GoogleDrive/BaiduDrive(code:s5ii), CULane: GoogleDrive/BaiduDrive(code:rlwj))

Without visualization

 python main.py configs/tusimple.py --validate --load_from ./tusimple_resnet34.pth --gpus 0

With visualization

 python main.py configs/tusimple.py --validate --load_from ./tusimple_resnet34.pth --gpus 0 --view

Training

python main.py configs/tusimple.py --gpus 0 --work_dirs /content/drive/MyDrive/chxsave
configs/tusimple.py
net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],  # replace stride with dilated convolution
    out_conv=True,                       
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,      # coefficient applied to the shifted feature in each RESA update
    iter=5,         # number of RESA iterations
    input_channel=128,
    conv_stride=9,
)

decoder = 'BUSD'        

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='Tusimple',        
    thresh = 0.60
)

optimizer = dict(
  type='sgd',
  lr=0.020,
  weight_decay=1e-4,
  momentum=0.9
)

total_iter = 80000
import math
scheduler = dict(
    type = 'LambdaLR',
    lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
)

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"

dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'

dataset = dict(
    train=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='train_val_gt.txt',
    ),
    val=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    ),
    test=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    )
)


loss_type = 'cross_entropy'
seg_loss_weight = 1.0


batch_size = 4
workers = 12
num_classes = 6 + 1
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''
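
To see what the LambdaLR schedule above does, here is a small check (base lr 0.020 and total_iter 80000 are taken from this config):

import math

total_iter = 80000
lr_lambda = lambda it: math.pow(1 - it / total_iter, 0.9)
for it in (0, 20000, 40000, 60000, 79999):
    print(it, round(0.020 * lr_lambda(it), 6))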

datasets/base_dataset.py

img_path is the dataset root; the list files are read from list_path under it.

import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_training = ('train' in data_list)

        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_train() if self.is_training else self.transform_val()

        self.init()

    def transform_train(self):
        raise NotImplementedError()
# resize samples to the configured size
# normalize images with the given mean and std
    def transform_val(self):
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform
# visualization: file_path is where the rendered image is saved
    def view(self, img, coords, file_path=None):
        for coord in coords:
        # iterate over the (x, y) coordinates
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                # draw each lane point as a small circle
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            if not os.path.exists(osp.dirname(file_path)):
                os.makedirs(osp.dirname(file_path))
            cv2.imwrite(file_path, img)


    def init(self):
        raise NotImplementedError()


    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        # read the image
        img = img[self.cfg.cut_height:, :, :]
        # crop off the top cut_height rows

        if self.is_training:
        # training: also load the segmentation label and lane existence
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
            # image preprocessing / augmentation
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))
        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        # permute dimensions: HWC -> CHW (see the toy example after this listing)
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data
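
A toy illustration of the two steps commented in __getitem__ above (crop the top cut_height rows, then convert OpenCV's HWC layout to PyTorch's CHW); the shapes are made up:

import numpy as np
import torch

cut_height = 160                                  # value from the config above
img = np.zeros((720, 1280, 3), dtype=np.float32)  # fake TuSimple frame (H, W, C)
img = img[cut_height:, :, :]                      # crop -> (560, 1280, 3)
tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
print(tensor.shape)                               # torch.Size([3, 560, 1280])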

datasets/registry.py
from utils import Registry, build_from_cfg

import torch
import torch.nn as nn  # needed by build() below

DATASETS = Registry('datasets')
def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)

def build_dataloader(split_cfg, cfg, is_train=True):
    if is_train:
        shuffle = True   # shuffle training data
    else:
        shuffle = False

    dataset = build_dataset(split_cfg, cfg)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size = cfg.batch_size, shuffle = shuffle,  
        num_workers = cfg.workers, pin_memory = False, drop_last = False)
      
    return data_loader

datasets/tusimple.py
  • init
  • transform_train
  • fix_gap
  • is_short
  • get_lane
  • probmap2lane
import os.path as osp
import numpy as np
import cv2
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)
# preprocessing / augmentation for training images
    def transform_train(self):
        input_mean = self.cfg.img_norm['mean']
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform


    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5]),
                              int(line_split[6]), int(line_split[7])
                              ]))

    def fix_gap(self, coordinate):
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate

    def is_short(self, lane):
        start = [i for i, x in enumerate(lane) if x > 0]
        if not start:
            return 1
        else:
            return 0

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map: prob map for single lane, np array size (h, w)
        resize_shape:  reshape size target, (H, W)
    
        Return:
        ----------
        coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height
    
        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        self.fix_gap(coords)
        #print(coords.shape)

        return coords

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred:      np.array size (5, h, w)
        resize_shape:  reshape size target, (H, W)
        exist:       list of existence, e.g. [0, 1, 1, 0]
        smooth:      whether to smooth the probability or not
        y_px_gap:    y pixel gap for sampling
        pts:     how many points for one lane
        thresh:  probability threshold
    
        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []
    
        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])
    
    
        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])
        #print(coordinates)
    
        return coordinates
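
The core of get_lane/probmap2lane above is a simple row-wise decoding: scan rows bottom-up, take the per-row argmax, and keep it only if its probability clears the threshold. A standalone sketch of the same idea on a toy probability map (names and shapes here are made up for illustration, not repo code):

import numpy as np

def decode_lane(prob_map, thresh=0.6, y_px_gap=10, pts=5, H=56, W=128):
    # scan sample rows bottom-up; keep the per-row argmax if above threshold
    h, w = prob_map.shape
    coords = -np.ones(pts)
    for i in range(pts):
        y = int((H - 10 - i * y_px_gap) * h / H)
        if y < 0:
            break
        row = prob_map[y]
        x = int(np.argmax(row))
        if row[x] > thresh:
            coords[i] = int(x / w * W)   # rescale x to the target width
    return coords

prob_map = np.zeros((56, 128), dtype=np.float32)
prob_map[:, 64] = 0.9                    # fake vertical lane down column 64
print(decode_lane(prob_map))             # -> [64. 64. 64. 64. 64.]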

The utils package mostly serves the config system.
utils/registry.py
import inspect

import six

# borrow from mmdetection

def is_str(x):
    """Whether the input is an string instance."""
    return isinstance(x, six.string_types)

class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = {}
    obj_type = cfg.type 
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
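
A minimal usage sketch of this Registry. One subtlety worth noting: unlike mmdetection's version, this build_from_cfg reads only cfg.type (via attribute access, so the repo passes addict-style dicts rather than plain dicts) and takes the constructor arguments from default_args, which is why build_dataset above copies the split config into default_args itself. AttrDict below is a hypothetical stand-in for the addict dict:

from utils import Registry, build_from_cfg  # assumes the repo root is on PYTHONPATH

class AttrDict(dict):
    """Stand-in for the addict-style dict the repo uses (attribute access)."""
    __getattr__ = dict.get

MODELS = Registry('models')

@MODELS.register_module
class Dummy(object):
    def __init__(self, width):
        self.width = width

# 'type' selects the class; constructor kwargs come from default_args.
obj = build_from_cfg(AttrDict(type='Dummy'), MODELS, default_args={'width': 3})
print(type(obj).__name__, obj.width)  # Dummy 3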
models/decoder.py


from torch import nn
import torch.nn.functional as F

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg
        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height,  self.cfg.img_width],
                           mode='bilinear', align_corners=False)
        # bilinear interpolation back to full image resolution
        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)

#Fine detailed branch + Coarse grained branch
class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate
        self.up_width = up_width
        self.up_height = up_height


        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height,  self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate
#Bilateral Up-sampling Decoder
class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        output = input

        for layer in self.layers:
            output = layer(output)

        output = self.output_conv(output)

        return output
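
A quick shape sanity check for BUSD, assuming the repo's models/decoder.py is importable; cfg is faked with SimpleNamespace:

import torch
from types import SimpleNamespace
from models.decoder import BUSD

cfg = SimpleNamespace(img_height=368, img_width=640, num_classes=7)
decoder = BUSD(cfg).eval()
fea = torch.randn(1, 128, 368 // 8, 640 // 8)  # encoder/RESA output at stride 8
with torch.no_grad():
    out = decoder(fea)
print(out.shape)                               # torch.Size([1, 7, 368, 640])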

models/registry.py
from utils import Registry, build_from_cfg
import torch.nn as nn  # needed by build() below

NET = Registry('net')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_net(cfg):
    return build(cfg.net, NET, default_args=dict(cfg=cfg))

models/resa.py
import torch.nn as nn
import torch
import torch.nn.functional as F

from models.registry import NET
from .resnet import ResNetWrapper 
from .decoder import BUSD, PlainDecoder 


class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride

        for i in range(self.iter):
        # nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)
            # // is floor division, ** is power; see the stride note after this listing
            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()

        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x


# lane existence prediction head
class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)  
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        self.fc10 = nn.Linear(128, cfg.num_classes-1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)
        x = x.view(-1, x.numel() // x.shape[0])
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x


@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        self.backbone = ResNetWrapper(cfg)
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)
        self.heads = ExistHead(cfg) 

    def forward(self, batch):
        fea = self.backbone(batch)
        fea = self.resa(fea)
        seg = self.decoder(fea)
        exist = self.heads(fea)

        output = {'seg': seg, 'exist': exist}

        return output
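
The idx_d/idx_u/idx_r/idx_l buffers in RESA above shift each slice by dim // 2**(iter - i): shifts start small and grow to half the feature map, which is how every location eventually sees the whole map. A quick check with the values used here (a 46 x 80 feature map at stride 8, iter=5):

height, width, n_iter = 368 // 8, 640 // 8, 5
for i in range(n_iter):
    print(i, height // 2 ** (n_iter - i), width // 2 ** (n_iter - i))
# vertical strides: 1, 2, 5, 11, 23; horizontal strides: 2, 5, 10, 20, 40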

models/resnet.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url


# This code is borrowed from torchvision.

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        # if dilation > 1:
        #     raise NotImplementedError(
        #         "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNetWrapper(nn.Module):

    def __init__(self, cfg):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = [64, 128, 256, 512]
        if 'in_channels' in cfg.backbone:
            self.in_channels = cfg.backbone.in_channels
        self.model = eval(cfg.backbone.resnet)(
            pretrained=cfg.backbone.pretrained,
            replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels)
        self.out = None
        if cfg.backbone.out_conv:
            out_channel = 512
            for chan in reversed(self.in_channels):
                if chan < 0: continue
                out_channel = chan
                break
            self.out = conv1x1(
                out_channel * self.model.expansion, 128)

    def forward(self, x):
        x = self.model(x)
        if self.out:
            x = self.out(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.in_channels[3] > 0:
            x = self.layer4(x)

        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)

runner/evaluator/tusimple/tusimple.py
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR 
import json
import os
import cv2

from .lane import LaneEval

def split_path(path):
    """split path tree into list"""
    folders = []
    while True:
        path, folder = os.path.split(path) 
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders


@EVALUATOR.register_module
class Tusimple(nn.Module):
    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg 
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        if not os.path.exists(exp_dir):
            os.mkdir(exp_dir)
        self.out_path = os.path.join(exp_dir, "coord_output")
        if not os.path.exists(self.out_path):
            os.mkdir(self.out_path)  # coordinate outputs
        self.dump_to_json = [] 
        self.thresh = cfg.evaluator.thresh
        self.logger = get_logger('resa')  # logger
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')  # visualization output

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for b in range(len(seg_pred)):  # per image: decide which lanes exist
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] >
                     0.5 else 0 for i in range(self.cfg.num_classes-1)]
            lane_coords = dataset.probmap2lane(seg, exist, thresh = self.thresh) 
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(
                    lane_coords[i], key=lambda pair: pair[1])  

            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)

            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_sample'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            for (x, y) in lane_coords[0]:
                json_dict['h_sample'].append(y)
            self.dump_to_json.append(json.dumps(json_dict))
            if self.cfg.view:
                img = cv2.imread(img_path[b])
                new_img_name = img_name[b].replace('/', '_')
                save_dir = os.path.join(self.view_dir, new_img_name)
                dataset.view(img, lane_coords, save_dir)


    def evaluate(self, dataset, output, batch):
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        best_acc = 0
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)

        eval_result, acc = LaneEval.bench_one_submit(output_file,
                            self.cfg.test_json_file)

        self.logger.info(eval_result)
        self.dump_to_json = []
        best_acc = max(acc, best_acc)
        return best_acc

runner/logger.py
  • Logging utilities
import logging

logger_initialized = {}

def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)

    logger_initialized[name] = True

    return logger
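
A minimal usage sketch of get_logger (assumes the repo root is on PYTHONPATH):

from runner.logger import get_logger

logger = get_logger('resa', log_file='demo.log')
logger.info('logger ready')            # goes to stderr and demo.log
assert get_logger('resa') is logger    # later calls return the cached logger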

runner/net_utils.py
import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger
# save a checkpoint; named 'best' when it is the best model so far
def save_model(net, optim, scheduler, recorder, is_best=False):
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.system('mkdir -p {}'.format(model_dir))  # -p creates parent directories as needed
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch             
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))


def load_network_specified(net, model_dir, logger=None):
    pretrained_net = torch.load(model_dir)['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state.keys() or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)


def load_network(net, model_dir, finetune_from=None, logger=None):
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    pretrained_model = torch.load(model_dir)
    net.load_state_dict(pretrained_model['net'], strict=True)

runner/optimizer.py
import torch


_optimizer_factory = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD
}


def build_optimizer(cfg, net):
    params = []
    lr = cfg.optimizer.lr
    weight_decay = cfg.optimizer.weight_decay

    for key, value in net.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if 'adam' in cfg.optimizer.type:
        optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
    else:
        optimizer = _optimizer_factory[cfg.optimizer.type](
                params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)

    return optimizer

runner/recorder.py
from collections import deque, defaultdict
import torch
import os
import datetime
from .logger import get_logger


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """
# keeps a window of values and exposes median / average
    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class Recorder(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.work_dir = self.get_work_dir()   # directory named by date/time and hyper-parameters
        cfg.work_dir = self.work_dir          # write it back into cfg
        self.log_path = os.path.join(self.work_dir, 'log.txt')  # log file

        self.logger = get_logger('resa', self.log_path)
        self.logger.info('Config: \n' + cfg.text)  # dump the config into the log

        # scalars
        self.epoch = 0
        self.step = 0
        self.loss_stats = defaultdict(SmoothedValue)
        self.batch_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.max_iter = self.cfg.total_iter  
        self.lr = 0.

    def get_work_dir(self):  # create the run directory
        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)  # named by time and hyper-parameters
        work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        return work_dir

    def update_loss_stats(self, loss_dict):
        for k, v in loss_dict.items():
            self.loss_stats[k].update(v.detach().cpu())

    def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
        self.logger.info(self)
        # self.write(str(self))

    def write(self, content):
        with open(self.log_path, 'a+') as f:
            f.write(content)
            f.write('\n')

    def state_dict(self):
        scalar_dict = {}
        scalar_dict['step'] = self.step
        return scalar_dict

    def load_state_dict(self, scalar_dict):
        self.step = scalar_dict['step']

    def __str__(self):
        loss_state = []
        for k, v in self.loss_stats.items():
            loss_state.append('{}: {:.4f}'.format(k, v.avg))
        loss_state = '  '.join(loss_state)

        recording_state = '  '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
        eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string)


def build_recorder(cfg):
    return Recorder(cfg)


runner/registry.py
from utils import Registry, build_from_cfg
import torch.nn as nn  # needed by build() below

TRAINER = Registry('trainer')
EVALUATOR = Registry('evaluator')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_trainer(cfg):
    return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))

def build_evaluator(cfg):
    return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))

runner/resa_trainer.py (loss computation)

  • X is the ground-truth segmentation map, Y the predicted segmentation map.
  • The dice loss implemented below is 1 - 2*sum(X*Y) / (sum(X*X) + sum(Y*Y)), averaged over the batch.
import torch.nn as nn
import torch
import torch.nn.functional as F

from runner.registry import TRAINER

def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return (1-d).mean()

@TRAINER.register_module
class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type == 'cross_entropy':
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight
            weights = weights.cuda()
            self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
                                              weight=weights).cuda()

        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        output = net(batch['img'])

        loss_stats = {}
        loss = 0.

        if self.loss_type == 'dice_loss':
            target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(F.softmax(
                output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:
            seg_loss = self.criterion(F.log_softmax(
                output['seg'], dim=1), batch['label'].long())

        loss += seg_loss * self.cfg.seg_loss_weight

        loss_stats.update({'seg_loss': seg_loss})

        if 'exist' in output:
            exist_loss = 0.1 * \
                self.criterion_exist(output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})

        ret = {'loss': loss, 'loss_stats': loss_stats}

        return ret
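
A quick, hypothetical sanity check of dice_loss on toy tensors: identical prediction and target should drive the loss towards 0, disjoint ones towards 1.

# Hypothetical sanity check for dice_loss (not part of the repo).
import torch

pred    = torch.tensor([[1., 0., 1., 0.]])  # fake flattened probability map
gt_same = torch.tensor([[1., 0., 1., 0.]])  # identical mask
gt_diff = torch.tensor([[0., 1., 0., 1.]])  # disjoint mask

print(dice_loss(pred, gt_same))  # ≈ 0.0  (tensor(0.0005))
print(dice_loss(pred, gt_diff))  # ≈ 1.0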

runner/scheduler.py ---- learning-rate schedule (LambdaLR)
import torch
import math


_scheduler_factory = {
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}


def build_scheduler(cfg, optimizer):

    assert cfg.scheduler.type in _scheduler_factory

    cfg_cp = cfg.scheduler.copy()
    cfg_cp.pop('type')

    scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)

    return scheduler
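
build_scheduler expects the config to carry the lr_lambda itself. A hedged sketch of what such an entry could look like with the conventional poly decay (field names and the 0.9 exponent are assumptions here; check configs/tusimple.py in the repo for the real values):

# Hypothetical config entry for a poly-decay LambdaLR schedule
# (assumed values; the repo's configs are authoritative).
import math

total_iter = 80000
scheduler = dict(
    type='LambdaLR',
    # lr multiplier decays from 1 towards 0 over total_iter steps
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9),
)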

runner/runner.py
import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup

from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network


class Runner(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)   # build the logger/recorder
        self.net = build_net(self.cfg)             # build the network
        self.net = torch.nn.parallel.DataParallel(
                self.net, device_ids=range(self.cfg.gpus)).cuda()  # multi-GPU
        self.recorder.logger.info('Network: \n' + str(self.net))  # log the network structure
        self.resume()  # resume from a checkpoint if requested
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        self.metric = 0.

    def resume(self):
        if not self.cfg.load_from and not self.cfg.finetune_from:
            return
        load_network(self.net, self.cfg.load_from,
                finetune_from=self.cfg.finetune_from, logger=self.recorder.logger)

    def to_cuda(self, batch):  # move the batch to GPU
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch
    
    def train_epoch(self, epoch, train_loader): 
        self.net.train()  # train mode: affects BN and dropout behavior
        end = time.time()
        max_iter = len(train_loader)
        for i, data in enumerate(train_loader):
            if self.recorder.step >= self.cfg.total_iter:    # stop at cfg.total_iter (e.g. 80000)
                break
            data_time = time.time() - end  # time spent waiting for the loader
            self.recorder.step += 1  
            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)  # forward pass + loss
            self.optimizer.zero_grad()  # clear gradients
            loss = output['loss']
            loss.backward()
            self.optimizer.step() 
            self.scheduler.step()
            self.warmup_scheduler.dampen()
            batch_time = time.time() - end
            end = time.time()
            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)  # per-iteration time
            self.recorder.data_time.update(data_time)  # data-loading time

            if i % self.cfg.log_interval == 0 or i == max_iter - 1:
                lr = self.optimizer.param_groups[0]['lr']
                self.recorder.lr = lr
                self.recorder.record('train')

    def train(self):
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)

        for epoch in range(self.cfg.epochs):
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)
            if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
                self.save_ckpt()
            if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
                self.validate(val_loader)
            if self.recorder.step >= self.cfg.total_iter:
                break

    def validate(self, val_loader):
        self.net.eval()
        for i, data in enumerate(tqdm(val_loader, desc='Validate')):
            data = self.to_cuda(data)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(val_loader.dataset, output, data)

        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:
            self.metric = metric
            self.save_ckpt(is_best=True)
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        save_model(self.net, self.optimizer, self.scheduler,
                self.recorder, is_best)
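
save_model and load_network are imported from .net_utils, which this post does not reproduce. As a rough sketch only, under the assumption that a checkpoint bundles network, optimizer, scheduler and recorder state (names and fields here are guesses, not the repo's code):

# Rough sketch of net_utils-style helpers (assumed, not the repo's code).
import os
import torch

def save_model(net, optim, scheduler, recorder, is_best=False):
    ckpt_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    name = 'best' if is_best else str(recorder.epoch)
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
    }, os.path.join(ckpt_dir, '{}.pth'.format(name)))

def load_network(net, model_path, finetune_from=None, logger=None):
    ckpt = torch.load(model_path, map_location='cpu')
    # strict=False tolerates head mismatches when finetuning
    net.load_state_dict(ckpt['net'], strict=False)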

main.py
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner 
from datasets import build_dataloader


def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)

    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view

    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    if args.validate:
        val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
        runner.validate(val_loader)
    else:
        runner.train()

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path') #config
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',            #workdir                            
        help='work dirs')
    parser.add_argument(
        '--load_from', default=None,
        help='the checkpoint file to resume from')               #resume              
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')       #finetune
    parser.add_argument(
        '--validate',
        action='store_true',                                                                           
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])  # list of GPU ids
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    main()

Reproducing the code on Colab

# download the Tusimple train/test sets and extract them to Google Drive
!wget https://s3.us-east-2.amazonaws.com/benchmark-frontend/datasets/1/test_set.zip
!7za x test_set.zip -o/content/drive/MyDrive/chx
!wget https://s3.us-east-2.amazonaws.com/benchmark-frontend/datasets/1/train_set.zip
!7za x train_set.zip -o/content/drive/MyDrive/chx
# check which GPU Colab assigned
! /opt/bin/nvidia-smi
# clone the RESA repo and install its dependencies
cd /content
!git clone https://github.com/ZJULearning/resa.git
cd ./resa
!pip install -r requirement.txt
!pip install pytorch_warmup
!pip install yapf

Then point the Tusimple config (configs/tusimple.py) at the extracted data:

dataset_path = '/content/drive/MyDrive/chx'
test_json_file = '/content/drive/MyDrive/chx/test_label.json'

Finally, launch training:

!python main.py configs/tusimple.py --gpus 0 --work_dirs /content/drive/MyDrive/chxsave