检测图像P图痕迹(论文复现)


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

备注

需要本文的详细复现过程的项目源码、数据和预训练好的模型可从该地址处获取完整版:地址跳转

一.揭秘AI图像篡改检测:让恶意P图无处遁形

在这个数字时代,图像篡改已经变得非常普遍,这给我们的社会带来了许多负面影响。为了应对这一挑战,我们开发了一款AI图像篡改检测系统,该系统可以自动识别并标记出疑似篡改的区域,让恶意P图无处遁形。 用户只需要上传疑似篡改的图像,系统便会输出该图像疑似篡改的区域。这一技术的出现,无疑为我们的社会带来了巨大的帮助,可以帮助我们识别虚假信息,保护我们的网络安全。 然而,我们需要明确一点,我们的模型并不是万能的。虽然它在检测拼接的自然图像的时候具有较高的准确率,但是对于物体擦除或者人脸P图来说,结果稍显逊色,适用的场景有限。因此,在使用我们的系统时,用户需要有一定的心理准备,不能期望它能解决所有问题。 此外,用户自己训练的模型的效果与数据集高度相关。不一样的数据集训练得到的检测结果可能不同。因此,如果用户希望得到更好的检测结果,就需要准备更高质量、更符合实际需求的数据集。 总的来说,虽然我们的AI图像篡改检测系统还有许多需要改进的地方,但我们相信,随着技术的不断发展,它的能力会越来越强,能够更好地服务于我们的社会。

一、背景及意义介绍 随着照相机、手机、平板电脑、摄像机等数码设备与Photoshop、美图秀秀等各种图片编辑软件的飞速发展,数字图像的生成与编辑已经变得非常容易,几乎人人都有能力生成、编辑大量的数字图像。这虽然方便了人们的生活,但也使图像篡改变得越来越容易,伪造图像也变得越来越不易察觉,甚至能够以假乱真。

日常生活中人们对图像进行拼接,往往是出于美化、娱乐的目的,这并不会带来不良影响,但是当图像在被恶意拼接篡改的情况下经过传播,就会引导人们对客观事物产生错误的判断,有时甚至会对社会和国家造成不良的影响。因此,图像拼接篡改检测在当今社会显得愈发重要。

2004年,小布什竞选时选用的宣传图片,实际上是将小布什的照片拼接到其他照片上得到的。虚假的拼接图像在选举时干扰了民众的决策,对选举的结果也造成了不小的影响。如图所示,上边是篡改前的图像,下边是经过篡改后的图像。

在这里插入图片描述在这里插入图片描述

因此,研究相应的图像算法检测方法,用于判断图片是否经过人为编辑非常重要。图像篡改检测的研究和发展,能够降低造假图像给社会带来的各种不良影响,有助于维护公共信任秩序,打击一些不良的图片造假犯罪行为。

二、概述

本文通过解读并复现两篇论文,来揭秘AI图像篡改检测领域的相关研究。本文解读并复现的论文是《Image Manipulation Detection by Multi-View Multi-Scale Supervision》和《MVSS-Net: Multi-View Multi-Scale Supervised Networks for Image Manipulation Detection》,其中前者2021年发表于ICCV(International Conference on Computer Vision),而《MVSS-Net: Multi-View Multi-Scale Supervised Networks for Image Manipulation Detection》是对该会议论文的改进版本,2022年发表于TPAMI(IEEE Transactions on Pattern Analysis and Machine Intelligence)期刊。众所周知,ICCV会议和TPAMI期刊都是计算机人工智能领域的顶会和顶刊。

三、论文背景

研究背景:

图像篡改技术的发展:随着图像编辑工具的普及,图像篡改变得越来越容易,而人眼很难分辨篡改图像与真实图像。这对信息安全、法律取证等领域提出了挑战。

  • 传统篡改检测方法的局限:早期的篡改检测方法主要基于手工特征,如SIFT、LBP等,泛化能力有限。深度学习方法虽然提高了性能,但主要基于单一视角和尺度,忽略了篡改痕迹的多样性。

  • 多视角多尺度学习的潜力:不同视角(如空域、频域)和尺度的特征可能包含互补的篡改线索。联合利用这些信息,有望进一步提升篡改检测性能。

研究意义:

技术创新:

  • 提出了一种新的多视角多尺度学习范式,开创了篡改检测的新思路。

  • 在特征提取、融合、监督学习等方面进行了精巧设计,实现了高效的多视角多尺度表示学习。 性能提升:

  • 在多个公开数据集上取得了领先的性能,刷新了篡改检测的技术水平。

  • 多视角多尺度学习显著优于单一视角或尺度,证明了融合互补信息的优势。

应用价值:

  • 在信息安全、法律取证、新闻真实性甄别等领域具有广阔的应用前景。 提高了篡改图像检测的准确率,有助于遏制虚假信息的传播,维护网络空间的真实可信。 启发意义:

  • 多视角多尺度学习可以推广到其他计算机视觉任务,如目标检测、语义分割等。 启发我们要从多角度理解问题,挖掘数据中的丰富信息,设计巧妙的融合机制,提升算法性能。

四、论文思路

这两篇论文的核心思路是利用多视角和多尺度的方式来监督图像篡改检测模型的训练。传统的方法通常只关注整个图像,而这些工作认为不同的区域和尺度对于检测也很重要。 具体来说,多视角是指从多个角度(视角)观察图像,比如全局视角和局部视角。全局视角关注整个图像,局部视角关注图像的局部区域。多尺度则是在不同的分辨率下观察图像。通过多视角多尺度的监督,模型能够同时学习到全局和局部的语义特征,并在不同尺度下捕获细节信息,从而提高检测性能。

  • 多视角学习: 篡改痕迹在RGB空间、噪声域、频域有不同的表现。单一视角难以全面捕捉,因此需要多视角学习。

  • 多尺度监督: 篡改痕迹在不同尺度表现不同。小尺度注重局部细节,大尺度注重全局一致性。多尺度监督有助于融合不同尺度的特征。

  • 全局与局部建模: 全局特征提供场景级别的信息,局部特征关注像素级别的变化。两者的结合可以更准确地定位篡改区域。

五、模型结构

这两篇论文的模型结构都是基于卷积神经网络,但做了相应改进。 ICCV的会议论文提出了一种新的金字塔特征融合模块,将不同尺度的特征进行融合。TPAMI的期刊论文在此基础上,引入了一种新的注意力模块,用于学习视角之间的交互关系。同时还设计了一种新的金字塔池化模块,以更好地融合多尺度特征。具体来说,可以从以下几个方面描述模型结构:

骨干网络: ICCV论文采用Res2Net,TPAMI论文采用SwinTransformer,通过更强的骨干提取多尺度特征。

多视角特征提取:

  • RGB空间:骨干网络提取特征。

  • 噪声域:用SRM滤波器提取噪声残差特征。

  • 频域:用DCT变换提取频域特征。

特征融合:

  • ICCV版论文:在每个尺度concat三个视角的特征图,再通过卷积层融合。

  • TPAMI版论文:引入自注意力机制,对多视角特征加权融合,提高表示能力。

在这里插入图片描述 模型框架图

六、损失函数

为了实现多视角多尺度监督,这两篇论文都设计了相应的损失函数。ICCV论文只使用多尺度交叉熵损失,TPAMI论文将以下两种损失相加作为总损失。

多尺度交叉熵损失:在每个尺度上计算预测结果与真值图的交叉熵损失,再加权求和。 多尺度F1损失:在每个尺度上计算预测结果的F1分数,将其转化为损失函数,再加权求和。F1损失可以缓解类别不平衡问题。

七、复现过程(重要)

