pytorch-上分之路-图像分割

项目结构

（此处原为项目结构截图）
本项目的数据集主要来自百度的道路（车道线）分割数据集，可在其官网下载。各模块的具体代码如下：

common

import torch.nn as nn


class ConvBNReLU(nn.Sequential):
    """Stride-1 ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero padding added on each spatial side.
        dilation: convolution dilation rate.

    The convolution is created without a bias term; the BatchNorm2d that
    follows it carries its own learnable shift.
    """

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        stack = [conv, nn.BatchNorm2d(planes), nn.ReLU(inplace=True)]
        super(ConvBNReLU, self).__init__(*stack)


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.
    Ref:
        Rethinking Atrous Convolution for Semantic Image Segmentation

    Five parallel branches are applied to the input and concatenated along the
    channel axis:
      * a 1x1 conv,
      * three dilated 3x3 convs,
      * an image-level branch: adaptive average pooling to 4x4 plus a 1x1 conv,
        upsampled back to the input size.
    A 1x1 conv then fuses the concatenation, followed by dropout (p=0.5).

    :param inplanes: channels of the input feature map.
    :param planes: channels produced by every branch and by the fused output.
    :param stride: backbone output stride. 8 selects dilations (12, 24, 36) and
        16 selects (6, 12, 18); any other value raises NotImplementedError.
    """

    def __init__(self, inplanes=2048, planes=256, stride=16):
        super(ASPP, self).__init__()
        if stride == 16:
            rates = (6, 12, 18)
        elif stride == 8:
            rates = (12, 24, 36)
        else:
            raise NotImplementedError

        # ConvBNReLU positional args: (in, out, kernel_size, padding, dilation).
        self.block1 = ConvBNReLU(inplanes, planes, 1, 0, 1)
        # A 3x3 conv keeps its spatial size when padding equals the dilation rate.
        self.block2 = ConvBNReLU(inplanes, planes, 3, rates[0], rates[0])
        self.block3 = ConvBNReLU(inplanes, planes, 3, rates[1], rates[1])
        self.block4 = ConvBNReLU(inplanes, planes, 3, rates[2], rates[2])

        # Image-level context branch.
        self.block5 = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            ConvBNReLU(inplanes, planes, 1, 0, 1),
        )

        branch_count = 5
        self.conv = ConvBNReLU(planes * branch_count, planes, 1, 0, 1)
        self.dropout = nn.Dropout(0.5)

        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for convs; BatchNorm starts with weight 1 and bias 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        """Return the fused ``planes``-channel map at the input's spatial size."""
        spatial_size = x.size()[2:]
        features = [self.block1(x), self.block2(x), self.block3(x), self.block4(x)]
        pooled = self.block5(x)
        # The pooled branch is 4x4; resize it so it can be concatenated with the others.
        features.append(F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=True))

        fused = self.conv(torch.cat(features, dim=1))
        # Dropout is the last op, so no BatchNorm follows it.
        return self.dropout(fused)


