论文精读+复现:Motion Artifact Removal in Pixel-Frequency Domain via Alternate Masks and Diffusion Model

 原文地址:Motion Artifact Removal in Pixel-Frequency Domain via Alternate Masks and Diffusion Model| Proceedings of the AAAI Conference on Artificial Intelligence

开源代码:https://github.com/medcx/PFAD

摘要

磁共振成像(MRI)中存在的运动伪影会严重干扰临床诊断。去除运动伪影是一个直接明了的解决方案,并且已经得到了广泛研究。然而,在近期的研究工作中仍然严重依赖成对数据,并且k空间(频域)中的扰动没有得到很好的考虑,这限制了它们在临床领域的应用。为了解决这些问题,我们提出了一种新颖的无监督净化方法,该方法利用有噪声的MRI图像的像素-频率信息来引导预训练的扩散模型恢复清晰的MRI图像。具体来说,考虑到运动伪影主要集中在k空间的高频分量中,我们利用低频分量作为引导,以确保组织纹理的正确性。此外,鉴于高频和像素信息有助于恢复形状和细节纹理,我们设计了交替互补掩码,以同时破坏伪影结构并利用有用信息。在来自不同组织的数据集上进行了定量实验,结果表明我们的方法在多个指标上都取得了优异的性能。与放射科医生进行的定性评估也表明,我们的方法提供了更好的临床反馈。

磁共振成像(MRI)作为无创医学成像技术,已在临床诊断中广泛应用。但因其采集时间长,部分患者(如老年人、幽闭恐惧症或癫痫患者等)在扫描过程中易因不适产生运动,这会破坏 MRI 的 k 空间(频域)信号分布,产生运动伪影,严重干扰临床诊断与治疗,因此亟需探索有效缓解运动伪影影响的方法。

最近,去噪扩散概率模型 (DDPM) 在图像生成领域表现出出色的表现 (Song, Meng, and Ermon 2020;Ho 和 Salimans 2022)。已经验证,结合图像(Rombach 等人,2022 年;Saharia 等人,2022a) 或文本(Nichol 等人,2021 年;Ramesh 等人,2022 年;Saharia 等人,2022b;Si 等人,2023 年;Zhang、Rao 和 Agrawala 2023)可以在 DDPM 的推理过程中促进生成针对特定要求量身定制的所需图像。然而,尽管基于扩散的方法在纯化或重建自然图像方面具有出色的能力,但在去除 MRI 图像的运动伪影方面的性能仍然有限。这是因为运动伪影存在于 k 空间中,而在其纯化过程中只有像素域信息感兴趣。

图1 在 MRI 成像过程中,由于患者的运动而产生运动伪影。通过扩散模型和像素频率域中的替代掩模去除伪影可以帮助放射科医生做出更好的诊断。

PFAD 方法核心思想概览

  1. 预训练扩散模型(Diffusion Model)

    • 使用未配对的 清晰MRI图像 对扩散模型进行预训练,学习无伪影图像的分布(即理想分布)。

  2. 低频引导(Low-Frequency Guidance)

    • 从受污染图像中提取低频部分,提供大致结构/纹理引导,使扩散生成更具结构一致性。

  3. 交替掩码(Alternate Masks)

    • 对图像的高频/像素部分及生成图像应用掩码,有选择地打乱伪影,保留有价值信息用于恢复;

    • 掩码在每次反向扩散步骤中 交替翻转,覆盖图像不同区域,有助于全局伪影去除。

  4. 扩散扰动调节参数(Diffusion Time-Dependent Parameter)

    • 引入一个随时间变化的参数,在扩散过程中增强对伪影的扰动效果,改善最终图像质量。

  5. 频域-像素域恢复权重平衡超参数(Balancing Hyperparameter)

    • 控制频域与像素域信号之间的恢复权重,获得最优清晰图像。

贡献

  • “无监督 + 像素-频率融合”:避开了配对数据难题,这是很多医学图像问题中的痛点;

  • “替代互补掩码”机制:和常规的掩码不同,PFAD是动态交替的,这种机制确实在频域/像素域去伪影中相当有效;

  • “动态调节参数 + 消融实验”:从方法设计到实验验证都非常严谨完整。

Methodology

Preliminary

Notice