在图像篡改检测的研究中,先看实验结果图,Images列展示的是被篡改的图像,而Mask列则显示的是对应的篡改区域。这类研究的核心目标是最准确地定位出图像中的篡改部分。从图像分类的角度来理解,这个问题可以被视为一个逐像素的二分类任务。具体来说,对于被篡改的图像A中的每一个像素点(x,y),模型需要判断该像素点是否被篡改。如果模型判断一个像素点被篡改,那么它就会输出1;如果判断为未篡改,则输出0。这样,所有像素点的0和1输出组合起来,就形成了一张与原始图像A分辨率相同的Mask图像。在这张Mask图像上,白色区域代表了图像中被篡改的位置,而黑色区域则表示未被篡改的部分。通过这种方式,篡改检测模型能够生成一幅精确的篡改区域图,帮助用户识别和定位图像中的不真实内容。

在这里插入图片描述 由于原论文只给出了推理代码,而无训练代码,因此本文主要复现其训练过程,达到如上图所示结果。接下来就是复现步骤,可以从以下几个过程分解复现(只给出关键代码):

步骤1:

搭建一个最常规的Pytorch训练框架,包括数据集的加载,迭代训练,这一部分的代码可以在Pytorch官网的教程文档中找到;

以Pytorch官网教程中给的以卷积神经网络识别mnist手写数字为例,部分的训练框架如下:

<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#aa5500"># 定义数据转换步骤,包括转换为张量,以及标准化处理</span>
<span style="color:#000000">transform</span><span style="color:#981a1a">=</span><span style="color:#000000">transforms</span>.<span style="color:#000000">Compose</span>([
    <span style="color:#000000">transforms</span>.<span style="color:#000000">ToTensor</span>(),  <span style="color:#aa5500"># 将图片转换为PyTorch张量</span>
    <span style="color:#000000">transforms</span>.<span style="color:#000000">Normalize</span>((<span style="color:#116644">0.1307</span>,), (<span style="color:#116644">0.3081</span>,))  <span style="color:#aa5500"># 根据MNIST数据集的均值和标准差进行标准化</span>
    ])
​
<span style="color:#aa5500"># 加载训练集,如果不存在则下载,并应用上述转换</span>
<span style="color:#000000">dataset1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">datasets</span>.<span style="color:#000000">MNIST</span>(<span style="color:#aa1111">'../data'</span>, <span style="color:#000000">train</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>, <span style="color:#000000">download</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>,
                   <span style="color:#000000">transform</span><span style="color:#981a1a">=</span><span style="color:#000000">transform</span>)
​
<span style="color:#aa5500"># 加载测试集,并应用上述转换</span>
<span style="color:#000000">dataset2</span> <span style="color:#981a1a">=</span> <span style="color:#000000">datasets</span>.<span style="color:#000000">MNIST</span>(<span style="color:#aa1111">'../data'</span>, <span style="color:#000000">train</span><span style="color:#981a1a">=</span><span style="color:#770088">False</span>,
                   <span style="color:#000000">transform</span><span style="color:#981a1a">=</span><span style="color:#000000">transform</span>)
​
<span style="color:#aa5500"># 创建训练数据加载器,使用train_kwargs中的参数</span>
<span style="color:#000000">train_loader</span> <span style="color:#981a1a">=</span> <span style="color:#000000">torch</span>.<span style="color:#000000">utils</span>.<span style="color:#000000">data</span>.<span style="color:#000000">DataLoader</span>(<span style="color:#000000">dataset1</span>,<span style="color:#981a1a">**</span><span style="color:#000000">train_kwargs</span>)
​
<span style="color:#aa5500"># 创建测试数据加载器,使用test_kwargs中的参数</span>
<span style="color:#000000">test_loader</span> <span style="color:#981a1a">=</span> <span style="color:#000000">torch</span>.<span style="color:#000000">utils</span>.<span style="color:#000000">data</span>.<span style="color:#000000">DataLoader</span>(<span style="color:#000000">dataset2</span>, <span style="color:#981a1a">**</span><span style="color:#000000">test_kwargs</span>)
​
<span style="color:#aa5500"># 实例化模型,并将其移动到指定设备</span>
<span style="color:#000000">model</span> <span style="color:#981a1a">=</span> <span style="color:#000000">Net</span>().<span style="color:#000000">to</span>(<span style="color:#000000">device</span>)
​
<span style="color:#aa5500"># 实例化Adadelta优化器,传入模型参数和学习率</span>
<span style="color:#000000">optimizer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">optim</span>.<span style="color:#000000">Adadelta</span>(<span style="color:#000000">model</span>.<span style="color:#000000">parameters</span>(), <span style="color:#000000">lr</span><span style="color:#981a1a">=</span><span style="color:#000000">args</span>.<span style="color:#000000">lr</span>)
​
<span style="color:#aa5500"># 实例化学习率调度器,按照固定的步长和衰减因子调整学习率</span>
<span style="color:#000000">scheduler</span> <span style="color:#981a1a">=</span> <span style="color:#000000">StepLR</span>(<span style="color:#000000">optimizer</span>, <span style="color:#000000">step_size</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>, <span style="color:#000000">gamma</span><span style="color:#981a1a">=</span><span style="color:#000000">args</span>.<span style="color:#000000">gamma</span>)
​
<span style="color:#aa5500"># 开始训练循环,对于每个epoch</span>
<span style="color:#770088">for</span> <span style="color:#000000">epoch</span> <span style="color:#770088">in</span> <span style="color:#3300aa">range</span>(<span style="color:#116644">1</span>, <span style="color:#000000">args</span>.<span style="color:#000000">epochs</span> <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>):
    <span style="color:#aa5500"># 执行训练过程</span>
    <span style="color:#000000">train</span>(<span style="color:#000000">args</span>, <span style="color:#000000">model</span>, <span style="color:#000000">device</span>, <span style="color:#000000">train_loader</span>, <span style="color:#000000">optimizer</span>, <span style="color:#000000">epoch</span>)
    <span style="color:#aa5500"># 执行测试过程</span>
    <span style="color:#000000">test</span>(<span style="color:#000000">model</span>, <span style="color:#000000">device</span>, <span style="color:#000000">test_loader</span>)
    <span style="color:#aa5500"># 调度器步进,根据策略调整学习率</span>
    <span style="color:#000000">scheduler</span>.<span style="color:#000000">step</span>()
​
<span style="color:#aa5500"># 如果指定了保存模型,则保存模型的参数</span>
<span style="color:#770088">if</span> <span style="color:#000000">args</span>.<span style="color:#000000">save_model</span>:
    <span style="color:#000000">torch</span>.<span style="color:#000000">save</span>(<span style="color:#000000">model</span>.<span style="color:#000000">state_dict</span>(), <span style="color:#aa1111">"mnist_cnn.pt"</span>)