import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """
    DeepLab v3+ style decoder.

    Ref:
        Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.

    :param planes: channels of the low-level feature map; a 1x1 conv projects it
        to ``planes`` channels.
    :param num_classes: number of output logit channels.

    The high-level input must have 256 channels, since the fusion conv expects
    ``planes + 256`` input channels.
    """

    def __init__(self, planes=128, num_classes=3):
        super(Decoder, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )

        fused_channels = planes + 256
        self.conv2 = nn.Sequential(
            nn.Conv2d(fused_channels, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 1x1 classifier producing per-pixel logits.
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

        self._init_weights()

    def _init_weights(self):
        """Normal init, std = sqrt(2 / (k_h * k_w * out_channels)), for conv weights; identity BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                kh, kw = module.kernel_size[0], module.kernel_size[1]
                fan_out = kh * kw * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))

    def forward(self, x1, x2, output_size):
        """
        Fuse high- and low-level features and upsample the logits.

        :param x1: high-level feature map (256 channels, e.g. the ASPP output).
        :param x2: low level feature map with ``planes`` channels.
        :param output_size: spatial size ``(H, W)`` of the returned logits.
        :return: logits of shape ``(N, num_classes, *output_size)``.
        """
        projected_low = self.conv1(x2)
        upsampled_high = F.interpolate(x1, size=x2.size()[2:], mode='bilinear', align_corners=True)
        logits = self.conv2(torch.cat((upsampled_high, projected_low), dim=1))
        return F.interpolate(logits, size=output_size, mode='bilinear', align_corners=True)


import torch
import math
import torch.nn as nn


class SeparableConv2d(nn.Module):
    """
    Depth Separable Convolution.

    A per-channel (``groups=inplanes``) convolution followed by a 1x1
    point-wise convolution. Padding is ``(kernel_size - 1) * dilation // 2``,
    so odd kernels at stride 1 keep the spatial size. No normalization or
    activation is applied.

    :param inplanes: input channels (also the depth-wise output channels).
    :param planes: output channels of the point-wise conv.
    :param kernel_size: depth-wise kernel size.
    :param stride: depth-wise stride.
    :param dilation: depth-wise dilation.
    :param bias: whether both convolutions have a bias term.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        same_padding = dilation * (kernel_size - 1) // 2
        self.depth_wise = nn.Conv2d(
            in_channels=inplanes,
            out_channels=inplanes,
            kernel_size=kernel_size,
            stride=stride,
            padding=same_padding,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        self.point_wise = nn.Conv2d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        return self.point_wise(self.depth_wise(x))


class BasicConv2d(nn.Module):
    """Residual unit: three separable 3x3 convs (ReLU between them) plus a skip path.

    Only the first separable conv uses ``stride``; all three use ``dilation``.
    When the channel count or the stride changes, the skip path is a strided
    1x1 conv + BatchNorm projection; otherwise it is the identity. No
    activation is applied after the residual sum.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicConv2d, self).__init__()
        self.features = nn.Sequential(
            SeparableConv2d(inplanes, planes, 3, stride=stride, dilation=dilation),
            nn.ReLU(inplace=True),
            SeparableConv2d(planes, planes, 3, stride=1, dilation=dilation),
            nn.ReLU(inplace=True),
            SeparableConv2d(planes, planes, 3, stride=1, dilation=dilation),
        )

        needs_projection = stride != 1 or inplanes != planes
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = self.features(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        return residual + shortcut


class AlignedXception(nn.Module):
    """
    Aligned Xception backbone used as the DeepLab encoder.

    ``forward`` returns two feature maps:
      * the 2048-channel exit-flow output;
      * the 128-channel ``stage1`` output, which the decoder uses as its
        low-level skip connection.

    Ref:
        Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
    """

    def __init__(self, stride=16):
        """

        :param stride: Output stride, i.e. the total down-sampling factor of the final
            feature map. Accepted values:
              * 16: the DeepLab v3+ default;
              * 8: the original note associates it with DeepLab v3;
              * 32.
            Any other value raises NotImplementedError.
        """
        super(AlignedXception, self).__init__()
        # self.stride   = [stride of stage3, stride of the exit flow's first block]
        # self.dilation = [dilation of the middle flow, dilation of the exit flow]
        # stem, stage1 and stage2 each down-sample by 2 (8x in total), so the overall
        # factor is 8 * stride[0] * stride[1]. In every setting, dilation = 32 / stride.
        if stride == 8:
            self.stride = [1, 1]
            self.dilation = [4, 4]
        elif stride == 16:
            self.stride = [2, 1]
            self.dilation = [2, 2]
        elif stride == 32:
            self.stride = [2, 2]
            self.dilation = [1, 1]
        else:
            raise NotImplementedError

        # Entry flow
        # stem: stride-2 3x3 conv (3 -> 32), then stride-1 3x3 conv (32 -> 64), each with BN + ReLU.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # stage1: residual block with stride 2 (64 -> 128), followed by a ReLU.
        # Its output (down-sampled by 4 overall) is returned as the low-level features.
        self.stage1 = nn.Sequential(
            BasicConv2d(64, 128, 2),
            nn.ReLU(inplace=True),
        )
        self.stage2 = BasicConv2d(128, 256, 2)
        self.stage3 = BasicConv2d(256, 728, self.stride[0])

        # Middle flow
        # 16 identical 728-channel residual blocks at stride 1, dilated by dilation[0].
        layers = []
        for _ in range(16):
            layers.append(BasicConv2d(728, 728, stride=1, dilation=self.dilation[0]))
        self.stage4 = nn.Sequential(*layers)

        # Exit flow
        # One residual block (728 -> 1024), then three separable convs
        # (1024 -> 1536 -> 1536 -> 2048), each followed by BN + ReLU.
        self.stage5 = nn.Sequential(
            BasicConv2d(728, 1024, stride=self.stride[1], dilation=self.dilation[1]),
            nn.ReLU(inplace=True),
            SeparableConv2d(1024, 1536, dilation=self.dilation[1]),
            nn.BatchNorm2d(1536),
            nn.ReLU(inplace=True),
            SeparableConv2d(1536, 1536, dilation=self.dilation[1]),
            nn.BatchNorm2d(1536),
            nn.ReLU(inplace=True),
            SeparableConv2d(1536, 2048, dilation=self.dilation[1]),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
        )

        self._init_weight()

    def _init_weight(self):
        # Re-initialize every conv, including those nested in SeparableConv2d / BasicConv2d,
        # from a normal distribution with std = sqrt(2 / (k_h * k_w * out_channels)),
        # i.e. He-style fan-out init. BatchNorm starts as the identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """
        Run the backbone.

        :param x: input image batch; the stem's first conv expects 3 channels.
        :return:
            (x, low_level_features):
              * x: the 2048-channel exit-flow output, down-sampled by the configured stride;
              * low_level_features: the 128-channel stage1 output, down-sampled by 4,
                used by the decoder as its skip connection.
        """
        x = self.stem(x)
        x = self.stage1(x)
        low_level_features = x
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)

        return x, low_level_features

config

import argparse


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so an explicit converter is needed.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def get_parser(argv=None):
    """Build the command-line parser for the segmentation project and parse arguments.

    :param argv: list of argument strings to parse. ``None`` (the default)
        parses ``sys.argv[1:]``, which matches the previous behaviour.
    :return: ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="segmentation test")
    parser.add_argument('--project_name', type=str, default="图像分割")
    # Booleans go through _str2bool so that e.g. "--use_cuda False" really yields False.
    parser.add_argument('--use_cuda', type=_str2bool, default=True)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--data_base', type=str, default='E:/Datasets2/lane_seg')
    parser.add_argument('--resume', type=_str2bool, default=False)
    parser.add_argument('--pretrained_model', type=str, default='./weights/')
    parser.add_argument('--lr', type=float, default=0.001)
    # Space-separated integers, e.g. "--milestones 50 80". The old type=list split a
    # single token into characters ("50" -> ['5', '0']).
    parser.add_argument('--milestones', type=int, nargs='+', default=[50, 80])
    parser.add_argument('--epoches', type=int, default=200)
    parser.add_argument('--save_path', type=str, default="./weights/")

    args = parser.parse_args(argv)
    return args

datalist

import os

from torch.utils.data import Dataset
from torchvision import transforms

from utils import *


class BaiDuLaneDataset(Dataset):
    """Baidu road/lane-marking segmentation dataset (colour frames + gray label maps).

    Data layout under ``root_file``:
      * images: ``ColorImage_road0{2,3,4}``;
      * labels: ``Gray_Label/Label_road0{2,3,4}``.

    Both file lists are shuffled with a fixed seed and split 70% / 20% / 10%
    into train / val / test. Each sample is a dict ``{'image': ..., 'label': ...}``
    produced by the phase-specific pipeline built in ``preprocess``.
    """

    # Raw gray-label id -> metadata. encode_label_map reads 'id', 'trainId' and 'ignoreInEval'.
    labels = {'void': {'id': 0, 'trainId': 0, 'category': 'void', 'catId': 0, 'ignoreInEval': False,
                       'color': [0, 0, 0]},
              's_w_d': {'id': 200, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': False,
                        'color': [70, 130, 180]},
              's_y_d': {'id': 204, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': False,
                        'color': [220, 20, 60]},
              'ds_w_dn': {'id': 213, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': True,
                          'color': [128, 20, 128]},
              'ds_y_dn': {'id': 209, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': False,
                          'color': [255, 0, 0]},
              'sb_y_do': {'id': 206, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': True,
                          'color': [0, 0, 60]},
              'sb_w_do': {'id': 207, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': True,
                          'color': [0, 60, 100]},
              'b_w_g': {'id': 201, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': False,
                        'color': [0, 0, 142]},
              'b_y_g': {'id': 203, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': False,
                        'color': [119, 11, 32]},
              'db_w_g': {'id': 211, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': True,
                         'color': [244, 35, 232]},
              'db_y_g': {'id': 208, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': True,
                         'color': [0, 0, 160]},
              'db_w_s': {'id': 216, 'trainId': 3, 'category': 'stopping', 'catId': 3, 'ignoreInEval': True,
                         'color': [153, 153, 153]},
              's_w_s': {'id': 217, 'trainId': 3, 'category': 'stopping', 'catId': 3, 'ignoreInEval': False,
                        'color': [220, 220, 0]},
              'ds_w_s': {'id': 215, 'trainId': 3, 'category': 'stopping', 'catId': 3, 'ignoreInEval': True,
                         'color': [250, 170, 30]},
              's_w_c': {'id': 218, 'trainId': 4, 'category': 'chevron', 'catId': 4, 'ignoreInEval': True,
                        'color': [102, 102, 156]},
              's_y_c': {'id': 219, 'trainId': 4, 'category': 'chevron', 'catId': 4, 'ignoreInEval': True,
                        'color': [128, 0, 0]},
              's_w_p': {'id': 210, 'trainId': 5, 'category': 'parking', 'catId': 5, 'ignoreInEval': False,
                        'color': [128, 64, 128]},
              's_n_p': {'id': 232, 'trainId': 5, 'category': 'parking', 'catId': 5, 'ignoreInEval': True,
                        'color': [238, 232, 170]},
              'c_wy_z': {'id': 214, 'trainId': 6, 'category': 'zebra', 'catId': 6, 'ignoreInEval': False,
                         'color': [190, 153, 153]},
              'a_w_u': {'id': 202, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
                        'color': [0, 0, 230]},
              'a_w_t': {'id': 220, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
                        'color': [128, 128, 0]},
              'a_w_tl': {'id': 221, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
                         'color': [128, 78, 160]},
              'a_w_tr': {'id': 222, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
                         'color': [150, 100, 100]},
              'a_w_tlr': {'id': 231, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
                          'color': [255, 165, 0]},
              'a_w_l': {'id': 224, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
                        'color': [180, 165, 180]},
              'a_w_r': {'id': 225, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
                        'color': [107, 142, 35]},
              'a_w_lr': {'id': 226, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
                         'color': [201, 255, 229]},
              'a_n_lu': {'id': 230, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
                         'color': [0, 191, 255]},
              'a_w_tu': {'id': 228, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
                         'color': [51, 255, 51]},
              'a_w_m': {'id': 229, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
                        'color': [250, 128, 224]},
              'a_y_t': {'id': 233, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
                        'color': [127, 255, 0]},
              'b_n_sr': {'id': 205, 'trainId': 8, 'category': 'reduction', 'catId': 8, 'ignoreInEval': False,
                         'color': [255, 128, 0]},
              'd_wy_za': {'id': 212, 'trainId': 8, 'category': 'attention', 'catId': 8, 'ignoreInEval': True,
                          'color': [0, 255, 255]},
              'r_wp_np': {'id': 227, 'trainId': 8, 'category': 'no parking', 'catId': 8, 'ignoreInEval': False,
                          'color': [178, 132, 190]},
              'vom_wy_n': {'id': 223, 'trainId': 8, 'category': 'others', 'catId': 8, 'ignoreInEval': True,
                           'color': [128, 128, 64]},
              'cm_n_n': {'id': 250, 'trainId': 8, 'category': 'others', 'catId': 8, 'ignoreInEval': False,
                         'color': [102, 0, 204]},
              'noise': {'id': 249, 'trainId': 0, 'category': 'ignored', 'catId': 0, 'ignoreInEval': True,
                        'color': [0, 153, 153]},
              'ignored': {'id': 255, 'trainId': 0, 'category': 'ignored', 'catId': 0, 'ignoreInEval': True,
                          'color': [255, 255, 255]},
              }

    @staticmethod
    def get_file_list(file_path, ext):
        """List every file three directory levels below each road sub-tree.

        :param file_path: dataset root directory.
        :param ext: ``''`` for colour images or ``'Label'`` for gray label maps.
        :return: paths relative to ``file_path``, sorted at every directory level so
            that the image list and the label list line up entry by entry.
        :raises NotImplementedError: for any other ``ext``.
        """
        if ext == '':
            dirs = ['ColorImage_road02/ColorImage', 'ColorImage_road03/ColorImage', 'ColorImage_road04/ColorImage']
        elif ext == 'Label':
            dirs = ['Gray_Label/Label_road02', 'Gray_Label/Label_road03', 'Gray_Label/Label_road04']
        else:
            raise NotImplementedError

        file_list = []
        for d in dirs:
            for sub_name in sorted(os.listdir(os.path.join(file_path, d, ext))):
                sub_dir = os.path.join(d, ext, sub_name)
                for camera_name in sorted(os.listdir(file_path + '/' + sub_dir)):
                    camera_dir = os.path.join(sub_dir, camera_name)
                    for file_name in sorted(os.listdir(file_path + '/' + camera_dir)):
                        file_list.append(camera_dir + '/' + file_name)
        return file_list

    def __init__(self, root_file, phase='train', output_size=(846, 255), num_classes=8, adjust_factor=(0.3, 2.),
                 radius=(0., 1.)):
        """
        :param root_file: dataset root directory.
        :param phase: one of 'train', 'val', 'test'.
        :param output_size: resize target handed to FixedResize; PIL ``resize`` reads it as (width, height).
        :param num_classes: number of training classes (stored; encode_label_map yields ids 0..7).
        :param adjust_factor: (low, high) range for the AdjustColor jitter (train only).
        :param radius: (low, high) range for the RandomGaussianBlur radius (train only).
        """
        super(BaiDuLaneDataset, self).__init__()
        assert phase in ['train', 'val', 'test']
        self.root_file = root_file
        img_ext = ''
        label_ext = 'Label'
        self.img_list = self.get_file_list(self.root_file, img_ext)
        self.label_list = self.get_file_list(self.root_file, label_ext)

        self.output_size = output_size
        self.factor = adjust_factor
        self.radius = radius
        self.transform = self.preprocess(phase)
        self.num_classes = num_classes
        self.phase = phase

        num_data = len(self.img_list)
        assert num_data == len(self.label_list)

        # A private RandomState seeded with 2020 yields exactly the permutation that
        # np.random.seed(2020) + np.random.permutation did, so the split is unchanged.
        # Unlike the old code, it does not reset the global NumPy RNG, which the random
        # augmentations draw from, every time a dataset object is constructed.
        order = np.random.RandomState(2020).permutation(num_data)
        self.img_list = [self.img_list[i] for i in order]
        self.label_list = [self.label_list[i] for i in order]

        # 70% / 20% / 10% split of the shuffled lists.
        train_end, val_end = int(0.7 * num_data), int(0.9 * num_data)
        if phase == 'train':
            start, stop = 0, train_end
        elif phase == 'val':
            start, stop = train_end, val_end
        elif phase == 'test':
            start, stop = val_end, num_data
        else:
            raise NotImplementedError
        self.img_list = self.img_list[start:stop]
        self.label_list = self.label_list[start:stop]

    def __getitem__(self, item):
        """Load, crop, label-encode and transform one image/label pair.

        :return: the sample dict produced by ``self.transform``.
        :raises FileNotFoundError: if OpenCV cannot read the image or the label.
        """
        # Make sure the label belongs to the same frame as the image.
        assert os.path.basename(self.img_list[item]).replace('.jpg', '') == \
               os.path.basename(self.label_list[item]).replace('_bin.png', '')

        img_path = self.root_file + '/' + self.img_list[item]
        label_path = self.root_file + '/' + self.label_list[item]
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        target = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(img_path)
        if target is None:
            raise FileNotFoundError(label_path)

        offset = 690  # drop the top rows of the frame (treated as useless area)
        img = img[offset:, :]
        # Crop the label the same way to keep it aligned. In the 'test' phase the label stays
        # uncropped; its preprocessing (FixedResize(is_resize=False)) also keeps it at full size.
        if self.phase != 'test':
            target = target[offset:, :]
        # The label file is a gray map of raw ids; convert those to training ids.
        target = self.encode_label_map(target)
        # OpenCV arrays -> PIL images, which the transforms expect.
        # NOTE(review): cv2 loads colour images as BGR, while the Normalize mean/std in
        # preprocess look like RGB ImageNet statistics. Not changed here, because doing so
        # would alter what any already-trained weights expect.
        img = Image.fromarray(img)
        target = Image.fromarray(target)
        sample = {'image': img, 'label': target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def __len__(self):
        """Number of samples served.

        Non-train phases are deliberately capped at 10 samples, a shortcut kept from the
        original code. The cap is clamped to the real list length so ``__getitem__`` is
        never asked for an index past the end.
        """
        if self.phase == 'train':
            return len(self.img_list)
        return min(10, len(self.img_list))

    def data_generator(self, batch_size):
        """Yield shuffled batches without replacement, as a hand-rolled alternative to a DataLoader.

        Every sample is visited exactly once per call; the last batch may be smaller than
        ``batch_size``. Samples come from ``__getitem__``, so they get the same crop, label
        encoding and transform as the DataLoader path.

        :param batch_size: number of samples per batch (must be >= 1).
        :yield: dict ``{'image': [...], 'label': [...]}`` of per-sample outputs.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
        order = np.random.permutation(len(self.img_list))
        for start in range(0, len(order), batch_size):
            images = []
            targets = []
            for item in order[start:start + batch_size]:
                sample = self[int(item)]
                images.append(sample['image'])
                targets.append(sample['label'])
            yield {'image': images, 'label': targets}

    def encode_label_map(self, mask):
        """Map the raw label ids in ``mask`` to training ids, in place, and return ``mask``.

        For each entry of ``labels``:
          * ids flagged ``ignoreInEval`` become 0 (background);
          * otherwise the entry's ``trainId`` is used, minus one when ``trainId > 4``.
        Every ``trainId == 4`` entry (chevron) is flagged ``ignoreInEval``, so the shift
        just closes that gap and the output ids are contiguous in 0..7. Pixel values not
        listed in ``labels`` are left unchanged.
        Overwriting in place is safe: raw ids are 0 or >= 200 and training ids are <= 7,
        so a value that has been rewritten is never matched again by a later entry.
        """
        for value in self.labels.values():
            pixel = value['id']
            if value['ignoreInEval']:
                mask[mask == pixel] = 0
            else:
                train_id = value['trainId']
                if train_id > 4:
                    train_id -= 1
                mask[mask == pixel] = train_id

        return mask

    def decode_label_map(self, mask):
        """Map training ids 1..7 back to one representative raw gray id each (in place); 0 stays 0."""
        mask[mask == 1] = 200
        mask[mask == 2] = 201
        mask[mask == 3] = 216
        mask[mask == 4] = 210
        mask[mask == 5] = 214
        mask[mask == 6] = 202
        mask[mask == 7] = 205

        return mask

    def decode_color_map(self, mask):
        """Render training ids 0..7 as a ``(*mask.shape, 3)`` uint8 colour image.

        Each id gets the ``color`` triple of one representative entry in ``labels``;
        any other value stays [0, 0, 0].
        """
        new_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
        new_mask[mask == 0] = [0, 0, 0]
        new_mask[mask == 1] = [70, 130, 180]
        new_mask[mask == 2] = [0, 0, 142]
        new_mask[mask == 3] = [153, 153, 153]
        new_mask[mask == 4] = [128, 64, 128]
        new_mask[mask == 5] = [190, 153, 153]
        new_mask[mask == 6] = [0, 0, 230]
        new_mask[mask == 7] = [255, 128, 0]

        return new_mask

    def preprocess(self, phase):
        """Build the transform pipeline for ``phase``.

        * train: resize plus augmentation (translate, cutout, flip, colour jitter, blur).
        * val: resize only.
        * test: resizes the image but not the label.
        All phases then normalize and convert to tensors.
        """
        if phase == 'train':
            preprocess = transforms.Compose([
                FixedResize(self.output_size),
                Translate(50, 255),
                # RandomScale(),
                CutOut(64),
                RandomHorizontalFlip(),
                AdjustColor(self.factor),
                RandomGaussianBlur(self.radius),
                Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
                ToTensor(),
            ])

        elif phase == 'val':
            preprocess = transforms.Compose([
                FixedResize(self.output_size),
                Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
                ToTensor(),
            ])

        elif phase == 'test':
            preprocess = transforms.Compose([
                FixedResize(self.output_size, is_resize=False),
                Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
                ToTensor(),
            ])

        else:
            raise NotImplementedError

        return preprocess

model

import numpy as np
import torch.nn as nn

from common import Decoder, ASPP, AlignedXception


def conv3x3(in_channels, out_channels, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding keeps the spatial size at stride 1.

    With dilation ``d``, the kernel's effective extent is ``(k - 1) * d + 1``, so
    "same" padding for the odd kernel ``k = 3`` is ``(k - 1) * d // 2 == d``.
    Kernel size and padding are plain Python ints. The old numpy-based computation
    leaked ``np.int64`` scalars into the module's hyper-parameters.

    :param in_channels: input channels.
    :param out_channels: output channels.
    :param stride: convolution stride.
    :param dilation: integer dilation rate.
    :return: the configured ``nn.Conv2d``.
    """
    kernel = 3
    dilated_extent = (kernel - 1) * dilation + 1
    pad = (dilated_extent - 1) // 2
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel, kernel),
                     stride=stride, padding=(pad, pad), dilation=dilation, bias=False)


# 这里的配置使用的是 DeepLab v3+
class DeepLab(nn.Module):
    """DeepLab v3+ segmentation network: Aligned Xception encoder, ASPP head and decoder.

    Args:
        backbone: unused; only the Aligned Xception backbone is built.
        stride: output stride of the backbone (8, 16 or 32). It also selects the ASPP
            dilation rates. ASPP only defines rates for 8 and 16, so stride 32 keeps the
            stride-16 rates, matching the previous behaviour.
        num_classes: number of output classes (logit channels).
        pretrained: unused; no weights are loaded here.
    """

    def __init__(self, backbone="aligned_inception", stride=16, num_classes=8, pretrained=False):
        super(DeepLab, self).__init__()
        self.backbone = AlignedXception(stride)
        # Channel count of the backbone's low-level feature map fed to the decoder.
        planes = 128
        # ASPP dilation rates must match the backbone's actual output stride; before this,
        # 16 was hard-coded, which gave a stride-8 backbone the wrong rates.
        aspp_stride = stride if stride in (8, 16) else 16
        self.aspp = ASPP(2048, 256, aspp_stride)
        self.decoder = Decoder(planes=planes, num_classes=num_classes)

    def forward(self, x):
        """Return per-pixel logits of shape ``(N, num_classes, H, W)`` at the input resolution."""
        high_level, low_level = self.backbone(x)
        high_level = self.aspp(high_level)
        return self.decoder(high_level, low_level, x.size()[2:])

utils

#
# Ref:https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/dataloaders/custom_transforms.py
#

import random

import cv2
import torchvision.transforms.functional as FF
from PIL import Image, ImageOps, ImageFilter


class Normalize(object):
    """Scale an image to [0, 1] and standardize it channel-wise.

    The image becomes a new float32 array that is divided by 255, shifted by
    ``mean`` and divided by ``std`` (both broadcast over the last axis). The
    label is only converted to a float32 array.

    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        image = np.array(sample['image'], dtype=np.float32)
        label = np.array(sample['label'], dtype=np.float32)
        # In-place updates keep the result in float32.
        image /= 255.0
        image -= self.mean
        image /= self.std
        return {'image': image, 'label': label}


class ToTensor(object):
    """Convert a sample's image and label to float32 torch tensors.

    The image goes from the numpy layout H x W x C to the torch layout
    C x H x W; the label keeps its shape.
    """

    def __call__(self, sample):
        image = np.array(sample['image']).astype(np.float32)
        label = np.array(sample['label']).astype(np.float32)
        chw_image = np.transpose(image, (2, 0, 1))
        return {
            'image': torch.from_numpy(chw_image).float(),
            'label': torch.from_numpy(label).float(),
        }


class RandomHorizontalFlip(object):
    """With probability 0.5, mirror both the image and the label left-to-right (PIL images)."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        should_flip = np.random.random() < 0.5
        if not should_flip:
            return {'image': image, 'label': label}
        return {
            'image': image.transpose(Image.FLIP_LEFT_RIGHT),
            'label': label.transpose(Image.FLIP_LEFT_RIGHT),
        }


class RandomRotate(object):
    """Rotate the image and the label by one random angle drawn from [-degree, degree].

    The image uses bilinear resampling; the label uses nearest-neighbour so its
    ids stay discrete.
    """

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        angle = random.uniform(-1 * self.degree, self.degree)
        return {
            'image': image.rotate(angle, Image.BILINEAR),
            'label': label.rotate(angle, Image.NEAREST),
        }


class RandomGaussianBlur(object):
    """With probability 0.5, Gaussian-blur the image; the label is never touched.

    The blur radius is drawn uniformly from the ``radius`` (low, high) range.
    """

    def __init__(self, radius=(0., 1.)):
        self.radius = radius

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if np.random.random() >= 0.5:
            return {'image': image, 'label': label}
        kernel = ImageFilter.GaussianBlur(radius=random.uniform(*self.radius))
        return {'image': image.filter(kernel), 'label': label}


class RandomScaleCrop(object):
    """Random rescale, optional pad, then random square crop of image and label (PIL).

    The short edge is rescaled to a random length in
    [0.5 * base_size, 2 * base_size]. If that is smaller than ``crop_size``, the
    right/bottom borders are padded: the image with 0, the label with ``fill``.
    A ``crop_size`` x ``crop_size`` window is then cut at a random position.
    """

    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))

        width, height = image.size
        if height > width:
            new_w, new_h = short_size, int(1.0 * height * short_size / width)
        else:
            new_w, new_h = int(1.0 * width * short_size / height), short_size
        image = image.resize((new_w, new_h), Image.BILINEAR)
        label = label.resize((new_w, new_h), Image.NEAREST)

        if short_size < self.crop_size:
            pad_h = max(self.crop_size - new_h, 0)
            pad_w = max(self.crop_size - new_w, 0)
            border = (0, 0, pad_w, pad_h)
            image = ImageOps.expand(image, border=border, fill=0)
            label = ImageOps.expand(label, border=border, fill=self.fill)

        width, height = image.size
        left = random.randint(0, width - self.crop_size)
        top = random.randint(0, height - self.crop_size)
        box = (left, top, left + self.crop_size, top + self.crop_size)
        return {'image': image.crop(box), 'label': label.crop(box)}