符号含义类比理解
x图像数据(小写)就是我们处理的图像,比如 MRI 图像
t当前时间步相当于“第几次加噪”
T总共的时间步加噪一共加了几次,比如 1000 步
N(\cdot ;\mu ,\sigma ^{^{2}}I)高斯分布就是随机生成噪声的“随机函数”
D_{\theta }神经网络(扩散模型)我们要训练的模型,参数是 θ\thetaθ
F,F^{^{-1}}傅里叶变换/逆变换把图像变成频率域、再变回来
\phi _{l},\phi {_{h}}理想低通、高通滤波器提取图像的低频轮廓 / 高频细节
\left | \cdot \right |模值运算
Hadamard积(点乘)对应元素相乘,比如两个图像重叠的像素点相乘

Diffusion Model

扩散过程(Forward Process)

我们从一张干净图像 x{_{0}} 开始,一步步往里面加随机噪声(白噪声),过程如下:

  • 第一步:x1=在 x0 上加一点点噪声

  • 第二步:x2=再加一点噪声

  • 第 T 步后:图像变成一堆纯噪声(就像雪花电视)

公式表达:

  • 每一步加的噪声强度是 \beta _{t},随着 t 增大;

  • I是单位矩阵,表示噪声是独立的;

  • 这一步是固定的,不需要神经网络来学习。

反向去噪(Reverse Process)

这是模型要学习的核心部分!我们希望:

能不能从纯噪声 xT​,一步步“去噪”,还原出干净图像 x0​?

于是我们设计了一个神经网络 ​D_{\theta },学会从 x_{t}预测出上一步 x_{t-1}

  • μθ​(xt​,t):神经网络预测的“去噪后图像的均值”;

  • \sigma _{t}^{2}:是我们手动设置的噪声强度;

  • 模型会从x_{T}\sim N(0,I)开始,逐步生成x_{T-1},x_{T-2}\cdots x_{0}

训练方法(Learning)

如何训练这个模型呢?我们不直接让模型预测图像本身,而是让它预测“加进去的噪声 \varepsilon”!

我们构造带噪图像:

然后让模型预测出\varepsilon ,用 L2 loss 训练:

其中:

  • \varepsilon _{\theta }​ 是你训练的神经网络,输入 x_{t}和时间步 t,输出噪声预测;

  • 如果它能正确预测出加进去的噪声,那么我们就能反推出原图 x_{0}

整个流程

步骤描述类比
1. Forward清晰图像一步步加噪成纯噪声把照片撕碎
2. Reverse模型学会从纯噪声一步步恢复出清晰图像照片修复术
3. 训练目标让模型学会预测噪声,以间接学会去噪学会还原破坏过程

Motion Artifact Removal with PFAD

图2 PFAD架构
绿色区域是频域中的伪影去除过程。这个过程以运动损坏图像的低频信息和部分高频信息为导向,并使用扩散模型生成另一部分干净的高频信息。黄色区域是像素域伪影去除过程,由前向过程的部分图像信息引导。将两个域的结果组合在一起后,最终迭代生成干净的图像。

 PFAD 的原理和过程框架如如图 2 所示。首先,我们介绍了如何在频域和像素域中分别使用交替互补掩码。其次,我们介绍如何创建替代互补掩码。最后,我们设计了一个超参数来平衡频域和像素域的权重。整个过程的详细信息如算法 1 所示。

算法 1

Frequency Domain Removal

目标是:在频率域中消除运动伪影(Motion Artifact)

  • x_ori:原始的含有运动伪影的图像。

  • x'_{t-1}:在频域上重组后的图像(去伪影处理的目标)。

  • f:图像的频率信息(通过傅里叶变换得到)。

  • Φ_l(.):低频成分提取操作(Low-Frequency Extractor)。

  • Φ_h(.):高频成分提取操作(High-Frequency Extractor)。

  • M_t:交替的互补掩码(Alternate Complementary Mask),用于融合不同图像的高频部分。

  • :表示逐元素乘法。

  • F⁻¹:傅里叶反变换操作,将频域图像还原到时域。

(4)低频保持不变:

x_ori(带有运动伪影的图像)中提取出的低频部分直接保留,用于重建图像的低频结构。
因为低频部分主要包含图像的大致轮廓和颜色,不容易受到运动伪影的影响。

(5)高频交替掩码融合:

