异常检测 | Anomaly Detection via Reverse Distillation from One-Class Embedding |很不错的方法,可以用来找出特例!!!

本文的主旨,不是讲这个论文,而是 全面的了解这个算法!!!

🐧大模型系列篇章
💖 Fine-tuning 🔎 zero-shot模型的微调,同时保持原始模型的鲁棒性 🔎 wise-ft
💖 多模态大模型 🔎 GroundingDINO 论文总结
💖 端到端目标检测 🔎 从DETR 到 GroundingDINO 🔥
💖 多模态大模型 👉 CLIP论文总结
💖 多模态大模型 👉 EVA-CLIP
💚 生成模型 👉 从 VAE 到 Diffusion Model (上)
💚 生成模型 👉 从 VAE 到 Diffusion Model (下)🔥
💧 天气大模型

🐧深度学习基础知识篇

💖 深度学习基础知识干货 🔎 Batch Normalization 批量归一化
💖 深度学习基础知识干货 🔎 卷积模型的Memory, Params, Flop是如何计算的?
💖 深度学习基础知识干货 🔎 Cross-Entropy Loss 多分类损失函数
💖 深度学习基础知识干货 🔎 Videos 动作检测
💖 深度学习基础知识干货 🔎 目标检测(Object Detection): 你需要知道的一些概念
💖 深度学习基础知识干货 🔎 微调(fine-tuning)和泛化(generalization)
💖 深度学习基础知识干货 🔎 Group Convolution / Depthwise Convolution 轻量模型的必有的卷积
💖 深度学习基础知识干货 🔎 Gradient checkpointing
💖 深度学习基础知识干货 🔎 Softmax中温度(temperature)参数
💖 深度学习基础知识干货 🔎 什么是few-shot learning

欢迎订阅专栏,第一时间掌握最新科技
大模型系列篇章 专栏链接
深度学习基础知识 专栏链接

本文的主旨,不是讲这个论文,而是 全面的了解这个算法!!!
在这里插入图片描述

本文的主旨,不是讲这个论文,而是 全面的了解这个算法!!!


论文链接:https://arxiv.org/pdf/2201.10703
代码仓库链接:https://github.com/hq-deng/RD4AD

1. 算法简介

本文提出了一种新的异常检测方法,名为“反向蒸馏”。该方法利用 预训练的教师模型提取图像特征,并将其蒸馏到学生解码器中

学生解码器的目标是重建教师模型的多尺度特征,但由于学生模型只学习正常模式,因此无法重建异常特征,从而实现异常检测。

主要贡献

  • 反向蒸馏框架: 教师模型为编码器,学生模型为解码器,打破传统蒸馏模型的结构限制,提高模型对异常的区分能力。
  • 一类别瓶颈嵌入模块: 将教师模型的高维特征压缩到低维空间,有效抑制异常特征的传播,增强异常检测效果。
  • 实验结果: 在多个公开数据集上取得了优于现有方法的性能,证明了该方法的有效性和泛化能力。

2. 训练过程(纯干货!!!)

请看,大家先了解一下,这个算法结构。(下面有图,先看图)

  1. 算法结构: 预训练过的教师encoder E + 可训练的one-class bottleneck embedding 模块 + 学生decoder D
  2. 然后,用一个multi-scale feature fusion(MFF) 去整合E输出的从高维到低维的特征,然后用one-class embedding(OCE)模块,把MFF降维。
  3. 在训练中,学生D, 去模仿E 的行为