class FixScaleCrop(object):
    """Rescale so the short edge equals ``crop_size``, then center-crop a square (PIL images).

    The image is resized bilinearly and the label with nearest-neighbour.
    """

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        width, height = image.size
        if width > height:
            new_h = self.crop_size
            new_w = int(1.0 * width * new_h / height)
        else:
            new_w = self.crop_size
            new_h = int(1.0 * height * new_w / width)
        image = image.resize((new_w, new_h), Image.BILINEAR)
        label = label.resize((new_w, new_h), Image.NEAREST)

        width, height = image.size
        left = int(round((width - self.crop_size) / 2.))
        top = int(round((height - self.crop_size) / 2.))
        box = (left, top, left + self.crop_size, top + self.crop_size)
        return {'image': image.crop(box), 'label': label.crop(box)}


class FixedResize(object):
    """Resize the image (bilinear) and, optionally, the label (nearest) to a fixed size.

    Args:
        size: an int ``s`` (meaning ``(s, s)``) or a 2-element tuple/list. It is passed
            directly to PIL ``Image.resize``, which reads it as ``(width, height)``.
        is_resize: when False, the label is returned unchanged, e.g. to keep a
            full-resolution label for evaluation.
    """

    def __init__(self, size, is_resize=True):
        if isinstance(size, (tuple, list)):
            # Lists are accepted too; wrapping one as (size, size) would build a nested
            # pair that PIL rejects.
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.is_resize = is_resize

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']

        img = img.resize(self.size, Image.BILINEAR)
        if self.is_resize:
            # Nearest-neighbour keeps the label ids discrete.
            mask = mask.resize(self.size, Image.NEAREST)

        return {'image': img, 'label': mask}