将当前带伪影图像 x_ori 与上一时刻图像 x_{t-1} 的高频部分按掩码 M_t 做融合处理。

  • M_t 是一个交替的二值掩码,用来决定当前时刻保留哪一部分频率的信息。

  • M_t 的作用是为了在时间上(或多个估计结果之间)互补性地去除伪影,防止同一个区域反复保留错误的高频信息。

  • 这一步的目标是:剔除掉受到运动伪影影响的高频区域,同时补充可靠的高频细节

(6)频域重建:

将上面获得的高频和低频信息组合,再通过傅里叶反变换 F⁻¹ 恢复出一个在时域中去除了部分运动伪影的图像 x'_{t-1}

总结
  1. 保留低频信息(不容易受到伪影影响)。

  2. 用互补掩码融合高频(去除运动伪影影响大的高频区域)。

  3. 组合高低频,重建图像。

整个过程就是在频率域里,通过选择性保留信息 + 互补融合机制,来减弱运动伪影对图像的破坏。

Pixel Domain Removal

前面 Frequency Domain Removal 已经做了一轮修复,但为什么还要在像素域再做一次呢?原因有两个核心点:

  • 弥补低频信息缺失,提高图像自然度:频率域处理主要是高频修复(细节、纹理),虽然有效,但缺少对低频结构(轮廓、色调等)的处理,导致整体图像会显得不自然、不协调
  • 修正频域中因取模操作造成的轻微相位偏差:频域重建时,因频域处理可能涉及相位错位(尤其在反变换时),会导致图像位置偏差或模糊。而像素域的操作可以直接在图像层面进行细化修复。

方法与公式

与频率域不同的是,这里用的是 x_for_{t-1} 而不是 x_ori,即前向过程生成的图像。

  • x''_{t-1}:在像素域中重组后的图像。

  • x^{for}_{t-1}:前向扩散过程生成的图像(forward process image),具备一定生成能力。(“由清晰图像加噪得到的第 t-1 步结果,比较接近真实,能提供修复参考。”)

  • x_{t-1}:上一个时间步的重建图像。(“正在恢复图像的当前版本,有点模糊、有些伪影。”)

  • M_t:互补掩码(和频域一样),决定保留哪个图像的哪一部分。

  • :逐元素乘法,代表按位融合。

在像素层面重新融合图像内容:

  • x^{for}_{t-1} ⊙ M_t:使用 diffusion 模型生成的图像作为“参考或修复”来源,用于替代图像中含有伪影的区域。

  • x_{t-1} ⊙ (1 - M_t):保留上一时刻的图像中未被污染的部分。

最终生成的 x''_{t-1} 是一个 经过像素级重构 的图像,既借助了扩散模型的生成能力,又保留了稳定的局部信息,从而提升整体质量和自然感。

名称形式含义来源作用
x_{t-1}reverse step 中的图像当前反扩散步骤的输出来自 x_t 通过去噪预测得到当前生成的图像版本
x^{for}_{t-1}forward step 中的图像清晰图像通过扩散正向到第 t-1从 GT 清晰图像加噪得到提供参考/指导用图像(较真实)
总结
阶段方法目的融合对象
Frequency Domain Removal在频域内对高频信息处理去除明显伪影、恢复细节纹理x_orix_{t-1} 的高频部分
Pixel Domain Removal在像素域中融合原图和扩散图像修复低频结构、弥补频域缺陷x^{for}_{t-1}x_{t-1} 的像素值

Alternate Complementary Masks

1. 设计背景:为什么要使用“掩码”?

在扩散模型中,图像是逐步生成的。由于运动伪影往往呈现区域性分布(比如条纹、拖影),我们不能整张图统一处理,否则会破坏图像中仍然保留的“有用信息”。

因此使用掩码 m_t,只处理图像中的一部分区域:

  • 有助于保留未被破坏的信息。

  • 但也存在问题:未处理区域可能继续带有运动伪影

2. 交替互补(Alternate Complementary)是啥意思?

以前的工作(比如 Lugmayr 2022)使用的是固定掩码,也就是说每一轮都遮住相同位置的信息,容易导致某些区域总是没被处理,某些区域总是被处理,信息不均衡。

而 PFAD 的创新点是:

每一轮反向步骤中使用 互补掩码

  • 假设 m_t 是一个“交错网格”掩码(比如类似国际象棋黑白格子),

  • 那么 m_{t-1} 就是它的反色(黑白互换),正好遮住之前没处理的区域。