​
<span style="color:#770088">def</span> <span style="color:#0000ff">train</span>(<span style="color:#000000">args</span>, <span style="color:#000000">model</span>, <span style="color:#000000">device</span>, <span style="color:#000000">train_loader</span>, <span style="color:#000000">optimizer</span>, <span style="color:#000000">epoch</span>):
    <span style="color:#aa5500"># 设置模型为训练模式</span>
    <span style="color:#000000">model</span>.<span style="color:#000000">train</span>()
    <span style="color:#aa5500"># 遍历训练数据集的每个批次</span>
    <span style="color:#770088">for</span> <span style="color:#000000">batch_idx</span>, (<span style="color:#000000">data</span>, <span style="color:#000000">target</span>) <span style="color:#770088">in</span> <span style="color:#3300aa">enumerate</span>(<span style="color:#000000">train_loader</span>):
        <span style="color:#aa5500"># 将数据和目标发送到计算设备(如GPU)</span>
        <span style="color:#000000">data</span>, <span style="color:#000000">target</span> <span style="color:#981a1a">=</span> <span style="color:#000000">data</span>.<span style="color:#000000">to</span>(<span style="color:#000000">device</span>), <span style="color:#000000">target</span>.<span style="color:#000000">to</span>(<span style="color:#000000">device</span>)
        <span style="color:#aa5500"># 清除优化器的梯度</span>
        <span style="color:#000000">optimizer</span>.<span style="color:#000000">zero_grad</span>()
        <span style="color:#aa5500"># 通过模型处理数据得到输出</span>
        <span style="color:#000000">output</span> <span style="color:#981a1a">=</span> <span style="color:#000000">model</span>(<span style="color:#000000">data</span>)
        <span style="color:#aa5500"># 计算输出和目标之间的负对数似然损失</span>
        <span style="color:#000000">loss</span> <span style="color:#981a1a">=</span> <span style="color:#000000">F</span>.<span style="color:#000000">nll_loss</span>(<span style="color:#000000">output</span>, <span style="color:#000000">target</span>)
        <span style="color:#aa5500"># 反向传播损失以计算每个参数的梯度</span>
        <span style="color:#000000">loss</span>.<span style="color:#000000">backward</span>()
        <span style="color:#aa5500"># 根据梯度更新模型的参数</span>
        <span style="color:#000000">optimizer</span>.<span style="color:#000000">step</span>()
        <span style="color:#aa5500"># 按照设定的间隔打印训练状态信息</span>
        <span style="color:#770088">if</span> <span style="color:#000000">batch_idx</span> <span style="color:#981a1a">%</span> <span style="color:#000000">args</span>.<span style="color:#000000">log_interval</span> <span style="color:#981a1a">==</span> <span style="color:#116644">0</span>:
            <span style="color:#3300aa">print</span>(<span style="color:#aa1111">'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'</span>.<span style="color:#000000">format</span>(
                <span style="color:#000000">epoch</span>, <span style="color:#000000">batch_idx</span> <span style="color:#981a1a">*</span> <span style="color:#3300aa">len</span>(<span style="color:#000000">data</span>), <span style="color:#3300aa">len</span>(<span style="color:#000000">train_loader</span>.<span style="color:#000000">dataset</span>),
                <span style="color:#116644">100.</span> <span style="color:#981a1a">*</span> <span style="color:#000000">batch_idx</span> <span style="color:#981a1a">/</span> <span style="color:#3300aa">len</span>(<span style="color:#000000">train_loader</span>), <span style="color:#000000">loss</span>.<span style="color:#000000">item</span>()))
            <span style="color:#aa5500"># 如果是干运行(仅用于测试),则在第一个批次后中断训练循环</span>
            <span style="color:#770088">if</span> <span style="color:#000000">args</span>.<span style="color:#000000">dry_run</span>:
                <span style="color:#770088">break</span>
​</span></span>

步骤2:

搭建好基础的训练框架好,需要将官网的数据集加载过程修改为自定义的数据集类,由于图像篡改检测的ground-truth label是一张mask图像,所以需要将原先数据集中的label由代表具体类别的数值改为mask图像,且需要和被篡改图像一一对应,顺序不能乱。

首先需要先进行一下文件名的处理:篡改图像文件夹和其对应的mask文件夹需要放在同一个目录下,然后篡改图像文件名需要和应的mask文件名一致,如文件结构为:

—Dataset

--------forgery image dir

------------------1.jpg

------------------2.jpg

------------------…

--------mask image dir

------------------1.png

------------------2.png

------------------…

然后再用以下加载数据集的代码读取图像文件以及mask图像:

<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#aa5500"># 导入相关库</span>
<span style="color:#770088">from</span> <span style="color:#000000">torch</span>.<span style="color:#000000">utils</span>.<span style="color:#000000">data</span> <span style="color:#770088">import</span> <span style="color:#000000">Dataset</span>
<span style="color:#770088">from</span> <span style="color:#000000">torchvision</span>.<span style="color:#000000">utils</span> <span style="color:#770088">import</span> <span style="color:#000000">save_image</span>
<span style="color:#770088">import</span> <span style="color:#000000">os</span>
<span style="color:#770088">import</span> <span style="color:#000000">numpy</span> <span style="color:#770088">as</span> <span style="color:#000000">np</span>
<span style="color:#770088">from</span> <span style="color:#000000">torchvision</span> <span style="color:#770088">import</span> <span style="color:#000000">transforms</span> <span style="color:#770088">as</span> <span style="color:#000000">T</span>
<span style="color:#770088">import</span> <span style="color:#000000">albumentations</span> <span style="color:#770088">as</span> <span style="color:#000000">A</span>
<span style="color:#770088">from</span> <span style="color:#000000">albumentations</span>.<span style="color:#000000">pytorch</span> <span style="color:#770088">import</span> <span style="color:#000000">ToTensorV2</span>
<span style="color:#770088">import</span> <span style="color:#000000">cv2</span>
<span style="color:#770088">import</span> <span style="color:#000000">torch</span>
​
<span style="color:#aa5500"># 设置图片的最大尺寸</span>
<span style="color:#000000">max_size_w</span> <span style="color:#981a1a">=</span> <span style="color:#116644">512</span>
<span style="color:#000000">max_size_h</span> <span style="color:#981a1a">=</span> <span style="color:#116644">512</span>
​
<span style="color:#aa5500"># 定义预处理掩码的函数</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">preprocess_mask</span>(<span style="color:#000000">mask</span>):
    <span style="color:#000000">mask</span> <span style="color:#981a1a">=</span> <span style="color:#000000">mask</span>.<span style="color:#000000">astype</span>(<span style="color:#000000">np</span>.<span style="color:#000000">float32</span>)  <span style="color:#aa5500"># 将掩码转换为浮点数类型</span>
    <span style="color:#000000">mask</span> <span style="color:#981a1a">=</span> <span style="color:#000000">mask</span>[:,:,<span style="color:#116644">0</span>:<span style="color:#116644">1</span>]  <span style="color:#aa5500"># 选择掩码的第一个通道</span>
    <span style="color:#000000">mask</span>[<span style="color:#000000">mask</span><span style="color:#981a1a"><=</span><span style="color:#116644">127.5</span>] <span style="color:#981a1a">=</span> <span style="color:#116644">0.0</span>  <span style="color:#aa5500"># 将掩码中小于等于127.5的值设置为0.0</span>
    <span style="color:#000000">mask</span>[<span style="color:#000000">mask</span><span style="color:#981a1a">></span><span style="color:#116644">127.5</span>] <span style="color:#981a1a">=</span> <span style="color:#116644">255.</span>  <span style="color:#aa5500"># 将掩码中大于127.5的值设置为255.</span>
    <span style="color:#770088">return</span> <span style="color:#000000">mask</span>
