异常检测之MemSeg

异常检测之MemSeg

MemSeg: A semi-supervised method for image surface defect detection using differences and commonalities

论文概要及其相关代码

文章主要贡献:

  • 提出了一种精心设计的异常模拟策略,用于模型的自监督学习,融合了目标前景异常、纹理和结构异常3个方面。
  • 提出了一种具有更高效特征匹配算法的记忆模块,并创新性地在U-Net结构中引入正常模式的记忆信息来辅助模型学习。
  • 通过以上两点,并结合多尺度特征融合模块和空间注意力模块,有效地将半监督异常检测简化为端到端的语义分割任务,使半监督图像表面缺陷检测更加灵活。
  • 通过广泛的实验验证,MemSeg在表面缺陷检测和定位任务中具有较高的准确性,同时更好地满足工业场景的实时性要求。

本文介绍了用于检测和定位细粒度异常的新框架,MemSeg的框架如下图所示。MemSeg以U-Net为基础框架,在训练阶段借助模拟异常样本记忆信息完成语义分割任务,在推理阶段端到端定位图像中的异常区域。MemSeg由几个重要的部分组成,我们将按照以下顺序描述这些部分:通过人工模拟的方式生成异常样本(第3.1节),记忆信息和空间注意图的生成(第3.2节),用于融合记忆信息和图像高层特征的多尺度特征融合模块(第3.3节),以及损失函数(第3.4节)。
在这里插入图片描述

3.1异常模拟策略