这样迭代下来,整个图像的所有区域都会被掩盖处理一遍,从而达到完整地打破运动伪影分布的目的,同时最大程度保留有用信息。

3. 问题:掩码虽然打破伪影,但也留下伪影!

虽然掩码有作用,但保留下来的区域依旧含有部分运动伪影,会继续影响扩散模型生成。

怎么办?

4. 解决方法:引入权重因子 ωₜ 来动态调整掩码作用强度

其中:

  • m_{t}是布尔掩码(0或1);

  • \omega {_{t}}​ 是一个 时间相关的权重

  • \omega _{t}=1-\sqrt{\overline{\alpha _{t}}},这是 DDPM 中的固有参数。

5.  \omega {_{t}}的变化趋势:从1逐渐到0

时间步 t位置解释
t = T (初始)ωₜ ≈ 1初始噪声很多,掩码可以大胆保留有用信息
t 减小时ωₜ 减小图像越生成越清晰,逐步降低伪影区域的保留强度
t → 0(生成完成)ωₜ → 0图像已经生成完整,此时若再保留带伪影的区域,会影响质量,需要抑制它

📌 所以这个机制的意义是:

初期:注重结构生成,允许有些带伪影区域提供信息;
后期:图像结构已具备,要重点清理掉残留伪影,减少其影响。

总结

PFAD 的交替互补掩码机制,通过交错处理所有区域 + 动态调整掩码作用强度,实现了在保留有用信息的同时,逐步剥离运动伪影的影响,这是其性能优于传统 Diffusion 的关键之一。

Dual Domain Balance(双域融合平衡)

1.问题背景:为什么要双域融合?

优点缺点

频率域

Frequency Domain

擅长去除结构性伪影(如条纹、重影),恢复纹理细节容易因频域处理带来不自然感
像素域 Pixel Domain更接近图像本身表现,利于还原自然视觉对细节和纹理恢复相对弱

所以,如果只靠一个域来重建图像,会有失偏颇:

  • 频域强去伪影,但图像不自然;

  • 像素域保证自然感,但伪影清理不彻底。

所以——将两者融合,并且融合比例要“动态变化” —— 这就是 Dual Domain Balance 的核心思想。

2.引入平衡因子 γₜ

作者提出了一个随时间变化的融合权重 γₜ,控制频率域图像 x′_{t-1} 和像素域图像 x′′_{t-1} 的权重。

公式如下:

  • \gamma _{t}\in (1-a,1),即初始时接近 1,后期逐渐降低

  • 超参数 a∈[0,1],可以控制下降速率。

这意味着,在整个去噪(反扩散)过程中,随着时间 t 变小,频域的权重逐步降低,像素域权重逐步提高。

3.融合公式:生成最终的组织图像 x_{t-1}^{e}

解释如下:

  • x_{t-1}^{'}:频域重组图像(带细节,擅长去伪影);

  • x_{t-1}^{''}:像素域重组图像(保持自然感,补偿频域损失);

  • x_{t-1}^{e}​:两者加权融合的最终重组图像,进入下一步扩散。

4.为什么这个设计很妙?

时间阶段γₜ 值代表策略原因

反扩散初期

(t ≈ T)

γₜ ≈ 1主要靠频域去除大伪影此时图像噪声较大,频域更有效
中期γₜ ≈ 0.7~0.5两者加权融合图像逐步成型,开始讲究自然感

后期

(t → 0)

γₜ → 1 - a(可能≈0.2)主要靠像素域恢复真实感此时形状基本定型,需要提升视觉质量
总结

PFAD 通过引入 随时间变化的权重 γₜ,实现了频域去伪影与像素域图像自然性的动态平衡融合,实现了既清晰又自然的图像重建,是整个系统精华所在.

总结

  1. 研究贡献:提出新颖扩散模型架构用于运动伪影去除
  2. 方法创新:引入 Pixel-Frequency 双域互补掩码
  3. 实验效果:真实图像效果好,模拟图像略逊
  4. 总结意义:提升诊断可靠性,有望辅助临床医生

代码讲解

scripts/image_sample.py

PFAD 的“推理入口脚本”,它负责加载预训练模型、读取运动伪影图像、进行图像还原(artifact removal),并将还原后的图像保存下来。

总体结构:

部分说明
1️⃣ 配置与参数设置种子,读取 YAML 配置文件,准备推理参数。
2️⃣ 模型加载创建 model 和 diffusion,并加载权重。
3️⃣ 读取数据从本地读取运动伪影图像。
4️⃣ 前向推理利用 PFAD 进行 artifact removal。
5️⃣ 保存图像把还原图像写入 .tiff 文件。
"""
生成去伪影后的图像样本,并保存为.tiff格式
此脚本用于PFAD模型的推理过程
"""

import argparse
import os
import random
import numpy as np
import torch
import yaml
import SimpleITK as sitk  # 用于处理医学图像格式(如TIFF)

# 导入PFAD相关模块
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,     # 获取默认模型配置
    create_model_and_diffusion,       # 创建模型和扩散器
    add_dict_to_argparser,            # 将配置字典转为argparser参数
    args_to_dict                      # 从argparser对象提取参数为字典
)