class AdjustColor(object):
    """Randomly jitter the image's brightness, contrast and saturation.

    Three factors are drawn independently from the uniform range ``factor``
    (low, high) and applied in that order. The label passes through unchanged;
    image and label must have the same size.
    """

    def __init__(self, factor=(0.3, 2.)):
        self.factor = factor

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        assert image.size == label.size

        brightness, contrast, saturation = [np.random.uniform(*self.factor) for _ in range(3)]
        image = FF.adjust_saturation(
            FF.adjust_contrast(FF.adjust_brightness(image, brightness), contrast),
            saturation,
        )
        return {'image': image, 'label': label}


class CutOut(object):
    """With probability 0.5, zero out a ``mask_size`` x ``mask_size`` square of the image.

    The square's center is drawn so the square lies inside the image; the box is
    also clipped at the borders. The label is not modified. Both outputs are
    returned as PIL images.

    Args:
        mask_size: side length of the erased square, in pixels.
    """

    def __init__(self, mask_size):
        self.mask_size = mask_size

    def __call__(self, sample):
        image = np.array(sample['image'])
        mask = np.array(sample['label'])

        mask_size_half = self.mask_size // 2
        # Widens the valid center range by one pixel for even mask sizes.
        offset = 1 if self.mask_size % 2 == 0 else 0

        h, w = image.shape[:2]

        # Range of valid mask-center coordinates.
        cxmin, cxmax = mask_size_half, w + offset - mask_size_half
        cymin, cymax = mask_size_half, h + offset - mask_size_half

        # On images smaller than the mask the range is empty and np.random.randint would
        # raise. Fall back to the image center; the box is clipped below anyway.
        cx = np.random.randint(cxmin, cxmax) if cxmax > cxmin else w // 2
        cy = np.random.randint(cymin, cymax) if cymax > cymin else h // 2

        # Top-left and bottom-right corners, clipped to the image.
        xmin, ymin = cx - mask_size_half, cy - mask_size_half
        xmax, ymax = xmin + self.mask_size, ymin + self.mask_size
        xmin, ymin, xmax, ymax = max(0, xmin), max(0, ymin), min(w, xmax), min(h, ymax)

        if random.uniform(0, 1) < 0.5:
            # A scalar 0 broadcasts over any channel count (gray, RGB, RGBA); the previous
            # (0, 0, 0) tuple only worked for exactly three channels.
            image[ymin:ymax, xmin:xmax] = 0
        return {'image': Image.fromarray(image), 'label': Image.fromarray(mask)}