所以在整个训练中, 锁住教师网络E,训练学生decoder D, 和 One-Class Bottleneck Embedding模块


	# 所以在整个训练中, 锁住教师网络E,训练学生decoder D, 和 One-Class Bottleneck Embedding模块
    encoder, bn = wide_resnet50_2(pretrained=True)
    encoder = encoder.to(device)
    bn = bn.to(device)
    encoder.eval()
    decoder = de_wide_resnet50_2(pretrained=False)
    decoder = decoder.to(device)
    optimizer = torch.optim.Adam(list(decoder.parameters())+list(bn.parameters()), lr=learning_rate, betas=(0.5,0.999))
   
    #  学生网络  还原/重建/学习  教师网络 的特征中... 
    for epoch in range(epochs):
        bn.train()
        decoder.train()
        loss_list = []
        for img, label in train_dataloader:
            img = img.to(device)
            inputs = encoder(img)
            outputs = decoder(bn(inputs))#bn(inputs))
            loss = loss_fucntion(inputs, outputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, np.mean(loss_list)))
        if (epoch + 1) % 10 == 0:
            auroc_px, auroc_sp, aupro_px = evaluation(encoder, bn, decoder, test_dataloader, device)
            print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
            torch.save({'bn': bn.state_dict(),
                        'decoder': decoder.state_dict()}, ckp_path)
    return auroc_px, auroc_sp, aupro_px

  1. 在推理阶段,Teacher 模型 E 会接收输入图像并进行特征提取。Student 模型 D 会接收 Teacher 模型输出的嵌入表示作为输入。
    由于 Student 模型在训练阶段只学习重建正常模式,因此它输出的特征将只包含正常模式的信息,而不会包含异常模式的信息。
    通过比较 Teacher 模型和 Student 模型输出的特征,可以判断图像中是否存在异常。
    with torch.no_grad():
        for img, gt, label, _ in dataloader:

            img = img.to(device)
            inputs = encoder(img)
            outputs = decoder(bn(inputs))
            
            # 计算输入和输出之间的异常图。这里使用了余弦相似度作为异常度量标准。
            # anomaly_map 是一个包含异常分数的二维矩阵,每个元素表示对应像素点的异常程度。
            anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a')
            # 使用高斯滤波器对异常图进行平滑,去除噪声
            anomaly_map = gaussian_filter(anomaly_map, sigma=4)
            # 标签中的异常像素设置为 1,正常像素设置为 0。
            gt[gt > 0.5] = 1
            gt[gt <= 0.5] = 0
            if label.item()!=0:
            	# compute_pro(gt, anomaly_map) 计算精确度-召回率曲线下的面积 (AUROC)
                aupro_list.append(compute_pro(gt.squeeze(0).cpu().numpy().astype(int),
                                              anomaly_map[np.newaxis,:,:]))
            gt_list_px.extend(gt.cpu().numpy().astype(int).ravel())
            pr_list_px.extend(anomaly_map.ravel())
            gt_list_sp.append(np.max(gt.cpu().numpy().astype(int)))
            pr_list_sp.append(np.max(anomaly_map))
	        auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3)
	        auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3)
	    return auroc_px, auroc_sp, round(np.mean(aupro_list),3)

在这里插入图片描述
让我们深入地探讨一下 Reverse Distillation(反向蒸馏) 模型的训练过程:在这里插入图片描述

  • 得到hearmap:
    D1的余弦距离map --> gaussian_filter --> 最大最小值一化 * 255 --> heatmap
  • 计算auroc:
    D1,D2,D3 的余弦距离map累加 --> gaussian_filter 进行后续计算

2.1. 数据

  • 从 MVTec AD 或语义数据集中选取 无异常图像作为训练数据。
  • 将图像输入教师编码器,提取多尺度特征。蓝色的就是多尺度特征,后续通过MFF模块和OCE模块变成一个低维的空间,这个就是学生网络的输入。
    在这里插入图片描述

2.2 OCBE(One-Class Bottleneck Embedding) 模块

  • MFF 模块: 将教师编码器的不同层级的特征进行融合,例如将低层特征上采样并与高层特征拼接,得到更丰富的特征表示。 将 MFF 模块输出的特征进一步压缩到低维空间,保留对学生解码器恢复教师编码器响应至关重要的信息。
  • OCE 模块:使用 1x1 卷积层 将 MFF block 生成的丰富特征进一步压缩到一个低维空间,形成一个紧凑的嵌入表示。这个嵌入将作为学生解码器的输入。高维变为低维,会损失他的语义细节信息,所以小小的异常区域会被 ”抑制掉“,没有那么多丰富的语义来表述这个异常点
    在这里插入图片描述

2.3. 学生解码器训练

  • 以 OCBE 模块的输出为输入,重建教师编码器的多尺度特征。
  • 使用多尺度特征之间的余弦相似度损失函数进行训练,目标是使学生解码器输出的特征与教师编码器特征尽可能相似。

重点!!!学生网络 D 的训练目标是重建教师网络的多尺度特征,这意味着它需要 学习如何从紧凑的一类嵌入中恢复出教师网络的正常模式特征。

2.4. 损失函数

  • 多尺度余弦相似度损失: 计算教师编码器和学生解码器在每个尺度上的特征之间的余弦相似度,并将其作为损失函数的一部分。
encoder, bn = wide_resnet50_2(pretrained=True)
encoder = encoder.to(device)
bn = bn.to(device)
encoder.eval()
decoder = de_wide_resnet50_2(pretrained=False)
decoder = decoder.to(device)
# 学生decoderD 与 OCBD模块 一起优化。
optimizer = torch.optim.Adam(list(decoder.parameters())+list(bn.parameters()), lr=learning_rate, betas=(0.5,0.999))