def main():
    # 指定GPU设备(这里默认用第0块卡)
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # 固定随机种子,保证结果可复现
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # 命令行参数:配置路径、图像目录、保存目录
    parser = argparse.ArgumentParser()
    parser.add_argument('--conf_path', type=str, default='../conf/brain_sample_config.yml')
    parser.add_argument('--img_dir', type=str, default='brain')  # 图像子目录名
    parser.add_argument('--save_path', type=str, default='motion_remove')  # 结果保存目录
    config = parser.parse_args()

    # 加载 YAML 配置文件
    with open(config.conf_path) as f:
        c = yaml.load(f, Loader=yaml.FullLoader)

    # 将配置文件中参数添加到 parser 中,并解析
    add_dict_to_argparser(parser, c)
    args = parser.parse_args()

    # 设置分布式(这里虽然是单卡,但PFAD框架是为多卡设计的)
    dist_util.setup_dist()

    # 日志记录初始化
    logger.configure()
    logger.log("creating model and diffusion...")

    # 创建模型与diffusion对象
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    # 加载预训练模型权重
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )

    # 将模型放到GPU上
    model.to(dist_util.dev())

    # 若设置fp16精度,进行转换
    if args.use_fp16:
        model.convert_to_fp16()

    # 获取 PFAD 中定义的采样函数
    sample_fn = diffusion.PFAD_sample
    logger.log("sampling...")

    # ========== 图像推理主循环 ==========

    # 设置运动伪影图像路径与结果保存路径
    motion_corrupted_dir = f'../data/{args.img_dir}'
    save_path = f'../results/{args.save_path}/{args.img_dir}'
    os.makedirs(save_path, exist_ok=True)

    # 获取所有待处理图像名列表
    motion_corrupted_list = os.listdir(motion_corrupted_dir)
    motion_corrupted_list.sort()  # 保证处理顺序一致

    # 遍历每张图像进行去伪影
    for i, d in enumerate(motion_corrupted_list):
        d_path = f'{motion_corrupted_dir}/{d}'

        # 读取图像并归一化到 [-1, 1]
        t = sitk.GetArrayFromImage(sitk.ReadImage(d_path)).astype(np.float32)
        t = (t - t.min()) / (t.max() - t.min()) * 2 - 1

        # 增加 batch 维度和通道维度,形成 shape: (1, 1, H, W)
        t_d = np.expand_dims(t, axis=(0, 1))

        logger.log(f"Artifact Removal: {d}")

        # 构建输入 batch
        batch = {
            'GT': t_d,  # GT 是运动伪影图像(PFAD 以此为输入)
        }

        # 转为 tensor,并移动到 GPU
        for k in batch.keys():
            if not isinstance(batch[k], torch.Tensor):
                batch[k] = torch.from_numpy(batch[k])
            batch[k] = batch[k].to(dist_util.dev())

        # 传入模型的附加信息
        model_kwargs = {
            "gt": batch['GT'],  # PFAD网络需要的原始图像
            "diff": diffusion   # 采样函数可能会用到diffusion对象本身
        }

        # 推理过程,不需要计算梯度
        with torch.no_grad():
            sample = sample_fn(
                model,
                (t_d.shape[0], 1, args.image_size, args.image_size),  # 输出尺寸
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
                device=dist_util.dev(),
                conf=args
            )

            # 取出结果,并还原回 [0, 65535](保存为 uint16 图像)
            recovered = sample['out'].cpu().detach().numpy().squeeze()
            recovered = recovered.clip(-1, 1)
            recovered = (((recovered + 1) / 2) * 65535).astype(np.uint16)

            # 保存为 tiff 图像
            recovered = sitk.GetImageFromArray(recovered)
            sitk.WriteImage(recovered, f'{save_path}/clean_{i}.tiff')

    logger.log("sampling complete")