​
<span style="color:#aa5500"># 自定义UNet数据集类</span>
<span style="color:#770088">class</span> <span style="color:#0000ff">UNetDataset</span>(<span style="color:#000000">Dataset</span>):
    <span style="color:#770088">def</span> <span style="color:#0000ff">__init__</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">dir_train</span>, <span style="color:#000000">dir_mask</span>,<span style="color:#000000">train_transform</span><span style="color:#981a1a">=</span><span style="color:#770088">None</span>,<span style="color:#000000">val_transform</span><span style="color:#981a1a">=</span><span style="color:#770088">None</span>,<span style="color:#000000">mode</span> <span style="color:#981a1a">=</span> <span style="color:#aa1111">'train'</span>):
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dirTrain</span> <span style="color:#981a1a">=</span> <span style="color:#000000">dir_train</span>  <span style="color:#aa5500"># 训练图片的目录</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dirMask</span> <span style="color:#981a1a">=</span> <span style="color:#000000">dir_mask</span>  <span style="color:#aa5500"># 掩码图片的目录</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">mode</span> <span style="color:#981a1a">=</span> <span style="color:#000000">mode</span>  <span style="color:#aa5500"># 数据集模式(训练、验证或预测)</span>
        <span style="color:#aa5500"># 获取训练图片的文件路径列表</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dataTrain</span> <span style="color:#981a1a">=</span> [<span style="color:#000000">os</span>.<span style="color:#000000">path</span>.<span style="color:#000000">join</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">dirTrain</span>, <span style="color:#000000">filename</span>)
                          <span style="color:#770088">for</span> <span style="color:#000000">filename</span> <span style="color:#770088">in</span> <span style="color:#000000">os</span>.<span style="color:#000000">listdir</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">dirTrain</span>)
                          <span style="color:#770088">if</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.jpg'</span>) <span style="color:#770088">or</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.png'</span>) <span style="color:#770088">or</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.tif'</span>)<span style="color:#770088">or</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.jpeg'</span>) ]
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dataTrain</span>.<span style="color:#000000">sort</span>()  <span style="color:#aa5500"># 对文件路径列表进行排序</span>
        <span style="color:#aa5500"># 获取掩码图片的文件路径列表</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dataMask</span> <span style="color:#981a1a">=</span> [<span style="color:#000000">os</span>.<span style="color:#000000">path</span>.<span style="color:#000000">join</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">dirMask</span>, <span style="color:#000000">filename</span>)
                         <span style="color:#770088">for</span> <span style="color:#000000">filename</span> <span style="color:#770088">in</span> <span style="color:#000000">os</span>.<span style="color:#000000">listdir</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">dirMask</span>)
                         <span style="color:#770088">if</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.jpg'</span>) <span style="color:#770088">or</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.png'</span>) <span style="color:#770088">or</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.tif'</span>) <span style="color:#770088">or</span> <span style="color:#000000">filename</span>.<span style="color:#000000">endswith</span>(<span style="color:#aa1111">'.jpeg'</span>) ]
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dataMask</span>.<span style="color:#000000">sort</span>()  <span style="color:#aa5500"># 对文件路径列表进行排序</span>
​
        <span style="color:#0055aa">self</span>.<span style="color:#000000">trainDataSize</span> <span style="color:#981a1a">=</span> <span style="color:#3300aa">len</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">dataTrain</span>)  <span style="color:#aa5500"># 训练图片的数量</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">maskDataSize</span> <span style="color:#981a1a">=</span> <span style="color:#3300aa">len</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">dataMask</span>)  <span style="color:#aa5500"># 掩码图片的数量</span>
​
        <span style="color:#0055aa">self</span>.<span style="color:#000000">transform1</span>  <span style="color:#981a1a">=</span> <span style="color:#000000">T</span>.<span style="color:#000000">Normalize</span>(<span style="color:#000000">mean</span><span style="color:#981a1a">=</span>(<span style="color:#116644">0.5</span>, <span style="color:#116644">0.5</span>, <span style="color:#116644">0.5</span>), <span style="color:#000000">std</span><span style="color:#981a1a">=</span>(<span style="color:#116644">0.5</span>, <span style="color:#116644">0.5</span>, <span style="color:#116644">0.5</span>))  <span style="color:#aa5500"># 归一化变换</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">toTensor</span>  <span style="color:#981a1a">=</span> <span style="color:#000000">A</span>.<span style="color:#000000">Compose</span>([<span style="color:#000000">ToTensorV2</span>()])  <span style="color:#aa5500"># 转换为PyTorch张量的变换</span>
​
        <span style="color:#aa5500"># 如果提供了训练变换,则使用提供的变换,否则使用默认的训练变换</span>
        <span style="color:#770088">if</span> <span style="color:#000000">train_transform</span> <span style="color:#770088">is</span> <span style="color:#770088">not</span> <span style="color:#770088">None</span>:
            <span style="color:#0055aa">self</span>.<span style="color:#000000">train_transform</span> <span style="color:#981a1a">=</span> <span style="color:#000000">train_transform</span>
        <span style="color:#770088">else</span>:
            <span style="color:#0055aa">self</span>.<span style="color:#000000">train_transform</span> <span style="color:#981a1a">=</span> <span style="color:#000000">A</span>.<span style="color:#000000">Compose</span>(
                                [
                                    <span style="color:#000000">A</span>.<span style="color:#000000">Resize</span>(<span style="color:#000000">max_size_h</span>,<span style="color:#000000">max_size_w</span>,<span style="color:#000000">p</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>),  <span style="color:#aa5500"># 调整图片尺寸</span>
                                    <span style="color:#000000">A</span>.<span style="color:#000000">VerticalFlip</span>(<span style="color:#000000">p</span><span style="color:#981a1a">=</span><span style="color:#116644">0.2</span>),  <span style="color:#aa5500"># 随机垂直翻转</span>
                                    <span style="color:#000000">A</span>.<span style="color:#000000">HorizontalFlip</span>(<span style="color:#000000">p</span> <span style="color:#981a1a">=</span> <span style="color:#116644">0.2</span>),  <span style="color:#aa5500"># 随机水平翻转</span>
                                    <span style="color:#000000">ToTensorV2</span>(),  <span style="color:#aa5500"># 转换为PyTorch张量</span>
                                ], <span style="color:#000000">is_check_shapes</span><span style="color:#981a1a">=</span><span style="color:#770088">False</span>
                            )
        <span style="color:#aa5500"># 如果提供了验证变换,则使用提供的变换,否则使用默认的验证变换</span>
        <span style="color:#770088">if</span> <span style="color:#000000">val_transform</span> <span style="color:#770088">is</span> <span style="color:#770088">not</span> <span style="color:#770088">None</span>:
            <span style="color:#0055aa">self</span>.<span style="color:#000000">val_transform</span> <span style="color:#981a1a">=</span> <span style="color:#000000">val_transform</span>
        <span style="color:#770088">else</span>:
            <span style="color:#0055aa">self</span>.<span style="color:#000000">val_transform</span> <span style="color:#981a1a">=</span> <span style="color:#000000">A</span>.<span style="color:#000000">Compose</span>(
                                        [
                                            <span style="color:#000000">A</span>.<span style="color:#000000">Resize</span>(<span style="color:#000000">max_size_h</span>,<span style="color:#000000">max_size_w</span>,<span style="color:#000000">p</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>),  <span style="color:#aa5500"># 调整图片尺寸</span>
                                            <span style="color:#000000">ToTensorV2</span>(),  <span style="color:#aa5500"># 转换为PyTorch张量</span>
                                        ], <span style="color:#000000">is_check_shapes</span><span style="color:#981a1a">=</span><span style="color:#770088">False</span>
                                    )
    
        <span style="color:#0055aa">self</span>.<span style="color:#000000">kernel</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">ones</span>((<span style="color:#116644">4</span>, <span style="color:#116644">4</span>), <span style="color:#000000">np</span>.<span style="color:#000000">uint8</span>)  <span style="color:#aa5500"># 用于形态学操作的核</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">feature</span></span></span>

步骤3:

定义MVSS模型(此模型在原论文中有给出),并将第一步中的卷积神经网络模型model训练模型替换为MVSS模型。

