TowardsDataScience 2023 博客中文翻译(一百八十九)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

IID: 初学者的意义和解释

原文:towardsdatascience.com/iid-meaning-and-interpretation-for-beginners-dbffab29022f

独立同分布

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jae Kim

·发表于Towards Data Science ·阅读时间 9 分钟·2023 年 8 月 19 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Yu Kato提供,来源于Unsplash

在统计学、数据分析和机器学习主题中,IID 概念作为一个基本假设或条件经常出现。它代表了“独立同分布”。IID 随机变量或序列是统计模型或机器学习模型的重要组成部分,同时也在时间序列分析中发挥作用。

在这篇文章中,我以直观的方式解释了在采样、建模和预测性三个不同背景下的 IID 概念。文中提供了一个带有 R 代码的应用案例,涉及时间序列分析和预测性。

采样中的 IID

表示 X ~ IID(μ,σ²)的符号表示从具有均值μ和方差σ²的总体中以纯随机的方式对(X1, …, Xn)进行采样。即,

  • 每个连续的 X 的实现都是独立的,与前一个或后一个没有关联;并且

  • 每个连续的 X 的实现都来自具有相同均值和方差的相同分布。

示例

假设从一个国家的年收入分布中采集了样本(X1, …, Xn)。

  1. 一名研究人员选择了男性作为 X1,女性作为 X2,男性作为 X3,然后女性作为 X4,这种模式保持到 Xn。这不是一个 IID 采样,因为采样中的可预测或系统性模式是非随机的,违反了独立性条件

  2. 一名研究人员从最贫困的群体中选择了(X1, … X500),然后从最富有的群体中选择了(X501, … X1000)。这不是一个 IID 采样,因为这两个群体的收入分布具有不同的均值和方差,违反了同一性条件

建模中的 IID

假设 Y 是你想建模或解释的变量。那么,它可以分解为两个部分,即,

Y = 系统性成分 + 不系统性成分。

系统性成分 是由与其他因素的基本关系驱动的 Y 部分。它是可以通过理论常识典型事实 解释或预期的部分。它是 Y 的基础部分,具有实质性和实际重要性。

不系统性成分Y 中不受基本因素驱动的部分,无法通过理论、推理或典型事实解释或预期。它捕捉 Y 中无法通过系统性成分解释的变动。它应该是 纯随机的 和特有的,没有任何系统性可预测的模式。在统计模型中称为误差项,通常表示为 IID 随机变量。

例如,考虑以下形式的线性回归模型:

方程 (1)

在 (1) 中,α + βX 是系统性成分,而 (1) 中的误差项 u 是不系统性成分。

如果 β 的值接近 0 或在实际中可以忽略,则变量 XY 的解释力(用 R² 测量)较低,表明它不能令人满意地解释 Y 的基本变动。

假设误差项 u 是一个 IID 随机变量,均值为零且方差固定,表示为 u ~ IID(0, σ²),这是纯随机的,代表 Y 中的不系统或意外变动。

如果 u 不是纯随机的且具有明显的模式,则系统性成分可能没有被正确指定,因为它缺少某些实质性或基本内容。

示例:自相关

假设误差项具有以下模式:

方程 (2)

这是线性依赖(或自相关),这是一个系统性模式。这种可预测模式应纳入模型部分,这将更好地解释 Y 的系统性成分。实现这一目标的一种方法是包含 Y 的滞后项在 (2) 中。即,

方程 (3)

在 (3) 中包含的 Yt 的滞后项能够捕捉 (2) 中误差项的自相关,因此 (3) 中的误差项 e 是 IID。

示例:异方差性

假设误差项显示出以下系统性模式:

方程 (4)

这种误差项模式称为异方差性,其中误差项的变异性随 X 变量的变化而变化。例如,假设 Y 是食品支出,X 是个人的可支配收入。方程 (4) 意味着高收入者的食品支出变异性更高。

这是一个可预测的模式,而具有性质(4)的误差项违反了 IID 的假设,因为误差项的方差不是常数。为了将这种模式纳入系统组件中,可以通过以下方式进行广义或加权最小二乘估计:

方程(5)

方程(5)是一个带有变换变量的回归,可以写成

方程(6)

其中

适用于异方差误差的变换

上述对YX的变换提供了方程(6)中的变换误差项(ut*),它是一个 IID 且不再具有异方差性。即,

这意味着,通过上述变换,误差项中的系统模式现在有效地纳入了系统组件。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者创作的图像

上述图形以直观的方式展示了变换的效果。在变换之前(左侧图),变量Y的变异性随着X的变化而增加,这反映了异方差性。变换有效地将异方差模式纳入了Y的系统组件中,变换后的误差项现在是一个 IID 随机变量,如右侧图所示。

许多回归或机器学习模型中的模型诊断测试旨在检查误差项是否遵循 IID 随机变量,使用从估计模型中得到的残差。这也称为残差分析。通过残差分析和诊断检查,可以改善模型的系统组件的规范。

IID 和可预测性

纯粹随机的 IID 序列完全没有可预测的模式。也就是说,它的过去历史对序列未来的走向没有任何信息。

示例:自回归模型

考虑一个自回归模型,记作 AR(1),

方程(7)

其中ut ~ IID(0,σ²)且 -1 < ρ < 1(ρ ≠ 0)。

如果ρ = 0,时间序列 Yt 是一个 IID 且不可预测的,因为它不依赖于自己的过去,仅由不可预测的冲击驱动。

为了简化,假设 Y0 = 0 且ρ ≠ 0,进行以下持续替代:

Y1 = u1;

Y2 = ρY1 + u2 = ρu1 + u2;

Y3 = ρY2 + u3 = ρ²u1 + ρ u2 + u3;

Y4 = ρY3 + u4 = ρ³u1 + ρ²u2 + ρu3 + u4;

其一般表达式为

方程(8)

方程(6)表明,一个时间序列(如自回归)可以表示为过去和当前 IID 误差(或冲击)的移动平均,并具有指数衰减的权重。

注意,远程冲击如(8)中的 u1 和 u2 对Yt的影响很小,因为它们的权重微不足道。例如,当ρ = 0.5 且t = 100 时,ρ⁹⁹和ρ⁹⁸几乎为 0。只有当前或最近的冲击,如 u100、u99 和 u98,才可能实际相关。

因此,如果研究人员在时间 t 对ρ有一个良好的估计(来自数据)并观察了当前和近期的冲击,如 ut、ut-1、ut-2 和 ut-3,她或他可能能够通过将(8)中的移动平均投射到未来,合理准确地预测 Yt+1 的值。

示例:随机游走

当ρ = 1 时,(7)中的时间序列变成了一个随机游走,其中当前的Y变化是一个纯粹不可预测的 IID 冲击:即,

在这种情况下,从(8)中,ρ = 1,我们得到

换句话说,随机游走是所有过去和当前 IID 冲击的总和,其权重为 1。因此,远离的冲击与近期和当前冲击同等重要。例如,如果 t = 100,冲击 u1 对 Y100 的影响与 u100 相同。

作为所有过去和当前冲击的总和,随机游走时间序列是完全不可预测的。它还表现出高度的不确定性和持久性(对过去的依赖),具有以下分析结果

方程(9)

这意味着随机游走的变异性随着时间的推移而增加,表明随时间的不确定性高且可预测性低。

此外,Yt 和 Yt-k 之间的相关性几乎等于 1,对于几乎所有的k值。例如,当 t = 100 时,Y100 和 Y99 的相关系数为 99/100 = 0.99。

应用

作为一个应用,通过时间图和自相关函数比较了 IID 过程、ρ ∈ {0.3, 0.6, 0.9}的 AR(1)时间序列和随机游走的基本描述特性。

时间图

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

时间图:图像由作者创建

  • IID 序列 Y1 作为一个 AR(1)时间序列,ρ = 0 时,完全没有规律,随机且频繁地在均值 0 附近波动。它有很强的回归均值的倾向。

  • 对于 Y2 到 Y4,当ρ的值从 0.3 增加到 0.9 时,时间序列变得更平滑且频率较低,反映出对自身过去的依赖性增加。均值回归的程度也随着ρ值的增加而下降。

  • 随机游走 Y5 显示出一个可以随机改变方向的趋势(称为随机趋势)。它表现出随时间增加的变异性,如(9)中的第一个结果所示,并且随着时间的推移有一点回归均值的倾向(均值回避)。

自相关函数

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

自相关函数(图像由作者提供)

时间序列的自相关函数绘制了 Corr(Yt,Yt-k) 与滞后值 k 的关系。它提供了时间序列结构依赖性的视觉总结。例如,Corr(Yt,Yt-1) 测量的是 Y 在相隔 1 周期的值之间的相关性。蓝色带表示 95% 的置信区间,自相关值在此带内意味着该相关性在 5% 显著性水平下统计上与 0 无显著差异。

  • 一个 IID 时间序列 Y1 的所有自相关值实际上都可以忽略不计,统计上为 0。

  • 随着 ρ 值从 0.3 增加到 0.9,Y 对自身过去的依赖程度增加,因为更多的自相关值变得显著大于 0,并且在统计上有所不同。

  • 随机游走时间序列 Y5 的所有自相关值都极接近 1,表明对自身过去的高度依赖(持久性)。这反映了第(9)点中给出的第二个属性。

该应用展示了 IID 时间序列的基本统计属性,并与 AR(1) 和随机游走的属性进行比较。它说明了依赖于过去的程度(或可预测性)如何随着 AR(1) 系数值从 0 变到 1 而变化,即从 IID 时间序列变为随机游走。如上所述,当依赖程度适中且 ρ 的值大于 0 但小于 1 时,时间序列是可预测的。

R 代码

时间序列和图表是通过以下 R 代码生成的:

set.seed(1234)

n=500  # Sample size
# IID
Y1 = rnorm(n)    
# AR(1) with rho = 0.3, 0.6, and 0.9
Y2 = arima.sim(list(order=c(1,0,0), ar=0.3), n)
Y3 = arima.sim(list(order=c(1,0,0), ar=0.6), n)
Y4 = arima.sim(list(order=c(1,0,0), ar=0.9), n)
# Random Walk
Y5 = cumsum(rnorm(n))

par(mfrow=c(3,1))
# Time plots
plot.ts(Y1,main="IID",lwd=2)
plot.ts(Y2,main="AR(1) with rho=0.3",lwd=2)
plot.ts(Y3,main="AR(1) with rho=0.6",lwd=2)
plot.ts(Y4,main="AR(1) with rho=0.9",lwd=2)
plot.ts(Y5,main="Random Walk",lwd=2)

# Autocorrelation functions
acf(Y1,main="IID"); 
acf(Y2,main="AR(1) with rho=0.3"); 
acf(Y3,main="AR(1) with rho=0.6"); 
acf(Y4,main="AR(1) with rho=0.9"); 
acf(Y5,main="Random Walk");

结论

IID 的概念在统计分析和机器学习模型中是基础的。本文回顾了 IID 在三种不同背景下的应用:抽样、建模和时间序列分析中的可预测性。展示了一个应用,该应用比较了 IID 时间序列与平稳 AR(1) 和随机游走的基本描述统计属性。

闪耀的洞察:GPT 从图表和表格中提取意义

原文:towardsdatascience.com/illuminating-insights-gpt-extracts-meaning-from-charts-and-tables-a0b71c991d34

使用 GPT 视觉来解释和汇总图像数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Ilia Teimouri PhD

·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 12 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 David Travis 拍摄,发布在 Unsplash

许多领域的专家认为,将图像等视觉输入与文本和语音整合到大型语言模型(LLMs)中,被视为 AI 研究中的一个重要新方向。通过增强这些模型处理除语言之外的多种数据模式,有可能显著拓宽它们的应用范围,同时提高它们在现有自然语言处理任务中的整体智能和性能。

多模态 AI 的前景从更具吸引力的用户体验,如能够感知其周围环境并提及周围物体的对话代理,到能够通过结合语言和视觉知识流畅地将指令转化为物理动作的机器人。通过将历史上分离的 AI 领域统一到一个模型架构中,多模态技术可能会加速依赖多种技能的任务的进展,如视觉问答或图像描述。不同领域的学习算法、数据类型和模型设计之间的协同作用可能会导致快速进步。

许多公司已经以各种形式采纳了多模态技术:OpenAIAnthropic,谷歌的 BardGemini 允许用户上传自己的图像或文本数据并进行聊天。

在这篇文章中,我希望展示大语言模型与计算机视觉在金融领域的一种简单而强大的应用。股票研究员和投资银行分析师可能会发现这特别有用,因为你们可能会花费大量时间阅读包含各种表格和图表的报告和声明。阅读冗长的表格和图表并正确解释它们需要大量时间、领域知识以及足够的专注以避免错误。更繁琐的是,分析师偶尔需要手动从 PDF 中输入表格数据,以便创建新的图表。一个自动化的解决方案可以通过提取和解释关键信息来减轻这些痛苦,而无需人工监督或疲劳。

实际上,通过将自然语言处理与计算机视觉相结合,我们可以创建一个助手来处理许多重复的分析任务,从而让分析师专注于更高级的战略和决策制定。

近年来,在使用 光学字符识别 或视觉文档理解(图像转文本)从图像/PDF 数据中提取文本方面取得了很多进展。然而,由于当前可用的训练数据的性质,现有方法仍然难以处理许多财务报表、研究报告和监管文件中的复杂布局和格式。

GPT-4V(ision)用于表格和图表

在 2023 年 9 月,OpenAI 发布了 GPT-4 Vision。根据 OpenAI 的说法:

GPT-4 带有视觉功能(GPT-4V)使用户能够指示 GPT-4 分析用户提供的图像输入。

GPT-4V 的视觉能力来自于 GPT-4,因此两个模型的训练方式相似。首先,研究人员向系统输入了大量文本,以教会它语言的基本知识。目标是预测文档中的下一个词。然后是使用一种称为人类反馈强化学习(RLHF)的精细训练方法。这涉及根据人类训练者的积极反应进一步微调模型,以产生我们认为真正有用的输出。

在这篇文章中,我将创建一个 Steamlit 应用程序,用户可以上传图像并询问关于图像的各种问题。我将使用的图像是金融 PDF 文档的截图。实际上,该文档是公开的 基金事实表

代码的主要部分有两个,第一个是一个将图像从给定文件路径编码的函数:

# Function to encode the image from a file path
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

你需要这个功能,因为模型期望你的输入图像是 base 64 编码格式。接下来的主要代码部分将是你如何将请求发送到 OpenAI 的 API:

# Function to send the request to OpenAI API
def get_image_analysis(api_key, base64_image, question):
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "model": "gpt-4-vision-preview",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                ]
            }
        ],
        "max_tokens": 300
    }

    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    return response.json()['choices'][0]['message']['content']

在这里,我们将模型名称设置为gpt-4-vision-preview。如你所见,这与通常的文本到文本的 OpenAI API 调用非常不同。在这种情况下,我们定义了一个名为payload的 json 对象,其中包含你的文本以及图像数据。

你可以扩展get_image_analysis方法,以发送多个图像,或通过detail参数控制模型如何处理图像。详细信息请参见这里

剩余的代码主要是 Streamlit 方法,我们允许用户上传他们的图像,并通过提问与图像互动。

完整代码:(也可以在Github上找到)

import streamlit as st
import os
import requests
import base64
from PIL import Image

# Function to encode the image from a file path
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

# Function to save the uploaded file
def save_uploaded_file(directory, file):
    if not os.path.exists(directory):
        os.makedirs(directory)
    file_path = os.path.join(directory, file.name)
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path

# Function to send the request to OpenAI API
def get_image_analysis(api_key, base64_image, question):
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "model": "gpt-4-vision-preview",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                ]
            }
        ],
        "max_tokens": 300
    }

    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    return response.json()['choices'][0]['message']['content']

def main():
    st.title("Image Analysis Application")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="file_uploader")

    if uploaded_file is not None:

        file_path = save_uploaded_file('data', uploaded_file)

        # Encode the uploaded image
        base64_image = encode_image(file_path)

        # Session state to store the base64 encoded image
        if 'base64_image' not in st.session_state or st.session_state['base64_image'] != base64_image:
            st.session_state['base64_image'] = base64_image

        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image.', use_column_width=True)

    question = st.text_input("Enter your question about the image:", key="question_input")

    submit_button = st.button("Submit Question")

    api_key = os.getenv("OPENAI_API_KEY")

    if submit_button and question and 'base64_image' in st.session_state and api_key:
        # Get the analysis from OpenAI's API
        response = get_image_analysis(api_key, st.session_state['base64_image'], question)
        st.write(response)
    elif submit_button and not api_key:
        st.error("API key not found. Please set your OpenAI API key.")

if __name__ == "__main__":
    main()

输出和总结

现在让我们来看几个例子:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由作者生成。图表来自公开的UBS 基金事实说明书

在这个例子中,问题是关于性能的峰值。我们可以看到模型正确地识别了峰值。模型还能够理解虚线是图例中的指数表现。在计算机视觉中,理解虚线和点线通常比较困难,但只要截图质量好(细节足够),GPT Vision 就可以轻松完成任务。

另一个例子:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由作者生成。资料来自公开的UBS 基金事实说明书

在这个例子中,我尝试检验模型在以下方面的表现:1)从其他数据中提取相关表格 2)提取表格的相关部分 3)进行一些基本的数学运算。

如所示,模型成功满足了此任务的所有三个要求——鉴于传统上涉及的复杂性,这并非易事。手动操作时,分析师即使使用光学字符识别(OCR)工具,也很难提取锁定在 PDF 中的双栏表格。还需要额外的编码来将图表解析为结构化的数据框,以便于汇总。这可能会在回答原始问题之前消耗大量时间。然而,在这里,只需一个提示就能实现预期结果。避免了解码图像、抓取数据、处理电子表格和编写脚本的繁琐工作,极大提高了效率。

最后:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由作者生成。资料来自公开的UBS 基金事实说明书

排序算法根据指定的排序规则系统地重新排列列表或数组的元素。然而,与传统代码不同,像 GPT 这样的 LLM 没有预定义的排序例程。

相反,GPT 被训练成根据先前的上下文预测序列中的下一个词。通过足够的数据和模型能力,排序能力从学习文本模式中显现出来。

上面的例子说明了这一点——GPT 正确地对从 PDF 截图中提取的表格中的两列进行排序,这是一项复杂的工作,涉及光学字符识别、数据提取和处理技能。即便在 Excel 中,多列排序也需要一定的专业知识。但只需在提示中提供目标,GPT 就能自动处理这些复杂的步骤。

与传统算法遵循严格的逐步指令不同,像 GPT 这样的语言模型通过在训练过程中识别文本中的关系来发展排序能力。这使得它们能够从多样的曝光中吸收各种能力,而不是被预定义的编程所限制。

为什么这很重要?

通过将这种灵活性应用于我们在这里看到的专业任务,提示可以解锁高效的问题解决方案,这些方案否则将需要大量的手动工作和技术知识。

揭开文本生成 AI 的黑箱

原文:towardsdatascience.com/illuminating-the-black-box-of-ai-ddea07e65c35

需求洞察

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Anthony Alcaraz

·发表于 Towards Data Science ·8 分钟阅读·2023 年 12 月 17 日

人工智能软件用于提升本文的语法、流畅性和可读性。