# 用于未来扩展时的argparser创建函数(当前没用上)
def create_argparser():
    defaults = dict(
        clip_denoised=True,
        num_samples=10000,
        batch_size=16,
        use_ddim=False,
        model_path="",
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()

GaussianDiffusion

核心任务:​​​​​​​

  • 前向扩散过程:给图像添加噪声。

  • 反向去噪采样过程:逐步还原图像。

  • 训练损失计算:构建监督信号进行模型训练。

方法名作用所属阶段
q_sample将噪声添加到图像,模拟前向过程前向扩散
q_mean_variance给定原图和时间步,计算前向分布的均值和方差前向扩散
q_posterior_mean_variance给定 t 时刻噪声图,预测 t-1 时刻的图反向过程
p_mean_variance模型预测的均值、方差输出反向采样
p_sample单步采样,即从 t 到 t-1 的预测反向采样
training_losses训练阶段计算损失函数训练
PFAD_samplePFAD中特有,带mask引导的扩散采样反向采样
deal_out结合 GT mask 和 gamma 计算最终输出图像后处理
_scale_timesteps如果使用浮点数时间步,进行归一化处理时间步变换

第1步:首先,了解扩散模型的前向过程(加噪声)和反向过程(去噪声)

def q_sample(self, x_start, t, noise=None):
    """
    执行正向扩散步骤,通过向数据添加噪声。

    该函数从 q(x_t | x_0) 中采样,通过将噪声添加到输入图像 `x_start` 中进行 `t` 次扩散步骤。

    :param x_start: 初始数据批次(原始图像)。
    :param t: 当前扩散步骤的时间步(从 0 开始,表示一次扩散步骤,`t` 对应于扩散的次数)。
    :param noise: 可选参数,添加的噪声。如果未提供,将生成正态分布的随机噪声。
    :return: 扩散 `t` 步之后的噪声图像。
    """
    if noise is None:
        # 如果未提供噪声,则生成与 `x_start` 形状相同的随机噪声。
        noise = th.randn_like(x_start)
    
    # 检查噪声的形状是否与 `x_start` 相匹配。
    assert noise.shape == x_start.shape
    
    # 通过应用扩散过程的 alpha 系数来计算带噪声的图像。
    # 我们为每个时间步提取相关的 alpha 值,并将噪声添加到图像中。
    return (
        _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start  # 缩放的 x_start
        + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise  # 噪声
    )


def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None
):
    """
    从模型中采样前一个时间步 `x_{t-1}`。

    该函数根据给定的时间步 `t` 从模型中采样前一个状态 `x_{t-1}`。
    模型生成用于去噪的预测均值和方差。

    :param model: 用于执行去噪的模型。
    :param x: 当前在时间步 `t` 的噪声图像。
    :param t: 当前时间步。
    :param clip_denoised: 如果为 True,则对去噪后的图像进行裁剪,将其值限制在 [-1, 1] 之间。
    :param denoised_fn: 如果不为 None,则应用该函数来处理去噪后的 `x_start` 预测结果。
    :param cond_fn: 如果不为 None,则这是一个类似于模型的梯度函数,用于条件化模型。
    :param model_kwargs: 如果不为 None,这是一个字典,包含要传递给模型的额外关键词参数。这可以用于条件化。
    :return: 一个字典,包含以下键:
             - 'sample': 从模型中采样得到的随机样本。
             - 'pred_xstart': 对 `x_0` 的预测。
    """
    # 生成随机噪声,形状与当前输入图像 `x` 相同。
    noise = th.randn_like(x)

    # 获取模型对当前图像 `x` 在时间步 `t` 下的均值和方差。
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )

    # 创建一个非零掩码,如果当前时间步 `t` 不为 0,则掩码为 1,否则为 0。
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )

    # 如果提供了条件函数,则使用该函数来条件化均值。
    if cond_fn is not None:
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )

    # 使用均值和方差对噪声进行采样。
    sample = out["mean"] + nonzero_mask * \
             th.exp(0.5 * out["log_variance"]) * noise

    # 返回一个字典,包含采样结果和 `x_0` 的预测。
    result = {"sample": sample,
              "pred_xstart": out["pred_xstart"], 'gt': model_kwargs.get('gt')}

    return result