class RandomScale(object):
    """Randomly zoom image and label by a factor in [0.7, 1.5) while keeping their size.

    * Zoom-out (scale < 1) pads the borders symmetrically: the image with 0 and
      the label with 255.
    * Zoom-in crops the center back to the original size.
    """

    def __call__(self, sample):
        image = np.array(sample['image'])
        mask = np.array(sample['label'])

        scale = np.random.uniform(0.7, 1.5)
        h, w = image.shape[:2]
        new_size = (int(scale * w), int(scale * h))  # cv2.resize takes (width, height)

        aug_image = cv2.resize(image, new_size)
        # Label values are class ids. Nearest-neighbour avoids inventing in-between ids at
        # region boundaries, which the default (bilinear) interpolation would do.
        aug_mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)

        new_h, new_w = aug_image.shape[:2]
        if scale < 1.:
            pre_h_pad = int((h - new_h) / 2)
            pre_w_pad = int((w - new_w) / 2)
            spatial_pad = [[pre_h_pad, h - new_h - pre_h_pad], [pre_w_pad, w - new_w - pre_w_pad]]
            # Leave any channel axis unpadded; this works for gray and colour images alike.
            image_pad = spatial_pad + [[0, 0]] * (aug_image.ndim - 2)
            aug_image = np.pad(aug_image, image_pad, mode="constant", constant_values=0)
            aug_mask = np.pad(aug_mask, spatial_pad, mode="constant", constant_values=255)
        else:
            pre_h_crop = int((new_h - h) / 2)
            pre_w_crop = int((new_w - w) / 2)
            aug_image = aug_image[pre_h_crop:pre_h_crop + h, pre_w_crop:pre_w_crop + w]
            aug_mask = aug_mask[pre_h_crop:pre_h_crop + h, pre_w_crop:pre_w_crop + w]

        return {'image': Image.fromarray(aug_image), 'label': Image.fromarray(aug_mask)}