像 ChatGPT、Claude 3、Gemini 和 Mistral 这样的语言模型以其表达能力和博学引人注目。然而,这些大型语言模型仍然是黑箱,掩盖了驱动其响应的复杂机制。它们生成类似人类的文本的能力超越了我们理解其机器思维如何运作的能力。

但随着人工智能在信任和透明度至关重要的场景中发挥作用,如招聘和风险评估,可解释性现在成为了重点。可解释性不再是复杂系统的可选配件,而是安全推动高影响领域 AI 的必要前提。

为了揭开这些黑箱模型的面纱,生动的可解释 NLP 领域提供了越来越多的工具——从揭示关注模式的注意力可视化,到探查输入的随机部分以量化影响。像 LIME 这样的某些方法创建了模拟关键决策的简化模型。其他方法,如 SHAP,则借鉴了合作博弈论的概念,将“信贷”和“责任”分配到模型输入的不同部分,基于其最终输出。

无论技术如何,所有方法都追求相同的关键目标:阐明语言模型如何利用我们提供的大量文本来编写连贯的段落或进行重要评估。

人工智能已经在影响人类生活的决策中发挥作用——选择性地评估申请者、审查仇恨内容、诊断疾病。

解释不仅仅是附加功能——它们将对监督这些强大的模型在社会中的普及起到关键作用。

随着大型语言模型的不断进步,它们的内部工作机制仍然笼罩在神秘之中。然而,可信的人工智能需要对其在重大决策中的推理过程保持透明。

解释性 NLP 的充满活力的领域提供了两种主要方法来阐明模型逻辑:

  1. 基于扰动的方法:如 LIME 和 SHAP 的技术通过遮蔽输入组件系统地探测模型,并根据输出变化量化重要性。这些外部视角将模型视为黑箱。

  2. 自我解释:一种替代范式使模型能够通过生成文本解释自己的推理。例如,突出影响预测的关键输入特征。这依赖于内省的模型意识,而不是强加的解释。

早期分析发现这两种方法都很有前景——LIME 和 SHAP 擅长忠实捕捉模型行为,而自我解释则更符合人类的理性。然而,当前的实践也难以充分评估这两者,建议重新考虑评估策略。

理想情况下,两者之间的协同作用可以结合起来推动进展。例如,自我声明的重要因素可以与扰动实验进行验证。并且归因分数可以增加验证信号,为自由形式的解释提供支撑。

随着模型不断吸收更多的世界知识,阐明它们多方面的推理变得越来越重要。多样化的新兴想法可能对应对这一挑战至关重要。

解释 AI 的平衡艺术

构建解释不可避免地需要简化。但过度简化会导致扭曲。以常见的基于注意力的解释为例——它们突出模型 supposedly 关注的输入部分。然而,注意力分数往往与 AI 系统的实际推理过程不一致。

更严格的技术如 SHAP 通过系统地遮蔽不同的输入组件并直接测量对输出的影响来避免这一点。通过比较有无每个特征的预测,SHAP 为每个特征分配一个“重要性分数”以表示其影响。这种基于扰动的方法更好地反映了模型的逻辑。

然而,忠实性往往以可理解性为代价。移除单词和子句的组合很快变得认知负担过重。因此,研究社区强调平衡两个关键标准:

忠实性:解释多准确地捕捉了模型的实际决策过程?基于遮蔽的扰动方法在这里表现出色。

可理解性:解释对目标受众的直观性和易消化程度如何?简化的线性模型有助于理解,但可能会扭曲。

理想情况下,解释应同时展现两者特征。但即便是忠实性高的 SHAP,在模型处理大量文本和不受限制的生成时,也会遇到困难——需要处理的输出组合呈指数级增长。对 10,000 字文章的所有可能遮蔽排列进行计算是不可行的!

这阻碍了对关键应用的进展,如解释作文评分模型或处理文档的问答系统。创建模仿预测的简化模型(如 LIME)在复杂文本推理中也变得不可行。需要更具针对性的解决方案来扩展大型语言模型的可解释性。

突出了关键挑战 —— 特别是由长输入和开放输出引入的指数复杂性。如果有任何部分需要更多解释,请告诉我!

TextGenSHAP:优化语言任务的解释

[## TextGenSHAP:在长文档中生成的可扩展事后解释

大型语言模型(LLMs)由于其日益准确的性能而引起了对实际应用的极大关注…

arxiv.org

为了克服复杂语言模型的可解释性障碍,研究人员开发了TextGenSHAP —— 在 SHAP 的基础上,融入了对效率的优化,并考虑了语言结构。

若干创新技术应对了指数级的计算复杂性。预测解码首先预测可能的文本输出,避免了浪费的解码尝试。闪电注意力简化了内存密集型的注意力计算。原地重采样为提高效率预计算了输入编码。

这些加速技术使运行时间从几小时减少到几分钟,实现了实用的周转。作者验证了在不同模型类型和数据集复杂性下的数量级加速。

但仅仅是原始效率是不够的 —— 语言本身的复杂性必须得到体现。TextGenSHAP 解决了自然语言处理独有的解释挑战:

层次结构 —— 除了单个词语外,语言模型还学习句子、段落甚至文档之间的概念联系。TextGenSHAP 的层次化归因允许在粗粒度和细粒度层面上分配重要性评分。

指数输出空间 —— 开放式文本生成产生了巨大的可能输出集,不同于受限的分类任务。通过如 Shapley-Shubik 指数等重新表述,TextGenSHAP 绕过了详尽的枚举来估计特征重要性。

自回归依赖 —— 生成的标记概率上依赖于前面的标记。TextGenSHAP 的适应性解码算法,如预测解码,明确地尊重这些标记间的依赖关系。

这些架构和语言上的进展为 TextGenSHAP 应对现代自然语言处理的复杂性铺平了道路。现在可以着手解决长期挑战中的可解释性问题,例如文档上的问答。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像,来源于 Dall-E-3.0

应用:解释文档中的问答

文档中的问答代表了 AI 的一个宝贵里程碑——综合散布在段落中的信息以解决复杂的查询。TextGenSHAP 现在使解释这些复杂的文本推理工作流程成为可能。

作者在需要从超过 10,000 个单词的上下文中推导答案的挑战性数据集上评估了 TextGenSHAP。令人印象深刻的是,它准确地识别出分布在扩展文本中的关键句子和短语,这些句子和短语对每个答案的形成最有帮助。

通过适当地归因于冗长文档的不同部分,TextGenSHAP 使强大的应用成为可能:

改进文档检索——通过影响评分对上下文进行排名和筛选,提取了更相关的段落。仅通过根据 TextGenSHAP 重新排序,作者展示了检索召回率的显著提升——从 84%提高到近 89%。这有助于更好地为下游推理步骤提供信息。

提炼证据——使用重要性评分来挑选每个问题回答的最核心支持段落,在具有多样证据的数据集上准确率从 50%提高到 70%。确保模型关注简明的解释提取物,以对抗在大型语料库中对虚假模式的过拟合。

人工监督——通过揭示最有影响力的文本片段,TextGenSHAP 使审计员能够快速验证模型是否使用了适当的支持内容,而不是依赖于非预期的提示。否则,监控复杂的推理过程是不可行的。

对于推理密集型问题回答的成功表明,解释 AI 能力的社会影响有更广泛的适用性——如评分论文内容和散文或解释医疗诊断。通过揭示语言中的关键联系,TextGenSHAP 使我们朝着负责任和可信赖的 NLP 系统迈进。

anon832098265.github.io/

调查自我解释

arxiv.org/abs/2310.11207?source=post_page-----ddea07e65c35-------------------------------- [## 大型语言模型能否自我解释?LLM 生成自我解释的研究

大型语言模型(LLMs)如 ChatGPT 在各种自然语言任务中展示了卓越的性能…

arxiv.org

我们讨论了将模型视为黑箱的传统事后方法。一个有趣的替代方案是使系统能够解释自身的推理——自我解释

最近的研究分析了这些用于情感分析,使用 ChatGPT。模型突出了输入词汇对其预测的影响。与直接扰动输入的外部技术不同,自我解释依赖于内省模型意识来声明重要因素。

论文系统地比较了不同格式,发现预测然后解释或反之都工作得相当好。模型容易生成所有单词的特征归因分数或仅仅是最重要的高亮部分。但有趣的是,重要性分数通常聚集在“全面的”水平(例如 0.0、0.5、0.75),这更类似于人类判断,而不是离散的机器精度。

尽管自我解释与人类的推理相对一致,但广泛使用的评估实践难以区分质量。依赖于细粒度模型预测变化的指标容易受到欺骗,而这些变化对于 ChatGPT 常常不敏感。研究人员得出结论认为,经典的可解释性流程需要重新考虑以适应大型语言模型。

要充分实现自我解释的潜力,需要新的评估框架,以适应其混合的人机特性。将它们与直接可观察的信号如注意力权重相结合,可能会增强其真实性。构建模块化的推理/解释组件也可能使得自我解释更加纯粹。

通过精心协同设计,以适应其新兴特性,自我解释有可能开启前所未有的模型透明度——将黑箱转换为“玻璃箱”,使系统不仅展示其内部工作原理,还讨论其内在机制。

TextGenSHAP 方法专注于为文本生成模型提供高效的 Shapley 值归因。它在量化特征重要性方面取得了进展,适用于长文档问答任务。

然而,TextGenSHAP 仍然依赖外部视角,扰动输入并观察输出变化,而不是让模型自我反思其推理。这为与自我解释方法的整合留下了空间。

自我解释可以提供更为定性、直观的理解,以补充来自 TextGenSHAP 的定量归因分数。例如,TextGenSHAP 可能会识别文档中的关键段落,并将某些句子突出为回答问题时最具影响力的部分。自我解释可以通过讨论聚焦于这些领域的逻辑来丰富这些信息。

相反,目前的自我解释通常以自由生成的形式出现,缺乏基础。与将模型推理综合为标记重要性排名的归因分数结合起来,可能有助于验证和增强自我解释的意义。

在架构上,TextGenSHAP 模块可以首先处理文档和问题,生成注意力分布和段落排名。然后,自我解释模块可以利用这些定量信号生成自由形式的推理,讨论评估内容,并通过归因分数引导解释。

联合评估还可以评估自我声明的解释因素是否与扰动基础评分指定为有影响的输入组件一致。

本质上,自我解释提供了模型理解的“是什么”,而归因分数则提供了“为什么”。它们的共生关系可能使丰富的解释性融合定量和定性见解成为可能。

前进的道路:通过透明度实现信任

TextGenSHAP 提供了一个关键的进展——窥视大型语言模型在处理大量文本时的复杂工作。通过创建高效准确的解释,它绕过了现有解释方法仅限于少量语言片段的障碍。

然而,单纯的语言流畅性并不能保证可信赖的人工智能。语言的掌握——推动 ChatGPT 口才进步的标志——必须与阐明的掌握相结合。

阐明不仅仅是抛出几个关键词那么简单——它需要复制复杂的推理链条,从而得出最终评估。像 TextGenSHAP 这样的进展将这一必要的透明度更接近现实。

随着模型不断吸收更多世界知识,其内部表征的复杂性也大大增加。通过简化的注意力分数或小扰动样本尝试监督只会混淆视听,而非阐明。尊重结构和逻辑依赖的更全面的方法,如 TextGenSHAP,将证明至关重要。

学习缺乏透明度将导致权力没有责任。观察缺乏阐明将导致缺乏严谨的橡皮图章。神经网络的显著复兴必须伴随揭示其复杂性的技术。

在这一领域的进展仍处于初期阶段——但重要的种子已经扎根。通过努力完善理解性与忠实性的混合,无论是通过高效的近似方法还是天生可解释的架构,也许未来的系统可以巧妙地解释其掌握,从而彻底揭开黑箱神秘的面纱。

初学者的图像分类

原文:towardsdatascience.com/image-classification-for-beginners-8546aa75f331

VGG 和 ResNet 架构来自 2014 年

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Mina Ghashami

·发布在Towards Data Science ·10 分钟阅读·2023 年 10 月 17 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源于unsplash — 作者修改

图像分类是我在Interview Kickstart教授的第一个主题,旨在帮助专业人士在顶尖科技公司找到工作。当我准备其中的一次讲座时,我写了这篇文章。因此,如果你对这个主题不熟悉,这个直观的解释可能对你也有帮助。

在这篇文章中,我们探讨了 VGG 和 ResNet 模型;这两者都是卷积神经网络(CNNs)在计算机视觉领域发展中的开创性和影响力巨大的作品。VGG[2] 是 2014 年由牛津大学的研究小组提出的,而 ResNet[3] 是 2015 年由微软研究人员提出的。

让我们开始吧。

什么是 VGG?

VGG 代表视觉几何组,是牛津大学的一个研究小组。2014 年,他们为图像分类任务设计了一个深度卷积神经网络架构,并以他们的名字命名了它,即 VGG。[2].

VGG 网络架构

这个网络有几种配置;所有配置的架构相同,只是层数不同。最著名的有 VGG16 和 VGG19。VGG19 比 VGG16 更深,性能更好。为了简化,我们关注 VGG16。

VGG16 的架构如下图所示。正如我们所见,它有 16 层;13 个卷积层和 3 个全连接层

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

VGG16 架构 — 图片由作者提供

这是一个非常简单的架构;它由 6 个块组成,其中前 5 个块包含卷积层,之后是一个最大池化层,第 6 个块仅包含全连接层。

所有卷积层使用 3x3 滤波器,步幅为 1,所有 最大池化层为 2x2,步幅为 2,因此它们将输入特征图的宽度和高度减半。这称为 下采样,因为它减少了输出特征图的大小。

注意,卷积层从 64 个滤波器开始,并在每次池化后翻倍,直到达到 512 个滤波器。所有卷积层使用“相同”填充以保持输入和输出之间的相同大小,并且它们都使用 RELU 激活函数。下面,我们解释这些概念:

相同填充:相同填充是一种填充技术,以确保卷积操作的输出体积具有与输入体积相同的高度和宽度。它通过在所有边缘均匀填充零来工作,使得卷积操作后空间维度保持不变。

最大池化:如我们上面所见,在每个块之后应用 2x2 最大池化,步幅为 2。最大池化输出窗口中的最大值。步幅为 2 将空间维度减半,并保留了对强大特征检测至关重要的信息。此外,这种减少带来了计算效率。

RELU 激活函数:如我们所提到的,VGG 使用的激活函数是 RELU。RELU 将负值设为零,保持正值不变。它所增加的非线性有助于提升模型的表现力,并有助于检测复杂的模式。VGG 模型在每个卷积层后使用 RELU。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

让我们逐层了解 VGG16 架构:

  • 假设输入是一个彩色图像,其尺寸为高度和宽度,则其大小为(高度,宽度,3)。注意 RGB 有 3 个通道。

  • 第一层具有 64 个神经元,并应用 3x3 卷积,具有“相同”填充,因此第一层的输出特征图为(高度,宽度,64)。

  • 第二层与第一层相同,因此这一层的输出特征图也为(高度,宽度,64)。

  • 第三层是 2x2 最大池化,步幅为 2,因此它将大小缩小到(高度/2,宽度/2,64)

  • 第四层和第五层是 conv3–128,具有“相同”填充,因此它们将输出大小更改为(高度/2,宽度/2,128)。

  • 第六层再次是 2x2 最大池化,它将输出大小更改为(高度/4,宽度/4,128)。

  • 如果我们继续这样下去,我们会发现当数据到达第一个全连接层时,它的形状是(高度/32,宽度/32,512)。因此,我们看到通道的数量从 3 增加到 512,同时高度和宽度减少了 32 倍!!!可以把它想象成压缩信息,而是捕捉通道中的模式。

VGG 计算成本

VGG16 是 最大的 CNN 模型之一;它拥有 1.38 亿个参数。在下图中,我们看到 VGG 的两个变体:VGG16(具有 16 层)和 VGG19(具有 19 层)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来自 [1]

我们看到 VGG16 和 VGG19 在一次前向传播中需要的操作次数是最大的 CNN 模型。注意,操作次数与模型的参数数量成正比。在下一篇文章中,我们将探讨 ResNet[3]模型,它比 VGG 小得多,并且表现优于 VGG。

为什么提出了 VGG?

在 VGG 之前,CNN 模型的层数较少,卷积滤波器较大。VGG 网络的提出是为了展示一个只有 3x3 卷积层堆叠在一起的简单 CNN 可以与具有大滤波器的复杂模型一样好。

它还展示了卷积网络中深度的重要性。他们表明,堆叠许多小的 3x3 卷积层可以有效模拟较大的感受野。在 VGG 被提出时,它在 ImageNet 数据集上的图像分类任务中超越了所有其他模型。

ResNet 是什么?

ResNet,即残差网络,是微软研究人员在 2015 年提出的[3]。在深入了解其架构之前,先来看看它为何被提出。

为什么提出了 ResNet?

总而言之,ResNet 的提出是为了解决在非常深的网络中的梯度消失问题。 让我们更深入地看看:

正如我们在 VGG 的案例中所看到的,深度神经网络极其强大。但它们也有更多的参数,因此训练时间较长,计算成本较高。此外,我们还需要更多的训练数据来训练它们。

除了计算成本和训练数据的大小外,训练深度神经网络也面临障碍。正如下图所示,当我们训练浅层神经网络时,训练损失在早期周期中开始减少。但在深度神经网络中,训练损失在早期周期中减少很少,经过几个周期后突然下降。这是深度神经网络实际训练中的一个大障碍。

那么为什么会发生这种情况呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

浅层和深度神经网络中训练损失随周期减少 —— 图片作者

这发生的原因有两个:

  1. 在深度神经网络的早期层中,梯度消失问题出现;即,损失的梯度在到达网络的早期层时会消失,因此这些层的参数更新非常少。

  2. 在深度神经网络的晚期层中,原始信号非常少(即原始输入)。这是为什么呢?因为信号被所有前面层的权重乘以,并通过激活函数,这会将信号推向零。因此,这些层在早期周期的输出几乎是随机噪声。因此,相对于损失的梯度是随机噪声,对这些层参数的更新没有意义。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

深度神经网络在前几个训练周期中学习受阻 —— 图片作者

这就是为什么在训练深度神经网络的前几个周期中没有看到太多改进的原因。

为了解决这个问题,我们希望找到一种方法,使得输入能够到达后期层,梯度能够到达早期层。我们可以使用跳跃连接来实现这两者。

跳跃连接

跳跃连接的理念是将网络层分组到块中,对于每个块,使输入同时通过和绕过该块。像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

在每个块内,层正常地向前传递它们的数据,而在块之间,我们有一种新类型的连接。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如上所示,这种连接通过将块的输入与块的输出结合起来工作。因此,数据基本上有两个路径流动:一个通过块,另一个绕过块。

所以一个残差块看起来像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

残差块 — 图片由作者提供

上面的“+”符号表示“组合”符号,它将输入张量和输出张量结合在一起。它必须是一个不会干扰梯度传递的操作。“+”操作可以是以下任意一种:

  1. 两个张量的逐元素相加

  2. 两个张量的拼接

值得强调的是,残差块之所以被称为“残差”,是因为它实现了一种残差学习方法。每个残差块学习一个相对于其输入的残差函数,而不是直接拟合一个期望的基础映射。

在前馈网络中,我们学习从输入到输出的直接映射,即 f(x): x->y。然而,在残差块中,如上所示,每个残差块学习一个残差函数,即 x->f(x)+x这个残差函数表示需要对输入进行的修改,以获得期望的输出。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源 [3]

ResNet 更容易训练

由残差块组成的网络称为残差网络或 ResNet。它们有几个优势,使得它们更快、更容易训练。

  1. 其中之一是每个残差块都会增强数据:由于它们将输入绕过块而不变,残差块的工作不是去判断输入中包含什么重要信息,而是去确定我们可以向输入中添加哪些额外的信息以达到输出。结果发现这是一项更简单的工作。

  2. 网络具有更短的梯度路径。由于每个块都有一个绕过块的路径,梯度也会经过这条路径。因此,网络中的任何层都有相对较短的路径,使得损失梯度能够到达。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

梯度在两个路径中流动:通过层和绕过块 — 图片由作者提供

关于 ResNet 的关注点

关于残差块,有一些关注点,我们在设计残差网络时需要注意:

  1. 为了能够添加/连接残差块的输入和输出,我们必须确保两个张量的形状相同。显然,如果我们强制每一层的输出形状与其输入相同,这个问题将不会出现。但是,强制这种约束会限制模型的容量。

  2. 如果我们使用连接而不是逐元素相加来组合每个块的输入和输出张量,那么我们将得到一个非常大的张量,并且参数会爆炸。因此,我们不应过度使用连接操作,如果我们的网络很深,必须优先考虑相加。通常,连接操作在一个或两个块中最多使用。

ResNet 架构

现在我们已经了解了残差块和跳跃连接,ResNet[3] 用于图像分类,通过堆叠多个残差块来构建。我们可以构建超过 100 层的非常深的网络。原始的 ResNet 具有从 18 层到 152 层的变体架构 [3]。

每个残差块由一个卷积层、批量归一化和 RELU 激活函数组成。正如我们在下图中看到的那样,“批量归一化”在每个卷积层之后使用;它通过减去均值并除以标准差来规范化激活。这一操作稳定了训练过程。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

残差块 - 作者提供的图像

当 ResNet 被提出时,它在 ImageNet 分类任务上取得了最先进的结果 [3]。

要点

要点 1: 深层神经网络中的最后几层接收到的输入信号非常少。这是因为每个中间层的激活函数如 sigmoid 或 tanh 对于大的正/负输入会饱和到 0 或 1。这会随着信号通过层而减弱。这被称为“饱和”。

要点 2: 深层神经网络的早期层在训练网络的前几个时期接收到的梯度非常少。这是因为在训练过程中,误差梯度通过许多层反向传播,它会指数级地缩小。这使得早期层难以有效学习。这个问题被称为“梯度消失问题”。

要点 3: VGG 的提出是为了展示使用简单的 3x3 滤波器的深层网络如何类似于使用大卷积的复杂网络。ResNet 的提出是为了解决非常深层网络中的梯度消失问题。

总结

在这篇文章中,我们研究了两个开创性的 CNN 架构,即 VGG 和 ResNet。VGG 是一个深层 CNN,仅包含 3x3 卷积层。它历史上用于图像分类任务,并且在提出时,它在 ImageNet 挑战中优于 AlexNet 和其他竞争模型。它展示了 CNN 中深度的力量,以及使用简单的 3x3 卷积可以类似于更大卷积核的效果。ResNet 在 VGG 之后被引入,并且优于 VGG。ResNet 的创新在于引入了残差块,这使得深层网络的训练变得更容易和更快。

如果你有任何问题或建议,请随时联系我:

电子邮件:mina.ghashami@gmail.com

领英:www.linkedin.com/in/minaghashami/

参考文献

  1. 实际应用中的深度神经网络模型分析

  2. 非常深的卷积网络用于大规模图像识别

  3. 深度残差学习用于图像识别

使用 PyTorch 和 SHAP 进行图像分类:你能信任自动驾驶汽车吗?

原文:towardsdatascience.com/image-classification-with-pytorch-and-shap-can-you-trust-an-automated-car-4d8d12714eea

构建一个目标检测模型,将其与强度阈值进行比较,评估并使用 DeepSHAP 解释它

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Conor O’Sullivan

·发表于Towards Data Science ·阅读时间 14 分钟·2023 年 3 月 21 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来源:作者)

如果世界不那么混乱,自驾车将会简单。但事实并非如此。为了避免严重的伤害,AI 必须考虑许多变量——速度限制、交通情况和路上的障碍物(例如分心的人)。AI 需要能够检测这些障碍物,并在遇到时采取适当的行动。

幸运的是,我们的应用并没有那么复杂。更幸运的是,我们将使用锡罐而不是人类。我们将建立一个模型,用于检测迷你自动驾驶汽车前方的障碍物。如果障碍物过于接近,汽车应该停下,否则前进

到头来,这是一个二分类问题。为了解决它,我们将:

  • 使用强度阈值创建基准

  • 使用 PyTorch 构建 CNN

  • 使用准确率、精确率和召回率评估模型

  • 使用 SHAP 解释模型

我们将看到模型不仅表现良好,而且其预测方式也似乎合理。在此过程中,我们将讨论 Python 代码,你可以在GitHub上找到完整的项目。

导入和数据集

# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob 
import random 

from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

import shap
from sklearn import metrics
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import ConfusionMatrixDisplay as cmd

在图 1 中,你可以看到我们数据集中图像的示例。这些图像的尺寸都是 224 x 224。如果没有黑色罐子或者罐子距离较远,图像被分类为 GO。如果罐子过于接近,图像被分类为 STOP。你可以在Kaggle上找到完整的数据集。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:示例图像(来源:作者)

我们使用下面的代码显示上述图像。注意图像的名称。它总是以一个数字开头。这是目标变量。我们用 0 表示 GO,用 1 表示 STOP。

# Paths of example images
ex_paths = ["../../data/object_detection/0_b812cd70-4eff-11ed-9b15-f602a686e36d.jpg",
          "../../data/object_detection/0_d1edcc80-4ef6-11ed-8ddf-a46bb6070c92.jpg",
          "../../data/object_detection/1_cb171726-4ef7-11ed-8ddf-a46bb6070c92.jpg"]

# Plot example images
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.set_facecolor('white')

for i, path in enumerate(ex_paths):

    # Load image
    img =  Image.open(path)

    # Get target
    name = path.split("/")[-1]
    target = int(name.split("_")[0])

    # Plot image
    ax[i].imshow(img)
    ax[i].axis("off")

    # Set title
    title = ["GO","STOP"][target]
    ax[i].set_title(title,size=20)

基准

在建模之前,值得创建一个基准。这可以提供一些对我们问题的见解。更重要的是,它为我们提供了一个比较模型结果的标准。我们更复杂的深度学习模型应该会优于简单的基准。

在图 1 中,我们可以看到锡罐比周围环境更暗。我们将在创建基准时利用这一点。即,如果图像中有许多暗像素,我们将其分类为 STOP。达到这一点需要几个步骤。对于每个图像,我们将:

  1. 进行灰度化,使每个像素的值在 0(黑色)和 255(白色)之间。

  2. 使用截止值,将每个像素转换为二进制值——深色像素为 1,浅色像素为 0。

  3. 计算平均强度——暗像素的百分比

  4. 如果平均强度超过某个百分比,我们将图像分类为 STOP。

合并步骤 1 和 2 是一种图像数据的特征工程方法,称为强度阈值。你可以在这篇文章中阅读更多关于此及其他特征工程方法的信息:

## 图像数据的特征工程

裁剪、灰度化、RGB 通道、强度阈值、边缘检测和颜色滤镜

towardsdatascience.com

我们使用下面的函数应用强度阈值。缩放后,一个像素将具有 0(黑色)或 1(白色)的值。对于我们的应用,颠倒这一点是有意义的。也就是说,原本深色的像素将被赋值为 1。

def threshold(img,cutoff,invert=False):
    """Apply intesity thresholding"""

    img = np.array(img)

    # Greyscale image
    img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)

    #Apply cutoff
    img[img>cutoff] = 255 #white
    img[img<=cutoff] = 0 #black

    # Scale to 0-1    
    img = img/255

    # Invert image so black = 1
    if invert: 
        img = 1 - img

    return img