def PFAD_sample(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        conf=None
):
    """
    从模型生成样本,并在每个扩散时间步生成中间样本。

    参数与 `p_sample_loop()` 函数相同。
    返回一个生成器,每次迭代返回一个字典,字典包含 `p_sample()` 函数的返回值。

    :param model: 用于生成样本的模型。
    :param shape: 生成样本的形状(例如,图像的大小)。
    :param noise: 添加到样本的噪声,默认为 None。
    :param clip_denoised: 如果为 True,则对去噪后的图像进行裁剪。
    :param denoised_fn: 如果不为 None,则应用该函数来处理去噪后的图像。
    :param cond_fn: 如果不为 None,则使用该函数进行条件化。
    :param model_kwargs: 包含额外的关键词参数,传递给模型。
    :param device: 设备(如 CPU 或 GPU),默认为 None。
    :param conf: 配置参数,控制扩散过程的细节(例如,时间步数、补丁大小等)。
    :return: 返回一个字典,包含最终生成的图像。
    """
    if device is None:
        # 如果未指定设备,则使用模型的默认设备(通常是 GPU)。
        device = next(model.parameters()).device
    
    # 确保 `shape` 是元组或列表。
    assert isinstance(shape, (tuple, list))

    # 获取原始图像
    ori_img = model_kwargs['gt']
    # 初始化图像
    image_after_step = th.randn_like(ori_img)

    if conf.middel_removal:
        # 如果启用了中间去除,则按照指定的中间去除步数进行扩散。
        t_f = th.tensor([conf.middel_removal] * shape[0], device=device)
        image_after_step = self.q_sample(ori_img, t_f, noise)
        time = conf.middel_removal
    else:
        # 否则,使用配置中的时间步数。
        time = int(conf.timestep_respacing) if conf.timestep_respacing is not None else 1000

    # 创建交替的互补掩码
    patch_size = conf.patch_size
    image_size = (shape[-1], shape[-1])
    mask = create_mask(patch_size, image_size)

    # 设定扩散步骤的时间序列,从高时间步到低时间步。
    times = range(time - 1, -1, -1)
    from tqdm import tqdm
    time_pairs = tqdm(times)
    j = 0

    # 在每个扩散时间步进行样本生成
    for t_cur in time_pairs:
        t = th.tensor([t_cur] * shape[0], device=device)
        img_forward = self.q_sample(ori_img, t, noise)
        if j % 2 == 1:
            mask = 1 - mask  # 切换掩码

        with th.no_grad():
            # 从模型中采样并去噪
            out = self.p_sample(
                model,
                image_after_step,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs={},
            )
            # 处理输出图像
            image_after_step = self.deal_out(out["sample"], ori_img, img_forward, t, mask, conf.gamma_t)
        j += 1

    return {'out': image_after_step}  # 返回最终生成的图像

PFAD_sample

PFAD 中的核心采样过程,它基于扩散模型的反向过程进行图像恢复,同时加入了一种创新性的 “交替掩膜机制” 来控制哪些区域恢复、哪些保持原始结构,是 PFAD 的最大亮点之一

这个函数从一个随机噪声图像开始,利用反向扩散过程一步步还原图像,同时引导模型专注于恢复图像的模糊区域,并保留清晰区域不被破坏。

函数结构拆解图(流程图级别理解):

输入:
  model, shape, ori_img (from model_kwargs), mask配置conf, etc.

↓(初始化)
1. image_after_step ← 纯噪声图(randn_like ori_img)
2. 如果 conf 设置了 middel_removal → 用 q_sample 加中间t加噪(=partial blur)

↓(for 循环,逆扩散)
3. for t in T-1 to 0:
    - 构造 img_forward = ori_img 在当前 t 的加噪版本
    - 隔一步翻转 mask(实现交替采样)
    - 执行 p_sample 得到模型预测图 sample
    - 调用 deal_out 将 sample + ori_img + img_forward + mask 组合成新的 image_after_step

↑ 重复

输出:
  image_after_step(恢复后的图像)

deal_out(双频融合)

PFAD 的核心处理函数之一,它将频域知识融合进了扩散模型的图像恢复过程中

deal_out 接收模型预测图(img_reverse)、原图(ori_img)、当前加噪图(img_forward)和交替掩膜,通过频域混合 + 空间域插值,生成新的更合理的图像状态 img_re,以便进入下一个扩散步。

