语义分割系列16-BiSeNetV1(pytorch实现)

BiSeNet:《BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation》

发布于ECCV 2018


引文

BiSeNet模型设计初衷是提升实时语义分割的速度(105FPS在Titan XP上)和精度(68.4%的mIoU在Cityscapes上),当然,这是废话,人家的定位就是实时语义分割。

在实时语义分割的算法中,大多数工作主要采用三种加速模型计算的方法:

  1. 第一是限制输入大小,通过剪裁或者调整大小来降低计算的复杂度。这也是大部分工作最初的思路,但是呢,这种方式有一个问题,就会丢失空间上的部分细节,尤其是边缘细节。
  2. 第二是修建模型的通道,把模型的通道数缩减到一定的值,比如某个阶段2048个通道,直接缩小到256甚至128这样。当然,这样缩小肯定会丢失一些信息,尤其是在较浅层,信息比较集中且重要的时候,会削弱空间上的一些信息。
  3. 第三是有的模型比较直接,删去后面几层,让深层网络变浅一点,比如ENet,放弃最后阶段的下采样操作,这会导致模型的感受野不大,导致一些物体分割不精确。

而为了提高模型的精度,很多模型都借鉴了FCN、Unet中的那种U型结构,也就是有一个skip-connection,通过融合骨干网络中的分层特征,填充细节来帮助分辨率恢复。不过这种方式会引入更多的计算。

论文思路

有了上述的这些问题和解决方案,作者在BiSeNet中设计了一个双边结构,分别为空间路径(Spatial Path)和上下文路径(Context Path)。通过一个特征融合模块(FFM)将两个路径的特征进行融合,得到分割结果。

在模型的一些细节上,作者主要引入了这么些个模块:

  • 两个路径Spatial Path + Context Path
  • U型的网络结构
  • Attention机制(ARM模块)
  • 特征融合模块 FFM
图1 BiSeNet模型

空间路径Spatial Path

很多模型试图保留输入图像的原始分辨率,用空洞卷积(扩张卷积)的方式来编码空间信息,尽量扩大感受野;还有一些方法通过空间金字塔池化或者用大卷积核来铺货空间信息,扩大感受野。空间信息和感受野对于模型精度的影响较大,但却很难同时满足两者,毕竟还要考虑速度问题。如果使用小尺寸的图像就会丢失信息。

因此,在BiSeNet中,作者就设计了一个简单但有效的快速下采样的空间路径,通过3个Conv+BN+ReLU的组合层将原图快速下采样8倍(通过卷积层的步幅来调整),保留空间信息的同时,速度却不慢。

上下文路径Context Path

空间路径能够编码足够的空间信息,但是需要更大的感受野,因此,作者设计了一个Context Path来提供上下文信息,扩大感受野。

在这个路径中,可以通过Xception或者ResNet来快速下采样到16倍和32倍,并且,作者设计了一个半U的结构,也就是只使用16x和32x下采样倍率的特征图,在保留信息的同时,不增加过多的计算量。每一个特征图都通过一个Attention Refinement Module(ARM)来计算Attention vector,突出特征。

在32x特征图的下方,作者还设计了一个全局池化的小模块,计算一个池化后的向量,加到32x特征图的ARM输出中。

特征融合模块FFM

FFM模块用于编码两个分支的特征,设计了一个类似注意力机制的融合模块,编码空间路径(低级别信息)和上下文路径(低级别信息)的输出。最后将结果上采样8倍得到原图。


模型复现

Context Path backbone - ResNet50(缩小了通道数)

import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    expansion: int = 4
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):
        
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes ,kernel_size=3, stride=stride, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes ,kernel_size=3, stride=stride, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64//4, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64//4
        self.dilation = 2
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64//4, layers[0])
        self.layer2 = self._make_layer(block, 128//4, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256//4, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512//4, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = stride
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        out.append(x)
        x = self.layer4(x)
        out.append(x)
        return out

    def forward(self, x) :
        return self._forward_impl(x)
    def _resnet(block, layers, pretrained_path = None, **kwargs,):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path),  strict=False)
        return model
    
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3],pretrained_path,**kwargs)
    
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3],pretrained_path,**kwargs)

Spatial Path

class SpatialPath(nn.Module):
    def __init__(self):
        super(SpatialPath, self).__init__()
        
        self.downpath = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        
    def forward(self, x):
        return self.downpath(x)

Context Path + ARM