在图 2 中,你可以看到我们应用强度阈值的一些示例。我们可以调整截止值。较小的截止值意味着我们包括的背景噪声更少。缺点是我们捕捉到的锡罐较少。在这种情况下,我们将使用截止值 60。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:使用强度阈值的特征工程(来源:作者)

我们加载了所有的图像(第 5 行)和目标变量(第 6 行)。然后,我们对这些图像应用强度阈值(第 9 行)。请注意,我们设置了invert=True。最后,我们计算每个处理过的图像的平均强度(第 10 行)。最终,每个图像由一个单一的数字——平均强度来表示。这可以解释为暗像素的百分比

# Load paths
paths = glob.glob("../../data/object_detection/*.jpg")

# Load images and targets
images = [Image.open(path) for path in paths]
target = [int(path.split("/")[-1].split("_")[0]) for path in paths]

# Apply thresholding and get intensity
thresh_img = [threshold(img,60,True) for img in images]
intensity = [np.average(img) for img in thresh_img]

图 3 给出了所有标记为 GO 和 STOP 的图像的平均强度箱线图。通常,我们可以看到 STOP 的值更高。这是有道理的——罐子离得更近,因此我们会有更多的暗像素。红线在 6.5% 处。这似乎能很好地分离图像类别。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:目标变量的平均强度(来源:作者)

# Split data into go and stop images
go_data = [intensity[i] for i in range(len(target)) if target[i]==0]
stop_data = [intensity[i] for i in range(len(target)) if target[i]==1]
data= [go_data,stop_data]

fig = plt.figure(figsize=(5,5))

# Plot boxplot
plt.boxplot(data)
plt.hlines(y=0.065,xmin=0.5,xmax=2.5,color='r')
plt.xticks([1,2],['GO','STOP'])
plt.ylabel("Average Intensity",size=15)

我们使用 6.5% 作为预测的截断值(第 2 行)。即如果暗像素的百分比超过 6.5%,则预测为 STOP(1),否则预测为 GO(0)。其余的代码用于评估这些预测。

# Predict using average intensity
prediction = [1 if i>0.065 else 0 for i in intensity]

# Evaluate
acc = metrics.accuracy_score(target,prediction)
prec,rec,_,_ = score(target, prediction,average='macro')

print('Accuracy: {}'.format(round(acc,4)))
print('Precision: {}'.format(round(prec,4)))
print('Recall: {}'.format(round(rec,4)))

# Plot confusion matrix
cm = metrics.confusion_matrix(target, prediction)
cm_display = cmd(cm, display_labels = ['GO', 'STOP'])

cm_display.plot()

最终,我们的准确率为 82%,精确率为 77.1%,召回率为 82.96%。不错!在混淆矩阵中,我们可以看到大多数错误是由于假阳性。这些是被预测为 STOP 的图像,而实际上我们应该 GO。这对应于图 3 的箱线图。查看 GO 强度值在红线以上的长尾。这可能是由于背景像素增加了图像中的暗像素数量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:基准预测的混淆矩阵(来源:作者)

卷积神经网络

如果一辆 AI 汽车的准确率只有 82%,你可能会有点担心。那么我们来看看更复杂的解决方案。

加载数据集

我们首先定义ImageDataset类。这用于加载我们的图像和目标变量。作为参数,我们需要传入所有图像路径的列表和用于转换图像的方法。我们的目标变量将是张量——[1,0] 代表 GO 和 [0,1] 代表 STOP。

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, paths, transform):

        self.paths = paths
        self.transform = transform

    def __getitem__(self, idx):
        """Get image and target (x, y) coordinates"""

        # Read image
        path = self.paths[idx]
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        image = Image.fromarray(image)

        # Transform image
        image = self.transform(image)

        # Get target
        target = path.split("/")[-1].split("_")[0]
        target = [[1,0],[0,1]][int(target)]

        target = torch.Tensor(target)

        return image, target

    def __len__(self):
        return len(self.paths)

我们将使用常见的图像转换。为了帮助创建一个更强大的模型,我们将对颜色进行抖动(第 2 行)。这将随机改变图像的亮度、对比度、饱和度和色调。我们还会对像素值进行归一化(第 4 行)。这将帮助模型收敛。

TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

我们加载了所有图像路径(第 1 行)并随机打乱它们(第 4 行)。然后我们为训练数据(第 8 行)和验证数据(第 9 行)创建ImageDataset对象。为此我们使用了 80/20 的划分(第 7 行)。最终,我们将在训练集中拥有3,892张图像,在验证集中拥有974张图像。

paths = glob.glob("../../data/object_detection/*.jpg")

# Shuffle the paths
random.shuffle(paths)

# Create a datasets for training and validation
split = int(0.8 * len(paths))
train_data = ImageDataset(paths[:split], TRANSFORMS)
valid_data = ImageDataset(paths[split:], TRANSFORMS)

此时,实际上还没有数据加载到内存中。在我们能够使用数据训练 PyTorch 模型之前,我们需要创建DataLoader对象。对于train_loader,我们设置了batch_size=128。这允许我们迭代所有训练图像,每次加载 128 张。对于验证图像,我们将批处理大小设置为验证集的完整长度。这允许我们一次加载所有 974 张图像。

# Prepare data for Pytorch model
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())

模型架构

接下来,我们定义我们的 CNN 架构。你可以在图 5 中看到这个架构的图示。我们从 224x224x3 的图像张量开始。我们有 3 个卷积层和最大池化层。这将我们缩减到 28x28x64 的张量。接下来是一个 drop-out 层和两个全连接层。我们对所有隐藏层使用 ReLu 激活函数。对于输出节点,我们使用 sigmoid 函数。这是为了使我们的预测值在 0 和 1 之间。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:CNN 架构(来源:作者)

我们在下面的Net类中捕捉了这种架构。需要指出的一点是使用了**nn.Sequential()**函数。必须使用这种定义 PyTorch 模型的方法,否则 SHAP 包将无法工作。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Convolutional layers
        self.conv_layers = nn.Sequential(
            # Sees 224x224x3 image tensor
            nn.Conv2d(3, #RGB channels
                            16, #number of kernels
                            3, #size of kernels
                            padding=1), 
            nn.MaxPool2d(2),
            nn.ReLU(),

            # Sees 112x112x16 tensor
            nn.Conv2d(16, 32, 3, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),

            # Sees 56x56x32 tensor
            nn.Conv2d(32, 64, 3, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU()
        )

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            # Sees flattened 28 * 28 * 64 tensor
            nn.Dropout(0.25),
            nn.Linear(64 * 28 * 28, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 64 * 28 * 28)
        x = self.fc_layers(x)
        return x

我们创建一个模型对象(第 2 行)。我们将其移动到 GPU 上(第 6–7 行)。我使用的是苹果 M1 笔记本电脑。你需要设置适合你机器的设备。

# create a complete CNN
model = Net()
print(model)

# move tensors to GPU if available
device = torch.device('mps')
model.to(device)

我们定义我们的损失函数(第 2 行)。由于我们的目标变量是二元的,我们将使用二元交叉熵损失。最后,我们使用 Adam 作为我们的优化器(第 5 行)。这是用来最小化损失的算法。

# specify loss function (binary cross-entropy)
criterion = nn.BCELoss()

# specify optimizer
optimizer = torch.optim.Adam(model.parameters())

训练模型

现在是有趣的部分!我们训练我们的模型 20 个周期,并选择验证损失最低的那个。模型可以在同一个GitHub Repo中找到。

name = "object_detection_cnn" # Change this to save a new model

# Train the model
min_loss = np.inf
for epoch in range(20):

    model = model.train()
    for images, target in iter(train_loader):

        images = images.to(device)
        target = target.to(device)

        # Zero gradients of parameters
        optimizer.zero_grad()  

        # Execute model to get outputs
        output = model(images)

        # Calculate loss
        loss = criterion(output, target)

        # Run backpropogation to accumulate gradients
        loss.backward()

        # Update model parameters
        optimizer.step()

    # Calculate validation loss
    model = model.eval()

    images, target = next(iter(valid_loader))
    images = images.to(device)
    target = target.to(device)

    output = model(images)
    valid_loss = criterion(output, target)

    print("Epoch: {}, Validation Loss: {}".format(epoch, valid_loss.item()))

    # Save model with lowest loss
    if valid_loss < min_loss:
        print("Saving model")
        torch.save(model, '../../models/{}.pth'.format(name))

        min_loss = valid_loss

需要提到的一点是**optimizer.zero_grad()**行。这将所有参数的梯度设置为 0。在每次训练迭代中,我们希望使用仅来自该批次的梯度来更新参数。如果不将梯度清零,它们会积累。这意味着我们将使用新批次和旧批次的梯度组合来更新参数。

模型评估

现在让我们看看这个模型的表现如何。我们从加载我们保存的模型开始(第 2 行)。切换到评估模式很重要(第 3 行)。如果我们不这样做,一些模型层(例如 dropout)在推理时会被不正确地使用。

# Load saved model 
model = torch.load('../../models/object_detection_cnn.pth')
model.eval()
model.to(device)

我们从验证集中加载图像和目标变量(第 2 行)。请记住,目标变量是维度为 2 的张量。我们获取每个张量的第二个元素(第 4 行)。这意味着我们现在将有一个二元目标变量——1 表示 STOP,0 表示 GO。

# Get images and targets
images, target = next(iter(valid_loader))
images = images.to(device)
target = [int(t[1]) for t in target]

我们使用模型对验证图像进行预测(第 2 行)。同样,输出将是维度为 2 的张量。我们考虑第二个元素。如果概率超过 0.5,我们预测 STOP,否则预测 GO。

# Get predictions
output=model(images)
prediction = [1 if o[1] > 0.5 else 0 for o in output]

最后,我们使用与评估基准相同的代码将目标预测进行比较。我们现在的准确率为 98.05%,精确率为 97.38%,召回率为 97.5%。相比基准有了显著的提升!在混淆矩阵中,你可以看到错误的来源。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:模型在验证集上的混淆矩阵(来源:作者)

在图 7 中,我们更详细地查看了一些这些错误。第一行显示了一些假阳性。这些是当汽车应该 GO 时被预测为 STOP 的图像。类似地,底行显示了假阴性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:预测错误的示例(来源:作者)

你可能已经注意到所有障碍物都在相似的距离处。当标记图像时,我们使用了一个截止距离。即当障碍物距离小于这个截止距离时,它被标记为 STOP。上述障碍物都接近这个截止距离。它们可能被错误标记,所以当障碍物接近这个截止距离时,模型可能会“困惑”。

使用 SHAP 解释模型

我们的模型似乎表现良好。通过了解它如何做出这些预测,我们可以更确定它的效果。为此,我们使用 SHAP。如果你对 SHAP 不熟悉,你可能会发现下面的视频很有用。否则,查看我的SHAP 课程 如果你注册我的新闻通讯,你可以获得免费访问权 😃

下面的代码计算并显示了我们在图 1 中看到的 3 个示例图像的 SHAP 值。如果你想了解更多关于这段代码如何工作的细节,请查看文末提到的文章。

# Load saved model 
model = torch.load('../../models/object_detection_cnn.pth')

# Use CPU
device = torch.device('cpu')
model = model.to(device)

#Load 100 images for background
shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)
background, _ = next(iter(shap_loader))
background = background.to(device)

#Create SHAP explainer 
explainer = shap.DeepExplainer(model, background)