<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#770088">def</span> <span style="color:#0000ff">get_mvss</span>(<span style="color:#000000">backbone</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'resnet50'</span>, <span style="color:#000000">pretrained_base</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>, <span style="color:#000000">nclass</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>, <span style="color:#000000">sobel</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>, <span style="color:#000000">n_input</span><span style="color:#981a1a">=</span><span style="color:#116644">3</span>, <span style="color:#000000">constrain</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>, <span style="color:#981a1a">**</span><span style="color:#000000">kwargs</span>):
    <span style="color:#000000">model</span> <span style="color:#981a1a">=</span> <span style="color:#000000">MVSSNet</span>(<span style="color:#000000">nclass</span>, <span style="color:#000000">backbone</span><span style="color:#981a1a">=</span><span style="color:#000000">backbone</span>,
                    <span style="color:#000000">pretrained_base</span><span style="color:#981a1a">=</span><span style="color:#000000">pretrained_base</span>,
                    <span style="color:#000000">sobel</span><span style="color:#981a1a">=</span><span style="color:#000000">sobel</span>,
                    <span style="color:#000000">n_input</span><span style="color:#981a1a">=</span><span style="color:#000000">n_input</span>,
                    <span style="color:#000000">constrain</span><span style="color:#981a1a">=</span><span style="color:#000000">constrain</span>,
                    <span style="color:#981a1a">**</span><span style="color:#000000">kwargs</span>)
    <span style="color:#770088">return</span> <span style="color:#000000">model</span>
