Semantic Segmentation Series 19: EMANet (PyTorch Implementation)

EMANet: "Expectation-Maximization Attention Networks for Semantic Segmentation"


Published at ICCV 2019. The first author also wrote an authoritative column on it: "Expectation-Maximization Attention Networks (EMANet) Explained".


Paper Idea

In semantic segmentation, one of the main ways to implement the attention mechanism is the Non-Local approach. Its big problem is computational cost: with $N = H \times W$ spatial positions, the pairwise affinity matrix alone is roughly $O(N^2)$ in time and memory, which is why models such as DANet are so expensive.

EMANet proposes a new way to compute attention that tackles this cost: it uses the EM (Expectation-Maximization) algorithm to learn the attention. The E-step estimates the attention maps, the M-step updates a set of bases, and after iterating T times the features are reconstructed from the bases and the attention maps.

The EM algorithm finds maximum-likelihood solutions for models with latent variables. Intuitively: the outcome is known, and we work backwards to the parameters that make that outcome most likely. The basic idea: first, from the observed data, estimate values for the model parameters; then use those parameters to compute new estimates of the latent quantities; then re-estimate the parameters from those updated estimates; and iterate back and forth until convergence.
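
To make the E/M alternation concrete, here is a tiny, self-contained sketch of EM for a two-component 1-D Gaussian mixture (purely illustrative; not from the paper):

import numpy as np

np.random.seed(0)
# observations drawn from two Gaussians; the component assignment is latent
x = np.concatenate([np.random.normal(-2, 1, 200), np.random.normal(3, 1, 200)])
mu, sigma, pi = np.array([-1.0, 1.0]), np.array([1.0, 1.0]), np.array([0.5, 0.5])

for _ in range(50):
    # E-step: responsibility of each component for each point
    p = pi * np.exp(-(x[:, None] - mu) ** 2 / (2 * sigma ** 2)) / sigma
    r = p / p.sum(axis=1, keepdims=True)
    # M-step: re-estimate the parameters from the responsibilities
    n = r.sum(axis=0)
    mu = (r * x[:, None]).sum(axis=0) / n
    sigma = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / n)
    pi = n / len(x)

print(mu)  # approaches the true means (-2, 3)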

For a detailed introduction to the EM algorithm, see: "EM算法详解" (EM Algorithm Explained).

An aside:

Reading this far, one has to admire the authors' ingenuity: starting from a classical machine-learning algorithm and combining it with the deep-learning frontier to get a new, stronger model. Many new methods are innovations built on old theory. Deep learning has been developing for years, yet good work keeps borrowing from easily overlooked classical machine learning: this EMANet, for example, and the stunning and far-reaching DeepLab, which used conditional random fields (CRFs) to refine its results.

Model

Figure: EMANet architecture

The EMANet architecture is shown in the figure above; the feature map X is the output of the backbone. We will focus on the EMA Unit.

As mentioned above, the EMA Unit uses the EM algorithm to compute the attention maps. It consists mainly of an input 1×1 Conv, an output 1×1 Conv+BN, and the three operators $A_E$, $A_M$ and $A_R$ (responsibility estimation, likelihood maximization, and data re-estimation).

Here the input first passes through a 1×1 Conv that maps the feature values from $(0, +\infty)$ (the range after ReLU) to $(-\infty, +\infty)$. This way the bases $\mu$ computed later can also take values in $(-\infty, +\infty)$; otherwise they would be confined to $(0, +\infty)$ and lose half of their expressive range.

So X, after the 1×1 convolution, is reshaped into $X \in \mathbb{R}^{N \times C}$ with $N = H \times W$. A set of bases $\mu \in \mathbb{R}^{K \times C}$ is then initialized, i.e. K basis vectors; here $\mu$ is initialized with Kaiming initialization.

E-step

The E-step computes the operator $A_E$, i.e. the attention maps $$Z = \mathrm{softmax}(\lambda X \mu^T)$$ where $\lambda$ is a temperature hyperparameter. $Z \in \mathbb{R}^{N \times K}$ shares the spatial layout of $X$, so each of the $K$ attention maps can be viewed as an $H \times W$ map.
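
A minimal sketch of this step under assumed shapes (N = H×W pixels, K bases; names are mine, not the official code):

import torch
import torch.nn.functional as F

N, C, K, lam = 64 * 64, 512, 64, 1.0         # N = H*W pixels, K bases
X = torch.randn(N, C)                        # features after the 1x1 conv
mu = F.normalize(torch.randn(K, C), dim=1)   # current bases, L2-normalized
Z = torch.softmax(lam * (X @ mu.T), dim=1)   # [N, K], softmax over the K bases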

M-step

The M-step updates $\mu$. Because the EM iteration unrolls like an RNN, the authors note it can run into exploding or vanishing gradients. For the update of the bases $\mu$ across training images they therefore use a moving-average strategy, with momentum $\alpha \in [0, 1]$:

$$\mu^{(0)} \leftarrow \alpha \mu^{(0)} + (1-\alpha)\bar{\mu}^{(T)}$$

Note that the bases $\mu$ are only updated this way during training; they are not updated during validation or testing.

Also, just as other layers of the network need normalization, $\mu^{(0)}$ here needs to be normalized too; the authors choose L2 normalization (Euclidean normalization).

Iterating

The E- and M-steps are then repeated alternately T times. The authors' experiments show that T = 3 gives good results.

Figure: results on PASCAL VOC 2012 for different values of T

Reconstructing the output

Finally, the bases $\mu$ and the attention maps $Z$ are used to reconstruct the output features $\tilde{X}$:

$$\tilde{X} = Z^{(T)}\mu^{(T)}$$

Some readers have proposed another way to think about this step: the EM iteration effectively clusters the features in feature space and reconstructs them from the cluster bases, which reduces intra-class feature variance and makes the features more compact.
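
Putting the E-step, M-step, and reconstruction together, here is a minimal self-contained sketch of one EMA pass (variable names and the λ temperature are my own; the real implementation appears in the EMAModule code below):

import torch
import torch.nn.functional as F

def ema_sketch(X, mu, T=3, lam=1.0):
    """One EMA pass. X: [N, C] flattened features; mu: [K, C] bases."""
    for _ in range(T):
        # E-step: responsibilities of each basis for each pixel
        Z = torch.softmax(lam * (X @ mu.T), dim=1)        # [N, K]
        # M-step: bases become responsibility-weighted means of the features
        Z_hat = Z / (Z.sum(dim=0, keepdim=True) + 1e-6)   # normalize over pixels
        mu = F.normalize(Z_hat.T @ X, dim=1, p=2)         # [K, C], L2-normalized
    # reconstruction: X_tilde = Z mu, a compact low-rank version of X
    return Z @ mu

X = torch.randn(64 * 64, 512)                    # N = H*W, C = 512
mu = F.normalize(torch.randn(64, 512), dim=1)    # K = 64 bases
print(ema_sketch(X, mu).shape)                   # torch.Size([4096, 512])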

Implementation

Backbone: ResNet-50 (8× downsampling)

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion: int = 1  # BasicBlock keeps the channel count (Bottleneck uses 4)
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):
        
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # only conv1 downsamples
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 2  # non-standard (torchvision starts at 1); kept as in the original post
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # strides/dilations below are arranged so the overall output stride is 8
        # (conv1 and maxpool each contribute 2x, layer3 contributes the final 2x)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            # unlike torchvision, the stride is intentionally NOT reset to 1 here,
            # so layer3 still downsamples and the backbone ends at output stride 8
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

    @staticmethod
    def _resnet(block, layers, pretrained_path=None, **kwargs):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path), strict=False)
        return model

    @staticmethod
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path, **kwargs)

    @staticmethod
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path, **kwargs)

EMA Unit

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

def reduce_mean(tensor):
    """Reduce mean when distributed training."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor


class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)

        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 - self.momentum) * self.bases + self.momentum * bases

        return feats_recon


class EMAHead(nn.Module):
    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 in_channels=2048,
                 channels=256,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        self.channels = channels
        self.in_channels = in_channels
        self.ema_in_conv = nn.Conv2d(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1)
        # project (0, inf) -> (-inf, inf)
        self.ema_mid_conv = nn.Conv2d(
            self.ema_channels,
            self.ema_channels,
            1)
        
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False

        self.ema_out_conv = nn.Conv2d(
            self.ema_channels,
            self.ema_channels,
            1)
        self.bottleneck = nn.Conv2d(
            self.ema_channels,
            self.channels,
            3,
            padding=1)
        if self.concat_input:
            self.conv_cat = nn.Conv2d(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1)
        self.attention = None
    def forward(self, inputs):
        """Forward function."""
        x = inputs
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        # cache the EMA output (the reconstructed features) so it can be inspected later
        recon = self.ema_module(feats)
        self.attention = recon
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        return output
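
Quick shape checks for the module and the head (my own tests; the arguments mirror the EMANet definition below):

ema = EMAModule(channels=512, num_bases=64, num_stages=3, momentum=0.1)
with torch.no_grad():
    print(ema(torch.randn(2, 512, 28, 28)).shape)        # torch.Size([2, 512, 28, 28])

head = EMAHead(ema_channels=512, num_bases=64, num_stages=3,
               in_channels=2048, channels=256)
with torch.no_grad():
    print(head(torch.randn(2, 2048, 28, 28)).shape)      # torch.Size([2, 256, 28, 28])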

EMANet

class EMANet(nn.Module):
    def __init__(self, num_classes):
        super(EMANet, self).__init__()
        self.num_classes = num_classes
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.Head = EMAHead(in_channels=2048, channels=256, ema_channels=512, num_bases=64, num_stages=3, momentum=0.1)
        self.cls_seg = nn.Sequential(
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_classes, 3, padding=1),
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.Head(x)
        x = self.cls_seg(x)
        return x
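
End-to-end sanity check (my test): a 224×224 input should come back at full resolution with num_classes channels.

model = EMANet(num_classes=33)
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 33, 224, 224])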

Dataset: CamVid

# imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

torch.manual_seed(17)
# custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# dataset paths
DATA_DIR = r'database/camvid/camvid/' # adjust to your own layout
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True,drop_last=True)
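
To confirm the pipeline, we can peek at one batch (assumes the CamVid folders above exist):

images, masks = next(iter(train_loader))
print(images.shape, masks.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8, 224, 224])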

Train

Note that during training the bases μ must be updated via the moving average, so set model.Head.ema_module.training = True.

During validation and testing they must not be updated, so set model.Head.ema_module.training = False.
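
Since EMAModule gates the update on self.training, the idiomatic nn.Module switches achieve the same thing, flipping the flag on every submodule:

model.train()  # enables the moving-average update of the bases mu
model.eval()   # freezes the bases for validation / testing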

model = EMANet(num_classes=33).cuda()

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
# multi-class cross-entropy loss
lossf = nn.CrossEntropyLoss(ignore_index=255)
# train with SGD
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)

# train for 100 epochs
epochs_num = 100
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            # set to True during training so the bases mu get the moving-average update
            model.Head.ema_module.training = True
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
                
        # set to False for validation / prediction
        model.Head.ema_module.training = False
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} ---  train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")
        
        #--------- save training logs ---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/EMANet_camvid.xlsx")
        #---------------- save the model -------------------
        if np.mod(epoch+1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/EMANet_{epoch+1}.pth')
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num,scheduler)

Result

 