# Load test images
test_images = [Image.open(path) for path in ex_paths]
test_images = np.array(test_images)

test_input = [TRANSFORMS(img) for img in test_images]
test_input = torch.stack(test_input).to(device)

# Get SHAP values
shap_values = explainer.shap_values(test_input)

# Reshape shap values and images for plotting
shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))
test_numpy = np.array([np.array(img) for img in test_images])

shap.image_plot(shap_numpy, test_numpy,show=False)

你可以在图 8 中看到输出。前两行是标记为 GO 的图像,第三行为标记为 STOP 的图像。我们有目标张量中每个元素的 SHAP 值。第一列是 GO 预测的 SHAP 值,第二列是 STOP 预测的 SHAP 值。

颜色非常重要。蓝色 SHAP 值告诉我们这些像素减少了预测值。换句话说,它们使得模型预测给定标签的可能性降低。类似地,红色 SHAP 值则增加了这种可能性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8:示例图像的 SHAP 值(来源:作者)

为了理解这一点,让我们关注图 8 的右上角。在图 9 中,我们有标记为 GO 的图像以及 GO 预测的 SHAP 值。你可以看到大多数像素是红色的。这些像素增加了该预测的值,从而导致正确的 GO 预测。你还可以看到像素聚集在障碍物截止位置——罐头的位置,其中标签从 GO 更改为 STOP。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9:GO 预测和 GO 标签的 SHAP 值。

在图 10 中,我们可以看到标记为 STOP 的图像的 SHAP 值。罐头在 GO 预测中为蓝色,在 STOP 预测中为红色。换句话说,模型使用罐头中的像素来减少 GO 值并增加 STOP 值。这是有道理的!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10:STOP 预测的 SHAP 值

这个模型不仅能够准确地进行预测,而且它做出这些预测的方式似乎也很合逻辑。然而,你可能注意到一些背景像素被突出显示了。这没有意义。为什么背景对预测如此重要?当我们移除物体或移动到新位置时,背景可能会发生变化。

原因是模型对训练数据过拟合了。这些物体出现在许多图像中。结果是模型将它们与 STOP/GO 标签关联。在下面的文章中,我们进行类似的分析。我们讨论了如何防止这种过拟合的方法。我们还花更多时间解释 SHAP 代码。

## 使用 SHAP 调试 PyTorch 图像回归模型

使用 DeepShap 来理解和改进支持自动驾驶汽车的模型

towardsdatascience.com

希望你喜欢这篇文章!你可以通过成为我的 推荐会员 😃 来支持我。

[## 使用我的推荐链接加入 Medium — Conor O’Sullivan

作为 Medium 会员,你的一部分会员费会分配给你阅读的作者,你将可以完全访问所有故事…

conorosullyds.medium.com

| Twitter | YouTube | Newsletter — 免费注册以获取 Python SHAP 课程

数据集

JatRacer 图像 (CC0: 公共领域) www.kaggle.com/datasets/conorsully1/jatracer-images

参考资料

stack overflow,为什么我们需要在 PyTorch 中调用 zero_grad()? stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch

Kenneth Leung如何轻松绘制神经网络架构图towardsdatascience.com/how-to-easily-draw-neural-network-architecture-diagrams-a6b6138ed875

使用 Vision Transformer 进行图像分类

原文:towardsdatascience.com/image-classification-with-vision-transformer-8bfde8e541d4

如何借助基于 Transformer 的模型进行图像分类

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Ruben Winastwan

·发表于 Towards Data Science ·阅读时长 13 分钟·2023 年 4 月 13 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

drmakete labUnsplash 上的照片

自 2017 年推出以来,Transformer 已被广泛认可为一种强大的编码器-解码器模型,可以解决几乎所有语言建模任务。

BERT、RoBERTa 和 XLM-RoBERTa 是在语言处理领域使用 Transformer 编码器堆栈作为其架构基础的一些最先进模型的例子。ChatGPT 和 GPT 系列也使用 Transformer 的解码器部分来生成文本。可以肯定地说,几乎所有最先进的自然语言处理模型都在其架构中融入了 Transformer。

Transformer 的表现非常优秀,以至于不把它用于自然语言处理之外的任务(例如计算机视觉)似乎有些浪费。然而,大问题是:我们能否实际将其用于计算机视觉任务?

事实证明,Transformer 也具有应用于计算机视觉任务的良好潜力。在 2020 年,Google Brain 团队推出了一种基于 Transformer 的模型,可以用于解决图像分类任务,称为 Vision Transformer(ViT)。与传统 CNN 在多个图像分类基准上的表现相比,它的表现非常有竞争力。

因此,在本文中,我们将讨论这个模型。具体来说,我们将讨论 ViT 模型如何工作以及如何利用 HuggingFace 库在我们自己的自定义数据集上对其进行微调,以进行图像分类任务。

所以,作为第一步,让我们开始使用本文中将要使用的数据集。

关于数据集

我们将使用一个小吃数据集,该数据集可以从 HuggingFace 的dataset库中轻松访问。该数据集标注为 CC-BY 2.0 许可证,这意味着你可以自由分享和使用它,只要在你的工作中引用数据集来源即可。

我们来瞧一瞧这个数据集:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集中图像的子集

我们只需要几行代码就可以加载数据集,如下所示:

!pip install -q datasets

from datasets import load_dataset 

# Load dataset
dataset = load_dataset("Matthijs/snacks")
print(dataset)

# Output
  '''
  DatasetDict({
      train: Dataset({
          features: ['image', 'label'],
          num_rows: 4838
      })
      test: Dataset({
          features: ['image', 'label'],
          num_rows: 952
      })
      validation: Dataset({
          features: ['image', 'label'],
          num_rows: 955
      })
  })''' 

数据集是一个字典对象,由 4898 张训练图像、955 张验证图像和 952 张测试图像组成。

每张图片都有一个标签,属于 20 个小吃类别之一。我们可以通过以下代码检查这 20 种不同的类别:

print(dataset["train"].features['label'].names)

# Output
'''
['apple','banana','cake','candy','carrot','cookie','doughnut','grape',
'hot dog', 'ice cream','juice','muffin','orange','pineapple','popcorn',
'pretzel','salad','strawberry','waffle','watermelon']''' 

我们来创建一个标签与其对应索引之间的映射。

# Mapping from label to index and vice versa
labels = dataset["train"].features["label"].names
num_labels = len(dataset["train"].features["label"].names)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

print(label2id)
print(id2label)

# Output
'''
{'apple': 0, 'banana': 1, 'cake': 2, 'candy': 3, 'carrot': 4, 'cookie': 5, 'doughnut': 6, 'grape': 7, 'hot dog': 8, 'ice cream': 9, 'juice': 10, 'muffin': 11, 'orange': 12, 'pineapple': 13, 'popcorn': 14, 'pretzel': 15, 'salad': 16, 'strawberry': 17, 'waffle': 18, 'watermelon': 19}
{0: 'apple', 1: 'banana', 2: 'cake', 3: 'candy', 4: 'carrot', 5: 'cookie', 6: 'doughnut', 7: 'grape', 8: 'hot dog', 9: 'ice cream', 10: 'juice', 11: 'muffin', 12: 'orange', 13: 'pineapple', 14: 'popcorn', 15: 'pretzel', 16: 'salad', 17: 'strawberry', 18: 'waffle', 19: 'watermelon'}
'''

在继续之前,我们需要了解的一件重要事情是每张图像的尺寸是不同的。因此,我们需要在将图像输入模型进行微调之前执行一些图像预处理步骤。

现在我们了解了正在使用的数据集,让我们更详细地了解 ViT 架构。

ViT 的工作原理

在 ViT 引入之前,Transformer 模型依赖自注意力机制,这给我们在计算机视觉任务中使用它带来了很大的挑战。

自注意力机制是 Transformer 模型能够区分一个词在不同上下文中语义的原因。例如,BERT 模型能够通过自注意力机制区分词语*‘park’在句子‘They park their car in the basement’‘She walks her dog in a park’中的含义。

但是,自注意力有一个问题:这是一个计算开销大的操作,因为它要求每个标记关注序列中的每个其他标记。

现在,如果我们在图像数据上使用自注意力机制,那么图像中的每个像素都需要关注并与每个其他像素进行比较。问题是,如果我们将像素值增加一个,那么计算成本将会呈二次增长。如果图像分辨率较大,这显然是不可行的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

为了解决这个问题,ViT 引入了将输入图像拆分为图像块的概念。每个图像块的尺寸为 16 x 16 像素。假设我们有一张 48 x 48 像素的图像,那么图像块将会像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

在实际应用中,ViT 有两种选项来将我们的图像拆分成图像块:

  1. 将我们的输入图像(大小为height x width x channel)重塑为一个展平的 2D 图像块序列,大小为no.of patches x (patch_size².channel)。然后,我们将展平的图像块投影到一个基本的线性层中,以获得每个图像块的嵌入表示。

  2. 将我们的输入图像投影到卷积层中,卷积核的大小和步幅等于补丁大小。然后,我们将该卷积层的输出展平。

在对多个数据集测试模型性能后,结果表明第二种方法具有更好的性能。因此,在本文中,我们将使用第二种方法。

让我们用一个简单的例子来演示如何使用卷积层将输入图像拆分成补丁。

import torch
import torch.nn as nn

# Create toy image with dim (batch x channel x width x height)
toy_img = torch.rand(1, 3, 48, 48)

# Define conv layer parameters
num_channels = 3
hidden_size = 768 #or emb_dimension
patch_size = 16

# Conv 2D layer
projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, 
             stride=patch_size)

# Forward pass toy img
out_projection = projection(toy_img)

print(f'Original image size: {toy_img.size()}')
print(f'Size after projection: {out_projection.size()}')

# Output
'''
Original image size: torch.Size([1, 3, 48, 48])
Size after projection: torch.Size([1, 768, 3, 3])
'''

模型接下来会将补丁展平,并按顺序排列,如下图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们可以使用以下代码进行展平处理:

# Flatten the output after projection with Conv2D layer

patch_embeddings = out_projection.flatten(2).transpose(1, 2)
print(f'Patch embedding size: {patch_embeddings.size()}')

# Output
'''
Patch embedding size: torch.Size([1, 9, 768]) #[batch, no. of patches, emb_dim]
'''

我们在展平处理后得到的基本上是每个补丁的向量嵌入。这类似于许多基于 Transformer 的语言模型中的标记嵌入。

接下来,类似于 BERT,ViT 将在我们补丁序列的第一个位置添加一个特殊的**[CLS]**向量嵌入。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

# Define [CLS] token embedding with the same emb dimension as the patches
batch_size = 1
cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
cls_tokens = cls_token.expand(batch_size, -1, -1)

# Prepend [CLS] token in the beginning of patch embedding
patch_embeddings = torch.cat((cls_tokens, patch_embeddings), dim=1)
print(f'Patch embedding size: {patch_embeddings.size()}')

# Output
'''
Patch embedding size: torch.Size([1, 10, 768]) #[batch, no. of patches+1, emb_dim]
'''

如你所见,通过在补丁嵌入的开头添加**[CLS]**标记嵌入,序列的长度增加了一个。接下来的最后一步是将位置嵌入添加到我们的补丁序列中。这一步很重要,以便我们的 ViT 模型可以学习补丁的序列顺序。

这个位置嵌入是一个可学习的参数,将在训练过程中由模型更新。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

# Define position embedding with the same dimension as the patch embedding
position_embeddings = nn.Parameter(torch.randn(batch_size, 10, hidden_size))

# Add position embedding into patch embedding
input_embeddings = patch_embeddings + position_embeddings
print(f'Input embedding size: {input_embeddings.size()}')

# Output
'''
Input embedding size: torch.Size([1, 10, 768]) #[batch, no. of patches+1, emb_dim]
'''

现在,每个补丁的位置信息加上向量嵌入将作为一组 Transformer 编码器的输入。Transformer 编码器的数量取决于你使用的 ViT 模型类型。总体上,有三种类型的 ViT 模型:

  • **ViT-base:**它具有 12 层,隐藏大小为 768,总参数量为 86M。

  • **ViT-large:**它具有 24 层,隐藏大小为 1024,总参数量为 307M。

  • **ViT-huge:**它具有 32 层,隐藏大小为 1280,总参数量为 632M。

在以下代码片段中,假设我们想使用Vit-base。这意味着我们有 12 层 Transformer 编码器:

# Define parameters for ViT-base (example)
num_heads = 12
num_layers = 12

# Define Transformer encoders' stack
transformer_encoder_layer = nn.TransformerEncoderLayer(
           d_model=hidden_size, nhead=num_heads,
           dim_feedforward=int(hidden_size * 4),
           dropout=0.1)
transformer_encoder = nn.TransformerEncoder(
           encoder_layer=transformer_encoder_layer,
           num_layers=num_layers)

# Forward pass
output_embeddings = transformer_encoder(input_embeddings)
print(f' Output embedding size: {output_embeddings.size()}')

# Output
'''
Output embedding size: torch.Size([1, 10, 768])
'''

最后,Transformer 编码器堆叠将输出每个图像补丁的最终向量表示。最终向量的维度对应于我们使用的 ViT 模型的隐藏大小。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

就是这些了。

我们当然可以从头开始构建和训练自己的 ViT 模型。然而,与其他基于 Transformer 的模型一样,ViT 需要在大量图像数据(14M-300M 图像)上进行训练,以便在未见过的数据上具有良好的泛化能力。

如果我们想在自定义数据集上使用 ViT,最常见的方法是微调预训练模型。最简单的方法是利用 HuggingFace 库。我们只需调用ViTModel.from_pretrained()方法,并将预训练模型的路径作为参数传递即可。HuggingFace 的VitModel()类还将作为我们上述所有步骤的封装器。

!pip install transformers

from transformers import ViTModel

# Load pretrained model
model_checkpoint = 'google/vit-base-patch16-224-in21k'
model = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)

# Example input image
input_img = torch.rand(batch_size, num_channels, 224, 224)

# Forward pass input image
output_embedding = model(input_img)
print(output_embedding)
print(f"Ouput embedding size: {output_embedding['last_hidden_state'].size()}")

# Output
'''
BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.0985, -0.2080,  0.0727,  ...,  0.2035,  0.0443, -0.3266],
         [ 0.1899, -0.0641,  0.0996,  ..., -0.0209,  0.1514, -0.3397],
         [ 0.0646, -0.3392,  0.0881,  ..., -0.0044,  0.2018, -0.3038],
         ...,
         [-0.0708, -0.2932, -0.1839,  ...,  0.1035,  0.0922, -0.3241],
         [ 0.0070, -0.3093, -0.0217,  ...,  0.0666,  0.1672, -0.4103],
         [ 0.1723, -0.1037,  0.0317,  ..., -0.0571,  0.0746, -0.2483]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=None, hidden_states=None, attentions=None)

Output embedding size: torch.Size([1, 197, 768])
'''

完整 ViT 模型的输出是一个向量嵌入,表示每个图像补丁加上**[CLS]**标记。其维度为[batch_size, image_patches+1, hidden_size]

要执行图像分类任务,我们遵循与 BERT 模型相同的方法。我们提取**[CLS]**标记的输出向量嵌入,并通过最终的线性层来确定图像的类别。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

num_labels = 20

# Define linear classifier layer
classifier = nn.Linear(hidden_size, num_labels) 

# Forward pass on the output embedding of [CLS] token
output_classification = classifier(output_embedding['last_hidden_state'][:, 0, :])
print(f"Output embedding size: {output_classification.size()}")

# Output
'''
Output embedding size: torch.Size([1, 20]) #[batch, no. of labels]
'''

微调实现

在本节中,我们将微调一个在 ImageNet-21K 数据集上进行过预训练的ViT-base模型,该数据集包含约 1400 万张图像和 21,843 个类别。数据集中的每张图像的尺寸为 224 x 224 像素。

首先,我们需要定义预训练模型的检查点路径,并加载必要的库。

import numpy as np
import torch
import cv2
import torch.nn as nn
from transformers import ViTModel, ViTConfig
from torchvision import transforms
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

#Pretrained model checkpoint
model_checkpoint = 'google/vit-base-patch16-224-in21k'

图像数据加载器

如前所述,ViT-base 模型已在包含 224 x 224 像素尺寸图像的数据集上进行过预训练。这些图像还根据其每个颜色通道的特定均值和标准差进行了归一化。

因此,在我们将自己的数据集输入 ViT 模型进行微调之前,我们必须首先对图像进行预处理。这包括将每张图像转换为张量,将其调整到适当的尺寸,然后使用与模型预训练数据集相同的均值和标准差值进行归一化。

class ImageDataset(torch.utils.data.Dataset):

  def __init__(self, input_data):

      self.input_data = input_data
      # Transform input data
      self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                             std=[0.5, 0.5, 0.5])
        ])

  def __len__(self):
      return len(self.input_data)

  def get_images(self, idx):
      return self.transform(self.input_data[idx]['image'])

  def get_labels(self, idx):
      return self.input_data[idx]['label']

  def __getitem__(self, idx):
      # Get input data in a batch
      train_images = self.get_images(idx)
      train_labels = self.get_labels(idx)

      return train_images, train_labels

从上面的图像数据加载器中,我们将获取一批预处理过的图像及其对应的标签。我们可以使用上述图像数据加载器的输出作为微调过程中模型的输入。

模型定义

我们的 ViT 模型架构非常简单。由于我们将微调一个预训练模型,我们可以使用VitModel.from_pretrained()方法,并提供模型的检查点作为参数。

我们还需要在最后添加一个线性层,作为最终的分类器。这个层的输出应该等于数据集中不同标签的数量。

class ViT(nn.Module):

  def __init__(self, config=ViTConfig(), num_labels=20, 
               model_checkpoint='google/vit-base-patch16-224-in21k'):

        super(ViT, self).__init__()

        self.vit = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)
        self.classifier = (
            nn.Linear(config.hidden_size, num_labels) 
        )

  def forward(self, x):

    x = self.vit(x)['last_hidden_state']
    # Use the embedding of [CLS] token
    output = self.classifier(x[:, 0, :])

    return output

上述 ViT 模型为每个图像补丁和**[CLS]标记生成最终的向量嵌入。为了对图像进行分类,如上所示,我们提取[CLS]**标记的最终向量嵌入,并将其传递给最终的线性层以获得最终的类别预测。

模型微调

现在我们已经定义了模型架构并准备了输入图像用于批处理过程,我们可以开始微调我们的 ViT 模型。训练脚本是一个标准的 Pytorch 训练脚本,如下所示:

def model_train(dataset, epochs, learning_rate, bs):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load nodel, loss function, and optimizer
    model = ViT().to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Load batch image
    train_dataset = ImageDataset(dataset)
    train_dataloader = DataLoader(train_dataset, num_workers=1, batch_size=bs, shuffle=True)

    # Fine tuning loop
    for i in range(epochs):
        total_acc_train = 0
        total_loss_train = 0.0

        for train_image, train_label in tqdm(train_dataloader):
            output = model(train_image.to(device))
            loss = criterion(output, train_label.to(device))
            acc = (output.argmax(dim=1) == train_label.to(device)).sum().item()
            total_acc_train += acc
            total_loss_train += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print(f'Epochs: {i + 1} | Loss: {total_loss_train / len(train_dataset): .3f} | Accuracy: {total_acc_train / len(train_dataset): .3f}')

    return model