​
​
<span style="color:#770088">class</span> <span style="color:#0000ff">MVSSNet</span>(<span style="color:#000000">ResNet50</span>):
    <span style="color:#770088">def</span> <span style="color:#0000ff">__init__</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">nclass</span>, <span style="color:#000000">aux</span><span style="color:#981a1a">=</span><span style="color:#770088">False</span>, <span style="color:#000000">sobel</span><span style="color:#981a1a">=</span><span style="color:#770088">False</span>, <span style="color:#000000">constrain</span><span style="color:#981a1a">=</span><span style="color:#770088">False</span>, <span style="color:#000000">n_input</span><span style="color:#981a1a">=</span><span style="color:#116644">3</span>, <span style="color:#981a1a">**</span><span style="color:#000000">kwargs</span>):
        <span style="color:#3300aa">super</span>(<span style="color:#000000">MVSSNet</span>, <span style="color:#0055aa">self</span>).<span style="color:#000000">__init__</span>(<span style="color:#000000">pretrained</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>, <span style="color:#000000">n_input</span><span style="color:#981a1a">=</span><span style="color:#000000">n_input</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span> <span style="color:#981a1a">=</span> <span style="color:#000000">nclass</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">aux</span> <span style="color:#981a1a">=</span> <span style="color:#000000">aux</span>
​
        <span style="color:#0055aa">self</span>.<span style="color:#000000">__setattr__</span>(<span style="color:#aa1111">'exclusive'</span>, [<span style="color:#aa1111">'head'</span>])
​
        <span style="color:#0055aa">self</span>.<span style="color:#000000">upsample</span> <span style="color:#981a1a">=</span> <span style="color:#000000">nn</span>.<span style="color:#000000">Upsample</span>(<span style="color:#000000">scale_factor</span><span style="color:#981a1a">=</span><span style="color:#116644">2</span>, <span style="color:#000000">mode</span><span style="color:#981a1a">=</span><span style="color:#aa1111">"bilinear"</span>, <span style="color:#000000">align_corners</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">upsample_4</span> <span style="color:#981a1a">=</span> <span style="color:#000000">nn</span>.<span style="color:#000000">Upsample</span>(<span style="color:#000000">scale_factor</span><span style="color:#981a1a">=</span><span style="color:#116644">4</span>, <span style="color:#000000">mode</span><span style="color:#981a1a">=</span><span style="color:#aa1111">"bilinear"</span>, <span style="color:#000000">align_corners</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel</span> <span style="color:#981a1a">=</span> <span style="color:#000000">sobel</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">constrain</span> <span style="color:#981a1a">=</span> <span style="color:#000000">constrain</span>
​
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_db_1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#116644">256</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_db_2</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#116644">512</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_db_3</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#116644">1024</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_db_4</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#116644">2048</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
​
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_trans_1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_trans_2</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
        <span style="color:#0055aa">self</span>.<span style="color:#000000">erb_trans_3</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ERB</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>)
​
        <span style="color:#770088">if</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel</span>:
            <span style="color:#3300aa">print</span>(<span style="color:#aa1111">"----------use sobel-------------"</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_x1</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_y1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">get_sobel</span>(<span style="color:#116644">256</span>, <span style="color:#116644">1</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_x2</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_y2</span> <span style="color:#981a1a">=</span> <span style="color:#000000">get_sobel</span>(<span style="color:#116644">512</span>, <span style="color:#116644">1</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_x3</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_y3</span> <span style="color:#981a1a">=</span> <span style="color:#000000">get_sobel</span>(<span style="color:#116644">1024</span>, <span style="color:#116644">1</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_x4</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">sobel_y4</span> <span style="color:#981a1a">=</span> <span style="color:#000000">get_sobel</span>(<span style="color:#116644">2048</span>, <span style="color:#116644">1</span>)
​
        <span style="color:#770088">if</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">constrain</span>:
            <span style="color:#3300aa">print</span>(<span style="color:#aa1111">"----------use constrain-------------"</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">noise_extractor</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ResNet50</span>(<span style="color:#000000">n_input</span><span style="color:#981a1a">=</span><span style="color:#116644">3</span>, <span style="color:#000000">pretrained</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">constrain_conv</span> <span style="color:#981a1a">=</span> <span style="color:#000000">BayarConv2d</span>(<span style="color:#000000">in_channels</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>, <span style="color:#000000">out_channels</span><span style="color:#981a1a">=</span><span style="color:#116644">3</span>, <span style="color:#000000">padding</span><span style="color:#981a1a">=</span><span style="color:#116644">2</span>)
            <span style="color:#0055aa">self</span>.<span style="color:#000000">head</span> <span style="color:#981a1a">=</span> <span style="color:#000000">_DAHead</span>(<span style="color:#116644">2048</span><span style="color:#981a1a">+</span><span style="color:#116644">2048</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>, <span style="color:#000000">aux</span>, <span style="color:#981a1a">**</span><span style="color:#000000">kwargs</span>)
        <span style="color:#770088">else</span>:
            <span style="color:#0055aa">self</span>.<span style="color:#000000">head</span> <span style="color:#981a1a">=</span> <span style="color:#000000">_DAHead</span>(<span style="color:#116644">2048</span>, <span style="color:#0055aa">self</span>.<span style="color:#000000">num_class</span>, <span style="color:#000000">aux</span>, <span style="color:#981a1a">**</span><span style="color:#000000">kwargs</span>)
​
    <span style="color:#770088">def</span> <span style="color:#0000ff">forward</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">x</span>):
        <span style="color:#000000">size</span> <span style="color:#981a1a">=</span> <span style="color:#000000">x</span>.<span style="color:#000000">size</span>()[<span style="color:#116644">2</span>:]
        <span style="color:#000000">input_</span> <span style="color:#981a1a">=</span> <span style="color:#000000">x</span>.<span style="color:#000000">clone</span>()
        <span style="color:#000000">feature_map</span>, <span style="color:#000000">_</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">base_forward</span>(<span style="color:#000000">input_</span>)
        <span style="color:#000000">c1</span>, <span style="color:#000000">c2</span>, <span style="color:#000000">c3</span>, <span style="color:#000000">c4</span> <span style="color:#981a1a">=</span> <span style="color:#000000">feature_map</span>
​
        <span style="color:#770088">if</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">constrain</span>:
            <span style="color:#000000">x</span> <span style="color:#981a1a">=</span> <span style="color:#000000">rgb2gray</span>(<span style="color:#000000">x</span>)
            <span style="color:#000000">x</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">constrain_conv</span>(<span style="color:#000000">x</span>)
            <span style="color:#000000">constrain_features</span>, <span style="color:#000000">_</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">noise_extractor</span>.<span style="color:#000000">base_forward</span>(<span style="color:#000000">x</span>)
            <span style="color:#000000">constrain_feature</span> <span style="color:#981a1a">=</span> <span style="color:#000000">constrain_features</span>[<span style="color:#981a1a">-</span><span style="color:#116644">1</span>]
            <span style="color:#000000">c4</span> <span style="color:#981a1a">=</span> <span style="color:#000000">torch</span>.<span style="color:#000000">cat</span>([<span style="color:#000000">c4</span>, <span style="color:#000000">constrain_feature</span>], <span style="color:#000000">dim</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>)
​
        <span style="color:#000000">x</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">head</span>(<span style="color:#000000">c4</span>)
        <span style="color:#000000">x0</span> <span style="color:#981a1a">=</span> <span style="color:#000000">F</span>.<span style="color:#000000">interpolate</span>(<span style="color:#000000">x</span>[<span style="color:#116644">0</span>], <span style="color:#000000">size</span>, <span style="color:#000000">mode</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'bilinear'</span>, <span style="color:#000000">align_corners</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>)
​
        <span style="color:#770088">return</span> <span style="color:#000000">x0</span></span></span>

步骤4:

定义损失函数,本文用图像分类中加权的BCE Loss和Dice Loss来对模型进行训练。

<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#770088">import</span> <span style="color:#000000">torch</span>.<span style="color:#000000">nn</span>.<span style="color:#000000">functional</span> <span style="color:#770088">as</span> <span style="color:#000000">F</span>
<span style="color:#770088">import</span> <span style="color:#000000">torch</span>.<span style="color:#000000">nn</span> <span style="color:#770088">as</span> <span style="color:#000000">nn</span>
<span style="color:#aa5500">######## WeightedBCE Loss ###########</span>
<span style="color:#770088">class</span> <span style="color:#0000ff">WeightedBCE</span>(<span style="color:#000000">nn</span>.<span style="color:#000000">Module</span>):
    <span style="color:#770088">def</span> <span style="color:#0000ff">__init__</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">weights</span><span style="color:#981a1a">=</span>[<span style="color:#116644">0.2</span>, <span style="color:#116644">0.8</span>]):
        <span style="color:#3300aa">super</span>(<span style="color:#000000">WeightedBCE</span>, <span style="color:#0055aa">self</span>).<span style="color:#000000">__init__</span>()
        <span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span> <span style="color:#981a1a">=</span> <span style="color:#000000">weights</span>
​
    <span style="color:#770088">def</span> <span style="color:#0000ff">forward</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">logit_pixel</span>, <span style="color:#000000">truth_pixel</span>):
        <span style="color:#000000">logit</span> <span style="color:#981a1a">=</span> <span style="color:#000000">logit_pixel</span>.<span style="color:#000000">reshape</span>(<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#000000">truth</span> <span style="color:#981a1a">=</span> <span style="color:#000000">truth_pixel</span>.<span style="color:#000000">reshape</span>(<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#770088">assert</span>(<span style="color:#000000">logit</span>.<span style="color:#000000">shape</span><span style="color:#981a1a">==</span><span style="color:#000000">truth</span>.<span style="color:#000000">shape</span>)
        <span style="color:#000000">loss</span> <span style="color:#981a1a">=</span> <span style="color:#000000">F</span>.<span style="color:#000000">binary_cross_entropy</span>(<span style="color:#000000">logit</span>, <span style="color:#000000">truth</span>, <span style="color:#000000">reduction</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'mean'</span>)
        <span style="color:#000000">pos</span> <span style="color:#981a1a">=</span> (<span style="color:#000000">truth</span><span style="color:#981a1a">>=</span><span style="color:#116644">0.35</span>).<span style="color:#000000">float</span>()
        <span style="color:#000000">neg</span> <span style="color:#981a1a">=</span> (<span style="color:#000000">truth</span><span style="color:#981a1a"><</span><span style="color:#116644">0.35</span>).<span style="color:#000000">float</span>()
        <span style="color:#000000">pos_weight</span> <span style="color:#981a1a">=</span> <span style="color:#000000">pos</span>.<span style="color:#000000">sum</span>().<span style="color:#000000">item</span>() <span style="color:#981a1a">+</span> <span style="color:#116644">1e-12</span>
        <span style="color:#000000">neg_weight</span> <span style="color:#981a1a">=</span> <span style="color:#000000">neg</span>.<span style="color:#000000">sum</span>().<span style="color:#000000">item</span>() <span style="color:#981a1a">+</span> <span style="color:#116644">1e-12</span>
        <span style="color:#000000">loss</span> <span style="color:#981a1a">=</span> (<span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span>[<span style="color:#116644">0</span>]<span style="color:#981a1a">*</span><span style="color:#000000">pos</span><span style="color:#981a1a">*</span><span style="color:#000000">loss</span><span style="color:#981a1a">/</span><span style="color:#000000">pos_weight</span> <span style="color:#981a1a">+</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span>[<span style="color:#116644">1</span>]<span style="color:#981a1a">*</span><span style="color:#000000">neg</span><span style="color:#981a1a">*</span><span style="color:#000000">loss</span><span style="color:#981a1a">/</span><span style="color:#000000">neg_weight</span>).<span style="color:#000000">sum</span>()
        <span style="color:#770088">return</span> <span style="color:#000000">loss</span>
​
<span style="color:#aa5500">######## WeightedDice Loss ###########</span>
<span style="color:#770088">class</span> <span style="color:#0000ff">WeightedDiceLoss</span>(<span style="color:#000000">nn</span>.<span style="color:#000000">Module</span>):
    <span style="color:#770088">def</span> <span style="color:#0000ff">__init__</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">weights</span><span style="color:#981a1a">=</span>[<span style="color:#116644">0.5</span>, <span style="color:#116644">0.5</span>]): <span style="color:#aa5500"># W_pos=0.8, W_neg=0.2</span>
        <span style="color:#3300aa">super</span>(<span style="color:#000000">WeightedDiceLoss</span>, <span style="color:#0055aa">self</span>).<span style="color:#000000">__init__</span>()
        <span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span> <span style="color:#981a1a">=</span> <span style="color:#000000">weights</span>
​
    <span style="color:#770088">def</span> <span style="color:#0000ff">forward</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">logit</span>, <span style="color:#000000">truth</span>, <span style="color:#000000">smooth</span><span style="color:#981a1a">=</span><span style="color:#116644">1e-5</span>):
        <span style="color:#000000">batch_size</span> <span style="color:#981a1a">=</span> <span style="color:#3300aa">len</span>(<span style="color:#000000">logit</span>)
        <span style="color:#000000">logit</span> <span style="color:#981a1a">=</span> <span style="color:#000000">logit</span>.<span style="color:#000000">reshape</span>(<span style="color:#000000">batch_size</span>,<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#000000">truth</span> <span style="color:#981a1a">=</span> <span style="color:#000000">truth</span>.<span style="color:#000000">reshape</span>(<span style="color:#000000">batch_size</span>,<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#770088">assert</span>(<span style="color:#000000">logit</span>.<span style="color:#000000">shape</span><span style="color:#981a1a">==</span><span style="color:#000000">truth</span>.<span style="color:#000000">shape</span>)
        <span style="color:#000000">p</span> <span style="color:#981a1a">=</span> <span style="color:#000000">logit</span>.<span style="color:#000000">view</span>(<span style="color:#000000">batch_size</span>,<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#000000">t</span> <span style="color:#981a1a">=</span> <span style="color:#000000">truth</span>.<span style="color:#000000">view</span>(<span style="color:#000000">batch_size</span>,<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
​
        <span style="color:#000000">w</span> <span style="color:#981a1a">=</span> <span style="color:#000000">truth</span>.<span style="color:#000000">detach</span>()
        <span style="color:#000000">w</span> <span style="color:#981a1a">=</span> <span style="color:#000000">w</span><span style="color:#981a1a">*</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span>[<span style="color:#116644">1</span>]<span style="color:#981a1a">-</span><span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span>[<span style="color:#116644">0</span>])<span style="color:#981a1a">+</span><span style="color:#0055aa">self</span>.<span style="color:#000000">weights</span>[<span style="color:#116644">0</span>]
        <span style="color:#000000">p</span> <span style="color:#981a1a">=</span> <span style="color:#000000">w</span><span style="color:#981a1a">*</span>(<span style="color:#000000">p</span>)
        <span style="color:#000000">t</span> <span style="color:#981a1a">=</span> <span style="color:#000000">w</span><span style="color:#981a1a">*</span>(<span style="color:#000000">t</span>)
        <span style="color:#000000">intersection</span> <span style="color:#981a1a">=</span> (<span style="color:#000000">p</span> <span style="color:#981a1a">*</span> <span style="color:#000000">t</span>).<span style="color:#000000">sum</span>(<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#000000">union</span> <span style="color:#981a1a">=</span>  (<span style="color:#000000">p</span> <span style="color:#981a1a">*</span> <span style="color:#000000">p</span>).<span style="color:#000000">sum</span>(<span style="color:#981a1a">-</span><span style="color:#116644">1</span>) <span style="color:#981a1a">+</span> (<span style="color:#000000">t</span> <span style="color:#981a1a">*</span> <span style="color:#000000">t</span>).<span style="color:#000000">sum</span>(<span style="color:#981a1a">-</span><span style="color:#116644">1</span>)
        <span style="color:#000000">dice</span>  <span style="color:#981a1a">=</span> <span style="color:#116644">1</span> <span style="color:#981a1a">-</span> (<span style="color:#116644">2</span><span style="color:#981a1a">*</span><span style="color:#000000">intersection</span> <span style="color:#981a1a">+</span> <span style="color:#000000">smooth</span>) <span style="color:#981a1a">/</span> (<span style="color:#000000">union</span> <span style="color:#981a1a">+</span><span style="color:#000000">smooth</span>)
        <span style="color:#000000">loss</span> <span style="color:#981a1a">=</span> <span style="color:#000000">dice</span>.<span style="color:#000000">mean</span>()
        <span style="color:#770088">return</span> <span style="color:#000000">loss</span>
​
​
<span style="color:#aa5500">######## Total Loss  = WeightedDice Loss  + WeightedBCE###########</span>
<span style="color:#770088">class</span> <span style="color:#0000ff">WeightedDiceBCE</span>(<span style="color:#000000">nn</span>.<span style="color:#000000">Module</span>):
    <span style="color:#770088">def</span> <span style="color:#0000ff">__init__</span>(<span style="color:#0055aa">self</span>,<span style="color:#000000">dice_weight</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>,<span style="color:#000000">BCE_weight</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>):
        <span style="color:#3300aa">super</span>(<span style="color:#000000">WeightedDiceBCE</span>, <span style="color:#0055aa">self</span>).<span style="color:#000000">__init__</span>()
        <span style="color:#0055aa">self</span>.<span style="color:#000000">BCE_loss</span> <span style="color:#981a1a">=</span> <span style="color:#000000">WeightedBCE</span>(<span style="color:#000000">weights</span><span style="color:#981a1a">=</span>[<span style="color:#116644">0.8</span>, <span style="color:#116644">0.2</span>])
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dice_loss</span> <span style="color:#981a1a">=</span> <span style="color:#000000">WeightedDiceLoss</span>(<span style="color:#000000">weights</span><span style="color:#981a1a">=</span>[<span style="color:#116644">0.5</span>, <span style="color:#116644">0.5</span>])
        <span style="color:#0055aa">self</span>.<span style="color:#000000">BCE_weight</span> <span style="color:#981a1a">=</span> <span style="color:#000000">BCE_weight</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">lovasz_weight</span> <span style="color:#981a1a">=</span> <span style="color:#116644">0</span>
        <span style="color:#0055aa">self</span>.<span style="color:#000000">dice_weight</span> <span style="color:#981a1a">=</span> <span style="color:#000000">dice_weight</span>
        
    <span style="color:#770088">def</span> <span style="color:#0000ff">forward</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">inputs</span>, <span style="color:#000000">targets</span>):
        <span style="color:#000000">dice</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">dice_loss</span>(<span style="color:#000000">inputs</span>, <span style="color:#000000">targets</span>)
        <span style="color:#000000">BCE</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">BCE_loss</span>(<span style="color:#000000">inputs</span>, <span style="color:#000000">targets</span>)
        <span style="color:#000000">dice_BCE_loss</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">dice_weight</span> <span style="color:#981a1a">*</span> <span style="color:#000000">dice</span> <span style="color:#981a1a">+</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">BCE_weight</span> <span style="color:#981a1a">*</span> <span style="color:#000000">BCE</span>
        <span style="color:#770088">return</span> <span style="color:#000000">dice_BCE_loss</span></span></span>

步骤5:

最后的训练框架就会变为:

<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#555555">#=============================数据集设置===================================</span>
    <span style="color:#555555"># 训练图片目录</span>
    <span style="color:#000000">img_dir</span> <span style="color:#981a1a">=</span><span style="color:#aa1111">'./test_data/train_img'</span>
    <span style="color:#555555"># 训练掩码目录</span>
    <span style="color:#000000">gt_mask_dir</span> <span style="color:#981a1a">=</span><span style="color:#aa1111">'./test_data/train_mask'</span>
<span style="color:#555555"># =============================数据集设置===================================</span>
    <span style="color:#555555"># 创建训练数据集</span>
    <span style="color:#000000">train_dataset</span> <span style="color:#981a1a">=</span> <span style="color:#000000">UNetDataset</span>(<span style="color:#000000">img_dir</span>, <span style="color:#000000">gt_mask_dir</span>, <span style="color:#000000">mode</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'train'</span>)
    <span style="color:#555555"># 创建训练数据加载器</span>
    <span style="color:#000000">train_loader</span> <span style="color:#981a1a">=</span> <span style="color:#000000">DataLoader</span>(
        <span style="color:#000000">train_dataset</span>,  <span style="color:#000000">#</span> <span style="color:#000000">使用之前定义的UNetDataset</span>
        <span style="color:#000000">batch_size</span><span style="color:#981a1a">=</span><span style="color:#000000">params</span>[<span style="color:#aa1111">"batch_size"</span>],  <span style="color:#000000">#</span> <span style="color:#000000">批处理大小</span>
        <span style="color:#000000">shuffle</span><span style="color:#981a1a">=</span><span style="color:#000000">True</span>,  <span style="color:#000000">#</span> <span style="color:#000000">打乱数据</span>
        <span style="color:#000000">num_workers</span><span style="color:#981a1a">=</span><span style="color:#000000">params</span>[<span style="color:#aa1111">"num_workers"</span>],  <span style="color:#000000">#</span> <span style="color:#000000">加载数据的工作进程数</span>
        <span style="color:#000000">pin_memory</span><span style="color:#981a1a">=</span><span style="color:#000000">True</span>,  <span style="color:#000000">#</span> <span style="color:#000000">将数据加载到CUDA中的固定内存中</span>
        <span style="color:#000000">drop_last</span><span style="color:#981a1a">=</span><span style="color:#000000">True</span>,  <span style="color:#000000">#</span> <span style="color:#000000">如果最后一个批次的数据量小于batch_size,则丢弃该批次</span>
    )
    <span style="color:#555555"># 获取模型,使用resnet50作为骨干网络,预训练基础模型,输出类别数为1,使用Sobel算子,应用约束,输入通道数为3</span>
    <span style="color:#000000">model</span> <span style="color:#981a1a">=</span> <span style="color:#000000">get_mvss</span>(<span style="color:#000000">backbone</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'resnet50'</span>,
                                <span style="color:#000000">pretrained_base</span><span style="color:#981a1a">=</span><span style="color:#000000">True</span>,
                                <span style="color:#000000">nclass</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>,
                                <span style="color:#000000">sobel</span><span style="color:#981a1a">=</span><span style="color:#000000">True</span>,
                                <span style="color:#000000">constrain</span><span style="color:#981a1a">=</span><span style="color:#000000">True</span>,
                                <span style="color:#000000">n_input</span><span style="color:#981a1a">=</span><span style="color:#116644">3</span>,
                                )
​
    <span style="color:#555555"># 实例化优化器,使用AdamW算法,传入模型参数和学习率</span>
    <span style="color:#000000">optimizer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">torch</span>.<span style="color:#000000">optim</span>.<span style="color:#000000">AdamW</span>(<span style="color:#000000">model</span>.<span style="color:#000000">parameters</span>(), <span style="color:#000000">lr</span><span style="color:#981a1a">=</span><span style="color:#000000">params</span>[<span style="color:#aa1111">"lr"</span>])
    <span style="color:#555555"># 实例化损失函数,使用加权Dice和BCE损失,权重分别为0.3和0.7,并将其移动到GPU上</span>
    <span style="color:#000000">criterion_1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">WeightedDiceBCE</span>(<span style="color:#000000">dice_weight</span><span style="color:#981a1a">=</span><span style="color:#116644">0.3</span>,<span style="color:#000000">BCE_weight</span><span style="color:#981a1a">=</span><span style="color:#116644">0.7</span>).<span style="color:#000000">cuda</span>()
    <span style="color:#555555"># 开始训练循环,从epoch_start开始,到设定的epochs结束</span>
    <span style="color:#770088">for</span> <span style="color:#000000">epoch</span> <span style="color:#000000">in</span> <span style="color:#0000ff">range</span>(<span style="color:#000000">epoch_start</span>, <span style="color:#000000">params</span>[<span style="color:#aa1111">"epochs"</span>] <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>):
        <span style="color:#555555"># 将模型设置为训练模式</span>
        <span style="color:#000000">model</span>.<span style="color:#000000">train</span>()
        <span style="color:#555555"># 创建进度条,显示处理过程,颜色为青色</span>
        <span style="color:#000000">stream</span> <span style="color:#981a1a">=</span> <span style="color:#000000">tqdm</span>(<span style="color:#000000">train_loader</span>,<span style="color:#000000">desc</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'processing'</span>,<span style="color:#000000">colour</span><span style="color:#981a1a">=</span><span style="color:#aa1111">'CYAN'</span>)
        <span style="color:#555555"># 遍历数据加载器中的每个批次</span>
        <span style="color:#770088">for</span> <span style="color:#000000">i</span>, (<span style="color:#000000">images</span>, <span style="color:#000000">masks</span>,<span style="color:#000000">_</span>) <span style="color:#000000">in</span> <span style="color:#0000ff">enumerate</span>(<span style="color:#000000">stream</span>, <span style="color:#000000">start</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>):
            <span style="color:#555555"># 将图片和掩码发送到GPU上,非阻塞式传输</span>
            <span style="color:#000000">images</span> <span style="color:#981a1a">=</span> <span style="color:#000000">images</span>.<span style="color:#000000">cuda</span>(<span style="color:#000000">non_blocking</span><span style="color:#981a1a">=</span><span style="color:#000000">params</span>[<span style="color:#aa1111">'non_blocking_'</span>])
            <span style="color:#000000">masks</span> <span style="color:#981a1a">=</span> <span style="color:#000000">masks</span>.<span style="color:#000000">cuda</span>(<span style="color:#000000">non_blocking</span><span style="color:#981a1a">=</span><span style="color:#000000">params</span>[<span style="color:#aa1111">'non_blocking_'</span>])
            <span style="color:#555555"># 通过模型处理图片得到输出</span>
            <span style="color:#000000">reg_outs</span> <span style="color:#981a1a">=</span> <span style="color:#000000">model</span>(<span style="color:#000000">images</span>)
            <span style="color:#555555"># 应用Sigmoid激活函数</span>
            <span style="color:#000000">reg_outs</span> <span style="color:#981a1a">=</span> <span style="color:#000000">torch</span>.<span style="color:#000000">sigmoid</span>(<span style="color:#000000">reg_outs</span>)
            <span style="color:#555555"># 计算损失</span>
            <span style="color:#000000">loss_region</span> <span style="color:#981a1a">=</span> <span style="color:#000000">criterion1</span>(<span style="color:#000000">reg_outs</span>, <span style="color:#000000">masks</span>)
            <span style="color:#555555"># 清除优化器的梯度</span>
            <span style="color:#000000">optimizer</span>.<span style="color:#000000">zero_grad</span>()
            <span style="color:#555555"># 反向传播损失</span>
            <span style="color:#000000">loss_region</span>.<span style="color:#000000">backward</span>()
            <span style="color:#555555"># 更新模型的参数</span>
            <span style="color:#000000">optimizer</span>.<span style="color:#000000">step</span>()</span></span>

使用附件代码以及附件数据集(附件txt文件中附上数据集下载地址),运行python main.py就可以一键实现对图像篡改检测定位的训练。如果想要直接得到结果,可以使用附件txt中提供的作者预训练好的模型,运行test.py,在其中修改测试数据集路径,或者把图像及其对应的mask图像放到./test_dataset/val_img目录下以及./test_dataset/val_mask目录下,要求两个目录中都有文件名相同的文件,如果测试集没有mask图像,则将274行代码val_gt_mask_dir=xxx也设置为与val_img_dir = xxx一样即可。预测所得图像在./test_dataset/predict_results目录中。

此外,还开发了一个桌面版的GUI界面和一个网页版的web供算法的可视化演示。 首先是桌面版的GUI界面,运行app.py,就会出来一个界面,点击请选择待检测图像按钮,在test_data的img文件夹中存放着几张例子图像,其对应的篡改区域可以见同一个目录下的mask文件夹。选择图像后,点击OK后界面左边会加载显示疑似篡改图像,点击篡改检测按钮,稍等几秒后界面右边就会显示疑似的篡改P图区域。值得一提的是,该界面使用了热力图的形式来标注篡改区域,越接近蓝色则表明其不是篡改区域,越接近红色则表明其是篡改区域。除了可以检测一般的P图,它还可以检测聊天截图中可能存在的篡改区域。 其次是网页版的web,运行python web.py文件,可以在浏览器中通过与以上类似的操作,得到一张篡改检测的图像。 在这里插入图片描述在这里插入图片描述

在这里插入图片描述

部署方式

环境要求:

  • python 3.7

  • torch==1.13.1

  • torchvision==0.14.1

  • segmentation-models-pytorch==0.3.3

  • opencv-python-headless==4.9.0.80

  • Pillow==9.5.0

  • imageio==2.31.2

  • dearpygui>=1.3.0

  • scikit-learn==1.0.2


需要本文的详细复现过程的项目源码、数据和预训练好的模型可从该地址处获取完整版:地址跳转

 ​​

希望对你有帮助!加油!

若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值