Non-Local Net:《Non-Local Neural Networks》
发表于CVPR 2018。
引文
什么是Local和Non Local?
在语义分割论文中,最多被提到的词恐怕就是receptive field了,也就是感受野。也就是,天下模型苦receptive field久矣。如之前写的文章中所提到的,大家在增加感受野的方式上基本大同小异,无非就是扩大卷积核、用扩张卷积(空洞卷积)、多叠几个卷积层。当然,这些操作都有一定的扩张感受野能力,但是缺点也是显而易见的,包括但不限于增加计算量、损失邻域信息等等。所以,可以说,卷积操作是Local的,因为只有卷积核覆盖的那一个局部区域内的信息有交互,再远了就没了。
而《Attention Is All You Need》横空出世后,大家发现,原来Attention机制(尤其是指空间注意力)可以无视距离,来建立两个位置之间的信息交互(就像峡谷里死歌开大,跑都跑不掉那种),那这就是Non-Local的操作。
当然,全连接其实也是Non-Local的,但是真的全部连接的话,参数量也是大得惊人。
为了解决这个感受野受限的问题,大家开始对Attention机制进行各种魔改。其中,我认为就包括了Non-Local的操作(至少在我浅薄的知识体系中,我认为Non-Local无非是一个Attention的变体)。
论文思想
![](https://img-blog.csdnimg.cn/c629c8a2069b4b5bbe998ae6e578fe59.png)
Non-Local的思想一张图就能概括完,乍一看,这三个输入Θ、Φ、g,不就对应着Attention机制里面的qkv么。确实结构也很像。
不妨回顾一下之前写过的DANet中的PAM结构。
啊这。一模一样有没有。当然DANet是2019年的工作,看到NonLocal还得喊一声大哥。
当然,这里的Non-Local有一个比较不同的地方。我们都知道,attention需要计算一个相似度,而NonLocal给出了几种相似度的计算方法,包括Gaussian、Embedded Gaussian、Dot product(点积)和Concatenation(拼接)四个版本。DANet中的PAM对点积结果做了softmax归一化,按照Non-Local的划分,它更接近Embedded Gaussian版本。NonLocal论文中默认使用的(也是结构图中画出的)同样是Embedded Gaussian版本。
那我们再回到Non-Local结构中来,把图在下面重新放一下。
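NL结构用一个公式就能概括,和Attention类似。以Embedded Gaussian版本为例,写成矩阵形式为 $$y = \mathrm{softmax}\left(x^{T} W_{\theta}^{T} W_{\phi}\, x\right) g(x)$$ 最终的输出再加上残差连接:$z_i = W_z y_i + x_i$(对应代码中的 `output = x + self.conv_out(y)`)。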
对于每一个点,也可以写成: $$y_i = \frac{1}{\mathcal{C}(x)} \sum_{\forall j} f(x_i, x_j)\, g(x_j)$$
其中i、j是两个位置的索引。$y_i$即为i位置的新值,也可以叫Attention激活值;$\mathcal{C}(x)$是归一化因子。g()函数比较简单,做线性嵌入即可,用一个1×1卷积来实现。而f()函数就是作者介绍的几种相似度计算函数,包括:
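Gaussian:$f(x_i, x_j) = e^{x_i^{T} x_j}$,$\mathcal{C}(x) = \sum_{\forall j} f(x_i, x_j)$,即在j维上做softmax
Embedded Gaussian:$f(x_i, x_j) = e^{\theta(x_i)^{T} \phi(x_j)}$,其中$\theta(x_i) = W_\theta x_i$、$\phi(x_j) = W_\phi x_j$,归一化方式同上
Dot product:$f(x_i, x_j) = \theta(x_i)^{T} \phi(x_j)$,$\mathcal{C}(x) = N$(N为参与计算的位置总数)
Concatenation:$f(x_i, x_j) = \mathrm{ReLU}\left(w_f^{T} [\theta(x_i), \phi(x_j)]\right)$,$\mathcal{C}(x) = N$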
小结
Non-Local机制类似于空间Attention,被引入来建立任意两个位置之间的联系,而不受它们之间距离的限制。后续很多工作也在Non-Local机制上做了一些改动。
模型复现
由于本文注重于二维的图像分割,所以只实现了2D的Non-Local模块,并瞎写了一个模型来实现,瞎写的模型大抵类似于DeepLabv3+结构。
backbone-resnet50(8倍下采样)
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Output channels are ``planes * expansion`` (== ``planes``). Only the first
    convolution (and the optional ``downsample`` branch) changes resolution.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution; the second uses stride 1.
        downsample: optional module applied to the identity branch so that its
            shape matches the residual branch (needed when ``stride != 1`` or
            ``inplanes != planes``).
        groups: only 1 is supported.
        base_width: only 64 is supported.
        dilation: only 1 is supported.
        norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    # A BasicBlock does not widen its output, so the expansion factor is 1.
    # ResNet._make_layer sizes the downsample branch as planes * expansion.
    expansion: int = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # The second conv must keep the resolution (stride 1); otherwise the
        # residual branch is downsampled twice and no longer matches identity.
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + identity)``."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The output has ``planes * expansion`` channels. The 3x3 convolution
    carries the stride and dilation, and is a grouped convolution when
    ``groups > 1``.

    Args:
        inplanes: number of input channels.
        planes: base width; output channels are ``planes * 4``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the identity branch so that its
            shape matches the residual branch.
        groups: number of groups of the 3x3 convolution.
        base_width: width per group; inner width is
            ``int(planes * base_width / 64) * groups``.
        dilation: dilation (and padding) of the 3x3 convolution.
        norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        # groups=groups: width was scaled by `groups` above, so this conv must be
        # grouped as well (it was previously always a dense conv).
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False,
                               padding=dilation, dilation=dilation, groups=groups)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + identity)``."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """ResNet backbone that returns two feature maps for segmentation.

    ``forward`` returns ``(x1, x)``: ``x1`` is the output of ``layer2`` and
    ``x`` the output of ``layer4``. Here ``layer2`` keeps stride 1 and
    ``layer3`` has stride 2, so after the stride-4 stem the two outputs have
    overall strides 4 and 8 respectively.

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        num_classes: output size of ``fc`` (``fc``/``avgpool`` are built but
            not used by ``forward``).
        zero_init_residual: zero the last BN scale of every residual branch.
        groups: forwarded to every block.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: 3 flags for layer2/layer3/layer4; a
            truthy flag multiplies the running dilation by that layer's stride.
        norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        # NOTE: the running dilation starts at 2 (torchvision starts at 1), so
        # every 3x3 conv in layer1 and layer2 is dilated by 2.
        self.dilation = 2
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should multiply the
            # running dilation by that layer's stride
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem: stride-2 7x7 conv followed by a stride-2 max-pool (stride 4 overall).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # layer2 keeps stride 1 (stays at stride 4); layer3 moves to stride 8.
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        # Classification head: constructed but not used by forward().
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        The first block receives ``stride``, the dilation in effect before this
        stage, and (when the shape changes) a 1x1 conv + norm ``downsample``
        branch. The remaining blocks use stride 1 and the updated dilation.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Unlike torchvision, the stride is NOT reset to 1 here: a dilated
            # stage with stride 2 both downsamples and increases the dilation.
            self.dilation *= stride
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))
        layers = [
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x1 = x  # stride-4 feature from layer2
        x = self.layer3(x)
        x = self.layer4(x)
        return x1, x

    def forward(self, x):
        """Return ``(layer2_output, layer4_output)`` for an image batch ``x``."""
        return self._forward_impl(x)

    @staticmethod
    def _resnet(block, layers, pretrained_path=None, **kwargs):
        """Build a ResNet and optionally load weights (non-strict) from a file."""
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # machines; load_state_dict copies values into the model's tensors.
            state_dict = torch.load(pretrained_path, map_location="cpu")
            model.load_state_dict(state_dict, strict=False)
        return model

    @staticmethod
    def resnet50(pretrained_path=None, **kwargs):
        """ResNet-50 layout: Bottleneck blocks [3, 4, 6, 3]."""
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path, **kwargs)

    @staticmethod
    def resnet101(pretrained_path=None, **kwargs):
        """ResNet-101 layout: Bottleneck blocks [3, 4, 23, 3]."""
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path, **kwargs)
Non-Local
import torch
import torch.nn as nn
class NonLocal2d(nn.Module):
    """Non-local block for 2D feature maps, in residual form.

    For an input ``x`` of shape ``[N, C, H, W]`` it computes
    ``y_i = sum_j w(i, j) * g(x_j)`` over all (key) positions ``j`` and
    returns ``x + conv_out(y)``, which has the same shape as ``x``.

    Args:
        in_channels (int): C, the number of input channels.
        reduction (int): the g/theta/phi embeddings use
            ``inter_channels = max(in_channels // reduction, 1)`` channels.
        use_scale (bool): in ``'embedded_gaussian'`` mode, divide the pairwise
            logits by ``sqrt(inter_channels)`` before the softmax.
        sub_sample (bool): append a 2x2 max-pool to the ``g`` and ``phi``
            branches so keys/values come from a 2x down-sampled grid.
        mode (str): pairwise function, one of ``'gaussian'``,
            ``'embedded_gaussian'``, ``'dot_product'``, ``'concatenation'``.

    Raises:
        ValueError: if ``mode`` is not one of the four names above.
    """

    def __init__(self, in_channels, reduction=2, use_scale=True, sub_sample=False, mode='embedded_gaussian'):
        super(NonLocal2d, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode
        if mode not in [
            'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
        ]:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')
        # g: value embedding (1x1 conv), used by every mode.
        self.g = nn.Conv2d(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,)
        # conv_out: maps y back to in_channels before the residual addition.
        self.conv_out = nn.Conv2d(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,)
        if self.mode != 'gaussian':
            # theta / phi: query / key embeddings; 'gaussian' uses x directly.
            self.theta = nn.Conv2d(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,)
            self.phi = nn.Conv2d(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,)
        if self.mode == 'concatenation':
            # Scores every concatenated [theta_i, phi_j] pair with a 1x1 conv + ReLU.
            self.concat_project = nn.Sequential(
                nn.Conv2d(
                    self.inter_channels * 2,
                    1,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
                nn.ReLU())
        self.sub_sample = sub_sample
        if sub_sample:
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer
        self.init_weights()

    def init_weights(self, std=0.01, zeros_init=True):
        """Initialise the projection convolutions.

        The convs inside ``g`` (and ``theta``/``phi`` when present) get
        N(0, std) weights. With ``zeros_init`` the weight and bias of
        ``conv_out`` are set to 0, so the block initially returns ``x``
        unchanged; otherwise ``conv_out.weight`` gets N(0, std).
        """
        branches = [self.g]
        if self.mode != 'gaussian':
            branches += [self.theta, self.phi]
        for branch in branches:
            # With sub_sample=True, g/phi are nn.Sequential(conv, pool), which
            # has no `.weight`; initialise the conv(s) inside instead.
            for m in branch.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight, std=std)
        if zeros_init:
            # normal_(w, 0) would draw from N(0, 1); constant_ really zeroes.
            nn.init.constant_(self.conv_out.weight, 0)
            nn.init.constant_(self.conv_out.bias, 0)
        else:
            nn.init.normal_(self.conv_out.weight, std=std)

    def gaussian(self, theta_x, phi_x):
        """Softmax over ``x_i^T x_j``; returns ``[N, HxW, HxW']``."""
        pairwise_weight = torch.matmul(theta_x, phi_x)
        return pairwise_weight.softmax(dim=-1)

    def embedded_gaussian(self, theta_x, phi_x):
        """Softmax over ``theta(x_i)^T phi(x_j)`` (optionally scaled); ``[N, HxW, HxW']``."""
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight = pairwise_weight / theta_x.shape[-1]**0.5
        return pairwise_weight.softmax(dim=-1)

    def dot_product(self, theta_x, phi_x):
        """``theta(x_i)^T phi(x_j)`` normalised by the number of key positions."""
        pairwise_weight = torch.matmul(theta_x, phi_x)
        return pairwise_weight / pairwise_weight.shape[-1]

    def concatenation(self, theta_x, phi_x):
        """``ReLU(w^T [theta_i, phi_j])`` normalised by the number of key positions.

        ``theta_x`` is ``[N, C', HxW, 1]`` and ``phi_x`` is ``[N, C', 1, HxW']``;
        returns ``[N, HxW, HxW']``.
        """
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)
        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        pairwise_weight = self.concat_project(concat_feature)
        n, _, h, w = pairwise_weight.size()
        pairwise_weight = pairwise_weight.view(n, h, w)
        # Out-of-place: an in-place division would modify the ReLU output that
        # autograd saved for ReLU's backward pass.
        return pairwise_weight / pairwise_weight.shape[-1]

    def forward(self, x):
        """Return ``x + conv_out(y)`` for ``x`` of shape ``[N, C, H, W]``."""
        n = x.size(0)
        # g_x: [N, HxW', C'] (HxW' == HxW, or the pooled grid with sub_sample)
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        if self.mode == 'gaussian':
            # theta_x: [N, HxW, C], phi_x: [N, C, HxW'] straight from x
            theta_x = x.reshape(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.reshape(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            # theta_x: [N, C', HxW, 1], phi_x: [N, C', 1, HxW']
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            # theta_x: [N, HxW, C'], phi_x: [N, C', HxW']
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)
        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, HxW']
        pairwise_weight = pairwise_func(theta_x, phi_x)
        # y: [N, HxW, C']
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C', H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])
        output = x + self.conv_out(y)
        return output
class NLHead(nn.Module):
    """Non-local head: 3x3 conv -> NonLocal2d -> 3x3 conv.

    Both convolutions map ``channels`` to ``channels`` with padding 1, so the
    result has the same shape as the input. When ``self.cat`` is True (it is
    False after construction) the input is concatenated in front of the result
    along the channel axis.

    NOTE(review): there is no normalization or activation around the convs;
    confirm this is intended.
    """

    def __init__(self, reduction=2, use_scale=True, mode='embedded_gaussian', channels=2048):
        super(NLHead, self).__init__()
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.channels = channels

        def conv3x3():
            # Wrapped in nn.Sequential so parameter names stay conv{1,2}.0.*
            return nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1))

        # Creation order (nl_block, conv1, conv2) fixes the order of random
        # initialisation and of the state_dict entries; keep it.
        self.nl_block = NonLocal2d(in_channels=channels, reduction=reduction,
                                   use_scale=use_scale, mode=mode)
        self.conv1 = conv3x3()
        self.conv2 = conv3x3()
        self.cat = False

    def forward(self, inputs):
        """Forward function."""
        features = self.conv2(self.nl_block(self.conv1(inputs)))
        if not self.cat:
            return features
        return torch.cat([inputs, features], dim=1)
Model
class NonLocalNet(nn.Module):
    """Segmentation network: ResNet-50 backbone, non-local head, light decoder.

    The backbone yields a stride-4 feature (512 channels, from layer2) and a
    stride-8 feature (2048 channels, from layer4). The deep feature passes
    through ``NLHead``, is reduced to 512 channels and upsampled x2, then is
    concatenated with the shallow feature. ``cls_seg`` upsamples x4 and
    predicts ``num_classes`` logits per pixel.

    Args:
        num_classes (int): number of output channels (classes).
    """

    def __init__(self, num_classes):
        super(NonLocalNet, self).__init__()
        self.num_classes = num_classes
        self.resnet = ResNet.resnet50(replace_stride_with_dilation=[1, 2, 4])
        self.NLHead = NLHead()
        # 2048 -> 512 channels, then x2 upsampling towards the stride-4 grid.
        self.upSample = nn.Sequential(
            nn.Conv2d(2048, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Upsample(scale_factor=2., mode="bilinear", align_corners=True),
        )
        # Input: 512 (upsampled deep) + 512 (shallow) = 1024 channels.
        self.cls_seg = nn.Sequential(
            nn.Conv2d(1024, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=4., mode="bilinear", align_corners=True),
            nn.Conv2d(256, self.num_classes, 3, padding=1),
        )

    def forward(self, x):
        """Return per-pixel logits ``[N, num_classes, H, W]`` for images ``x``."""
        input_size = x.shape[-2:]
        x_4, x_8 = self.resnet(x)
        x_8 = self.NLHead(x_8)
        x_8 = self.upSample(x_8)
        # When H or W is not a multiple of 8, doubling the stride-8 map does not
        # land exactly on the stride-4 grid; resize so torch.cat succeeds.
        if x_8.shape[-2:] != x_4.shape[-2:]:
            x_8 = F.interpolate(x_8, size=x_4.shape[-2:], mode="bilinear", align_corners=True)
        out = torch.cat([x_8, x_4], 1)
        out = self.cls_seg(out)
        # Likewise, x4 upsampling may miss the input size by a few pixels.
        if out.shape[-2:] != input_size:
            out = F.interpolate(out, size=input_size, mode="bilinear", align_corners=True)
        return out
Dataset-Camvid
# Imports and global run setup for the CamVid experiment
import os
# Expose only GPU 0 to this process (only takes effect if set before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import optim  # re-imports `optim`; harmless duplicate of the line above
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
# Silence every warning. NOTE(review): this also hides deprecation/user warnings that may matter.
warnings.filterwarnings("ignore")
import os.path as osp
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import matplotlib.pyplot as plt
# Seed PyTorch's global RNG. NOTE(review): Python/NumPy RNGs (possibly used by the augmentations) are not seeded here.
torch.manual_seed(17)
# Custom CamVid dataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid segmentation dataset.

    Every file in ``images_dir`` is paired with the file of the same name in
    ``masks_dir``. Each sample is resized to 224x224, optionally flipped at
    random (horizontally and vertically), normalized, and converted to tensors
    by ``ToTensorV2``.

    Args:
        images_dir (str): folder containing the input images.
        masks_dir (str): folder containing the label masks; each mask must
            have the same file name as its image.
        augment (bool): apply the random flips. Defaults to True (the
            previous, unconditional behaviour); pass False for evaluation data.

    Each item is ``(image, mask)`` where ``mask`` is channel 0 of the
    RGB-converted, transformed mask.
    """

    def __init__(self, images_dir, masks_dir, augment=True):
        steps = [A.Resize(224, 224)]
        if augment:
            # Random horizontal/vertical flips (training-time augmentation).
            steps += [A.HorizontalFlip(), A.VerticalFlip()]
        steps += [A.Normalize(), ToTensorV2()]
        self.transform = A.Compose(steps)
        # sorted(): make the index -> file mapping independent of filesystem order.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        # Masks are looked up by the image's file name.
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    def __getitem__(self, i):
        # Context managers close the underlying image files promptly.
        with Image.open(self.images_fps[i]) as img:
            image = np.array(img.convert('RGB'))
        with Image.open(self.masks_fps[i]) as msk:
            mask = np.array(msk.convert('RGB'))
        sample = self.transform(image=image, mask=mask)
        # NOTE(review): only channel 0 of the RGB-converted mask is kept; this is
        # correct only if class indices are stored identically in every channel
        # (e.g. grayscale masks) — confirm for palette ("P" mode) PNGs.
        return sample['image'], sample['mask'][:, :, 0]

    def __len__(self):
        return len(self.ids)
# Dataset paths
DATA_DIR = r'database/camvid/camvid/'  # adjust to your own location
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
train_dataset = CamVidDataset(
    x_train_dir,
    y_train_dir,
)
# NOTE(review): the validation set is built with the same transform as the
# training set, including its random flips.
val_dataset = CamVidDataset(
    x_valid_dir,
    y_valid_dir,
)
# Training: shuffle, and drop the last incomplete batch to keep batch size constant.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
# Validation: evaluate every image in a fixed order. drop_last=True would skip
# up to 7 images per pass, and shuffling brings no benefit for evaluation.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, drop_last=False)
Train
model = NonLocalNet(num_classes=33).cuda()
# Optionally load pretrained weights
#model.load_state_dict(torch.load(r"checkpoints/Unet++_25.pth"),strict=False)
from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
# Loss: multi-class cross-entropy; pixels labelled 255 are ignored
lossf = nn.CrossEntropyLoss(ignore_index=255)
# Optimizer: plain SGD with lr=0.1
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Multiply the learning rate by 0.5 every 50 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)
# Train for 100 epochs
epochs_num = 100
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, scheduler,
               devices=d2l.try_all_gpus()):
    """Train ``net`` with d2l helpers, logging metrics and saving checkpoints.

    Wraps ``net`` in ``nn.DataParallel`` on ``devices`` and trains for
    ``num_epochs`` epochs. ``scheduler`` is stepped once per epoch. After each
    epoch the accuracy on ``test_iter`` is computed with
    ``d2l.evaluate_accuracy_gpu``, the metric history is written to
    ``savefile/NonLocalNet_camvid.xlsx``, and every 5 epochs the weights of
    ``net`` are saved to ``checkpoints/NonLocalNet_<epoch>.pth``.

    Args:
        net: model to train.
        train_iter / test_iter: training / evaluation data loaders.
        loss: loss module passed to ``d2l.train_batch_ch13``.
        trainer: optimizer.
        num_epochs: number of epochs.
        scheduler: learning-rate scheduler, stepped once per epoch.
        devices: devices for ``nn.DataParallel``; the first one holds the model.
    """
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    # Keep the unwrapped module: its state_dict has no "module." prefix and
    # saving it does not depend on a global variable.
    base_net = net
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Update the plot ~5 times per epoch; max(1, ...) avoids a modulo-by-zero
    # when an epoch has fewer than 5 batches.
    plot_every = max(1, num_batches // 5)
    # NOTE(review): d2l.train_batch_ch13 returns loss(pred, y).sum(); with
    # reduction='mean' that is already a batch mean, so average over batches
    # (slot 4) instead of images (slot 2).
    loss_denom = 4 if getattr(loss, 'reduction', None) == 'mean' else 2
    os.makedirs("savefile", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    history = {'epoch': [], 'loss': [], 'train_acc': [], 'test_acc': [], 'time': []}
    for epoch in range(num_epochs):
        # Sum of training loss, sum of correct predictions, no. of examples,
        # no. of predictions, no. of batches
        metric = d2l.Accumulator(5)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel(), 1)
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[loss_denom], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        train_loss = metric[0] / metric[loss_denom]
        train_acc = metric[1] / metric[3]
        print(f"epoch {epoch+1} --- loss {train_loss:.3f} --- train acc {train_acc:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")
        # --------- save training history ---------------
        history['epoch'].append(epoch + 1)
        history['loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['time'].append(timer.sum())
        pd.DataFrame(history).to_excel("savefile/NonLocalNet_camvid.xlsx")
        # ---------------- save model every 5 epochs -------------------
        if (epoch + 1) % 5 == 0:
            torch.save(base_net.state_dict(), f'checkpoints/NonLocalNet_{epoch+1}.pth')


train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, scheduler)
结果