# Hyperparameters
EPOCHS = 10
LEARNING_RATE = 1e-4
BATCH_SIZE = 8

# Train the model
trained_model = model_train(dataset['train'], EPOCHS, LEARNING_RATE, BATCH_SIZE)

由于我们的零食数据集有 20 个不同的类别,因此我们面临的是一个多类分类问题。因此,CrossEntropyLoss()将是合适的损失函数。在上面的示例中,我们训练了模型 10 个周期,学习率设置为 1e-4,批量大小为 8。你可以调整这些超参数以优化模型的性能。

训练模型后,你将得到一个类似下面的输出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

模型预测

既然我们已经微调了模型,自然希望将其用于测试数据的预测。为此,首先创建一个函数来封装所有必要的图像预处理步骤和模型推理过程。

def predict(img):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                             std=[0.5, 0.5, 0.5])
        ])

    img = transform(img)
    output = trained_model(img.unsqueeze(0).to(device))
    prediction = output.argmax(dim=1).item()

    return id2label[prediction]

正如你所见,上面显示的推理过程中的图像预处理步骤与我们在训练数据上进行的步骤完全相同。然后,我们将变换后的图像作为输入传递给训练好的模型,最后将其预测映射到相应的标签。

如果我们想对测试数据中的特定图像进行预测,我们只需调用上述函数,之后我们会得到预测结果。让我们试试看。

print(predict(dataset['test'][900]['image']))
# Output: waffle

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集中的测试数据示例

我们的模型正确预测了我们的测试图像。让我们尝试另一张。

print(predict(dataset['test'][250]['image']))
# Output: cookie

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集中的测试数据示例

我们的模型再次正确地预测了测试数据。通过微调 ViT 模型,我们可以在自定义数据集上获得良好的性能。你也可以对任何自定义数据集执行相同的过程,用于图像分类任务。

结论

在本文中,我们已经看到 Transformer 不仅可以用于语言建模任务,还可以用于计算机视觉任务,在本例中是图像分类。

为了做到这一点,首先将输入图像分解成大小为 16 x 16 像素的补丁。然后,Vision Transformer 模型利用一系列 Transformer 编码器来学习每个图像补丁的向量表示。最后,我们可以使用图像补丁序列开头的**[CLS]**标记的最终向量表示来预测输入图像的标签。

我希望这篇文章对你开始使用 Vision Transformer 模型有所帮助。与往常一样,你可以在这个笔记本中找到本文中展示的代码实现。

数据集参考

huggingface.co/datasets/Matthijs/snacks

使用预训练扩散模型进行图像合成

原文:towardsdatascience.com/image-composition-with-pre-trained-diffusion-models-772cd01b5022?source=collection_archive---------5-----------------------#2023-07-12

一种提高对预训练文本到图像扩散模型生成图像的控制的方法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Gabriele Sgroi, PhD

·

关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 7 月 12 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用文章中描述的方法生成的稳定扩散图像。图片由作者提供。

文本到图像的扩散模型在生成符合自然语言描述的逼真图像方面取得了惊人的表现。开源预训练模型的发布,例如稳定扩散,促进了这些技术的民主化。预训练扩散模型使任何人都可以创造出令人惊叹的图像,而不需要大量计算能力或漫长的训练过程。

尽管文本引导的图像生成提供了控制水平,但即使有大量提示,获得具有预定组成的图像仍然很棘手。实际上,标准的文本到图像扩散模型对生成图像中将描绘的各种元素几乎没有控制。

在这篇文章中,我将解释一种基于论文MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation的最新技术。这种技术使得在由文本引导的扩散模型生成的图像中放置元素的控制更为精确。论文中提出的方法更为通用,还可以用于其他应用,如生成全景图像,但我将在这里限制讨论图像组成性,使用基于区域的文本提示。该方法的主要优点是可以与开箱即用的预训练扩散模型一起使用,无需昂贵的重新训练或微调。

为了补充这篇文章的代码,我准备了一个简单的Colab notebook和一个GitHub 仓库,其中包含了我用于生成本文中图像的代码实现。该代码基于 Hugging Face 的diffusers library中的稳定扩散管道,但只实现了其功能所需的部分,使其更简单易读。

扩散模型

在这一部分中,我将回顾一些关于扩散模型的基本事实。扩散模型是生成模型,通过逆转扩散过程来生成新数据,该过程将数据分布映射到各向同性的高斯分布。更具体地说,给定一个图像,扩散过程包括一系列步骤,每一步都向图像中添加少量高斯噪声。在无限多步的极限下,噪声图像将与从各向同性高斯分布中采样的纯噪声无法区分。

扩散模型的目标是通过尝试猜测扩散过程中的步骤 t-1 处的噪声图像来逆转这一过程,给定步骤 t 处的噪声图像。例如,可以通过训练一个神经网络来预测该步骤添加的噪声,并将其从噪声图像中减去来实现这一目标。

一旦我们训练好这样一个模型,就可以通过从各向同性的高斯分布中采样噪声来生成新图像,并使用模型通过逐渐去除噪声来逆转扩散过程。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

扩散模型的目标是学习所有时间步 t 的概率 q(x(t-1)|x(t))。图像来自论文:Denoising Diffusion Probabilistic Models

文本到图像扩散模型反转扩散过程,试图达到与文本提示描述相对应的图像。这通常是通过神经网络完成的,该网络在每一步 t 预测步骤 t-1 的噪声图像,条件不仅是步骤 t 的噪声图像,还包括描述其试图重建的图像的文本提示。

许多图像扩散模型,包括稳定扩散,不是在原始图像空间中操作,而是在一个较小的学习潜空间中操作。通过这种方式,可以在最小质量损失的情况下减少所需的计算资源。潜空间通常通过变分自编码器来学习。潜空间中的扩散过程与之前完全相同,允许从高斯噪声生成新的潜向量。从这些向量中,可以使用变分自编码器的解码器获得新生成的图像。

使用 MultiDiffusion 进行图像组合

现在让我们转向解释如何使用 MultiDiffusion 方法获得可控图像组合。目标是通过预训练的文本到图像扩散模型更好地控制生成图像中的元素。具体而言,给定图像的一般描述(例如封面图像中的客厅),我们希望一系列通过文本提示指定的元素出现在特定位置(例如中心的红色沙发,左侧的盆栽和右上角的画作)。这可以通过提供一组描述所需元素的文本提示和一组基于区域的二进制掩码来实现,该掩码指定了元素必须描绘在其中的位置。例如,下面的图像包含封面图像中图像元素的边界框。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

生成封面图像所用的边界框和提示。图像由作者提供。

MultiDiffusion用于可控图像生成的核心思想是将多个扩散过程结合在一起,针对不同指定的提示,以获得在预定区域显示每个提示内容的连贯和平滑图像。与每个提示关联的区域通过与图像相同尺寸的二进制掩码来指定。如果提示必须在该位置描绘,则掩码的像素设为 1,否则设为 0。

更具体地说,我们用 t 表示在潜在空间中运行的扩散过程中的一个通用步骤。给定时间步 t 的噪声潜在向量,模型将预测每个指定文本提示的噪声。从这些预测的噪声中,我们通过从时间步 t 的前一个潜在向量中去除每个预测噪声,获得时间步 t-1 的一组潜在向量(每个提示一个)。为了获得扩散过程下一时间步的输入,我们需要将这些不同的向量组合在一起。这可以通过将每个潜在向量乘以相应的提示掩码,然后按掩码加权取每像素的平均值来完成。按照这个程序,在特定掩码指定的区域内,潜在向量将遵循由相应局部提示引导的扩散过程轨迹。在每一步将潜在向量组合在一起后,再预测噪声,可以确保生成图像的全球一致性以及不同掩码区域之间的平滑过渡。

MultiDiffusion 在扩散过程开始时引入了一个自举阶段,以更好地遵循紧密的掩码。在这些初步步骤中,与不同提示相对应的去噪潜在向量不会被组合在一起,而是与一些对应于常色背景的噪声潜在向量结合在一起。通过这种方式,由于布局通常在扩散过程早期就已经确定,因此可以在模型最初只专注于掩码区域来描绘提示的情况下,获得与指定掩码更好的匹配。

示例

在本节中,我将展示该方法的一些应用。我使用了 HuggingFace 托管的预训练 stable diffusion 2 模型创建了本文中的所有图像,包括封面图像。

正如讨论的那样,该方法的一个直接应用是获得包含在预定义位置生成的元素的图像。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

边界框。图片由作者提供。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用上述边界框生成的图像。图片由作者提供。

该方法允许指定单个元素的风格或其他属性。这可以用于例如在模糊背景上获得清晰的图像。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

模糊背景的边界框。图片由作者提供。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用上述边界框生成的图像。图片由作者提供。

元素的风格也可以非常不同,带来令人惊叹的视觉效果。例如,下面的图像是通过将高质量照片风格与梵高风格的画作混合获得的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同风格的边界框。图片由作者提供。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用上述边界框生成的图像。图像由作者提供。

结论

在这篇文章中,我们探讨了一种结合不同扩散过程的方法,以提高对由文本条件扩散模型生成的图像的控制能力。该方法增强了对图像中元素生成位置的控制,并且能够无缝地结合以不同风格描绘的元素。

描述的程序的主要优点之一是它可以与预训练的文本到图像扩散模型一起使用,而无需进行通常较为昂贵的微调。另一个优势是可控图像生成通过二进制掩码实现,这比更复杂的条件设置更容易指定和处理。

这种技术的主要缺点是,在每个扩散步骤中需要为每个提示进行一次神经网络传递,以预测相应的噪声。幸运的是,这些操作可以批量进行,以减少推断时间开销,但代价是更大的 GPU 内存使用。此外,有时一些提示(特别是仅在图像小部分中指定的提示)会被忽视或覆盖的区域比相应掩码指定的区域要大。虽然可以通过引导步骤来缓解这个问题,但过多的引导步骤可能会显著降低图像的整体质量,因为可以用来协调元素的步骤减少了。

值得一提的是,结合不同扩散过程的想法并不限于本文描述的内容,它还可以用于其他应用,例如论文中描述的全景图像生成MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation

希望你喜欢这篇文章,如果你想深入了解技术细节,可以查看这个Colab 笔记本GitHub 仓库的代码实现。

使用 Python 的图像滤镜

原文:towardsdatascience.com/image-filters-with-python-3dc223a12624

一个简洁的计算机视觉项目,用于使用 Python 构建图像滤镜

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Bharath K

·发布于Towards Data Science ·8 分钟阅读·2023 年 2 月 10 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Pineapple Supply Co.拍摄,来源于Unsplash

图像存在于不同的尺度、对比度、位深和质量中。我们被各种独特和美丽的图像包围,这些图像遍布我们的周围和互联网。操作这些图像可以产生一些有趣的结果,这些结果被用于各种有趣和有用的应用。

在图像处理和计算机视觉中,操作图像是解决不同任务和获得各种项目期望结果的关键组成部分。通过正确处理图像任务,我们可以重新创建一个修改后的图像版本,这对多种计算机视觉和深度学习应用(如数据增强)非常有用。

在本文中,我们将重点开发一个简单的图像滤镜应用程序,主要用于修改特定图像的亮度和对比度。还可以实现并添加到项目中的其他一些显著修改,包括着色器样式、剪贴画、表情符号和其他类似的附加内容。

如果读者不熟悉计算机视觉和 OpenCV,我建议查看我之前的一篇文章,内容是关于 OpenCV 和计算机视觉的全面初学者指南。相关链接如下。我建议在继续阅读本文其余内容之前先查看它。

## OpenCV: 完整的初学者指南掌握计算机视觉基础及代码!

一个教程,包含代码,旨在掌握计算机视觉的所有重要概念及其使用 OpenCV 实现的方法