# 这个损失函数就是多尺度余弦相似度损失函数
def loss_fucntion(a, b):
    #mse_loss = torch.nn.MSELoss()
    cos_loss = torch.nn.CosineSimilarity()
    loss = 0
    for item in range(len(a)):
        #print(a[item].shape)
        #print(b[item].shape)
        #loss += 0.1*mse_loss(a[item], b[item])
        loss += torch.mean(1-cos_loss(a[item].view(a[item].shape[0],-1),
                                      b[item].view(b[item].shape[0],-1)))
    return loss

2.5. 训练策略

  • 冻结教师编码器: 为了防止教师编码器在训练过程中发生改变,其参数在训练过程中被冻结。
  • Adam 优化器: 使用 Adam 优化器进行参数更新。

2.6. 推理过程

  • 将测试图像输入教师编码器,提取多尺度特征。
  • 将教师编码器的特征输入 OCBE 模块,得到紧凑的嵌入。
  • 将嵌入输入学生解码器,重建教师编码器的特征。
  • 计算教师编码器和学生解码器特征之间的差异,得到异常评分图。
  • 根据异常评分图进行异常检测和定位。

2.7. 模型评估

  • 使用 AUROC、PRO 等指标评估模型的异常检测和定位性能。
  • 可以通过消融实验分析不同模块和参数对模型性能的影响。

2.8 总结

Reverse Distillation 模型的训练过程涉及多个模块和步骤,需要仔细设计和调整。通过反向蒸馏策略、紧凑嵌入和多尺度特征融合等技术,该模型能够有效地学习无异常图像的特征,并准确地检测和定位异常。

3. 教师编码器的作用

3.1 教师编码器 不需要 在反向蒸馏框架中进行训练

  • 教师编码器通常采用预训练模型,例如 ImageNet 上训练的 ResNet 或 WideResNet。这些预训练模型已经具备了强大的特征提取能力,能够从图像中提取丰富的语义和结构信息。冻结教师编码器的参数可以避免破坏其已有的特征提取能力,并使其专注于提取异常特征。

  • 如果教师编码器进行训练,那么学生解码器在反向蒸馏过程中可能会学习到教师编码器的所有特征,包括异常特征。这将导致 T-S 模型在异常样本上的特征差异消失,从而无法有效进行异常检测。

  • 教师编码器充当了知识蒸馏过程中的知识来源,其丰富的特征表示为学生解码器提供了学习目标。

  • 学生解码器通过学习模仿教师编码器的行为,逐步建立起自身的特征表示能力。

  • 教师编码器在反向蒸馏框架中发挥着知识传递和特征提取的关键作用,为异常检测任务提供了强有力的支持。

3.2 教师编码器和学生解码器的不同

教师编码器和学生解码器在反向蒸馏框架中会共享一些特征,但它们之间的差异同样重要,这些差异正是学生解码器进行异常检测的关键。

  • 差异来源

    1. 网络结构: 教师编码器和学生解码器采用不同的网络结构。教师编码器通常采用卷积神经网络,用于提取图像特征;而学生解码器则采用与之相反的卷积网络结构,用于重建图像。这种结构上的差异导致两个模型在特征提取和重建方面存在不同的侧重点和表达方式。
    2. 知识蒸馏顺序: 反向蒸馏中,知识蒸馏的顺序是从高维语义特征到低维特征。这意味着学生解码器首先学习理解图像的语义信息,然后再学习重建低层特征。这种顺序有助于学生解码器更好地捕捉图像的整体结构和语义,从而更有效地识别异常。
    3. 信息瓶颈: OCBE 模块充当了信息瓶颈,将教师编码器的高维特征压缩到一个低维空间。这种压缩过程有助于去除冗余信息,并保留对异常检测有用的关键信息。因此,学生解码器从 OCBE 模块接收到的特征已经过筛选,包含更少的异常信息,更容易区分正常样本和异常样本。
  • 异常检测机制:

    • 当输入图像为正常样本时,学生解码器能够成功重建图像,重建误差较小。
    • 当输入图像为异常样本时,由于学生解码器缺乏对异常特征的理解,重建误差较大。因为OCBE模块去除了异常信息的特征。
    • 通过比较学生解码器重建图像的误差,可以判断输入图像是否为异常样本。

总结: 虽然教师编码器和学生解码器之间存在一些特征共享,但它们之间的差异同样重要。这些差异来源于网络结构、知识蒸馏顺序和信息瓶颈,共同构成了学生解码器进行异常检测的基础。

4. 应用

  • 他是按类别来应用的,比如说, 每一个类别单独训练,每一个类别,需要训练一个单独的 模型

if __name__ == '__main__':

    setup_seed(111)
    item_list = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill',
                 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']
    for i in item_list:
        train(i)

在这里插入图片描述

  • 这个是 bottle 类别,也就是说,每一个类别,需要训练一个单独的 模型
    在这里插入图片描述
  • 12
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值