大致结构如下(流程拆解):

img_reverse     ← 当前模型预测的图像
ori_img         ← 原始清晰图像(GT)
img_forward     ← 当前时间步的前向噪声图(模糊版本)
gt_keep_mask    ← 控制哪个区域保留原图的掩膜
gamma_t         ← 控制两个输出融合程度的参数

→ step 1: 频域融合(得到x_1)
→ step 2: 空间域融合(得到x_2)
→ step 3: 两个结果再加权融合(输出img_re)

第一步:频域重建 x_1:结合原图和模型预测图的频率成分

gt_k_space = fft2(ori_img_2d)               # 原图的频谱
x_reverse_space = fft2(img_reverse_2d)      # 模型预测图的频谱

傅里叶变换后的图像变成频域图,每个值表示一种频率成分的强弱。

接着:

low_freq[keep, :] = gt_k_space[keep, :]
high_freq_keep[change, :] = gt_k_space[change, :]
high_freq_change[change, :] = x_reverse_space[change, :]

把:

  • 低频部分(图像结构信息) → 保留原图

  • 高频部分 → 保留原图的一部分 + 模型预测图的一部分

然后构造融合频谱:

x_re_freq = low_freq + high_freq_keep * freq_mask + high_freq_change * (1 - freq_mask)

这个 freq_mask 控制“用 GT 还是预测的高频”。

最后你用反傅里叶:

x_re = np.abs(ifft2(x_re_freq))

得到 x_1,它是一个 频域混合出来的新图像

第二步:像素混合图 x_2

k_2 = gt_keep_mask_2d * omega_t
x_2 = img_forward_2d * k_2 + img_reverse_2d * (1 - k_2)

得到 x_2,是另一个版本的新图像。

第三步:加权融合两个图像 x_1 & x_2

k = 1 - gamma_t * np.exp(-t_c / 1000)
img_re = k * x_1 + (1 - k) * x_2

一个随时间变化的融合系数 k

  • 时间步越小,k 越大 → 更信任 x_1(频域精细重建)

  • 时间步越大,k 越小 → 更信任 x_2(空间更平滑)

最终输出 img_re融合频域重建图像和空间混合图像的结果

总结

deal_outPFAD_sample 就是双域互补掩码的核心代码:

模块作用
🔁 PFAD_sample扩散逆过程控制流程(判断 mask 轮换、执行 deal_out
🔀 deal_out执行“双域融合”:频域和空间域互补 + 掩膜引导

 dist_util.py

封装了分布式训练中最常用的辅助函数,比如:

  • 初始化多GPU环境

  • 获取当前 GPU

  • 多GPU 同步模型参数

  • 从主GPU广播变量

  • 多卡下加载模型权重

功能函数名用途
初始化通信setup_dist()构建分布式环境(GPU通信)
设备分配dev()给当前进程分配对应 GPU
载入模型权重load_state_dict()避免多卡重复加载模型
参数同步sync_params()保证多卡的模型参数完全一致
通信端口管理_find_free_port()自动找可用端口做通信

👉 分布式训练流程:

  1. 所有进程并行执行一套相同的训练逻辑

  2. 每个 GPU 加载自己的 mini-batch 数据;

  3. 前向传播 → 损失计算 → 反向传播;

  4. PyTorch 自动在多个 GPU 之间 同步梯度(AllReduce)

  5. 更新模型参数。

代码复现

运用conda+mamba搭建环境实现:

1.配置环境

mamba create -n pfa_test python=3.9 -y

conda activate pfa_test

conda install -c conda-forge simpleitk=2.3.1 libitk=5.3.0 -y

验证(pfa_test) D:\B——SR_project\PFAD>conda list | findstr "itk":

libitk                    5.3.0                h170c4af_8    conda-forge
simpleitk                 2.3.1            py39h99910a6_2    conda-forge

安装必要包:

pip install numpy

pip install blobfile

pip install mpi4py

pip install scipy

pip install numpy==1.24.3

pip install tqdm

由于我的电脑是单卡,所以,修改原来的开源代码,禁用分布式训练,修改:PFAD\scripts\guided_diffusion\dist_util.py文件,

测试:

cd scripts
python image_sample.py --conf_path ../conf/brain_sample_config.yml --img_dir brain --save_path motion_remove

效果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值