class Translate(object):
    """Randomly shift an image and its label map by the same offset.

    With probability one half, an offset (dx, dy) is drawn uniformly from
    [-t, t] on each axis and both arrays are translated by it. Image content
    shifted in from outside is black. Label content shifted in from outside is
    ``ingore_index``.

    :param t: maximum absolute shift along each axis, in pixels.
    :param ingore_index: label value written into the uncovered label area.
    """

    def __init__(self, t=50, ingore_index=255):
        self.t = t
        self.ingore_index = ingore_index

    def __call__(self, sample):
        img_arr = np.array(sample['image'])
        lbl_arr = np.array(sample['label'])

        apply_shift = np.random.random() > 0.5
        if apply_shift:
            dx = random.uniform(-self.t, self.t)
            dy = random.uniform(-self.t, self.t)
            affine = np.float32([[1, 0, dx],
                                 [0, 1, dy]])
            height, width = img_arr.shape[:2]
            dsize = (width, height)
            img_arr = cv2.warpAffine(img_arr, affine, dsize,
                                     flags=cv2.INTER_LINEAR, borderValue=(0, 0, 0))
            # Nearest-neighbour keeps the label map made of valid class ids.
            lbl_arr = cv2.warpAffine(lbl_arr, affine, dsize,
                                     flags=cv2.INTER_NEAREST, borderValue=(self.ingore_index,))

        return {'image': Image.fromarray(img_arr), 'label': Image.fromarray(lbl_arr)}


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class SegmentationLosses(nn.Module):
    """Segmentation loss chosen by ``mode``.

    Modes:
        'CE'         -- cross-entropy, class-weighted when ``weights`` is given.
        'FL'         -- per-pixel focal loss.
        'Dice'       -- soft Dice per sample (``GeneralizedSoftDiceLoss``).
        'Dice2'      -- soft Dice over the whole batch (``BatchSoftDeviceLoss``).
        'CE || Dice' -- sum of the 'CE' and 'Dice' losses.

    Pixels whose target equals ``ignore_index`` contribute to none of them.

    :param num_classes: number of classes (stored only; not read by the losses).
    :param mode: which loss ``forward`` computes (see above).
    :param weights: optional per-class weights (tensor or sequence of length C);
        converted to the predictions' device and dtype on every call.
    :param ignore_index: target value marking pixels to exclude.
    :param gamma: focusing exponent of the focal loss.
    :param alpha: optional scalar factor of the focal loss (``None`` disables it).
    :param reduction: 'mean' averages. For 'FL', 'sum' sums and any other value
        returns the per-pixel loss. For 'Dice', any value other than 'mean'
        returns per-sample losses. 'CE' and 'Dice2' always return a scalar.
    """

    # Additive smoothing of the Dice ratio; avoids 0/0 when both sums are zero.
    _DICE_SMOOTH = 1

    def __init__(self, num_classes=8, mode='CE', weights=None,
                 ignore_index=255, gamma=2, alpha=0.5, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.mode = mode
        self.weights = weights
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, preds, target):
        """Compute the loss selected by ``self.mode``.

        :param preds: logits of shape [N, C, H, W]
        :param target: class indices of shape [N, H, W]
        :return: loss tensor
        :raises NotImplementedError: for an unknown ``mode``.
        """
        H1, W1 = preds.size()[2:]
        H2, W2 = target.size()[1:]
        assert H1 == H2 and W1 == W2, \
            f'prediction size {(H1, W1)} does not match target size {(H2, W2)}'

        if self.mode == 'CE':
            return self.CrossEntropyLoss(preds, target)
        elif self.mode == 'FL':
            return self.FocalLoss(preds, target)
        elif self.mode == 'Dice':
            return self.GeneralizedSoftDiceLoss(preds, target)
        elif self.mode == 'Dice2':
            return self.BatchSoftDeviceLoss(preds, target)
        elif self.mode == 'CE || Dice':
            return (self.CrossEntropyLoss(preds, target)
                    + self.GeneralizedSoftDiceLoss(preds, target))
        else:
            raise NotImplementedError

    def _class_weights(self, preds):
        """Return ``self.weights`` as a tensor on ``preds``' device/dtype, or None if unset."""
        if self.weights is None:
            return None
        return torch.as_tensor(self.weights, dtype=preds.dtype, device=preds.device)

    def CrossEntropyLoss(self, preds, target):
        """Mean cross-entropy over non-ignored pixels, class-weighted if ``weights`` is set.

        :param preds: Tensor of shape [N, C, H, W]
        :param target: Tensor of shape [N, H, W]
        :return: scalar loss tensor
        """
        return F.cross_entropy(preds, target.long(), weight=self._class_weights(preds),
                               ignore_index=self.ignore_index)

    def FocalLoss(self, preds, target):
        """Per-pixel focal loss.

        FL = -alpha * (1 - pt) ** gamma * log(pt), where pt is the predicted
        probability of the true class at each pixel. The modulating factor is
        applied per pixel, so easy pixels are down-weighted individually.

        :param preds: Tensor of shape [N, C, H, W]
        :param target: Tensor of shape [N, H, W]
        :return: reduced according to ``self.reduction``
        """
        # Per-pixel -log(pt); F.cross_entropy returns 0 at ignored pixels.
        ce = F.cross_entropy(preds, target.long(), reduction='none',
                             ignore_index=self.ignore_index)
        pt = torch.exp(-ce)
        loss = (1 - pt) ** self.gamma * ce
        if self.alpha is not None:
            loss = loss * self.alpha

        if self.reduction == 'mean':
            valid = target != self.ignore_index
            if not valid.any():
                # Everything ignored: return a graph-connected zero instead of NaN.
                return loss.sum()
            return loss[valid].mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def _dice_terms(self, preds, target):
        """Return per-sample, per-class Dice numerator and denominator, each [N, C].

        Ignored pixels are removed from both the one-hot labels and the sigmoid
        probabilities, so they affect neither term. Class weights, if set,
        scale both terms.
        """
        valid = target != self.ignore_index
        # Ignored pixels get a placeholder class 0; they are masked out right after.
        label = torch.where(valid, target, torch.zeros_like(target)).long()
        mask = valid.unsqueeze(1).to(preds.dtype)  # [N, 1, H, W]
        lb_one_hot = torch.zeros_like(preds).scatter_(1, label.unsqueeze(1), 1)
        lb_one_hot = (lb_one_hot * mask).detach()

        probs = torch.sigmoid(preds) * mask
        numer = torch.sum(probs * lb_one_hot, dim=(2, 3))
        denom = torch.sum(probs + lb_one_hot, dim=(2, 3))

        weight = self._class_weights(preds)
        if weight is not None:
            numer = numer * weight.view(1, -1)
            denom = denom * weight.view(1, -1)
        return numer, denom

    def GeneralizedSoftDiceLoss(self, preds, target):
        """Soft Dice loss computed per sample, summing the terms over classes.

        Paper:
            https://arxiv.org/pdf/1606.04797.pdf
        :param preds: Tensor of shape [N, C, H, W]
        :param target: Tensor of shape [N, H, W]
        :return: scalar if ``reduction == 'mean'``, else a tensor of shape [N]
        """
        numer, denom = self._dice_terms(preds, target)
        numer = numer.sum(dim=1)
        denom = denom.sum(dim=1)
        smooth = self._DICE_SMOOTH
        loss = 1 - (2 * numer + smooth) / (denom + smooth)

        if self.reduction == 'mean':
            loss = loss.mean()
        return loss

    def BatchSoftDeviceLoss(self, preds, target):
        """Soft Dice loss with numerator and denominator summed over the whole batch.

        :param preds: Tensor of shape [N, C, H, W]
        :param target: Tensor of shape [N, H, W]
        :return: scalar loss tensor
        """
        numer, denom = self._dice_terms(preds, target)
        smooth = self._DICE_SMOOTH
        return 1 - (2 * numer.sum() + smooth) / (denom.sum() + smooth)

    # Correctly spelled alias of the historically misnamed method.
    BatchSoftDiceLoss = BatchSoftDeviceLoss