class ARM(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ARM, self).__init__()
        self.reduce_conv =  nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.module = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, out_channels, 1),
            # nn.BatchNorm2d(out_channels),
            nn.Sigmoid()
        )

class ContextPath(nn.Module):
    def __init__(self, out_channels=128):
        super(ContextPath, self).__init__()
        self.resnet = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.ARM16 = ARM(256, 128)
        self.ARM32 = ARM(512, 128)
        self.conv_head32 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv_head16 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv_avg = conv_head16 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, out_channels, 1),
            # nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.up32 = nn.Upsample(scale_factor=2., mode="bilinear")
        self.up16 = nn.Upsample(scale_factor=2., mode="bilinear")
        
    def forward(self, x):
        feat16, feat32 = self.resnet(x)
        avg = self.conv_avg(feat32)
        
        feat32_arm = self.ARM32(feat32) + avg
        feat32_up = self.up32(feat32_arm)
        feat32_up = self.conv_head32(feat32_up)
        
        feat16_arm = self.ARM16(feat16) + feat32_up
        feat16_up = self.up16(feat16_arm)
        feat16_up = self.conv_head16(feat16_up)     
        
        return feat16_up, feat32_up

FFM

class FFM(nn.Module):
    def __init__(self, channels=128):
        super(FFM, self).__init__()
        self.fuse = nn.Sequential(
            nn.Conv2d(2*channels, channels, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        
        self.skip_forward = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels//4, 1),
            nn.ReLU(),
            nn.Conv2d(channels//4, channels, 1),
            nn.Sigmoid()
        )
        
    def forward(self, SP_input, CP_input):
        x = torch.cat([SP_input, CP_input], 1)
        x = self.fuse(x)
        identify = self.skip_forward(x)
        out = torch.mul(x, identify) + x
        return out

BiSeNet

class BiSeNet(nn.Module):
    def __init__(self, num_classes):
        super(BiSeNet, self).__init__()
        self.num_classes = num_classes
        
        self.SpatialPath = SpatialPath()
        self.ContexPath = ContextPath()
        self.FFM = FFM()
        self.cls_seg = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=8., mode="bilinear"),
            nn.Conv2d(128, self.num_classes, 3, padding=1),  
        )
        
    def forward(self, x):
        b, c, h, w = x.size()
        
        SP_out = self.SpatialPath(x)
        CP_out16, CP_Out32 = self.ContexPath(x)
        FFM_out = self.FFM(SP_out, CP_out16)
        return self.cls_seg(FFM_out)

数据集

# 导入库
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
from PIL import Image
import numpy as np
from albumentations.pytorch.transforms import ToTensorV2
import matplotlib.pyplot as plt
import albumentations as A
torch.manual_seed(17)
# 自定义数据集CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# 设置数据集路径
DATA_DIR = r'database/camvid/camvid/' # 根据自己的路径来设置
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True,drop_last=True)

模型训练

model = BiSeNet(num_classes=33).cuda()
#载入预训练模型
#model.load_state_dict(torch.load(r"checkpoints/Unet++_25.pth"),strict=False)

from d2l import torch as d2l
import pandas as pd
#损失函数选用多分类交叉熵损失函数
lossf = nn.CrossEntropyLoss(ignore_index=255)
#选用adam优化器来训练
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)

#训练50轮
epochs_num = 100
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} ---  train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")
        
        #---------保存训练数据---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/BiSeNetV1_camvid.xlsx")
        #----------------保存模型-------------------
        if np.mod(epoch+1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/BiSeNetV1_{epoch+1}.pth')
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num,scheduler)

训练结果

 简单测试一下

def evaluate_accuracy_gpu(net,val_loader, device=None):
    from tqdm import tqdm
    import time
    
    """Compute the accuracy for a model on a dataset using a GPU.
    Defined in :numref:`sec_lenet`"""
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)
    
    with torch.no_grad():
        start = time.time()
        num_batches = 0

        for X,y in tqdm(val_loader, ncols = 60):
            num_batches += X.size()[0]
            if isinstance(X, list):
                # Required for BERT Fine-tuning (to be covered later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), d2l.size(y))
            
            
        end = time.time()
        costTime = round(end - start, 2)
        print("测试耗费时间: ", costTime, "s")
        print("FPS: ", round(num_batches/costTime, 1))
        print("测试集精度: ", round(test_acc*100, 2), "%")
        
    return metric[0] / metric[1]

test_acc = evaluate_accuracy_gpu(model, val_loader)

  • 1
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yumaomi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值