[towardsdatascience.com

使用 Python 的亮度和对比度调整器的起始代码:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Jacopo Maia 提供,来自 Unsplash

在本节中,我们将查看一个简单的起始代码,这将帮助我们开始使用 OpenCV 计算机视觉库来修改原始图像的亮度和对比度的基本图像过滤器。为此任务,我们将使用随机图像来测试示例代码并理解其基本工作原理。

为了理解这个测试案例,我使用上面的图像作为测试样本来分析和理解亮度和对比度调整器的工作过程。为了跟进这个项目,我强烈建议下载上述图像并将其存储为 test_image.jpg 在工作目录中。

请注意,你可以用任何其他功能名称存储图像,并使用任何其他格式的图像。唯一的重要步骤是读取图像时提及适当的名称及其格式。第一步是导入 OpenCV 库,如下面的代码块所示,并确保该库完全正常运行。

# Importing the computer vision library
import cv2

一旦导入库并验证其工作,我们可以继续读取我们最近在工作目录中保存的原始图像。只需提及图像名称,cv2.imread 函数就能读取图像。

如果图像未存储在工作目录中,请确保提及特定文件及其图像名称。由于上述图像的原始尺寸为 4608 x 3072,将图像缩小一点并减少尺寸是最佳的。我已将图像调整为 (512, 512) 规模,以便更容易跟踪进度并以稍高的速度执行所需任务。

在下面代码示例的最后一步,我们将定义 alpha 和 gamma 参数,这些参数将分别作为对比度和亮度的调整器。使用这两个参数,我们可以相应地控制这些值。请注意,对比度参数的范围是 0 到 127,而亮度参数的范围是 0 到 100。

# read the input image
image = cv2.imread('test_image.jpg')
image = cv2.resize(image, (512, 512))

# Define the alpha and gamma parameters for brightness and contrast accordingly
alpha = 1.2
gamma = 0.5

一旦我们完成了所有之前的步骤,我们将继续使用 Open CV 库中的“add weighted”函数,这有助于我们计算 alpha 和 gamma(亮度和对比度值)。该函数主要用于通过使用 alpha、beta 和 gamma 值来混合图像。请注意,对于单个图像,我们可以将 beta 设为零,并获得适当的结果。

最后,我们将在计算机视觉窗口中显示修改后的图像,如下面的示例代码块所示。一旦图像显示出来,当用户点击关闭按钮时,我们可以继续关闭窗口。为了进一步理解 Open CV 的一些基本函数,我强烈建议查看我之前提到的文章以获得更多信息。

# Used the added weighting function in OpenCV to compute the assigned values together
final_image = cv2.addWeighted(image, alpha, image, 0, gamma)

# Display the final image
cv2.imshow('Modified Image', final_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

在上述起始代码示例中,参数 alpha 作为对比度,gamma 作为亮度调整器。然而,修改这些参数以适应多种不同的变化可能会稍显不切实际。因此,最好的方法是创建一个 GUI 界面,并为所需的图像滤镜找到最佳值。这个话题将在接下来的部分中进一步讨论。

带有控制器的项目进一步开发:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者修改后的图像截图

我们可以使用 Open CV 创建一个自定义 GUI,利用控制器优化亮度和对比度参数,并相应调整我们的结果。我们可以为每个属性设置一个滑块,当滑动时,可以通过对每个亮度和对比度变化应用独特的滤镜来创建图像的不同效果。

前几步与上一部分类似,我们将导入 Open CV 库,读取相应的图像并按需调整其大小。读者可以选择适合自己的尺寸。以下是相关的代码片段。

# Import the Open CV library
import cv2

# read the input image
image = cv2.imread('test_image.jpg')
image = cv2.resize(image, (512, 512))

在下一步中,我们将定义亮度和对比度函数,通过这些函数我们可以定义获取 trackbar 位置的函数,以获得当前亮度和对比度元素的位置。加权函数将用于计算这两个参数,以计算这两个组合的综合输出,如下面的代码块所示。

# Creating the control function for the brightness and contrast
def BrightnessContrast(brightness=0):
    brightness = cv2.getTrackbarPos('Brightness',
                                    'Image')

    contrast = cv2.getTrackbarPos('Contrast',
                                  'Image')

    effect = cv2.addWeighted(image, brightness, image, 0, contrast)

    cv2.imshow('Effect', effect)

一旦定义了 trackbar 函数,我们将创建一个命名窗口和一个用于获取所需参数的 trackbar。我们将为亮度和对比度特性分别创建两个独立的 trackbars,如下面的代码块所示。当原始图像窗口和带有图像的 trackbar 窗口都关闭时,wait key 函数将激活,以帮助终止程序。

# Defining the parameters for the Open CV window
cv2.namedWindow('Image')

cv2.imshow('Image', image)

cv2.createTrackbar('Brightness',
                    'Image', 0, 10,
                    BrightnessContrast) 

cv2.createTrackbar('Contrast', 'Image',
                    0, 20,
                    BrightnessContrast)  

BrightnessContrast(0)

cv2.waitKey(0)

我们可以调整亮度和对比度滑块的轨迹条位置,将数字分别从 0 到 10 和 0 到 20 进行更改。通过调整相应的位置可以观察到变化。然而,这个项目仍然有很大的改进空间,以及更高的参数调整尺度。我们可以在 255 的尺度上调整亮度值,在 127 的尺度上调整对比度,以获得更细致的图像。

为了解决这些问题并进一步提升这个项目,我建议访问 Geeks for Geeks 网站,我在本节中使用了一部分代码。强烈推荐读者查看以下链接以进一步阅读和理解这个项目。

结论:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 MailchimpUnsplash 上提供

我结合魔法和科学来创造幻觉。我从事新媒体和互动技术的工作,如人工智能或计算机视觉,并将它们融入我的魔法中。

Marco Tempest

图像在计算机视觉和图像处理中扮演着重要角色。通过修改图像以创建类似或高度过滤的变体,我们可以通过积累更多的数据来解决各种项目,比如数据增强或其他类似任务。

我们还可以通过对特定图像进行改进来实现更理想的效果,通过这种方式,可以执行其他有用的深度学习任务。在本文中,我们探索了通过添加亮度和对比度等特性来修改原始图像的使用。在第一部分中,我们学习了如何利用 alpha 和 gamma 参数作为亮度和对比度参数。

在接下来的部分中,我们开发了一个 GUI 界面,通过它可以操作原始图像以获取其修改版本的副本。所有任务都是在一些基本的计算机视觉和图像处理知识及库的基础上完成的。

如果你想在我的文章发布时第一时间得到通知,可以查看以下链接以订阅电子邮件推荐。如果你希望支持其他作者和我,请订阅以下链接。

[## 使用我的推荐链接加入 Medium - Bharath K

阅读 Bharath K(以及 Medium 上成千上万的其他作者)的每一个故事。您的会员费用将直接支持…

bharath-k1297.medium.com

如果你对本文中的各个点有任何疑问,请在下方评论中告诉我。我会尽快回复你。

查看我与本文主题相关的其他文章,你可能也会喜欢阅读!

前往数据科学 [## Jupyter Notebooks 的终极替代方案

讨论 Jupyter Notebooks 在数据科学项目中的一个优秀替代选项

前往数据科学 [## 阅读七篇最佳研究论文,启动深度学习项目

七篇经得起时间考验的最佳研究论文,将帮助你创建出色的项目

前往数据科学 [## 使用 Python 开发自己的拼写检查工具包

使用 Python 创建一个有效的拼写检查应用程序

前往数据科学

感谢大家一直看到最后。希望你们喜欢阅读这篇文章。祝大家有美好的一天!

使用 ChatGPT 生成图像的代码

原文:towardsdatascience.com/image-generation-with-chatgpt-68c98a061bec

如果图像仅仅是像素值的矩阵,那么 ChatGPT 能否编写代码生成对应于有意义图像的矩阵呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jamshaid Shahir, Ph.D

·发布于Towards Data Science ·阅读时间 8 分钟·2023 年 2 月 10 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Jonathan Kemper提供,来自Unsplash

像许多人一样,我发现自己被 OpenAI 的 ChatGPT 深深吸引,并将其用于各种活动。从调试代码这样的严肃任务,到写诗和短篇故事这样的创意应用,探索 ChatGPT 的功能非常有趣。同时,查看它的不足之处也很有启发性。作为一个语言模型,它不能像 Midjourney AI 或 DALL-E 那样生成直接的图像,因为它并未在大量图像上进行训练,而是训练于文本。然而,从数学的角度来看,图像只是二维(或在彩色图像的情况下,即 RGB 的三维)数组,你可以让它编写代码来生成图像。因此,我要求 ChatGPT 生成对应于图像的数组。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

没有修改,我将这段代码复制并粘贴到 Jupyter Notebook 中运行,确实得到了一个白色方块(尽管不完全在中间):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片(代码由 ChatGPT 编写)

因此,从技术上讲,它可以生成图像。然而,我想看看它是否能生成更复杂的东西,比如篮球。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

和上次一样,我直接在 Jupyter Notebook 中运行了这段代码。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

如你所见,ChatGPT 显然没有生成一个篮球,而是一个模糊的物体,给人一种 3D 球体的错觉。然后我给了它一些额外的指示,指出应该有黑色线条贯穿其中(尽管事后看来,我可能应该指定黑色的“曲线”)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

这一次,我只得到了一个纯黑色的方块。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们已经接近目标,因为现在圆圈里有一条粗黑线,但它仍然离篮球相差甚远(如果你问我,它几乎看起来像一个日本饭团)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

现在我们只是有一条细细的黑色水平线穿过我们的圆圈。为了使其更像一个真正的篮球,我们需要一些黑色的曲线贯穿其中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

这一次,当我运行代码时,我实际上得到了一个IndexError,我要求 ChatGPT 进行调试并相应修订

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In [7], line 14
     11         image[y.astype(int) + i, x] = 0
     12     return image
---> 14 basketball = basketball_image()
     15 plt.imshow(basketball, cmap='gray')
     16 plt.show()

Cell In [7], line 11, in basketball_image()
      9 y[y < 64] = 64
     10 for i in range(0, 256, 8):
---> 11     image[y.astype(int) + i, x] = 0
     12 return image

IndexError: index 256 is out of bounds for axis 0 with size 256

当我将这个错误报告给 ChatGPT 时,它提供了一个可能的解释和代码重写:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

这一次,当运行 ChatGPT 的代码时,我遇到了ValueError:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [9], line 14
     11                 image[y.astype(int) + i, j] = 0
     12     return image
---> 14 basketball_image()

Cell In [9], line 10, in basketball_image()
      8 for i in range(0, 256, 8):
      9     for j in range(256):
---> 10         if y[j].astype(int) + i < 256:
     11             image[y.astype(int) + i, j] = 0
     12 return image

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我再次将错误报告给 ChatGPT,它解释说是将一个数组与一个整数进行比较,并提出了以下修订,但仍然导致了另一个IndexError

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In [10], line 16
     13                 image[y[j].astype(int) - i, j] = 0
     14     return image
---> 16 basketball = basketball_image()
     17 plt.imshow(basketball, cmap='gray')
     18 plt.show()

Cell In [10], line 11, in basketball_image()
      9 for j in range(256):
     10     if y[j].astype(int) + i < 256:
---> 11         image[y[j].astype(int) + i, j] = 0
     12     if y[j].astype(int) - i >= 0:
     13         image[y[j].astype(int) - i, j] = 0

IndexError: index -9223372036854775808 is out of bounds for axis 0 with size 256

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

结果证明第三次终于成功了,因为我没有收到任何错误。不过,我没有得到任何看起来像篮球的东西,甚至连球的形状都没有(至少知道要画曲线):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我要求再试一次,这次它尝试使用正弦函数生成图像。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不幸的是,这只给了我一个黑色方块:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我尝试给它一些鼓励,但经过这么多次失败和几次错误后,它决定放弃了:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

对于我的最终练习,我要求它生成一个笑脸,并且加了个要求,即生成一个 RGB 图像,这样它就需要生成一个 3D 数组而不是 2D。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

第一次尝试显然很糟糕。实际上,chatGPT 还需要多尝试八次才能最终得到一些连贯的结果。为了简洁起见,我将展示它绘制的前三个最佳图像及其对应的代码。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

我一直强调的主要点是画一个大圆圈,在圆圈内画两个较小的圆圈来表示眼睛,并画一个竖直的半圆来表示嘴巴。

生成下一个图像的代码片段过长,无法显示完整的屏幕截图,因此我将代码单独粘贴。

import matplotlib.pyplot as plt
import numpy as np

def smiley_face(size=512):
    # Create an empty image with the given size
    image = np.ones((size, size, 3), dtype=np.uint8) * 255

    # Create the circle for the face
    center = (size // 2, size // 2)
    radius = size // 2 - 10
    for i in range(size):
        for j in range(size):
            if (i - center[0]) ** 2 + (j - center[1]) ** 2 < radius ** 2:
                image[i, j, :] = [255, 255, 102]

    # Create the circle for the left eye
    eye_radius = size // 8
    eye_center = (size // 2 - size // 6, size // 2 - size // 6)
    for i in range(size):
        for j in range(size):
            if (i - eye_center[0]) ** 2 + (j - eye_center[1]) ** 2 < eye_radius ** 2:
                image[i, j, :] = [0, 0, 0]

    # Create the circle for the right eye
    eye_center = (size // 2 + size // 6, size // 2 - size // 6)
    for i in range(size):
        for j in range(size):
            if (i - eye_center[0]) ** 2 + (j - eye_center[1]) ** 2 < eye_radius ** 2:
                image[i, j, :] = [0, 0, 0]

    # Create the smile
    smile_center = (size // 2, size // 2 + size // 4)
    smile_radius = size // 4
    for i in range(size):
        for j in range(size):
            if (i - smile_center[0]) ** 2 + (j - smile_center[1]) ** 2 < smile_radius ** 2 and j > smile_center[1]:
                image[i, j, :] = [255, 0, 0]

    return image

smiley = smiley_face()
plt.imshow(smiley)
plt.axis('off')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

这是离笑脸最接近的图像,除了某种原因,它被旋转到了侧面。经过两个额外的提示,它最终生成了这个图像:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

到那时,经过 8 次尝试,我宣布这相比篮球已经算是成功了!

结论

总之,ChatGPT 具备使用 Python 的numpy库生成对应有意义图像的能力,但存在局限性。尽管缺乏训练数据,ChatGPT 仍能够生成简单的图像,如笑脸。然而,它在生成复杂图像如篮球时遇到困难,需要多次迭代才能生成笑脸的代码。随着进一步的发展,ChatGPT 可能能够在较少的帮助下生成基础图像。未来,我们可能会看到 ChatGPT 与 AI 图像生成器的结合。在我的下一篇文章中,我将讨论如何利用 ChatGPT 为 AI 图像生成器提供提示。

如果你喜欢这篇文章并且是 Medium 的新用户,可以考虑成为会员。如果你通过这个推荐链接加入,我将从你的会员费中获得一部分,而你可以享受 Medium 提供的全部内容,且无需额外费用。

[## Jamshaid Shahir - Medium

阅读来自 Jamshaid Shahir 在 Medium 上的文章。计算生物学博士生。喜欢通过数据探索世界……

medium.com](https://medium.com/@jashahir?source=post_page-----68c98a061bec--------------------------------)

医学数据集的图像配准

原文:towardsdatascience.com/image-registration-for-medical-datasets-ee605ff8eb2e

从 SimpleElastix 到空间变换网络

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Charlie O’Neill

·发布于 Towards Data Science ·阅读时间 31 分钟·2023 年 2 月 22 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Michael Dziedzic 提供,来源于 Unsplash

介绍

图像配准是图像处理中的一项基础任务,涉及将两个或多个图像对齐到一个共同的坐标系统中。通过这样做,图像中的对应像素表示现实世界中的同源点,从而使图像的比较和分析成为可能。图像配准的一个常见应用是在医学成像中,其中对同一患者进行多次扫描或拍摄,由于时间、位置或其他因素的不同而产生变化。配准这些图像可以揭示出可能指示疾病进展或治疗效果的微妙变化或模式。

图像配准涉及寻找一种空间变换,将一个图像中的点映射到另一个图像中的对应点,以便可以将图像重叠在一起。空间变换通常由一组控制点参数化,这些控制点用于将一个图像扭曲以匹配另一个图像。配准的质量通过相似度度量来衡量,该度量量化了图像之间的对应程度。

近年来,由于先进成像技术的出现、计算能力的提升以及对更准确和高效医学图像分析的需求,医学图像配准引起了越来越多的关注。图像配准已经成为广泛医学图像分析任务的前提条件,包括解剖结构的分割、计算机辅助诊断、疾病进展监测、外科干预和治疗规划。

尽管大量研究集中在开发图像配准算法上,但对这些算法的可访问性、互操作性和扩展性关注较少。科学源代码通常未公开,因未考虑其他研究人员的需求而难以使用,或缺乏适当的文档。这限制了图像配准算法的采用和使用,阻碍了科学进步和可重复性。

为了解决这些挑战,开发了几个开源医学图像配准库,其中 SimpleElastix 是最受欢迎的之一。SimpleElastix 是 SimpleITK 的扩展,SimpleITK 是一个开源医学图像分析库,允许用户完全在 Python、Java、R、Octave、Ruby、Lua、Tcl 和 C# 中配置和运行 Elastix 配准算法。SimpleElastix 提供了一个简单的参数接口、模块化架构和多种变换、度量和优化器,使其易于使用且计算高效。它还提供了一系列功能,如随机采样、多线程和代码优化,以加快配准速度,而不牺牲鲁棒性。

在这里,我将深入探讨使用 SimpleElastix 进行图像配准的过程,重点介绍注册来自地理萎缩患者的视网膜图像的具体示例。我还将提供实施这一配准过程的逐步指南,并探讨其他技术,如光流和空间变换网络。希望这能让你更好地理解医学成像中图像配准的重要性以及实施它的工具。

设置

任务是处理来自地理萎缩患者(眼病的一种)的视网膜图像,并将这些图像进行患者间注册,即仅将来自同一患者的图像注册到该患者。为了解释一下,地理萎缩(GA)特征是视网膜色素上皮细胞的丧失,这些细胞负责支持和滋养黄斑中的视网膜感光细胞。视网膜色素上皮细胞的丧失会导致黄斑中出现一个或多个萎缩区或“孔洞”,这可能导致中心视力丧失,影响个人进行日常活动,如阅读、驾驶和面孔识别。你将在下面的视网膜图像中注意到这些萎缩区域。

你可以从这个 Kaggle 数据集中获取与代码一起使用的图像。首先,我们需要导入适当的模块。

from pathlib import Path
import matplotlib.pyplot as plt
from typing import List
import numpy as np
import seaborn as sns
import os
import cv2
import pandas as pd
from tqdm.notebook import tqdm
from skimage.registration import optical_flow_tvl1, optical_flow_ilk
from skimage.transform import warp
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import normalized_root_mse as nrmse

接下来,我们编写一个函数来处理图像的检索。由于我们只想将视网膜图像注册到同一只眼睛,我们需要指定要加载的患者和侧别:

def retrieve_images(patient_id = '156518', laterality = 'L', date = None):
    # Set the root directory for the patient data
    root_dir = Path(f'../data/{patient_id}')

    # Get the list of image filenames for the left eye
    image_filenames = [f for f in os.listdir(root_dir) if f'{laterality}.png' in f]

    # If we are registering to same visit, only keep files from given date
    if date != None:
        pattern = re.compile(r"\w+_(\d{4}-\d{2}-\d{2})_")
        image_filenames = [file for file in image_filenames if date in file]
    # Read the images into a list
    images = [cv2.imread(str(root_dir / f)) for f in image_filenames]
    # Convert the images to grayscale
    gray_images = [rgb2gray(img) for img in images]
    # Register all images to the first image
    template = gray_images[0]
    # Remove invalid images
    final_images = [x for x in gray_images[1:] if x.shape == template.shape]
    return final_images, template

在评估我们的配准算法时,我们的评估指标将是一个计算注册图像与模板图像之间距离的函数。我们希望能够追踪这些指标中的一些。常见的指标包括:

  • L1 损失,也称为平均绝对误差,测量两张图像之间逐元素差异的平均幅度。它对离群值具有鲁棒性,并对所有像素赋予相等的权重,使其成为图像配准的一个不错选择。

  • 均方根误差(RMSE)是两张图像之间平方差的均值的平方根。它对较大的差异赋予更多权重,使其对离群值非常敏感。RMSE 常用于图像配准,以测量两张图像之间的总体差异。

  • 归一化互相关 是一种衡量两张图像之间相似度的指标,考虑了它们的强度。它被归一化以确保结果在 -1 和 1 之间,其中 1 表示完全匹配。归一化互相关常用于图像配准,以评估配准质量,特别是在处理强度不同的图像时。

  • 相似度 是衡量两张图像之间重叠程度的指标,考虑了强度和空间信息。常见的用于图像配准的相似度指标包括互信息、归一化互信息和詹森-香农散度。这些指标提供了两张图像之间共享信息的度量,使其非常适合评估图像配准的质量。

以下函数接受一个注册图像的列表以及模板图像,并计算每张图像的上述指标:

def evaluate_registration(template_img: np.ndarray, 
                          registered_imgs: List[np.ndarray]) -> (List[float], List[float], List[float]):
    """
    Evaluate the registration quality of multiple registered images with respect to a template image.
    """
    l1_losses = []
    ncc_values = []
    ssim_values = []

    for registered_img in registered_imgs:
        # Compute L1 loss between the template and registered images
        l1_loss = np.mean(np.abs(template_img - registered_img))
        l1_losses.append(l1_loss)

        # Compute normalized cross-correlation between the template and registered images
        ncc = np.corrcoef(template_img.ravel(), registered_img.ravel())[0,1]
        ncc_values.append(ncc)

        # Compute structural similarity index between the template and registered images
        ssim_value = ssim(template_img, registered_img, data_range=registered_img.max() - registered_img.min())
        ssim_values.append(ssim_value)

    return l1_losses, ncc_values, ssim_values

根据这些损失,最好有某种函数可以根据损失显示最佳和最差注册图像。这在某种程度上类似于在分类任务中查看混淆矩阵的个别示例。

def visualise_registration_results(registered_images, original_images, template, loss_values):
    num_images = min(len(registered_images), 3)
    # Get the indices of the three images with the highest L1 losses
    top_indices = np.argsort(loss_values)[-num_images:]
    # Get the indices of the three images with the lowest L1 losses
    bottom_indices = np.argsort(loss_values)[:num_images]
    # Create the grid figure
    fig, axes = plt.subplots(num_images, 4, figsize=(20, 15))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    # Loop through the top three images
    for i, idx in enumerate(top_indices):
        # Plot the original image in the first column of the left section
        ax = axes[i][0]
        ax.imshow(original_images[idx], cmap='gray')
        original_l1 = np.mean(np.abs(template - original_images[idx]))
        ax.set_title("Original Image (L1 Loss: {:.2f})".format(original_l1))
        # Plot the registered image in the second column of the left section
        ax = axes[i][1]
        ax.imshow(registered_images[idx], cmap='gray')
        ax.set_title("Registered Image (L1 Loss: {:.2f})".format(loss_values[idx]))
    # Loop through the bottom three images
    for i, idx in enumerate(bottom_indices):
        # Plot the original image in the first column of the right section
        ax = axes[i][2]
        ax.imshow(original_images[idx], cmap='gray')
        original_l1 = np.mean(np.abs(template - original_images[idx]))
        ax.set_title("Original Image (L1 Loss: {:.2f})".format(original_l1))
        # Plot the registered image in the second column of the right section
        ax = axes[i][3]
        ax.imshow(registered_images[idx], cmap='gray')
        ax.set_title("Registered Image (L1 Loss: {:.2f})".format(loss_values[idx]))
    # Show the grid
    plt.show()

编写一个汇总函数,显示我们的配准算法所取得的整体改进,这可能是个好主意。

def highlight_worse(val, comparison_column, worse_val, better_val):
    color = better_val if val == worse_val else worse_val
    return 'background-color: {}'.format(color)

def style_df(df_dict):
    df = pd.DataFrame(df_dict)
    for column in df.columns:
        comparison_column = 'original' if column == 'registered' else 'registered'
        worse_val = 'red'
        better_val = 'green'
        if column in ['ncc', 'ssim']:
            worse_val, better_val = better_val, worse_val
        df.style.apply(highlight_worse, axis=1, subset=[column], comparison_column=comparison_column, worse_val=worse_val, better_val=better_val)
    return df

def summarise_registration(original_images, registered_images, template):

    # Calculate metrics for original images
    l1_losses, ncc_values, ssim_values = evaluate_registration(template, original_images)
    l1_original, ncc_original, ssim_original = np.mean(l1_losses), np.mean(ncc_values), np.mean(ssim_values)

    # Calculate metrics for registered images
    l1_losses, ncc_values, ssim_values = evaluate_registration(template, registered_images)
    l1_registered, ncc_registered, ssim_registered = np.mean(l1_losses), np.mean(ncc_values), np.mean(ssim_values)

    # Create dataframe
    df_dict = {'original': {'l1': l1_original, 'ncc': ncc_original, 'ssim': ssim_original}, 
               'registered': {'l1': l1_registered, 'ncc': ncc_registered, 'ssim': ssim_registered}}

    return style_df(df_dict)

最后,我们将为任何配准算法编写一个简洁的包装器,以便我们能够轻松地应用和评估它:

class RegistrationAlgorithm:

    def __init__(self, registration_function):
        self.registration_function = registration_function
        self.final_images, self.template = retrieve_images()
        self.registered_images = self.apply_registration()

    def apply_registration(self):
        # Do the registration process
        registered_images = []
        for i, img in enumerate(tqdm(self.final_images)):
            registered = self.registration_function(self.template, img) 
            registered_images.append(registered)
        return registered_images

    def evaluate(self, template_img, registered_imgs):
        l1_losses = []
        ncc_values = []
        ssim_values = []

        for registered_img in registered_imgs:

            # Compute L1 loss between the template and registered images
            l1_loss = np.mean(np.abs(template_img - registered_img))
            l1_losses.append(l1_loss)

            # Compute normalized cross-correlation between the template and registered images
            ncc = np.corrcoef(template_img.ravel(), registered_img.ravel())[0,1]
            ncc_values.append(ncc)

            # Compute structural similarity index between the template and registered images
            ssim_value = ssim(template_img, registered_img, data_range=registered_img.max() - registered_img.min())
            ssim_values.append(ssim_value)

        return l1_losses, ncc_values, ssim_values

探索性数据分析

让我们获取一些图像,看看我们要处理的是什么。

images, template = retrieve_images()

一个好的主意是检查哪些图像与模板图像的差异最大。我们可以重复使用上述函数来实现这一点。让我们计算未注册图像与模板之间的损失,然后查看差异最大(损失最高)的图像。

# calculate various distances
l1_losses, ncc_values, ssim_values = evaluate_registration(template, images)

# plot most and least similar images
visualise_registration_results(images, images, template, l1_losses)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

作为比较,这里是模板图像:

plt.imshow(template, cmap="gray");

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

寻找最佳模板图像

显然,选择第一张眼底图像作为 固定 或“模板”图像可能并不理想。如果第一张图像质量差,或者旋转,或者与大多数需要配准的图像差异很大,这将导致结果不佳、大的仿射变换和高的“死”图像区域。因此,我们需要某种方法来选择模板图像。我们可以有几种不同的想法来实现这一点:

  • 计算每张图像与数据集中所有其他图像的累积 L2 距离,并选择结果最低的那一张。这代表了与所有其他图像“最接近”的图像。

  • 重复上述过程,但这次创建一个累积 L2 距离的直方图。选择最好的 k 张图像,取平均值,并将其作为模板。

让我们从第一个想法开始。这个函数循环遍历每张图像,计算与所有其他图像的聚合 L2 距离。

def calculate_total_rmse(images):
    n = len(images)
    sum_rmse = np.zeros(n)
    for i in range(n):
        for j in range(i+1, n):
            rmse = np.sqrt(np.mean((images[i] - images[j])**2))
            sum_rmse[i] += rmse
            sum_rmse[j] += rmse
    return sum_rmse

patient_id = '123456'
laterality = 'L'
# Set the root directory for the patient data
root_dir = Path(f'../data/{patient_id}')
# Get the list of image filenames for the left eye
image_filenames = [f for f in os.listdir(root_dir) if f'{laterality}.png' in f]
# Read the images into a list
images = [cv2.imread(str(root_dir / f)) for f in image_filenames]
# Convert the images to grayscale
gray_images = [rgb2gray(img) for img in images]
# Remove invalid images
final_images = [x for x in gray_images[1:] if x.shape == (768, 768)]
# Calculate the RMSEs
rmses = calculate_total_rmse(final_images)

让我们看看四张总 RMSE 最低的图像:

images = final_images
sorted_indices = [i[0] for i in sorted(enumerate(rmses), key=lambda x:x[1])]
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for i in range(4):
    ax[i//2][i%2].imshow(images[sorted_indices[i]], cmap='gray')
    ax[i//2][i%2].set_title("RMSE: {:.2f}".format(rmses[sorted_indices[i]]))
    ax[i//2][i%2].axis("off")
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

现在让我们尝试第二种方法。首先,看看总 RMSE 的直方图:

# Plot the histogram
sns.set_style("whitegrid")
sns.displot(rmses, kde=False)
plt.show()
plt.rcdefaults()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

我们可以选择最好的 15 张图像(所有图像的 RMSE 都低于 10):

plt.rcdefaults()

def get_best_images(images, rmses, num_images=10):
    sorted_indices = sorted(range(len(rmses)), key=lambda x: rmses[x])
    best_indices = sorted_indices[:num_images]
    return [images[i] for i in best_indices]
best_images = get_best_images(images, rmses)
av_img = np.mean(best_images, axis=0)
plt.imshow(av_img, cmap='gray')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

显然,我们用来取平均的图像越多,最终图像就会越模糊。

num_images_range = np.linspace(4, 36, 9, dtype=int)
best_images_list = []
for num_images in num_images_range:
    best_images = get_best_images(images, rmses, num_images)
    av_img = np.mean(best_images, axis=0)
    best_images_list.append(av_img)

fig, axs = plt.subplots(3, 3, figsize=(12,12))
for i, av_img in enumerate(best_images_list):
    row, col = i//3, i%3
    axs[row, col].imshow(av_img, cmap='gray')
    axs[row, col].axis('off')
    axs[row, col].set_title(f"Best {num_images_range[i]} images")
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

让我们选择最好的 8 张图像,并将其作为我们的模板图像。

算法

这是项目中的实际工作马驹:注册算法本身。

刚性

刚性配准是医学图像分析中的一种基本技术,它通过应用平移、旋转和缩放来对齐两张或多张图像。这是一个将图像变换的过程,目的是保持图像中对应点之间的距离不变。刚性配准的目标是找到最佳变换,以最小化图像之间的差异,同时保持基础结构的解剖一致性。刚性配准有多个应用,包括图像融合、图像引导手术和纵向研究,并且是更高级配准技术的关键预处理步骤。

import SimpleITK as sitk
import numpy as np

def rigid(fixed_image, moving_image):
    # Convert the input images to SimpleITK images
    fixed_image = sitk.GetImageFromArray(fixed_image)
    moving_image = sitk.GetImageFromArray(moving_image)
    # Create a rigid registration method and set the initial transform to the identity
    registration_method = sitk.ImageRegistrationMethod()
    initial_transform = sitk.Euler2DTransform()
    initial_transform.SetMatrix(np.eye(2).ravel())
    initial_transform.SetTranslation([0, 0])
    registration_method.SetInitialTransform(initial_transform)
    # Set the number of iterations and the learning rate for the optimization
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
    # Use mean squared error as the similarity metric
    registration_method.SetMetricAsMeanSquares()
    # Execute the registration
    final_transform = registration_method.Execute(fixed_image, moving_image)
    # Transform the moving image using the final transform
    registered_image = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelIDValue())
    # Convert the registered image back to a Numpy array
    registered_image = sitk.GetArrayFromImage(registered_image)
    return registered_image
opt = RegistrationAlgorithm(rigid)
l1_losses, ncc_values, ssim_values = opt.evaluate(opt.template, opt.registered_images)
print("L1 losses:", f"{np.mean(l1_losses):.2f}")
print("Normalized cross-correlation values:", f"{np.mean(ncc_values):.2f}")
print("Structural similarity index values:", f"{np.mean(ssim_values):.2f}")
L1 losses: 0.14
Normalized cross-correlation values: 0.56
Structural similarity index values: 0.55
images, template = retrieve_images()
summarise_registration(images, opt.registered_images, template)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与原始图像相比,我们的指标较差。让我们看看实际发生了什么:

visualise_registration_results(opt.registered_images, images, template, l1_losses)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

有趣。因此,指标通常较差,但这样出现的原因是因为这些指标比较的是移动图像上的大黑色区域。我们可能需要包括与上述相同的指标,但排除完全黑色的像素进行比较。

def evaluate_registration(template_img: np.ndarray, registered_imgs: List[np.ndarray]):
    """
    Evaluate the registration quality of multiple registered images with respect to a template image.
    """
    l1_losses = []
    ncc_values = []
    ssim_values = []
    l1_losses_excl_black = []
    ncc_values_excl_black = []
    ssim_values_excl_black = []

    for i, registered_img in enumerate(registered_imgs):

        # Create mask of non-black pixels in original image
        mask = (registered_img.ravel() != 0.0)

        # Compute L1 loss between the template and registered images
        l1_loss = np.mean(np.abs(template_img - registered_img))
        l1_losses.append(l1_loss)

        # Compute L1 loss between the template and registered images, excluding black pixels
        l1_loss_excl_black = np.mean(np.abs(template_img.ravel()[mask] - registered_img.ravel()[mask]))
        l1_losses_excl_black.append(l1_loss_excl_black)

        # Compute normalized cross-correlation between the template and registered images
        ncc = np.corrcoef(template_img.ravel(), registered_img.ravel())[0,1]
        ncc_values.append(ncc)

        # Compute normalized cross-correlation between the template and registered images, excluding black pixels
        ncc_excl_black = np.corrcoef(template_img.ravel()[mask], registered_img.ravel()[mask])[0,1]
        ncc_values_excl_black.append(ncc_excl_black)

        # Compute structural similarity index between the template and registered images
        ssim_value = ssim(template_img, registered_img, data_range=registered_img.max() - registered_img.min())
        ssim_values.append(ssim_value)

        # Compute structural similarity index between the template and registered images, excluding black pixels
        ssim_value_excl_black = ssim(template_img.ravel()[mask], registered_img.ravel()[mask], 
                                     data_range=registered_img.ravel()[mask].max() - registered_img.ravel()[mask].min())
        ssim_values_excl_black.append(ssim_value_excl_black)

    return l1_losses, ncc_values, ssim_values, l1_losses_excl_black, ncc_values_excl_black, ssim_values_excl_black

def summarise_registration(original_images, registered_images, template):

    # Calculate metrics for original images
    l1_losses, ncc_values, ssim_values, l1_losses_black, ncc_values_black, ssim_values_black = evaluate_registration(template, original_images)
    l1_original, ncc_original, ssim_original = np.mean(l1_losses), np.mean(ncc_values), np.mean(ssim_values)
    l1_black_original, ncc_black_original, ssim_black_original = np.mean(l1_losses_black), np.mean(ncc_values_black), np.mean(ssim_values_black)

    # Calculate metrics for registered images
    l1_losses, ncc_values, ssim_values, l1_losses_black, ncc_values_black, ssim_values_black = evaluate_registration(template, registered_images)
    l1_registered, ncc_registered, ssim_registered = np.mean(l1_losses), np.mean(ncc_values), np.mean(ssim_values)
    l1_black_registered, ncc_black_registered, ssim_black_registered = np.mean(l1_losses_black), np.mean(ncc_values_black), np.mean(ssim_values_black)

    # Create dataframe
    df_dict = {'original': {'l1': l1_original, 'ncc': ncc_original, 'ssim': ssim_original,
                            'l1_excl_black': l1_black_original, 'ncc_excl_black': ncc_black_original,
                            'ssim_excl_black': ssim_black_original}, 
               'registered': {'l1': l1_registered, 'ncc': ncc_registered, 'ssim': ssim_registered,
                              'l1_excl_black': l1_black_registered, 'ncc_excl_black': ncc_black_registered,
                              'ssim_excl_black': ssim_black_registered}}

    return style_df(df_dict)
images, template = retrieve_images()
summarise_registration(images, opt.registered_images, template)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

实际上没有显著改善。唯一绝对更好的指标是排除黑色像素的 SSIM。关于为什么会这样的一种理论是,通过排除黑色像素,我们也在排除视网膜,视网膜已经与大多数图像非常对齐,因此“抑制”了对齐良好的图像的指标。

光流

光流是计算机视觉中的一种基本技术,它估计视频序列中两个连续帧之间物体的运动。其假设是物体的像素强度在帧之间保持不变,并且物体的表观运动仅由于其实际运动。光流可以表示为一个 2D 矢量场(u,v),将一个速度矢量分配给图像中的每个像素。

光流场可以通过求解将图像亮度变化与像素运动相关联的方程组来计算。这些方程可以使用不同的方法求解,例如 Lucas-Kanade、Horn-Schunck 或 Farneback,每种方法都有其自身的优点和局限性。一旦计算出光流场,它可以用于通过将一幅图像扭曲以对齐另一幅图像来实现图像配准。

光流具有广泛的应用,包括物体跟踪、运动分析、视频稳定和视频压缩。然而,光流估计对图像噪声、遮挡和大位移非常敏感,这可能导致运动估计中的错误和不准确性。当前的研究集中在提高光流方法的准确性、鲁棒性和效率,以增强其在实际场景中的适用性。

让我们来看看这如何工作:

# --- Load the sequence
images, template = retrieve_images()
image0, image1 = images[0], template

# --- Convert the images to gray level: color is not supported.
#image0 = rgb2gray(image0)
#image1 = rgb2gray(image1)
# --- Compute the optical flow
v, u = optical_flow_tvl1(image0, image1)
# --- Use the estimated optical flow for registration
nr, nc = image0.shape
row_coords, col_coords = np.meshgrid(np.arange(nr), np.arange(nc),
                                     indexing='ij')
image1_warp = warp(image1, np.array([row_coords + v, col_coords + u]),
                   mode='edge')
# build an RGB image with the unregistered sequence
seq_im = np.zeros((nr, nc, 3))
seq_im[..., 0] = image1
seq_im[..., 1] = image0
seq_im[..., 2] = image0
# build an RGB image with the registered sequence
reg_im = np.zeros((nr, nc, 3))
reg_im[..., 0] = image1_warp
reg_im[..., 1] = image0
reg_im[..., 2] = image0
# build an RGB image with the registered sequence
target_im = np.zeros((nr, nc, 3))
target_im[..., 0] = image0
target_im[..., 1] = image0
target_im[..., 2] = image0
# --- Show the result
fig, (ax0, ax1, ax2) = plt.subplots(3, 1, figsize=(5, 10))
ax0.imshow(seq_im)
ax0.set_title("Unregistered sequence")
ax0.set_axis_off()
ax1.imshow(reg_im)
ax1.set_title("Registered sequence")
ax1.set_axis_off()
ax2.imshow(target_im)
ax2.set_title("Target")
ax2.set_axis_off()
fig.tight_layout()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

上述代码演示了使用光流进行图像配准。首先,代码加载一系列图像和模板。然后,将图像转换为灰度,并使用 TVL1 算法计算第一幅图像和模板之间的光流。计算出的光流矢量用于将模板图像配准到第一幅图像上。

要实现这一点,代码生成了模板图像的行和列坐标网格,并将光流矢量应用于这些行和列坐标,以获取第一幅图像中的相应位置。然后使用基于样条的图像变形函数,将这些变换后的坐标用于将模板图像扭曲到第一幅图像上。

代码然后生成 RGB 图像,以显示未配准序列、配准序列和目标图像(即第一幅图像)。未配准序列是一个 RGB 图像,其中第一幅图像和模板图像叠加在一起。配准序列是一个 RGB 图像,其中扭曲后的模板图像和第一幅图像叠加在一起。目标图像是一个仅包含第一幅图像的 RGB 图像。

最后,代码使用 Matplotlib 子图显示了三张 RGB 图像。第一个子图显示了未配准的序列,第二个子图显示了已配准的序列,第三个子图显示了目标图像。生成的图提供了未配准和已配准序列的视觉比较,突出了基于光流的配准方法的有效性。

估计的向量场(u,v)也可以通过箭头图进行显示。

# --- Compute the optical flow
v, u = optical_flow_ilk(image0, image1, radius=15)
# --- Compute flow magnitude
norm = np.sqrt(u ** 2 + v ** 2)
# --- Display
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(8, 4))
# --- Sequence image sample
ax0.imshow(image0, cmap='gray')
ax0.set_title("Sequence image sample")
ax0.set_axis_off()
# --- Quiver plot arguments
nvec = 20  # Number of vectors to be displayed along each image dimension
nl, nc = image0.shape
step = max(nl//nvec, nc//nvec)
y, x = np.mgrid[:nl:step, :nc:step]
u_ = u[::step, ::step]
v_ = v[::step, ::step]
ax1.imshow(norm)
ax1.quiver(x, y, u_, v_, color='r', units='dots',
           angles='xy', scale_units='xy', lw=3)
ax1.set_title("Optical flow magnitude and vector field")
ax1.set_axis_off()
fig.tight_layout()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

让我们实现算法。

def optical_flow(template, img):
    # calculate the vector field for optical flow
    v, u = optical_flow_tvl1(template, img)
    # use the estimated optical flow for registration
    nr, nc = template.shape
    row_coords, col_coords = np.meshgrid(np.arange(nr), np.arange(nc),
                                         indexing='ij')
    registered = warp(img, np.array([row_coords + v, col_coords + u]), mode='edge')
    return registered

opt = RegistrationAlgorithm(optical_flow)
images, template = retrieve_images()
summarise_registration(images, opt.registered_images, template).loc[['l1', 'ncc', 'ssim']]

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

显著提高了性能!让我们可视化一下:

images, template = retrieve_images()
visualise_registration_results(opt.registered_images, images, template, l1_losses)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

看起来光流在较难的图像上有些“作弊”,通过完全变形图像来实现。让我们看看是否可以改进这一点。

SimpleElastix

SimpleElastix 是一个开源的多平台软件库,提供了一个简单的接口来执行医学图像配准。图像配准是通过在图像之间找到空间映射来对齐两张或更多张图像的过程。SimpleElastix 提供了广泛的预实现配准组件,包括变换、相似性度量和优化器,这些组件可以轻松组合以创建配准管道。该库支持各种类型的配准,包括刚性、仿射、非刚性和组配准,并允许用户在不同的成像模态中配准图像,如 MRI、CT、PET 和显微镜。

SimpleElastix 的一个关键优点是其易用性。它提供了一个用户友好的高级接口,要求的编码知识很少,并且可以通过 Python 或 C++ 接口使用。此外,该库包括高级功能,如多分辨率优化、正则化和空间约束,这些功能提高了配准的准确性和鲁棒性。SimpleElastix 在医学影像研究和临床实践中被广泛使用,并在许多研究中得到了验证。它是一个有价值的工具,适用于广泛的应用,包括图像引导手术、纵向研究和图像分析。

刚性配准

如上所述,刚性变换能够对齐通过平移和旋转相关的对象。例如,在对齐患者骨骼的图像时,刚性变换通常足以对齐这些结构。尽可能使用简单的变换是有利的,因为这减少了可能的解决方案数量,并且最小化了可能影响配准结果准确性的非刚性局部极小值的风险。这种方法可以看作是在配准过程中融入领域专长的一种手段。

让我们看看单个已配准的图像:

import SimpleITK as sitk
from IPython.display import clear_output
from IPython.display import Image

images, template = retrieve_images()
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(sitk.GetImageFromArray(images[0]))
elastixImageFilter.SetMovingImage(sitk.GetImageFromArray(template))
elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("rigid"))
elastixImageFilter.Execute()
clear_output()
sitk.WriteImage(elastixImageFilter.GetResultImage(), 'test.tif')
# load image with SimpleITK
sitk_image = sitk.ReadImage('test.tif')
# convert to NumPy array
im = sitk.GetArrayFromImage(sitk_image)
plt.imshow(im, cmap='gray');

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

现在,让我们使用上面的架构来应用和验证刚性注册:

def simple_elastix_rigid(image, template):
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(sitk.GetImageFromArray(image))
    elastixImageFilter.SetMovingImage(sitk.GetImageFromArray(template))
    elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("rigid"))
    elastixImageFilter.Execute()
    clear_output()
    sitk.WriteImage(elastixImageFilter.GetResultImage(), 'reg.tif')
    # load image with SimpleITK
    sitk_image = sitk.ReadImage('reg.tif')
    # convert to NumPy array
    registered_img = sitk.GetArrayFromImage(sitk_image)
    # delete the tif file
    os.remove('reg.tif')
    return registered_img
# retrieve images to be registered, and the image to register to
images, template = retrieve_images()

# perform the registration using SimpleElastix
opt = RegistrationAlgorithm(simple_elastix_rigid)

可视化结果:

l1_losses, ncc_values, ssim_values = opt.evaluate(opt.template, opt.registered_images)
visualise_registration_results(opt.registered_images, images, template, l1_losses)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

最后,让我们审视这些度量:

images, template = retrieve_images()
summarise_registration(images, opt.registered_images, template)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

尽管 L1 损失相似,SimpleElastix 的刚性注册显著改善了 NCC 和结构相似性损失。

仿射注册

非常类似于刚性注册,仿射变换允许我们在旋转和平移之外进行剪切和缩放。通常,仿射注册作为非刚性变换之前的初步预处理步骤使用。

def simple_elastix_affine(image, template):
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(sitk.GetImageFromArray(image))
    elastixImageFilter.SetMovingImage(sitk.GetImageFromArray(template))
    elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("affine"))
    elastixImageFilter.Execute()
    clear_output()
    sitk.WriteImage(elastixImageFilter.GetResultImage(), 'reg.tif')
    # load image with SimpleITK
    sitk_image = sitk.ReadImage('reg.tif')
    # convert to NumPy array
    registered_img = sitk.GetArrayFromImage(sitk_image)
    # delete the tif file
    os.remove('reg.tif')
    return registered_img
# retrieve images to be registered, and the image to register to
images, template = retrieve_images()

# perform the registration using SimpleElastix
opt = RegistrationAlgorithm(simple_elastix_affine)
l1_losses, ncc_values, ssim_values = opt.evaluate(opt.template, opt.registered_images)
visualise_registration_results(opt.registered_images, images, template, l1_losses

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

images, template = retrieve_images()
summarise_registration(images, opt.registered_images, template)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

略好于刚性变换。

非刚性注册

非刚性注册技术能够对齐需要局部变形的图像,使其更适合处理患者之间的解剖、生理和病理变化。

为了参数化自由形状变形 (FFD) 场,通常使用 B-splines。FFD 场的注册比简单变换复杂得多。参数空间维度的增加使得解决这个问题具有挑战性,因此推荐使用多分辨率方法。以仿射初始化开始也有助于简化注册。在 SimpleElastix 中,实施多分辨率方法非常简单。

以下代码运行多分辨率仿射初始化,然后应用 B-spline 非刚性注册变换。

def simple_elastix_nonrigid(image, template):

    # Initialise the filter, as well as fixed and moving images
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(sitk.GetImageFromArray(image))
    elastixImageFilter.SetMovingImage(sitk.GetImageFromArray(template))

    # Setup the initialisation and transforms 
    parameterMapVector = sitk.VectorOfParameterMap()
    parameterMapVector.append(sitk.GetDefaultParameterMap("affine"))
    parameterMapVector.append(sitk.GetDefaultParameterMap("bspline"))
    elastixImageFilter.SetParameterMap(parameterMapVector)

    # Execute and save
    elastixImageFilter.Execute()
    clear_output()
    sitk.WriteImage(elastixImageFilter.GetResultImage(), 'reg.tif')
    # load image with SimpleITK
    sitk_image = sitk.ReadImage('reg.tif')
    # convert to NumPy array
    registered_img = sitk.GetArrayFromImage(sitk_image)
    # delete the tif file
    os.remove('reg.tif')
    return registered_img
# retrieve images to be registered, and the image to register to
images, template = retrieve_images()

# perform the registration using SimpleElastix
opt = RegistrationAlgorithm(simple_elastix_nonrigid)
l1_losses, ncc_values, ssim_values = opt.evaluate(opt.template, opt.registered_images)
visualise_registration_results(opt.registered_images, images, template, l1_losses)
images, template = retrieve_images()
summarise_registration(images, opt.registered_images, template)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

所以 SSIM 比仿射变换稍差,但 NCC 绝对更好。

群体注册

群体注册方法在医学影像中用于解决将一张图像注册到选定参考框架时的不确定性。相反,所有图像同时注册到一个位于群体中心的平均参考框架。该方法使用三维或四维 B-spline 变形模型和一个相似度度量,该度量最小化强度方差,同时确保所有图像的平均变形为零。该方法还可以结合变形的时间平滑性和时间维度上的循环变换,这在解剖运动具有周期性特征的情况下非常有用,例如心脏或呼吸运动。通过使用此方法,消除了对特定参考框架的偏倚,从而实现了图像的更准确和无偏注册。然而,该方法计算量巨大,未进行并行处理,因此在这里为了效率而未作介绍。

2D Voxelmorph 和空间变换网络

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

空间变换网络(STN)是一种神经网络架构,能够学习在空间上变换图像,以提高下游任务的性能。特别地,STN 能够自动学习裁剪、旋转、缩放和扭曲输入图像,以适应当前任务的最佳方式。这是通过学习估计每个输入图像的一组仿射变换参数来实现的,这些参数可用于将图像扭曲成新的配置。

在下面的代码中,STN 作为一个模块实现于一个更大的神经网络中,该网络包括几个卷积层和全连接层。STN 由两个组件组成:定位网络和回归器。

定位网络是一组卷积层,用于从输入图像中提取一组特征。这些特征随后被输入到回归器中,回归器是一组用于估计仿射变换参数的全连接层。在提供的代码中,回归器由两个带 ReLU 激活函数的线性层组成。

STN 模块还包括一个stn方法,该方法接受输入图像并通过双线性插值将学习到的仿射变换应用于图像。stn方法在更大神经网络的前向方法中被调用,用于对变换后的输入进行预测。

总的来说,STN 模块提供了一个强大的工具,用于学习对输入图像进行空间变换,这可以用于提高各种图像处理和计算机视觉任务的性能。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 188 * 188, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )
        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 188 * 188)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x
    def forward(self, x):
        # transform the input
        x = self.stn(x)
        return x

model = Net().to(device)

我们还将尝试一种不同的可微分损失,这种损失可能比之前的度量方法更适合图像配准。给定的代码定义了一种名为 voxelmorph 损失的自定义损失函数,用于 2D 图像配准。该损失函数由两个组件组成:重建损失和平滑惩罚。

重建损失衡量源图像与目标图像之间的不同。它计算为两幅图像之间的绝对差的平均值,并按目标权重加权。源图像和目标图像是配准网络的输入图像,其中源图像被变换以对齐目标图像。

平滑惩罚通过惩罚源图像和目标图像之间的空间变化形变来鼓励平滑变换。该惩罚通过计算目标图像在 x 和 y 方向上梯度的绝对差的平均值来计算,并按平滑权重加权。这个惩罚项有助于避免形变场中的急剧变化,这可能导致过拟合并对新图像的泛化能力差。

总体的 voxelmorph 损失是重建损失和平滑惩罚的总和。通过在训练期间使用基于梯度的优化器来优化损失,以提高配准网络的准确性。

Voxelmorph 损失函数因其处理大变形、多模态图像和个体差异的能力而在医学图像配准中被广泛使用。它对于图像的可变形配准尤为有用,其目标是对齐具有显著形状变化的图像。损失函数中的平滑性惩罚项有助于规范化变形场,并提高配准的准确性。

def voxelmorph_loss_2d(source, target, source_weight=1, target_weight=1, smoothness_weight=0.001):
    def gradient(x):
        d_dx = x[:, :-1, :-1] - x[:, 1:, :-1]
        d_dy = x[:, :-1, :-1] - x[:, :-1, 1:]
        return d_dx, d_dy

    def gradient_penalty(x):
        d_dx, d_dy = gradient(x)
        return (d_dx.abs().mean() + d_dy.abs().mean()) * smoothness_weight

    reconstruction_loss = (source - target).abs().mean() * target_weight
    smoothness_penalty = gradient_penalty(target)
    return reconstruction_loss + smoothness_penalty

下面的代码定义了一个 PyTorch 数据集类,名为 FundusDataset,用于加载和预处理用于神经网络的训练图像。数据集类接受训练图像列表和目标图像作为输入,并返回一个图像及其对应的目标图像,以便在训练过程中使用。

class FundusDataset(Dataset):
    def __init__(self, image_list, target_image):
        self.image_list = image_list
        self.target_image = target_image

    def __len__(self):
            return len(self.image_list)

    def __getitem__(self, idx):
        image = self.image_list[idx]
        image = torch.from_numpy(image).float()
        return image, self.target_image

# Load your list of Numpy arrays of training images
training_images, template = retrieve_images()
template_image = torch.from_numpy(template).float()
# Create the dataset
dataset = FundusDataset(training_images, template_image)
# Create the data loader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

现在,让我们编写一个简短的训练循环:

optimizer = optim.SGD(model.parameters(), lr=0.05)
criterion = voxelmorph_loss_2d #nn.L1Loss() #nn.MSELoss()

def train(epoch):
    model.train()
    batch_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.unsqueeze(1)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.reshape(output.shape[0], 768, 768), target)
        batch_loss += loss.item()
        loss.backward()
        optimizer.step()
    if epoch % 1 == 0:
        avg_loss = batch_loss / len(train_loader)
        print('Train Epoch: {}, Average Loss: {:.6f}'.format(epoch, avg_loss))
for epoch in range(1, 5 + 1):
    train(epoch)

最后,我们定义了一个名为 convert_image_np 的 Python 函数,将 PyTorch 张量转换为 numpy 图像。该函数以 PyTorch 张量为输入,应用标准归一化程序,通过减去均值并除以标准差值来完成归一化。生成的 numpy 图像随后被裁剪到 0 和 1 之间。

代码接着定义了一个名为 visualize_stn 的函数,用于在训练过程中可视化空间变换网络(STN)层的输出。使用 matplotlib 库中的 subplots() 函数将生成的输入和变换后的 numpy 图像并排绘制。图中左侧显示输入图像,右侧显示对应的变换图像。

def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.

def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(train_loader))[0].to(device)
        data = data.unsqueeze(1)
        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()
        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))
        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))
        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2, figsize=(20,20))
        axarr[0].imshow(in_grid, cmap='gray')
        axarr[0].set_title('Dataset Images')
        axarr[1].imshow(out_grid, cmap='gray')
        axarr[1].set_title('Transformed Images')
# Visualize the STN transformation on some input batch
visualize_stn()
plt.ioff()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

显然,网络已经在某种程度上移动了图像,但对于与模板图像差异过大的眼底图像仍然存在困难。

超大规模配准

为了完整性,我在这里包括了一些我尝试过的事情,以使 STN 的效果更好。

图像增强

我直觉上认为,图像增强可以通过增加训练数据的多样性和数量来改善 STN 在学习图像配准变换中的性能。STN 依赖大量的训练数据来学习图像之间的复杂空间变换。然而,由于患者数据有限和成像模式的变异等因素,获取足够大且多样的医学图像数据集可能具有挑战性。此外,我每个患者的数据有限,因此训练 STN 仅对许多患者的组合数据可行。

图像增强通过对现有图像应用多种图像变换来生成合成训练数据,提供了一个解决方案。这增加了训练数据集的大小和多样性,使 STN 能够学习更强大且具有更好泛化能力的配准变换。图像增强还可以帮助 STN 学习对某些成像条件如光照、对比度和噪声变化不变的变换。

常见的图像增强技术包括随机旋转、平移、缩放和翻转,以及更复杂的变换,如弹性形变和强度变化。这些变换在训练过程中随机应用,以生成与原始图像相似的各种变换图像。然后使用增强的图像来训练 STN,从而提高其对新图像的泛化能力。

import numpy as np
from PIL import Image
import cv2
import random
import torchvision.transforms as transforms

def image_augmentation(images, base_index=0, n_aug=5):
    # Convert the NumPy arrays to Pillow Image objects
    items = [Image.fromarray(image).convert("RGBA") for image in images]
    # Define the image transformation pipeline
    transform = transforms.Compose([
        transforms.Resize(460),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2),
                                scale=(0.9, 1.1), shear=0,
                                fillcolor=(128, 128, 128, 255)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Generate the augmented images
    new_items = []
    for i in range(n_aug):
        # Get the base image
        base_item = items[base_index]
        base_image = np.array(base_item)
        # Apply the random transforms to the base image
        transformed_item = transform(base_item)
        # Convert the transformed image to a NumPy array and add it to the list of augmented images
        transformed_image = np.transpose(transformed_item.numpy(), (1, 2, 0))
        transformed_image = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR)
        new_items.append(transformed_image)
    # Convert the augmented data back to NumPy arrays
    new_images = [np.array(image) for image in new_items]
    return new_images

然后你可以将这个函数应用到图像列表中,并传入扩展的数据集进行训练。

使用 k 最近邻的聚类模型

以下代码实现了对存储为 NumPy 数组的一组图像进行 k-means 聚类。代码的目的是找到最佳的簇数,以最佳地表示图像集。

代码首先将图像列表转换为 2D NumPy 数组,然后将数组重塑为 2D 形状。这是为了创建一个可以输入到 k-means 聚类算法的数据集。然后,对一系列的 k 值(其中 k 是要生成的簇的数量)运行 k-means 算法。对于每个 k 值,运行算法,并计算簇内平方和(WCSS)。WCSS 是衡量每个簇内数据点分散程度的指标,用于评估聚类质量。WCSS 值存储在一个列表中,并对所有 k 值重复此过程。

一旦计算出 WCSS 值,就会生成一个肘部图来可视化簇数与 WCSS 值之间的关系。肘部图展示了一个下降的曲线,并到达一个肘部点,在此点 WCSS 值的下降速度开始平缓。最佳的簇数被选择为曲线开始平缓的值。

from sklearn.cluster import KMeans

# Assume you have a list of images stored as numpy arrays in a variable called "images"
images, template = retrieve_images()
# Convert the list of images to a 2D numpy array
data = np.array(images)
n_samples, height, width = data.shape
data = data.reshape(n_samples, height * width)
# Set up an empty list to hold the within-cluster sum of squares (WCSS) values for each value of k
wcss_values = []
# Set up a range of values for k
k_values = range(1, 11)
# Loop over the values of k and fit a k-means model for each value
for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(data)

    # Calculate the within-cluster sum of squares (WCSS)
    wcss = kmeans.inertia_
    wcss_values.append(wcss)

# Plot the WCSS values against the number of clusters
fig, ax = plt.subplots()
ax.plot(k_values, wcss_values, 'bo-')
ax.set_xlabel('Number of clusters (k)')
ax.set_ylabel('Within-cluster sum of squares (WCSS)')
ax.set_title('Elbow Plot')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

从这个图中来看,最佳的簇数可能是三个。让我们用这个来对我们的图像进行分组:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

# Assume you have a list of images stored as NumPy arrays in a variable called "images"
images, template = retrieve_images()
# First, flatten each image into a 1D array
image_vectors = np.array([image.flatten() for image in images])
# Use k-means to cluster the image vectors
kmeans = KMeans(n_clusters=3, random_state=0).fit(image_vectors)
cluster_labels = kmeans.labels_
# Use k-nearest neighbors to find the nearest images to each centroid
n_neighbors = 5
nn = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(image_vectors)
distances, indices = nn.kneighbors(kmeans.cluster_centers_)
# Plot the nearest images to each centroid
fig, axs = plt.subplots(kmeans.n_clusters, n_neighbors, figsize=(15, 15))
for i in range(kmeans.n_clusters):
    for j in range(n_neighbors):
        axs[i][j].imshow(images[indices[i][j]], cmap='gray')
        axs[i][j].axis('off')
        axs[i][j].set_title(f"Cluster {i}, Neighbor {j+1}")
plt.show()
# Store the cluster labels and image labels in a dictionary
labels_dict = {}
for i in range(len(images)):
    labels_dict[i] = {"cluster_label": cluster_labels[i]}

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

这看起来不错。让我们为每个簇创建一个新列表:

# Assume you have a list of images stored as NumPy arrays in a variable called "images"
images, template = retrieve_images()

# First, flatten each image into a 1D array
image_vectors = np.array([image.flatten() for image in images])

# Use k-means to cluster the image vectors
kmeans = KMeans(n_clusters=3, random_state=0).fit(image_vectors)
cluster_labels = kmeans.labels_

# Use k-nearest neighbors to find the nearest images to each centroid
n_neighbors = 5
nn = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(image_vectors)
distances, indices = nn.kneighbors(kmeans.cluster_centers_)

# Store the images in each cluster
cluster_0_images = []
cluster_1_images = []
cluster_2_images = []
for i, cluster_label in enumerate(cluster_labels):
    if cluster_label == 0:
        cluster_0_images.append(images[i])
    elif cluster_label == 1:
        cluster_1_images.append(images[i])
    else:
        cluster_2_images.append(images[i])

# Print the number of images in each cluster
print(f"Number of images in cluster 0: {len(cluster_0_images)}")
print(f"Number of images in cluster 1: {len(cluster_1_images)}")
print(f"Number of images in cluster 2: {len(cluster_2_images)}")
Number of images in cluster 0: 38
Number of images in cluster 1: 31
Number of images in cluster 2: 32
trained_models = {}

for i, cluster_images in enumerate([cluster_0_images, cluster_1_images, cluster_2_images]):
    # load the spatial transformer network
    model = Net().to(device)
    template_image = torch.from_numpy(template).float()
    # Create the dataset
    dataset = FundusDataset(cluster_images, template_image)
    # Create the data loader
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=0.05)
    criterion = voxelmorph_loss_2d

    print(f"TRAINING CLUSTER {i}")
    print("-"*50)
    for epoch in range(1, 5 + 1):
        train(epoch)

    trained_models[f"cluster_{i}_model"] = model
    print(" ")
TRAINING CLUSTER 0
--------------------------------------------------
Train Epoch: 1, Average Loss: 0.122363
Train Epoch: 2, Average Loss: 0.120482
Train Epoch: 3, Average Loss: 0.117854
Train Epoch: 4, Average Loss: 0.112581
Train Epoch: 5, Average Loss: 0.119771

TRAINING CLUSTER 1
--------------------------------------------------
Train Epoch: 1, Average Loss: 0.079963
Train Epoch: 2, Average Loss: 0.080274
Train Epoch: 3, Average Loss: 0.077222
Train Epoch: 4, Average Loss: 0.076657
Train Epoch: 5, Average Loss: 0.077101

TRAINING CLUSTER 2
--------------------------------------------------
Train Epoch: 1, Average Loss: 0.172036
Train Epoch: 2, Average Loss: 0.171105
Train Epoch: 3, Average Loss: 0.170653
Train Epoch: 4, Average Loss: 0.170199
Train Epoch: 5, Average Loss: 0.169759

实际上没有帮助。这也违背了 STN 的整体思想,即使用卷积层来确定应用于哪些图像的变换,即一些图像在变换中需要显著更高的权重。

结论

总之,对医学图像配准技术的评估表明,虽然现代方法如空间变换网络(STNs)提供了有希望的结果,但需要大量投资才能实现预期效果。相比之下,传统技术如 SimpleElastix 被证明更有效且高效。尽管实施了各种策略来提高 STN 的性能,但模型未能学习到足够的权重来调整目标,显示出需要替代的损失函数。一种方法可能是忽略由于仿射变换产生的黑色像素。此外,点对点配准利用视网膜或血管等生物标记来指导配准过程,对特定应用可能更有利。因此,需要进一步研究以确定最合适的配准方法,基于具体问题和可用资源。

参考文献

数据来自 Kaggle 视网膜基金图像配准数据集,该数据集的许可证为国际署名 4.0 (CC BY 4.0)。

  1. Kaggle: 视网膜基金图像配准 数据集

  2. Jaderberg, M., Simonyan, K., & Zisserman, A. (2015). 空间变换网络。神经信息处理系统的进展, 28

  3. Hill, D. L., Batchelor, P. G., Holden, M., & Hawkes, D. J. (2001). 医学图像配准。医学与生物物理学, 46(3), R1。

  4. Brown, L. G. (1992). 图像配准技术的综述。ACM 计算机调查(CSUR), 24(4), 325–376。

  5. Lee, M. C., Oktay, O., Schuh, A., Schaap, M., & Glocker, B. (2019). 图像和空间变换网络用于结构引导的图像配准。在医学图像计算与计算机辅助干预–MICCAI 2019: 第 22 届国际会议,深圳,中国,2019 年 10 月 13–17 日,会议录,第二十二部分(第 337–345 页)。Springer International Publishing。

  6. Pytorch: 空间变换网络教程

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值