if __name__ == '__main__':
    # Smoke test: logits from a tiny conv net over 19 classes, with one pixel
    # set to the ignore index, pushed through the loss and back-propagated.
    torch.manual_seed(0)
    num_classes = 19
    # Explicit uniform class weights (equivalent to unweighted CE) keep the demo
    # independent of how the loss handles ``weights=None``.
    criteria = SegmentationLosses(num_classes=num_classes, mode='CE',
                                  weights=torch.ones(num_classes))
    im = torch.randn(16, 3, 14, 14)
    label = torch.randint(0, num_classes, (16, 14, 14)).long()
    label[2, 3, 3] = 255  # an ignored pixel
    net = torch.nn.Conv2d(3, num_classes, 3, 1, 1)
    print(label.dtype)

    logits = net(im)
    loss = criteria(logits, label)
    loss.backward()
    print(loss)

train

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.backends import cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import get_parser
from datalist import BaiDuLaneDataset
from model import DeepLab
from utils import SegmentationLosses

# Lowest mean training-epoch loss seen so far. A checkpoint is saved only when
# an epoch's loss is strictly below it, so it must start at +inf: with 0, no
# non-negative loss could ever beat it and nothing would be saved.
metric_loss = float('inf')
class train():
    """Train ``DeepLab`` on ``BaiDuLaneDataset``; the whole run happens in ``__init__``.

    Construction:
      * reads the configuration from ``get_parser()``;
      * builds the data loaders, the model, the weighted cross-entropy loss,
        the Adam optimizer and a ``MultiStepLR`` schedule;
      * calls ``train`` and then ``test`` for every epoch in ``1..args.epoches``
        (inclusive).
    """

    # (height, width) the network output is upsampled to before the argmax.
    EVAL_SIZE = (1020, 3384)
    # Zero rows stacked above the prediction so it spans the full preview frame
    # (690 + 1020 = 1710 rows, the height in FULL_SIZE). Presumably the dataset
    # crops these top rows from the input -- TODO confirm against BaiDuLaneDataset.
    TOP_CROP_ROWS = 690
    # (width, height) the de-normalized input image is resized to for the preview.
    FULL_SIZE = (3384, 1710)

    def __init__(self):
        import random

        global metric_loss

        self.args = get_parser()
        print(f"-----------{self.args.project_name}-------------")

        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        # Seed torch (CPU and every GPU), numpy and ``random`` (the augmentation
        # transforms in this project draw from the latter two), whether or not CUDA is used.
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)
        if use_cuda:
            torch.cuda.manual_seed_all(self.args.seed)
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        train_kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
        test_kwargs = {'num_workers': 0, 'pin_memory': False} if use_cuda else {}

        # Data loaders; training batches are reshuffled every epoch.
        self.train_dataset = BaiDuLaneDataset(root_file=self.args.data_base, phase='train')
        self.test_dataset = BaiDuLaneDataset(root_file=self.args.data_base, phase='test')
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=10, shuffle=True,
                                           **train_kwargs)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=10, **test_kwargs)

        # Model, wrapped in DataParallel across all visible GPUs when CUDA is used.
        self.model = DeepLab().to(self.device)
        if use_cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True

        # Optionally load pretrained / resumed weights.
        if self.args.resume and self.args.pretrained_model:
            self._load_pretrained(self.args.pretrained_model)
        else:
            print("initial net weights from stratch!")

        # Weighted cross-entropy (one weight per class, 8 classes), Adam, and a
        # x0.1 learning-rate drop at each milestone epoch.
        weights = torch.FloatTensor([0.00289, 0.2411, 1.068, 2.547, 7.544, 0.2689, 0.9043, 1.572])
        self.criterion = SegmentationLosses(mode='CE', weights=weights).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.args.milestones,
                                                              gamma=0.1)

        # Best mean epoch loss so far; +inf so the first completed epoch saves.
        metric_loss = float('inf')

        for epoch in range(1, self.args.epoches + 1):
            self.train(epoch)
            self.test(epoch)
        torch.cuda.empty_cache()
        print("model finish training")

    def _load_pretrained(self, path):
        """Load model weights from ``path`` non-strictly and report key mismatches.

        Accepts either a bare state dict or a checkpoint dict as written by
        ``_save_checkpoint`` (weights under ``'model_state_dict'``). A leading
        ``module.`` prefix (added by ``DataParallel``) is stripped, so the file
        loads whether or not the current model is wrapped.
        """
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        data_dict = torch.load(path, map_location=self.device)
        if isinstance(data_dict, dict) and 'model_state_dict' in data_dict:
            data_dict = data_dict['model_state_dict']
        state_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in data_dict.items()
        }
        net = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
        result = net.load_state_dict(state_dict, strict=False)
        print(f"missing keys: {len(result.missing_keys)}, "
              f"unexpected keys: {len(result.unexpected_keys)}")
        print("load pretrained model successful!")

    def train(self, epoch):
        """Run one training epoch and step the LR scheduler.

        A checkpoint is saved when the epoch's mean loss beats the best so far
        and ``args.save_path`` is set.
        """
        global metric_loss
        self.model.train()
        average_loss = []
        pbar = tqdm(self.train_dataloader, desc=f'Train Epoch{epoch}/{self.args.epoches}')
        for data in pbar:
            img, target = data['image'], data['label']
            img, target = img.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()

            outputs = self.model(img)
            # Keep the loss on the training device; .item() copies the scalar for logging.
            loss = self.criterion(outputs, target.long())
            average_loss.append(loss.item())
            loss.backward()
            self.optimizer.step()
            pbar.set_description(
                f'Train Epoch:{epoch}/{self.args.epoches} '
                f'train_loss:{round(np.mean(average_loss), 2)} '
                f'learning_rate:{self.optimizer.param_groups[0]["lr"]}')
        self.scheduler.step()

        if not average_loss:
            return  # empty training set: nothing to compare or save
        epoch_loss = np.mean(average_loss)
        if epoch_loss < metric_loss and self.args.save_path:
            metric_loss = epoch_loss
            self._save_checkpoint(epoch, epoch_loss)

    def _save_checkpoint(self, epoch, epoch_loss):
        """Write model/optimizer state to ``./weights/Epoch-<epoch>-loss-<loss>.pth``."""
        import os

        os.makedirs('./weights', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Plain float so the checkpoint holds no numpy objects.
            'loss': round(float(epoch_loss), 2)
        },
            './weights/' + f'Epoch-{epoch}-loss-{epoch_loss}.pth')
        print("model saved")

    def test(self, epoch):
        """Run the model over the test set and save a preview image for this epoch.

        Writes ``./result/epoch-<epoch>_result_predict.jpg``: the colour-decoded
        prediction for the first image of the last batch, stacked above that
        de-normalized input image. Returns early if the test set is empty.
        """
        import os

        self.model.eval()
        img = outputs = None
        with torch.no_grad():
            pbar = tqdm(self.test_dataloader, desc=f'Test Epoch{epoch}/{self.args.epoches}')
            for data in pbar:
                img = data['image'].to(self.device)
                outputs = self.model(img)
                pbar.set_description(
                    f'【Test Epoch】:{epoch}/{self.args.epoches} '
                )
            if outputs is None:
                return
            # Upsample and argmax only the last batch -- the only one visualised.
            outputs = F.interpolate(outputs, size=self.EVAL_SIZE, mode='bilinear', align_corners=True)
            preds = outputs.argmax(dim=1).cpu().numpy()

        # Undo the per-channel mean/std normalization of the first image (CHW -> HWC).
        image = np.transpose(img[0].cpu().numpy(), axes=[1, 2, 0])
        image *= (0.229, 0.224, 0.225)
        image += (0.485, 0.456, 0.406)
        image *= 255.0
        image = np.clip(image, 0, 255).astype(np.uint8)
        image = cv2.resize(image, self.FULL_SIZE)

        # Prepend zero (class 0) rows so the prediction matches the preview height.
        pred = preds[0].astype(np.uint8)
        top = np.zeros((self.TOP_CROP_ROWS, pred.shape[1]), dtype=np.uint8)
        pred = np.vstack((top, pred))
        pred = self.test_dataset.decode_color_map(pred)
        result = np.vstack((pred, image))
        # NOTE(review): cv2.imwrite expects BGR; if the dataset yields RGB tensors
        # the saved preview has swapped red/blue -- confirm against BaiDuLaneDataset.
        os.makedirs('./result', exist_ok=True)
        cv2.imwrite("./result/epoch-" + str(epoch) + "_result_predict.jpg", result)


# Instantiating the trainer runs the whole training loop (its __init__ iterates
# over all epochs) and rebinds the name ``train`` from the class to the instance.
# NOTE(review): this also runs when the module is imported; consider guarding it
# with ``if __name__ == '__main__':`` if the file is ever imported.
train = train()

总结

在这里插入图片描述
如图所示，这是训练几个 epoch 之后的效果。通过更换 datalist 和 model，大致可以实现不同的分割项目。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值