异常检测之MemSeg
MemSeg: A semi-supervised method for image surface defect detection using differences and commonalities
- 论文链接:https://arxiv.org/abs/2205.00908
- 论文开源代码:https://github.com/TooTouch/MemSeg或https://download.csdn.net/download/thisiszdy/88893714
论文概要及其相关代码
文章主要贡献:
- 提出了一种精心设计的异常模拟策略,用于模型的自监督学习,融合了目标前景异常、纹理和结构异常3个方面。
- 提出了一种具有更高效特征匹配算法的记忆模块,并创新性地在U-Net结构中引入正常模式的记忆信息来辅助模型学习。
- 通过以上两点,并结合多尺度特征融合模块和空间注意力模块,有效地将半监督异常检测简化为端到端的语义分割任务,使半监督图像表面缺陷检测更加灵活。
- 通过广泛的实验验证,MemSeg在表面缺陷检测和定位任务中具有较高的准确性,同时更好地满足工业场景的实时性要求。
本文介绍了用于检测和定位细粒度异常的新框架,MemSeg的框架如下图所示。MemSeg以U-Net为基础框架,在训练阶段借助模拟异常样本和记忆信息完成语义分割任务,在推理阶段端到端定位图像中的异常区域。MemSeg由几个重要的部分组成,我们将按照以下顺序描述这些部分:通过人工模拟的方式生成异常样本(第3.1节),记忆信息和空间注意图的生成(第3.2节),用于融合记忆信息和图像高层特征的多尺度特征融合模块(第3.3节),以及损失函数(第3.4节)。
3.1异常模拟策略
在工业场景中,异常发生的形式多种多样,在进行数据收集时无法覆盖所有异常,这限制了监督学习方法的建模。然而,在半监督框架中,仅使用正常样本而不与非正常样本进行比较,不足以让模型学习到什么是正常模式。受DRAEM的启发,设计了一种更有效的策略来模拟异常样本,并在训练过程中引入异常样本来完成自监督学习。MemSeg通过比较非正常样本的模式来总结正常样本的模式,以缓解半监督学习的缺陷。如下图所示,本文提出的异常模拟策略主要分为三个步骤。
-
第一步:生成二维Perlin噪声 P P P,然后通过阈值 T T T二值化得到掩码 M P M_P MP;由其产生的Perlin噪声具有多个随机峰值,能够提取图像中连续的区域块。同时,考虑到某些工业部件的主体在获取的图像中所占比例较小,如果直接进行数据增强而不进行处理,容易在图像的背景部分产生噪声,增大了模拟异常样本与真实异常样本在数据分布上的差异,不利于模型学习有效的判别信息。因此,对这类图像采用前景增强策略。即对输入图像 I I I进行二值化得到掩模 M I M_I MI,利用open或close操作去除二值化过程中产生的噪声。然后,通过对得到的两个掩码 M I M_I MI和 M P M_P MP进行元素积来获得最终的掩码图像 M M M。
-
第二步,将掩模图像 M M M和噪声图像 I n I_n In进行元素积,得到感兴趣区域(ROI)。遵循DRAEM的思想,在融合过程中引入透明度因子 δ \delta δ来平衡原始图像和噪声图像的融合,使模拟异常的模式更接近真实异常。因此,有噪声的前景图像 I n ′ I_n^{'} In′可以用下面的公式生成:
I n ′ = δ ( M ⨀ I n ) + ( 1 − δ ) ( M ⨀ I ) I_n^{'}=\delta(M\bigodot I_n)+(1-\delta)(M\bigodot I) In′=δ(M⨀In)+(1−δ)(M⨀I)
对于噪声图像 I n I_n In,我们希望其最大的透明度更高,以增加模型学习的难度,从而提高模型的鲁棒性。因此,对于上述公式中的 δ \delta δ,我们将从[0.15,1]中随机均匀采样。 -
在第三步中,对掩模图像 M M M进行逆求得到 M ˉ \bar{M} Mˉ,然后对 M ˉ \bar{M} Mˉ与原图像 I I I进行元素积得到图像 I ′ I^{'} I′,并根据
I A = M ˉ ⨀ I + I n ′ I_A=\bar{M}\bigodot I+I_n^{'} IA=Mˉ⨀I+In′
得到数据增强后的图像 I A I_A IA,即模拟异常图像。将原始输入图像 I I I作为背景,通过掩模图像 M M M提取噪声图像 I n I_n In中的ROI作为前景。
其中,噪声图像 I n I_n In来自两部分,一部分来自DTD纹理数据集,旨在模拟纹理异常,另一部分来自输入图像,旨在模拟结构异常。为了模拟结构异常,首先对输入图像 I I I进行镜像对称性、旋转、亮度、饱和度和色调的随机调整;然后将初步处理后的图像均匀划分为4×8网格并随机排列得到无序图像 I n I_n In。利用上述异常模拟策略,从纹理和结构两个角度获取模拟异常样本,并在目标前景上生成大部分异常区域,最大化模拟异常样本与真实异常样本的相似性。
3.2 记忆模块和空间注意力图
记忆模块:对于人类来说,我们识别异常的基础是知道什么是正常的,通过比较测试图像与记忆中的正常图像来获得异常区域。受人类学习过程和基于嵌入的方法的启发,使用少量正常样本作为记忆样本,并使用预训练编码器(ResNet18)提取记忆样本的高层特征作为记忆信息来辅助MemSeg的学习。
为了获得记忆信息,我们首先从训练数据中随机选择正常图像作为记忆样本,并输入到编码器中,分别从ResNet18的block 1、block 2、block 3中获取N×64×64×64、N×128×32×32、N×256×16×16维度的特征。这些不同分辨率的特征共同构成了内存信息。需要强调的是,为了保证记忆信息与输入图像高层特征的统一,我们始终在ResNet18中冻结block 1、block 2、block 3的模型参数,但模型的其余部分仍然是可训练的。
在训练或推理阶段,给定一个输入图像,如图2所示,编码器还提取输入图像的高级特征,以获得64×64×64, 128×32×32和256×16×16维度的特征。这些具有不同分辨率的特征共同构成了输入图像的信息。然后计算与所有记忆信息之间的L2距离,得到输入图像与记忆样本之间的差值信息:
D
I
=
⋃
i
=
1
N
∥
M
I
i
−
I
I
∥
2
DI=\bigcup_{i=1}^{N}{\lVert MI_i-II\rVert_2}
DI=i=1⋃N∥MIi−II∥2
其中
N
N
N是记忆样本的数量。对于
N
N
N个差异信息,以每个
D
I
DI
DI元素中所有元素之和最小为标准,以获得
I
I
II
II与
M
I
MI
MI之间的最佳差异信息
D
I
∗
DI^{*}
DI∗ ,即
D
I
∗
=
a
r
g
m
i
n
D
I
i
∈
D
I
∑
x
∈
D
I
i
x
DI^{*}=\underset{DI_{i}\in DI}{argmin}\sum_{x\in DI_{i}}x
DI∗=DIi∈DIargminx∈DIi∑x
其中
i
∈
[
1
,
N
]
i\in [1,N]
i∈[1,N],最佳差值信息
D
I
∗
DI^*
DI∗包含输入样本与其最相似的记忆样本之间的差值,某一位置的差值越大,该位置对应的输入图像区域异常的概率越高。然后,将最佳差分信息
D
I
∗
DI^*
DI∗与输入图像
I
I
II
II在通道维度上的高层特征进行级联操作,得到级联信息
C
I
1
,
C
I
2
,
C
I
3
CI_1,CI_2,CI_3
CI1,CI2,CI3;最后,级联信息通过多尺度特征融合模块进行特征融合,融合后的特征通过U-Net的跳跃连接流到解码器。
源码如下:
import torch
import torch.nn.functional as F
import numpy as np
from typing import List
class MemoryBank:
def __init__(self, normal_dataset, nb_memory_sample: int = 30, device='cpu'):
self.device = device
# memory bank
self.memory_information = {}
# normal dataset
self.normal_dataset = normal_dataset
# the number of samples saved in memory bank
self.nb_memory_sample = nb_memory_sample
def update(self, feature_extractor):
feature_extractor.eval()
# define sample index
samples_idx = np.arange(len(self.normal_dataset))
np.random.shuffle(samples_idx)
# extract features and save features into memory bank
with torch.no_grad():
for i in range(self.nb_memory_sample):
# select image
input_normal, _, _ = self.normal_dataset[samples_idx[i]]
input_normal = input_normal.to(self.device)
# extract features
features = feature_extractor(input_normal.unsqueeze(0))
# save features into memoery bank
for i, features_l in enumerate(features[1:-1]):
if f'level{i}' not in self.memory_information.keys():
self.memory_information[f'level{i}'] = features_l
else:
self.memory_information[f'level{i}'] = torch.cat(
[self.memory_information[f'level{i}'], features_l], dim=0)
def _calc_diff(self, features: List[torch.Tensor]) -> torch.Tensor:
# batch size X the number of samples saved in memory
diff_bank = torch.zeros(features[0].size(0), self.nb_memory_sample).to(self.device)
# level
for l, level in enumerate(self.memory_information.keys()):
# batch
for b_idx, features_b in enumerate(features[l]):
# calculate l2 loss
diff = F.mse_loss(
input=torch.repeat_interleave(features_b.unsqueeze(0), repeats=self.nb_memory_sample, dim=0),
target=self.memory_information[level],
reduction='none'
).mean(dim=[1, 2, 3])
# sum loss
diff_bank[b_idx] += diff
return diff_bank
def select(self, features: List[torch.Tensor]) -> torch.Tensor:
# calculate difference between features and normal features of memory bank
diff_bank = self._calc_diff(features=features)
# concatenate features with minimum difference features of memory bank
for l, level in enumerate(self.memory_information.keys()):
selected_features = torch.index_select(self.memory_information[level], dim=0, index=diff_bank.argmin(dim=1))
diff_features = F.mse_loss(selected_features, features[l], reduction='none')
features[l] = torch.cat([features[l], diff_features], dim=1)
return features
空间注意力图:从具体的观察和实验(第4.6节)可以明显看出,最佳差异信息 D I ∗ DI^* DI∗对异常区域的定位有重要影响。为了充分利用差异信息,文中利用 D I ∗ DI^* DI∗提取了3个空间注意图,用于加强对异常区域最佳差异信息的猜测。
对于
D
I
∗
DI^*
DI∗中具有三个不同维度的特征,在通道维度上计算平均值,分别得到大小为16×16、32×32和64×64的三个特征图。直接使用16×16特征图作为空间注意图
M
3
M_3
M3。
M
3
M_3
M3上采样后,对32×32特征图执行元素乘积操作以获得
M
2
M_2
M2;
M
2
M_2
M2上采样后,对64×64特征图执行元素乘积操作得到
M
1
M_1
M1。空间注意度图和加权后得到的信息,分别经过MSFF处理。公式如下:
M
3
=
1
C
3
∑
i
=
1
C
3
D
I
3
i
∗
M_3=\frac{1}{C_3}\sum_{i=1}^{C_3}DI^*_{3i}
M3=C31i=1∑C3DI3i∗
M
2
=
1
C
2
(
∑
i
=
1
C
2
D
I
2
i
∗
)
⨀
M
3
U
M_2=\frac{1}{C_2}(\sum_{i=1}^{C_2}DI^*_{2i})\bigodot M^U_{3}
M2=C21(i=1∑C2DI2i∗)⨀M3U
M
1
=
1
C
1
(
∑
i
=
1
C
1
D
I
1
i
∗
)
⨀
M
2
U
M_1=\frac{1}{C_1}(\sum_{i=1}^{C_1}DI^*_{1i})\bigodot M^U_{2}
M1=C11(i=1∑C1DI1i∗)⨀M2U
其中
C
3
C_3
C3表示
D
I
3
∗
DI^*_{3}
DI3∗的通道数,
D
I
3
i
∗
DI^*_{3i}
DI3i∗表示
D
I
3
∗
DI^*_{3}
DI3∗中通道
i
i
i的特征图,
M
3
U
M^U_{3}
M3U和
M
3
U
M^U_{3}
M3U表示分别为上采样
M
3
M_{3}
M3和
M
2
M_{2}
M2后得到的特征图。
源码如下:
# https://github.com/houqb/CoordAttention/blob/main/coordatt.py
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
class h_sigmoid(nn.Module):
def __init__(self, inplace=True):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6(inplace=inplace)
def forward(self, x):
return self.relu(x + 3) / 6
class h_swish(nn.Module):
def __init__(self, inplace=True):
super(h_swish, self).__init__()
self.sigmoid = h_sigmoid(inplace=inplace)
def forward(self, x):
return x * self.sigmoid(x)
class CoordAtt(nn.Module):
def __init__(self, inp, oup, reduction=32):
super(CoordAtt, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
mip = max(8, inp // reduction)
self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(mip)
self.act = h_swish()
self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
def forward(self, x):
identity = x
n,c,h,w = x.size()
x_h = self.pool_h(x)
x_w = self.pool_w(x).permute(0, 1, 3, 2)
y = torch.cat([x_h, x_w], dim=2)
y = self.conv1(y)
y = self.bn1(y)
y = self.act(y)
x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)
a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()
out = identity * a_w * a_h
return out
3.3 多尺度特征融合模块
在记忆模块的帮助下,我们得到了由输入图像信息 I I II II和最佳差分信息 D I ∗ DI^* DI∗组成的级联信息,一方面直接使用存在特征冗余的问题,另一方面,增加了模型的计算规模,导致推理速度下降。鉴于多尺度特征融合在目标检测中的成功,一种直观的思路是借助通道注意力机制和多尺度特征融合策略,充分融合拼接信息中的视觉信息和语义信息。
我们提出的多尺度特征融合模块如下图所示:级联的信息
C
I
n
(
n
=
1
,
2
,
3
)
CI_n(n=1,2,3)
CIn(n=1,2,3)最初由3×3卷积层进行融合,该卷积层保持通道数量。同时,考虑
C
I
n
CI_n
CIn是在通道维度上两种信息的简单拼接,使用坐标注意力(CA)来捕获通道
C
I
n
CI_n
CIn之间的信息关系。然后,对于通过坐标注意力加权的不同维度特征,继续进行多尺度信息融合:首先利用上采样对不同维度的特征图进行分辨率对齐,然后利用卷积对通道数对齐,最后进行元素相加操作以实现多尺度特征融合。通过在3.2小节中获得的空间注意图
M
n
(
n
=
1
,
2
,
3
)
M_n(n=1,2,3)
Mn(n=1,2,3)对融合的特征进行加权,然后将其输入到最终的解码器。
源码如下:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from .coordatt import CoordAtt
class MSFFBlock(nn.Module):
def __init__(self, in_channel):
super(MSFFBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)
self.attn = CoordAtt(in_channel, in_channel)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channel, in_channel // 2, kernel_size=3, stride=1, padding=1),
nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, stride=1, padding=1)
)
def forward(self, x):
x_conv = self.conv1(x)
x_att = self.attn(x)
x = x_conv * x_att
x = self.conv2(x)
return x
class MSFF(nn.Module):
def __init__(self):
super(MSFF, self).__init__()
self.blk1 = MSFFBlock(128)
self.blk2 = MSFFBlock(256)
self.blk3 = MSFFBlock(512)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.upconv32 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
)
self.upconv21 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
)
def forward(self, features):
# features = [level1, level2, level3]
f1, f2, f3 = features
# MSFF Module
f1_k = self.blk1(f1)
f2_k = self.blk2(f2)
f3_k = self.blk3(f3)
f2_f = f2_k + self.upconv32(f3_k)
f1_f = f1_k + self.upconv21(f2_f)
# spatial attention
# mask
m3 = f3[:,256:,...].mean(dim=1, keepdim=True)
m2 = f2[:,128:,...].mean(dim=1, keepdim=True) * self.upsample(m3)
m1 = f1[:,64:,...].mean(dim=1, keepdim=True) * self.upsample(m2)
f1_out = f1_f * m1
f2_out = f2_f * m2
f3_out = f3_k * m3
return [f1_out, f2_out, f3_out]
MemSeg整体结构:
import torch.nn as nn
from .decoder import Decoder
from .msff import MSFF
class MemSeg(nn.Module):
def __init__(self, memory_bank, feature_extractor):
super(MemSeg, self).__init__()
self.memory_bank = memory_bank
self.feature_extractor = feature_extractor
self.msff = MSFF()
self.decoder = Decoder()
def forward(self, inputs):
# extract features
features = self.feature_extractor(inputs)
f_in = features[0]
f_out = features[-1]
f_ii = features[1:-1]
# extract concatenated information(CI)
concat_features = self.memory_bank.select(features = f_ii)
# Multi-scale Feature Fusion(MSFF) Module
msff_outputs = self.msff(features = concat_features)
# decoder
predicted_mask = self.decoder(
encoder_output = f_out,
concat_features = [f_in] + msff_outputs
)
return predicted_mask
3.4训练约束
为了确保MemSeg的预测值接近于其真实值,我们使用L1损失和focal损失来保证图像空间中所有像素的相似性。相比于L2损失,L1损失约束下预测的分割图像保留了更多的边缘信息。同时,focal loss缓解了图像中正常和异常区域面积不平衡的问题,使模型更专注于困难样本的分割,提高异常分割的准确性。
具体来说,我们分别用以下两个公式来最小化模拟图像中异常区域的ground truth
S
S
S和模型预测值
S
^
\hat{S}
S^之间的L1损失
L
l
1
L_{l1}
Ll1和focal损失
L
f
L_f
Lf。
L
l
1
=
∥
S
−
S
^
∥
L_{l1}=\lVert S-\hat{S}\rVert
Ll1=∥S−S^∥
L
f
=
−
α
t
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
L_{f}=-\alpha _t(1-p_t)^{ \gamma }log(p_t)
Lf=−αt(1−pt)γlog(pt)
其中,当 S S S中对应像素的ground truth为1时, p t p_t pt为该像素类别的预测概率 p p p;当 S S S中对应像素的ground truth为0时, p t = 1 − p p_t=1-p pt=1−p, α \alpha α和 γ \gamma γ作为控制加权程度的超参数。
最后,我们将所有这些约束组合成一个目标函数,并得到以下目标函数:
L
a
l
l
=
λ
l
1
L
l
1
+
λ
f
L
f
L_{all}=\lambda_{l1}L_{l1}+\lambda_{f}L_f
Lall=λl1Ll1+λfLf
其中
λ
l
1
\lambda_{l1}
λl1和
λ
f
\lambda_{f}
λf是平衡超参数。在训练过程中,我们的优化目标是最小化由上式定义的目标函数。
实验结果
测试结果
1.训练结果
2.安装jupyter notebook并启动[demo] model inference.ipynb
注意,源码中要修改如下才可以跑通
3.测试结果可视化