在工业场景中,异常发生的形式多种多样,在进行数据收集时无法覆盖所有异常,这限制了监督学习方法的建模。然而,在半监督框架中,仅使用正常样本而不与非正常样本进行比较,不足以让模型学习到什么是正常模式。受DRAEM的启发,设计了一种更有效的策略来模拟异常样本,并在训练过程中引入异常样本来完成自监督学习。MemSeg通过比较非正常样本的模式来总结正常样本的模式,以缓解半监督学习的缺陷。如下图所示,本文提出的异常模拟策略主要分为三个步骤。
在这里插入图片描述

  1. 第一步:生成二维Perlin噪声 P P P,然后通过阈值 T T T二值化得到掩码 M P M_P MP;由其产生的Perlin噪声具有多个随机峰值,能够提取图像中连续的区域块。同时,考虑到某些工业部件的主体在获取的图像中所占比例较小,如果直接进行数据增强而不进行处理,容易在图像的背景部分产生噪声,增大了模拟异常样本与真实异常样本在数据分布上的差异,不利于模型学习有效的判别信息。因此,对这类图像采用前景增强策略。即对输入图像 I I I进行二值化得到掩模 M I M_I MI,利用open或close操作去除二值化过程中产生的噪声。然后,通过对得到的两个掩码 M I M_I MI M P M_P MP进行元素积来获得最终的掩码图像 M M M

  2. 第二步,将掩模图像 M M M和噪声图像 I n I_n In进行元素积,得到感兴趣区域(ROI)。遵循DRAEM的思想,在融合过程中引入透明度因子 δ \delta δ来平衡原始图像和噪声图像的融合,使模拟异常的模式更接近真实异常。因此,有噪声的前景图像 I n ′ I_n^{'} In可以用下面的公式生成:
    I n ′ = δ ( M ⨀ I n ) + ( 1 − δ ) ( M ⨀ I ) I_n^{'}=\delta(M\bigodot I_n)+(1-\delta)(M\bigodot I) In=δ(MIn)+(1δ)(MI)
    对于噪声图像 I n I_n In,我们希望其最大的透明度更高,以增加模型学习的难度,从而提高模型的鲁棒性。因此,对于上述公式中的 δ \delta δ,我们将从[0.15,1]中随机均匀采样。

  3. 在第三步中,对掩模图像 M M M进行逆求得到 M ˉ \bar{M} Mˉ,然后对 M ˉ \bar{M} Mˉ与原图像 I I I进行元素积得到图像 I ′ I^{'} I,并根据
    I A = M ˉ ⨀ I + I n ′ I_A=\bar{M}\bigodot I+I_n^{'} IA=MˉI+In
    得到数据增强后的图像 I A I_A IA,即模拟异常图像。将原始输入图像 I I I作为背景,通过掩模图像 M M M提取噪声图像 I n I_n In中的ROI作为前景。

其中,噪声图像 I n I_n In来自两部分,一部分来自DTD纹理数据集,旨在模拟纹理异常,另一部分来自输入图像,旨在模拟结构异常。为了模拟结构异常,首先对输入图像 I I I进行镜像对称性、旋转、亮度、饱和度和色调的随机调整;然后将初步处理后的图像均匀划分为4×8网格并随机排列得到无序图像 I n I_n In。利用上述异常模拟策略,从纹理和结构两个角度获取模拟异常样本,并在目标前景上生成大部分异常区域,最大化模拟异常样本与真实异常样本的相似性。

3.2 记忆模块和空间注意力图

记忆模块:对于人类来说,我们识别异常的基础是知道什么是正常的,通过比较测试图像与记忆中的正常图像来获得异常区域。受人类学习过程和基于嵌入的方法的启发,使用少量正常样本作为记忆样本,并使用预训练编码器(ResNet18)提取记忆样本的高层特征作为记忆信息来辅助MemSeg的学习。

为了获得记忆信息,我们首先从训练数据中随机选择正常图像作为记忆样本,并输入到编码器中,分别从ResNet18的block 1、block 2、block 3中获取N×64×64×64、N×128×32×32、N×256×16×16维度的特征。这些不同分辨率的特征共同构成了内存信息。需要强调的是,为了保证记忆信息与输入图像高层特征的统一,我们始终在ResNet18中冻结block 1、block 2、block 3的模型参数,但模型的其余部分仍然是可训练的。

在训练或推理阶段,给定一个输入图像,如图2所示,编码器还提取输入图像的高级特征,以获得64×64×64, 128×32×32和256×16×16维度的特征。这些具有不同分辨率的特征共同构成了输入图像的信息。然后计算与所有记忆信息之间的L2距离,得到输入图像与记忆样本之间的差值信息:
D I = ⋃ i = 1 N ∥ M I i − I I ∥ 2 DI=\bigcup_{i=1}^{N}{\lVert MI_i-II\rVert_2} DI=i=1NMIiII2
其中 N N N是记忆样本的数量。对于 N N N个差异信息,以每个 D I DI DI元素中所有元素之和最小为标准,以获得 I I II II M I MI MI之间的最佳差异信息 D I ∗ DI^{*} DI ,即
D I ∗ = a r g m i n D I i ∈ D I ∑ x ∈ D I i x DI^{*}=\underset{DI_{i}\in DI}{argmin}\sum_{x\in DI_{i}}x DI=DIiDIargminxDIix
其中 i ∈ [ 1 , N ] i\in [1,N] i[1,N],最佳差值信息 D I ∗ DI^* DI包含输入样本与其最相似的记忆样本之间的差值,某一位置的差值越大,该位置对应的输入图像区域异常的概率越高。然后,将最佳差分信息 D I ∗ DI^* DI与输入图像 I I II II在通道维度上的高层特征进行级联操作,得到级联信息 C I 1 , C I 2 , C I 3 CI_1,CI_2,CI_3 CI1,CI2,CI3;最后,级联信息通过多尺度特征融合模块进行特征融合,融合后的特征通过U-Net的跳跃连接流到解码器。

源码如下:

import torch
import torch.nn.functional as F

import numpy as np
from typing import List


class MemoryBank:
    def __init__(self, normal_dataset, nb_memory_sample: int = 30, device='cpu'):
        self.device = device

        # memory bank
        self.memory_information = {}

        # normal dataset
        self.normal_dataset = normal_dataset

        # the number of samples saved in memory bank
        self.nb_memory_sample = nb_memory_sample

    def update(self, feature_extractor):
        feature_extractor.eval()

        # define sample index
        samples_idx = np.arange(len(self.normal_dataset))
        np.random.shuffle(samples_idx)

        # extract features and save features into memory bank
        with torch.no_grad():
            for i in range(self.nb_memory_sample):
                # select image
                input_normal, _, _ = self.normal_dataset[samples_idx[i]]
                input_normal = input_normal.to(self.device)

                # extract features
                features = feature_extractor(input_normal.unsqueeze(0))

                # save features into memoery bank
                for i, features_l in enumerate(features[1:-1]):
                    if f'level{i}' not in self.memory_information.keys():
                        self.memory_information[f'level{i}'] = features_l
                    else:
                        self.memory_information[f'level{i}'] = torch.cat(
                            [self.memory_information[f'level{i}'], features_l], dim=0)

    def _calc_diff(self, features: List[torch.Tensor]) -> torch.Tensor:
        # batch size X the number of samples saved in memory
        diff_bank = torch.zeros(features[0].size(0), self.nb_memory_sample).to(self.device)

        # level
        for l, level in enumerate(self.memory_information.keys()):
            # batch
            for b_idx, features_b in enumerate(features[l]):
                # calculate l2 loss
                diff = F.mse_loss(
                    input=torch.repeat_interleave(features_b.unsqueeze(0), repeats=self.nb_memory_sample, dim=0),
                    target=self.memory_information[level],
                    reduction='none'
                ).mean(dim=[1, 2, 3])

                # sum loss
                diff_bank[b_idx] += diff

        return diff_bank

    def select(self, features: List[torch.Tensor]) -> torch.Tensor:
        # calculate difference between features and normal features of memory bank
        diff_bank = self._calc_diff(features=features)

        # concatenate features with minimum difference features of memory bank
        for l, level in enumerate(self.memory_information.keys()):
            selected_features = torch.index_select(self.memory_information[level], dim=0, index=diff_bank.argmin(dim=1))
            diff_features = F.mse_loss(selected_features, features[l], reduction='none')
            features[l] = torch.cat([features[l], diff_features], dim=1)

        return features

空间注意力图:从具体的观察和实验(第4.6节)可以明显看出,最佳差异信息 D I ∗ DI^* DI对异常区域的定位有重要影响。为了充分利用差异信息,文中利用 D I ∗ DI^* DI提取了3个空间注意图,用于加强对异常区域最佳差异信息的猜测。

对于 D I ∗ DI^* DI中具有三个不同维度的特征,在通道维度上计算平均值,分别得到大小为16×16、32×32和64×64的三个特征图。直接使用16×16特征图作为空间注意图 M 3 M_3 M3 M 3 M_3 M3上采样后,对32×32特征图执行元素乘积操作以获得 M 2 M_2 M2; M 2 M_2 M2上采样后,对64×64特征图执行元素乘积操作得到 M 1 M_1 M1。空间注意度图和加权后得到的信息,分别经过MSFF处理。公式如下:
M 3 = 1 C 3 ∑ i = 1 C 3 D I 3 i ∗ M_3=\frac{1}{C_3}\sum_{i=1}^{C_3}DI^*_{3i} M3=C31i=1C3DI3i
M 2 = 1 C 2 ( ∑ i = 1 C 2 D I 2 i ∗ ) ⨀ M 3 U M_2=\frac{1}{C_2}(\sum_{i=1}^{C_2}DI^*_{2i})\bigodot M^U_{3} M2=C21(i=1C2DI2i)M3U
M 1 = 1 C 1 ( ∑ i = 1 C 1 D I 1 i ∗ ) ⨀ M 2 U M_1=\frac{1}{C_1}(\sum_{i=1}^{C_1}DI^*_{1i})\bigodot M^U_{2} M1=C11(i=1C1DI1i)M2U
其中 C 3 C_3 C3表示 D I 3 ∗ DI^*_{3} DI3的通道数, D I 3 i ∗ DI^*_{3i} DI3i表示 D I 3 ∗ DI^*_{3} DI3中通道 i i i的特征图, M 3 U M^U_{3} M3U M 3 U M^U_{3} M3U表示分别为上采样 M 3 M_{3} M3 M 2 M_{2} M2后得到的特征图。

源码如下:

# https://github.com/houqb/CoordAttention/blob/main/coordatt.py

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

3.3 多尺度特征融合模块

在记忆模块的帮助下,我们得到了由输入图像信息 I I II II和最佳差分信息 D I ∗ DI^* DI组成的级联信息,一方面直接使用存在特征冗余的问题,另一方面,增加了模型的计算规模,导致推理速度下降。鉴于多尺度特征融合在目标检测中的成功,一种直观的思路是借助通道注意力机制和多尺度特征融合策略,充分融合拼接信息中的视觉信息和语义信息。

我们提出的多尺度特征融合模块如下图所示:级联的信息 C I n ( n = 1 , 2 , 3 ) CI_n(n=1,2,3) CIn(n=1,2,3)最初由3×3卷积层进行融合,该卷积层保持通道数量。同时,考虑 C I n CI_n CIn是在通道维度上两种信息的简单拼接,使用坐标注意力(CA)来捕获通道 C I n CI_n CIn之间的信息关系。然后,对于通过坐标注意力加权的不同维度特征,继续进行多尺度信息融合:首先利用上采样对不同维度的特征图进行分辨率对齐,然后利用卷积对通道数对齐,最后进行元素相加操作以实现多尺度特征融合。通过在3.2小节中获得的空间注意图 M n ( n = 1 , 2 , 3 ) M_n(n=1,2,3) Mn(n=1,2,3)对融合的特征进行加权,然后将其输入到最终的解码器。
在这里插入图片描述
源码如下:

import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from .coordatt import CoordAtt

class MSFFBlock(nn.Module):
    def __init__(self, in_channel):
        super(MSFFBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)
        self.attn = CoordAtt(in_channel, in_channel)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 2, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        x_conv = self.conv1(x)
        x_att = self.attn(x)
        
        x = x_conv * x_att
        x = self.conv2(x)
        return x

    
class MSFF(nn.Module):
    def __init__(self):
        super(MSFF, self).__init__()
        self.blk1 = MSFFBlock(128)
        self.blk2 = MSFFBlock(256)
        self.blk3 = MSFFBlock(512)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.upconv32 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        )
        self.upconv21 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, features):
        # features = [level1, level2, level3]
        f1, f2, f3 = features 
        
        # MSFF Module
        f1_k = self.blk1(f1)
        f2_k = self.blk2(f2)
        f3_k = self.blk3(f3)

        f2_f = f2_k + self.upconv32(f3_k)
        f1_f = f1_k + self.upconv21(f2_f)

        # spatial attention
        
        # mask 
        m3 = f3[:,256:,...].mean(dim=1, keepdim=True)
        m2 = f2[:,128:,...].mean(dim=1, keepdim=True) * self.upsample(m3)
        m1 = f1[:,64:,...].mean(dim=1, keepdim=True) * self.upsample(m2)
        
        f1_out = f1_f * m1
        f2_out = f2_f * m2
        f3_out = f3_k * m3
        
        return [f1_out, f2_out, f3_out]

MemSeg整体结构:

import torch.nn as nn
from .decoder import Decoder
from .msff import MSFF

class MemSeg(nn.Module):
    def __init__(self, memory_bank, feature_extractor):
        super(MemSeg, self).__init__()

        self.memory_bank = memory_bank
        self.feature_extractor = feature_extractor
        self.msff = MSFF()
        self.decoder = Decoder()

    def forward(self, inputs):
        # extract features
        features = self.feature_extractor(inputs)
        f_in = features[0]
        f_out = features[-1]
        f_ii = features[1:-1]

        # extract concatenated information(CI)
        concat_features = self.memory_bank.select(features = f_ii)

        # Multi-scale Feature Fusion(MSFF) Module
        msff_outputs = self.msff(features = concat_features)

        # decoder
        predicted_mask = self.decoder(
            encoder_output  = f_out,
            concat_features = [f_in] + msff_outputs
        )

        return predicted_mask

3.4训练约束

为了确保MemSeg的预测值接近于其真实值,我们使用L1损失和focal损失来保证图像空间中所有像素的相似性。相比于L2损失,L1损失约束下预测的分割图像保留了更多的边缘信息。同时,focal loss缓解了图像中正常和异常区域面积不平衡的问题,使模型更专注于困难样本的分割,提高异常分割的准确性。

具体来说,我们分别用以下两个公式来最小化模拟图像中异常区域的ground truth S S S和模型预测值 S ^ \hat{S} S^之间的L1损失 L l 1 L_{l1} Ll1和focal损失 L f L_f Lf
L l 1 = ∥ S − S ^ ∥ L_{l1}=\lVert S-\hat{S}\rVert Ll1=SS^
L f = − α t ( 1 − p t ) γ l o g ( p t ) L_{f}=-\alpha _t(1-p_t)^{ \gamma }log(p_t) Lf=αt(1pt)γlog(pt)

其中,当 S S S中对应像素的ground truth为1时, p t p_t pt为该像素类别的预测概率 p p p;当 S S S中对应像素的ground truth为0时, p t = 1 − p p_t=1-p pt=1p α \alpha α γ \gamma γ作为控制加权程度的超参数。

最后,我们将所有这些约束组合成一个目标函数,并得到以下目标函数:
L a l l = λ l 1 L l 1 + λ f L f L_{all}=\lambda_{l1}L_{l1}+\lambda_{f}L_f Lall=λl1Ll1+λfLf
其中 λ l 1 \lambda_{l1} λl1 λ f \lambda_{f} λf是平衡超参数。在训练过程中,我们的优化目标是最小化由上式定义的目标函数。

实验结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

测试结果

1.训练结果
在这里插入图片描述
2.安装jupyter notebook并启动[demo] model inference.ipynb
在这里插入图片描述
注意,源码中要修改如下才可以跑通
在这里插入图片描述
3.测试结果可视化
在这里插入图片描述

  • 14
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值