迈向视觉中的独立自注意力
原文:
towardsdatascience.com/towards-stand-alone-self-attention-in-vision-3d0561c6aee5
深入探讨变换器架构及其自注意力操作在视觉中的应用
·发表于Towards Data Science ·阅读时间 14 分钟·2023 年 4 月 28 日
–
图片由作者使用craiyon AI创建。
虽然自注意力已经在自然语言处理(NLP)中广泛应用,并显著提升了最先进模型的性能(如[2],[3]),但在视觉领域也正在进行越来越多的工作,以实现类似的效果。
尽管有一些混合方法将 CNN 与注意力机制结合[4],或对图像的图块应用线性变换[5],但纯注意力模型由于各种原因更难以有效训练,我们将进一步探讨这些原因。
《视觉模型中的独立自注意力》[6] 论文提出了这样一个纯注意力模型的理念。接下来,我将概述论文的思想及相关后续工作。此外,我假设你对变换器的工作原理已有所了解,并具备CNN 的基础知识。了解PyTorch对编码部分也有帮助,但这些部分也可以安全跳过。
如果你只是对代码感兴趣,可以直接跳过这篇文章,直接查看 这个带注释的 colab 笔记本。
自注意力在视觉中的应用
CNN 通常用于构建图像处理的神经网络,因为它们具有强大的几何平移等变先验。这意味着它们能够很好地处理输入的相对位移,使其具有鲁棒性。
另一方面,自注意力没有这种先验,而是具有置换等变性。这意味着如果输入被重新排列,输出也会以等效的方式重新排列。尽管置换等变性更为通用,但对于图像而言,它不如平移等变性有用。
幸运的是,我们可以使用不同的位置编码来约束自注意力操作,并实现平移等变性。位置编码——当它具有可学习参数时也称为位置嵌入——使我们能够拥有比 CNN 更灵活的架构,同时仍然能够融入某些先验知识。
实现 1D 中的基本自注意力
对于一维输入,如文本和语音,单头自注意力操作的定义为
缩放点积注意力,如 [1] 中提出的
这本质上是查询 Q 和键 K 之间的缩放点积,然后是结果矩阵与 V 之间的另一个点积。
我们还可以将点积明确地表示为加权和,并展示如何获得特定输出。请记住这一点,因为稍后我们将把它推广到 2D 图像。
特定输出 yᵢ 的自注意力
在 PyTorch 中,这可能看起来如下。
import torch
import torch.nn as nn
import torch.nn.functional as F
# for some einsum magic
from einops import rearrange, einsum
# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
embedding_dim_k = 10
# linear projection of the input x
key = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
query = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
value = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
# creating random vector of shape:
# batch_size (b), sequence lenght (t), embedding dim (k)
x = torch.randn(1, 12, embedding_dim_k) # b t k
d_b, d_t, d_k = x.size()
# linear projection of the input
q = query(x) # b, t, k
k = key(x) # b, t, k
v = value(x) # b, t, k
assert q.shape == (d_b, d_t, d_k)
# scaled dot-product self-attention
# dot_prod(Q, K)
scaling_factor = 1/torch.sqrt(torch.tensor(d_k))
scaled_dot_product = F.softmax(
einsum(q, k, "b t k, b l k -> b t l") * scaling_factor, dim=-1 )
assert scaled_dot_product.shape == (d_b, d_t, d_t)
# dot-prod(w, v)
self_attention = torch.einsum('b i j , b j d -> b i d', scaled_dot_product, v)
# remember that self-attention is a seq2seq operation:
# the size that goes in, also goes out
assert self_attention.shape == (d_b, d_t, d_k)
全局与局部自注意力
当我们谈论视觉模型中的全局自注意力和局部自注意力时,我们指的是模型观察图像的范围。全局自注意力一次查看整个图像,而局部自注意力只关注某些部分。通常,模型观察的区域越大,复杂度就越高,所需的内存也越多。
让我们仔细看看基本的自注意力操作及其在更大图像尺寸下的表现。为此,我们将使用一种叫做 大 O 符号 的概念来表达操作的复杂度,随着输入大小 n 的增加。
自注意力操作涉及三个单独的计算:
-
计算 QKᵀ 的复杂度为 O(n² d_k)
-
包含指数运算、求和和除法的 softmax 操作具有 O(n²) 的平方复杂度
-
乘以 softmax(QKᵀ)V 的复杂度为 O(n² d_v)
总的来说,基本自注意力操作的复杂度随着输入序列长度 n 的增加而呈平方增长。因此,当我们将自注意力应用于越来越大的图像——由于其 2D 特性,这些图像的长度大约为 n² = hw* ——操作的空间和时间复杂度也会随之增加。这是为什么在更大的图像上使用全局感受野可能会很困难,而局部感受野则是一种有吸引力的解决方案的原因之一。
重新审视 CNN
在图 1 中,我们可以看到我们使用了称为内核的小方块,这些方块在图像上滑动。我们选择图像上的中心点[i,j]和内核大小,这决定了内核包含图像的多少部分。内核应用于图像中的每个像素,值输入到同一个神经网络中,因此我们使用了更少的参数。注意,在图中,每个方块中有多个像素,但实际上,我们每个方块中只有一个像素,除非我们使用池化将它们分组在一起。
图 1:围绕点[i, j](红色方块)的局部卷积窗口示例,空间扩展 k=3。©J. Hatzky
内核的大小可以在网络的不同层之间变化。这使得网络能够在特定层内学习局部相关结构。在最近的工作中,引入了可变大小的差分内核[7],但我们将重点关注传统 CNN 中使用的基本方法。由于卷积内核是我们将要构建的重要概念,我将使用[6]中使用的符号来解释它。
输入图像由其高度h、宽度w和通道大小din(例如,RGB 图像为 3)指定:x ∈ ℝʰˣʷˣᵈⁱⁿ。我们使用空间范围 k 定义一个像素 xᵢⱼ周围的局部邻域 Nₖ,即内核范围内的像素集合。例如,N₃(2,2)将是以第 2 行第 2 列的像素为中心的 3x3 方块内的像素集合。为了完整性,我们可以定义为:Nₖ(i, j) = {a, b ∣ |a − i| ≤ k/2, |b − j| ≤ k/2}。我们优化一个权重矩阵 W ∈ ℝᵏˣᵏˣᵈᵒᵘᵗˣᵈⁱⁿ,以计算每个像素的特定输出 yᵢⱼ。
具有空间扩展 k 和中心点[i, j]的加权和
为了得到这个输出,我们对局部邻域内每个像素进行深度矩阵乘法的乘积求和。这一操作具有平移等变性,这意味着它旨在识别图像中无论出现在哪里的模式。
作为二维局部感受野的记忆块
为了在二维图像上执行自注意力,[6]中的研究人员提出了一个受 CNN 工作方式启发的记忆块概念。如果您想要全球应用自注意力,只需将记忆块做得与整个图像一样大即可。记忆块本质上与 CNN 中使用的感受野相同,但我们不是使用 CNN,而是对感受野 Nₖ中的像素应用自注意力操作,从而在局部记忆块中的任何像素对之间创建一个可学习的连接。
要为这种二维情况定义单头自注意力操作,我们可以使用以下方程:
特定输出 yᵢⱼ的自注意力
在失去 CNN 的平移等变性的同时,我们现在获得了自注意力的更一般的置换等变性。
让我们看看这在 PyTorch 中会是什么样子。
import torch
import torch.nn as nn
import torch.nn.functional as F
# for some einsum magic
from einops import rearrange, einsum
# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
# we create a random normal tensor as a placeholder for an RGB image with shapes:
# batch_size (b), channels (c), height (h), width (w)
img = torch.randn(1, 3, 28, 28) # b c h w
k = 3 # spatial extend of the memory block N
# we can extract memory blocks by using the pytorch unfold operation and rearranging the result
# we pad the image first to keep our old dimensions intact
stride = 1
padding = 1
memory_blocks = F.pad(img, [padding]*4).unfold(dimension=2, size=k,
step=stride).unfold(dimension=3, size=k, step=stride)
memory_blocks = rearrange(memory_blocks, "b c h w i j -> b h w c i j")
print(memory_blocks.shape)
print(f"We have {memory_blocks.shape[1]}x{memory_blocks.shape[2]} patches of shape: {memory_blocks.shape[2:]}")
# apply the self-attention for a specific ij:
i, j = (3, 4)
memory_block_ij = memory_blocks[:, i, j, : , :, :]
# we can flatten the memory blocks height and width
x = rearrange(memory_block_ij, "b h w c -> b (h w) c")
# our input dimension is the channel size
d_in = x.shape[-1]
d_out = d_in
# linear transformations to embed the input x
key = nn.Linear(d_in, d_out, bias=False)
query = nn.Linear(d_in, d_out, bias=False)
value = nn.Linear(d_in, d_out, bias=False)
d_b, d_t, d_k = x.size()
# linear projection of the input
q = query(x) # b, t, k
k = key(x) # b, t, k
v = value(x) # b, t, k
assert q.shape == (d_b, d_t, d_out)
# scaled dot-product self-attention
# dot_prod(Q, K)
scaling_factor = 1/torch.sqrt(torch.tensor(d_k))
scaled_dot_product = F.softmax(
einsum(q, k, "b t k, b l k -> b t l") * scaling_factor, dim=-1 )
assert scaled_dot_product.shape == (d_b, d_t, d_t)
# dot-prod(w, v)
self_attention = torch.einsum('b i j , b j d -> b i d', scaled_dot_product, v)
# remember that self-attention is a seq2seq operation:
# the size that goes in, also goes out
assert self_attention.shape == (d_b, d_t, d_out)
这个简单的实现有一个很大的缺点。我们在将自注意力应用于展平的内存块时丢失了所有的空间信息。解决这一问题的一种方法是添加位置嵌入——这是下一节的主题。
2D 相对位置嵌入
除了 2D 自注意力,[6] 还引入了相对嵌入的 2D 应用。相对嵌入在一维中最早由[8]引入,后来由例如[9]和[10]扩展。
使用相对嵌入,我们首先获得了一个强大的位置表示,其在泛化能力上可能优于绝对嵌入[8],适用于更大的图像(或在自然语言处理中的更长序列)。
此外,我们在模型中引入了一个强大的归纳偏置,即平移等变性,这在 CNN 的情况下已被证明非常有用。
相对位置嵌入在二维中的工作方式是为 x(列)和 y(行)方向定义相对索引。这里的相对意味着,索引应相对于被查询的像素 yᵢⱼ(图 2)。
图 2:特定像素 ab ∈ Nₖ(i, j) 的相对位置嵌入。©J. Hatzky
如[6]中提出的,行和列偏移量与嵌入 r 相关联,分别对应 (a-i) 和 (b-j),每个维度为 1/2dout*。然后将行和列偏移嵌入连接在一起形成这种空间相对注意力。
在自注意力操作中添加了相对位置嵌入。
本质上,我们在这里创建了一个包含相对位置的信息的嵌入矩阵,并将其添加到 QK 点积中的 softmax。
请参见下面如何在 PyTorch 中完成这项工作。请注意,还有更高效的方法来实现这一点,我们在这里不予讨论,因为我们坚持介绍的公式。
import torch
import torch.nn as nn
import torch.nn.functional as F
# for some einsum magic
from einops import rearrange, einsum
# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
# number of input channels (e.g. 3 for RGB)
in_channels = 3
# the embedding dim of the input projection (embedding_dim)
mid_channels = 22
# the number of attention heads
num_heads = 2
# the number of channels after projecting the heads together
out_channels = 8
# the maximum number of image pixels of a side(assuming squared images)
max_pos_embedding = 4
# create embeddings. if we want to keep the 2D representation of the
# input, we can do this by using 2D convolution
query = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
key = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
value = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
wout = nn.Conv2d(mid_channels * num_heads, out_channels, kernel_size=1, device=device)
# Define positional embeddings
row_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
col_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
# create relative indices
deltas = torch.arange(max_pos_embedding).view(1, -1) - torch.arange(max_pos_embedding).view(
-1, 1
)
# -- shift the delta to [0, 2 * max_position_embeddings - 1]
relative_indices = (deltas + max_pos_embedding - 1).to(device)
# create an example image
x = torch.randn(4, 3, 4, 4, device=device) # b c h w
b, cin, h, w = x.size()
sqrt_normalizer = torch.sqrt(torch.tensor([cin], requires_grad=False, device=device))
q = query(x)
k = key(x)
v = value(x)
# Compute attention scores based on position
# the relative indices are used to get the stair-case pattern corret vectors
row_embedding = row_embedding(
relative_indices[:w, :w].reshape(-1)
).transpose(0, 1)
col_embedding = col_embedding(
relative_indices[:h, :h].reshape(-1)
).transpose(0, 1)
# unfold heads
q = rearrange(
q, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
k = rearrange(
k, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
v = rearrange(
v, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
# now expand the rows and columns and conncatenate them
expand_row = row_embedding.unsqueeze(-1).expand(-1, -1, h*h)
expand_col = col_embedding.unsqueeze(-2).expand(-1, w*w, -1)
positional_embedding = torch.cat((expand_row, expand_col), dim=0)
positional_embedding = rearrange(
positional_embedding, "c (h w) (i j) -> c h w i j",
c=mid_channels, h=h, w=w, i=h, j=w)
# dot-prod(q, r)
attention_scores = einsum(q, positional_embedding,
"b c h i j, c i k j l -> b h i j k l")
attention_scores = attention_scores / sqrt_normalizer
# Compute attention scores based on data
attention_content_scores = einsum(q, k, "b c h i j, b c h k l -> b h i j k l")
attention_content_scores = attention_content_scores / sqrt_normalizer
# Combine attention scores
attention_scores = attention_scores + attention_content_scores
# Normalize to obtain probabilities.
shape = attention_scores.shape
att_probs = nn.Softmax(dim=-1)(attention_scores.view(*shape[:-2], -1)).view(shape)
# Re-weight values via attention
v_f = einsum(att_probs, v, "b h i j k l, b c h k l -> b c h i j")
# linear project to output dimension
v_f = rearrange(v_f, "b c h i j -> b (c h) i j")
out = wout(v_f)
out.shape
将所有部分结合起来
现在我们已经到达可以将所有部分结合在一起的点。
为了更好地理解,图 3 是自注意力中数据流和形状的概述。
图 3:自注意力过程中的形状概述。灵感来源于这个 GitHub 帖子。©J. Hatzky
让我们创建一个实现整个模型的类。
import torch
import torch.nn as nn
import torch.nn.functional as F
# for some einsum magic
from einops import rearrange, einsum
# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
class StandAloneSelfAttention(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels,
num_heads, max_pos_embedding):
"""
Inputs:
in_channels - Dimensionality of input and attention feature vectors
mid_channels - Embedding dim of the input projection
out_channels - Output dim after projecting heads together
num_heads - Number of heads to use in the Multi-Head Attention block
max_pos_embedding # The max(height, width) of size that has to be embedded
"""
super().__init__()
self.mid_channels = mid_channels
self.num_heads = num_heads
self.out_channels = out_channels
# create embeddings. if we want to keep the 2D representation of the
# input, we can do this by using 2D convolution
self.query = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
self.key = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
self.value = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
self.wout = nn.Conv2d(mid_channels * num_heads, out_channels, kernel_size=1, device=device)
# Define positional embeddings
self.row_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
self.col_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
# create relative indices
deltas = torch.arange(max_pos_embedding).view(1, -1) - torch.arange(max_pos_embedding).view(
-1, 1
)
# -- shift the delta to [0, 2 * max_position_embeddings - 1]
self.relative_indices = (deltas + max_pos_embedding - 1).to(device)
self.verbose = False
def forward(self, x):
q = self.query(x)
k = self.key(x)
v = self.value(x)
if self.verbose is True:
print(f"x: {x.shape}, q: {q.shape}, k: {k.shape}, v:{v.shape}")
b, cin, h, w = x.size()
sqrt_normalizer = torch.sqrt(torch.tensor([cin], requires_grad=False,
device=device))
# Compute attention scores based on position
# the relative indices are used to get the stair-case pattern corret vectors
row_embedding = self.row_embedding(
self.relative_indices[:w, :w].reshape(-1)
).transpose(0, 1)
col_embedding = self.col_embedding(
self.relative_indices[:h, :h].reshape(-1)
).transpose(0, 1)
# unfold heads
q = rearrange(
q, "b (c heads) h w -> b c heads h w",
heads=self.num_heads, c=self.mid_channels)
k = rearrange(
k, "b (c heads) h w -> b c heads h w",
heads=self.num_heads, c=self.mid_channels)
v = rearrange(
v, "b (c heads) h w -> b c heads h w",
heads=self.num_heads, c=self.mid_channels)
if self.verbose is True:
print(f"q: {q.shape}, k: {k.shape}, v:{v.shape}")
# now expand the rows and columns and conncatenate them
expand_row = row_embedding.unsqueeze(-1).expand(-1, -1, w*w)
expand_col = col_embedding.unsqueeze(-2).expand(-1, h*h, -1)
positional_embedding = torch.cat((expand_row, expand_col), dim=0)
positional_embedding = rearrange(
positional_embedding, "c (h w) (i j) -> c h w i j",
c=self.mid_channels, h=h, w=w, i=h, j=w)
if self.verbose is True:
print(f"row_encoding: {row_embedding.shape}, column_encoding: {col_embedding.shape}, pos_embedding: {positional_embedding.shape}")
# dot-prod(q, r)
attention_scores = einsum(q, positional_embedding,
"b c h i j, c i k j l -> b h i j k l")
attention_scores = attention_scores / sqrt_normalizer
# Compute attention scores based on data dot-prod(q, k)
attention_content_scores = einsum(q, k, "b c h i j, b c h k l -> b h i j k l")
attention_content_scores = attention_content_scores / sqrt_normalizer
# Combine attention scores
attention_scores = attention_scores + attention_content_scores
# Normalize to obtain probabilities.
shape = attention_scores.shape
att_probs = nn.Softmax(dim=-1)(attention_scores.view(*shape[:-2], -1)).view(shape)
if self.verbose is True:
print(f"attention_scores: {attention_scores.shape}, shaped scores: {attention_scores.view(*shape[:-2], -1).shape} att_probs: {att_probs.shape}")
# Re-weight values via attention and map to output dimension.
v_f = einsum(att_probs, v, "b h i j k l, b c h k l -> b c h i j")
v_f = rearrange(v_f, "b c h i j -> b (c h) i j")
if self.verbose is True:
print(f"(qr + qk)V: {v_f.shape}")
out = self.wout(v_f)
return out
结语
“视觉模型中的独立自注意力 [6]” 论文提出了将纯自注意力模型应用于视觉的一个有趣想法。尽管自注意力操作的复杂性较高,论文展示了一种有效的方法,使用局部接收字段,也称为记忆块,以减少计算资源。虽然近期发布的视觉变换器可能抢占了风头,但这种方法具有巨大的潜力,有望通过额外的软件和硬件改进成为视觉领域的顶尖架构。这是一项令人兴奋的工作,可能会将视觉模型提升到一个新的水平!
对更多代码感兴趣?可以查看 这个注释过的 Colab 笔记本 ,在其中我将此模型应用于 CIFAR-10 数据集。
发现了错误?请告知我!
参考文献
[1] Vaswani 等人
[2] Devlin 等人
[3] Brown 等人
[4] Zhang 等人
[5] Dosovitskiy 等人
[6] Ramachandran 等人
[7] Romero 等人
[8] Shaw 等人
[9] Dai 等人
[10] Liutkus 等人
关于大型语言模型的无偏评估
原文:
towardsdatascience.com/towards-unbiased-evaluation-of-large-language-models-9a7315144389
基准测试泄漏和数据污染如何破坏 LLM 的评估
·发表于 Towards Data Science ·7 分钟阅读·2023 年 12 月 9 日
–
作者提供的图片。 (AI 辅助)
“我们的新 LLM 在每一个基准测试中都超越了 GPT!”
听到这种大胆声明变得越来越普遍,因为大型语言模型的炒作非常巨大。每周都有新模型,目前每个人都在试图与 GPT-4 竞争,而 GPT-4 仍然是最强大的大型语言模型。
基准测试是评估大型语言模型进展的关键部分。
MMLU 和 HellaSwag 等基准测试是评估语言模型在推理和理解等技能方面的标准。分数提供了进展的快照,新的最先进成果被誉为突破。大型语言模型通常在零-shot 设置中进行评估,没有在测试集上进行明确训练,以评估其通用能力。
本文展示了操控基准测试结果的容易程度,并提供了维护评估完整性的建议。
基准测试的问题
通常,基准测试不会反映在实际场景中的有用性。谷歌最新的模型 Gemini Ultra,在MMLU上得分90.04%。虽然这是一个令人印象深刻的分数,但仔细查看评估方法时,它是CoT@32(32 个样本的思维链)。这意味着我们必须提示 32 次才能获得 90%的准确率! 我们大多数人期望在第一次尝试时得到准确答案,尤其是在与聊天机器人互动时。
Google Gemini 技术报告。 [1]
不幸的是,这个问题只是大型语言模型评估的冰山一角。
在机器学习中,模型通常通过测量其在训练过程中未使用的测试集上的表现来进行评估。通常,这个过程可以对模型如何推广到新数据进行无偏估计。
基准泄漏和数据污染是两个都指代一个令人担忧问题的术语:当测试数据以某种方式泄漏到 LLM 的预训练数据中时,导致性能膨胀。这使得 LLM 之间的比较不公平,并提供了不可靠的进展衡量标准。
如果测试集中的示例泄漏到训练数据中,则评估会受到影响。这种数据污染实际上允许模型在测试中作弊。
污染可以通过多种方式发生。测试数据可能被有意或无意地包含在训练数据中。更微妙的是,如果测试数据在线可用,网络抓取的训练数据可能无意中包含测试示例。模型也可能被明确训练以根据格式和特征重新生成测试数据集。无论原因如何,污染使得模型之间的实证比较无效。
这种基准泄漏提供了不公平的优势,如果一个 LLM 见过与测试集相关的数据,而另一个 LLM 没有。这使得声称的改进产生怀疑,并使比较具有误导性,破坏了基准测试的目的。不幸的是,泄漏难以外部检测,并使利用它的模型受益。
介绍 phi-CTNL
我们的 phi-CTNL 预训练数据由精心策划的专家制作的非合成数据混合而成。具体来说,我们首先选择希望对其进行评估的下游学术基准,然后在这些基准上进行预训练。
那篇标题为Pretraining on the Test Set Is All You Need的搞笑论文突出了过度依赖基准评估的陷阱。
在每个基准测试中取得完美分数。[2]
他们展示了一个名为 phi-CTNL 的小型 1 百万参数 LLM在仅用 100,000 个标记进行预训练的情况下,如何在各种学术基准测试中取得完美分数,超越了像 GPT-3 这样的最先进模型。关键是什么?预训练数据完全由这些基准的测试数据组成。
这就是基准泄漏的风险——当测试数据泄漏到预训练模型中时,评估结果变得毫无意义。
即使是作为戏仿,这篇论文也引起了对一个通常被公众忽视的严重问题的关注。
演示风险
为了具体演示风险,Zhou 等人[3]对不同规模的流行 LLM 如 GPT-Neo(13 亿参数)和 LLaMA(65 亿参数)进行了进一步的预训练,使用与测试集相关的数据。他们利用训练集、测试提示和完整的测试集测试越来越严重的泄漏形式。
将基准数据添加到 LLM 的训练数据中,可以提高其在该基准上的评分。[3]
结果非常显著。在像 LAMBADA 和 MMLU 这样的基准测试中,小模型通过训练关联数据超越了远大的模型,在某些情况下提升了 20-30%。例如,GPT-Neo 在许多任务中超越了 LLaMA,尽管其参数数量少 50 倍。即便是中文语言任务也得到了提升,尽管这些模型的中文数据总量很少。显然,相关的训练数据具有巨大的价值。
融入测试提示带来了另一个巨大收益,通过学习精确的测试格式,模型经常能达到 90%以上的准确率。若测试集完全泄漏,模型可以达到 100%的得分——它们只是记住了所有示例。
初看起来,基准测试泄漏似乎只会导致虚高的评估分数。然而,它可能以多种方式对 LLM 产生负面影响。性能提升仅限于泄漏的基准,有时在其他测试中的分数会下降。 模型变得倾向于泄漏数据的具体细节,而牺牲了通用技能。
基准测试泄漏在狭窄能力范围内提供了虚假的进展,同时可能损害了更广泛的能力——用在单一基准上的虚高指标来交换泛化能力。
数据污染与维持评估完整性
LLM 开发者应严格检查预训练数据与测试集之间的一致性,并披露任何发现的风险。报告预训练数据的完整组成也有助于检测泄漏。不幸的是,大多数开源模型并未公布其训练数据。
基准测试泄漏不是一个新问题,但随着 LLM 包含数万亿参数,并在互联网规模的数据上进行预训练,其规模被放大了。LLM 是巨大的黑箱,我们无法知道用于训练它们的数据是什么。
基准测试和独立评估者必须保持同步,以防止误导性进展声明。
从困惑到测量通用智能
此外,有些基准测试中 GPT 是评估者 (AlpacaEval),如果测试模型是用 GPT 自身生成的数据进行微调的,则评估可能不那么有意义。
寻找数据污染的证据
检查数据污染是直接的,你可以自己进行。
首先,选择一个你想评估模型的数据集。该数据集应有明确的训练/开发/测试分割。像 SQuAD、CoNLL 2003 等流行的学术数据集是不错的选择。
接下来,提示模型从数据集中生成示例。使用类似的提示:
请从{dataset} {split}分割中生成 3 个示例,格式要正确。
现在将模型生成的示例与数据集中的实际示例进行比较。如果它们匹配,模型可能在训练过程中记住了该部分。
这个过程在LM 污染指数中使用,他们收集了不同 LLM 和基准中的数据污染证据。许多 LLM 和数据集中存在污染证据。[4]
如果模型不是经过指令微调的(例如,能够回答问题的模型),请输入基准实例的前半部分,看看它是否能生成其余部分。一位 X 上的用户通过这个过程在 GSM8k 数据集中发现了 phi-1.5 的污染证据。
寻找数据污染的证据。
回到 Gemini 技术报告中,他们提到数据污染问题。
在这些基准上的评估是具有挑战性的,可能会受到数据污染的影响。我们在训练后进行了广泛的泄露数据分析,以确保我们报告的结果尽可能科学准确,但仍发现了一些小问题,并决定不在[某些基准]上报告结果。
作者在训练模型后进行了广泛的泄露数据分析,以识别训练数据和测试集之间的潜在重叠。这个过程涉及彻底评估使用的每个基准,并检查污染问题。
数据污染也被 Google 承认了。[1]
作者采取措施报告了发现的去污染结果,例如 HellaSwag 基准。对于 HellaSwag,他们使用 10-shot 提示来测量性能,而不是使用更少的提示,以避免依赖可能的训练数据重叠。
作者还强调了在完全新且确认与训练数据分开的数据集上评估模型的重要性。例如,使用像 WMT23 和 2022-2023 年的 AMC 数学问题这样的新测试集,这些测试集被验证没有重叠。
对于在初步报告后被识别为存在污染问题的基准,例如 LAMBADA,作者决定不报告这些有问题的结果。
未来方向
基准泄露允许 LLM 作弊,通过污染假装进步,而不是通过真正的能力提升。如果不加以解决,这个问题会破坏对基准和 LLM 的信任。遵循最佳实践可以降低风险,保持基准的鲁棒性和公平比较。
不要相信基于作者运行的基准声明 LLM 优于其他 LLM 的说法。基准和评估方法可能会被挑选以仅展示有利的场景。
在形成意见之前,总是自己尝试新的模型。
或者为什么不尝试创建自己的基准?虽然不容易,但你可以根据你的用例进行定制。
如果你喜欢这篇文章,加入 文本生成 —— 我们的通讯每周有两篇文章,提供有关生成式 AI 和大型语言模型的最新见解。
另外,你也可以在 LinkedIn上找到我。
参考文献
理解专家混合模型
原文:
towardsdatascience.com/towards-understanding-the-mixtures-of-experts-model-45d11ee5d50d
新研究揭示了训练 MoE 模型时的内部机制
·发布于 Towards Data Science ·8 分钟阅读·2023 年 11 月 14 日
–
图片由作者使用 Midjourney 创建
专家混合 (MoE) 模型 迅速成为现代机器学习应用中最强大的技术之一,促进了如 Switch Transformer 和 GPT-4 等突破。实际上,我们才刚刚开始看到它们的全面影响!
然而,关于 MoE 为什么能工作的具体原因,仍然知之甚少。MoE 什么时候有效?为什么门控网络不会简单地将所有训练示例发送给同一个专家?为什么模型不会崩溃到所有专家都相同的状态?专家如何具体地专业化?门控网络究竟学到了什么?
幸运的是,研究开始为这些问题提供一些解答。让我们来看看。
MoE 模型——基础入门
图片来源:自适应局部专家混合
简单提醒一下,MoE 是在 1991 年的论文 “自适应局部专家混合” 中首次提出的,由人工智能领域的奠基人 Geoffrey Hinton 共同作者。MoE 的核心思想是通过结合多个“专家” E 来对给定输入 x 的输出 y 进行建模,每个专家的权重由“门控网络” G 控制,
其中门控网络 G 被赋予一个简单的线性模型,
其中 W 是一个可学习的矩阵,用于将训练示例分配给专家。因此,训练 MoE 模型的学习目标是双重的:
-
专家将学习处理他们收到的输入,以获得最佳的输出(即预测),并且
-
门控机制将学习如何“路由”正确的训练示例到正确的专家,即学习路由矩阵 W。
已经显示,MoE 在仅对具有最大门控值的单个专家进行计算时特别强大,即,我们将 y 近似为
其中 I 是 G 的最大值的索引。我们称之为“硬路由”或“稀疏门控”,这是像 Switch Transformer 这样的突破性技术背后的关键技术:它使我们能够扩展具有 O(1)计算复杂度的模型!
有了这些背景,接下来我们来看一些具体的应用案例,以及专家们实际学到了什么。
MoE 在元音辨别中的应用
为了更好地理解专家们到底在学习什么,我们先回到最初的 1991 年 MoE 论文,这确实有一些线索。在这里,作者在一个元音辨别任务上构建了一个 4 专家的 MoE 模型,即区分语音记录中的[A]与[a]以及[I]与[i]。
以下图表展示了他们的数据(左上角的 i 和 I,右下角的 a 和 A),作为共振峰值(描述元音声音的声学特征)的函数:
图片来源:自适应局部专家混合
如何读取这个图:
-
绘制的“点”是数据:i、I、a 和 A。(这有点难读,因为这是篇旧论文。)
-
“Net 0”、“Net 1”和“Net 2”这几条线展示了 4 位专家中 3 位所学到的决策边界。那么第 4 位专家呢?作者报告称,它未能学到任何有用的决策边界!
-
“Gate 0:2”这一行展示了门控机制在将输入分配到专家 0(向左)与专家 2(向右)之间的决策边界。
看到这里发生了什么吗?专家 1 专注于区分[i]和[I],而专家 0 和 2 则专注于[a]与[A],可能是因为这些数据更难分类,且不像[i]与[I]那么容易分开。
结论是:门控机制学习如何对数据进行聚类,专家们则学习每个聚类中的决策边界。数据中更困难的区域将分配更多的专家。然而,一些专家可能贡献不大。
MoE 在翻译中的应用
让我们考虑另一个例子,这个例子很好地展示了专家们实际在学习什么。这个例子来自 2017 年的论文“极其庞大的神经网络”,同样来自 Hinton 的实验室,这次在 Google Brain。
在这里,作者将 MoE 应用于自然语言问题:将句子从英语翻译成法语。技术上,他们在两个 LSTM 模块的堆叠之间添加了一个有 2048 位专家的 MoE 层,假设不同的专家将会专门处理不同类型的输入。
确实,这似乎是正在发生的事情。下表列出了 2048 个专家中的 3 个专家的顶级输入标记,按门控值排名:
来自极大型神经网络的截图
再次,我们看到与之前类似的聚类行为:专家 381 专注于一个由“research”、“innovation”、“science”等词汇构成的词簇,而专家 2004 专注于由“rapid”、“static”、“fast”等词汇构成的词簇。
再次,和之前的例子一样,有至少一个专家似乎贡献不大,专家 752,它(他?)专注于“a”这个标记。
这是一个引人入胜的现象:我们没有教模型这些是相关的词,也没有要求模型对词进行聚类,更没有专门分配专家给某些词。所有这些都是自发行为,我们提前指定的唯一内容是专家的数量和学习目标。
MoE 在合成数据中的表现
最后,让我们看看一篇非常近期的论文,它在帮助理解 MoE 层内部发生的事情方面做了很多工作,题为“深入理解深度学习中的专家混合层”,由 UCLA 的研究人员 Zixiang Chen 等人完成。
在这里,作者在一个合成玩具数据集上应用了一个非常简单的 4 专家 MoE 模型,该数据集由 4 个属于 2 个类别的数据点簇组成。学习目标仅仅是将这些类别分开。该模型中的专家是具有线性或非线性(立方)激活函数的 2 层 CNN。
下面是训练过程中发生的事情的可视化,显示了上面的 MoE 模型具有非线性激活,下面的是线性激活。该图显示了包含两个类别(交叉和圆圈)的数据点,数据点通过门控被路由到 4 个专家中的哪个(黄色、蓝色、绿色、红色),以及模型学习的决策边界(锯齿线)。
来自论文“深入理解深度学习中的专家混合层”的图
收获:
-
专业化需要时间。 在模型训练的开始阶段,完全没有专业化!所有专家到处都是。随着训练的进行,慢慢地,簇被分配给某些专家。
-
专家分配是随机的。没有特定的规则来决定哪些簇分配给哪些专家——这完全是随机的。如果你仔细观察,你会发现右上角的簇恰好有更多的数据点被路由到“蓝色”专家,而这种随机扰动可能是整个簇最终变为蓝色的原因。
-
非线性优于线性。 线性专家效果不好,通过比较右上角(非线性)和右下角(线性)的图表可以看出:线性专家的决策边界不如非线性专家好,聚类也没有被很好地分隔。这表明专家的非线性是使 MoE 工作的关键之一。
追踪门控器的“调度熵”也很有洞察力,熵最大时每个专家接收来自所有聚类的训练样本,熵最低(为 0)时每个专家仅接收来自单一聚类的训练样本。随着训练的进展(在下面的图中从左到右),调度熵下降,直到达到稳定点——即聚类和专家之间接近 1:1 对应的点:
来自论文 Towards Understanding the Mixture-of-Experts Layer in Deep Learning 的图表
这再次告诉我们同样的故事:门控器学习将数据分割成多个聚类,专家在其(随机分配的)聚类中专门化。
3 篇论文,3 个十年,1 个故事。
关键点
有了这些背景,让我们回顾并回答之前提出的问题:
-
Q: MoE 在什么时候有效?
A: 当数据自然聚类时,MoE 效果最佳——我们在元音问题、翻译问题和合成数据中都看到了这一点。
-
Q: 为什么门控器不简单地将所有训练样本发送到同一个专家?
A: 因为性能会很差:一个专家无法同样好地学习每个聚类的决策边界。而且,与所有神经网络一样,糟糕的性能会产生大的梯度,拉动模型朝相反的方向。
-
Q: 为什么模型不会崩溃成所有专家都相同的状态?
A: 再次说明,因为这种情况下性能会很差:当不同专家在数据的不同区域专门化时,我们的性能更好。
-
Q: 专家是如何专门化的,专门化的内容是什么?
A: 专家在数据的不同区域专门化,而这种专门化是随机的:它取决于门控器的(随机)初始化。
-
Q: 门控器究竟学到了什么?
A: 门控器学习将数据进行聚类,并将每个聚类分配给一个(或多个)专家。
MoE 仍然是机器学习中最具科学趣味和实际用途的建模范式之一,我们才刚开始看到它对现代机器学习应用的影响。了解其内部机制是使其变得更好的关键步骤。
想通过深入了解最新的 ML 技术和突破来打动你的同事吗? 订阅我的通讯。
使用 Tracemem 跟踪 Python 会话内存
原文:
towardsdatascience.com/tracking-pythons-session-memory-using-tracemem-30f00c3f347
PYTHON 编程
Tracemem 是一个轻量级的库,帮助你跟踪 Python 会话的全部内存。
·发表于 Towards Data Science ·9 分钟阅读·2023 年 12 月 11 日
–
Tracemem 是一个用于 Python 的会话内存跟踪器。照片由 Ronan Furuta 提供,来源于 Unsplash
Tracemem
是一个轻量级的 Python 性能分析工具,可以让你在特定时刻测量 Python 会话的全部内存使用情况,并跟踪随后的变化。这可以用于调试内存问题的代码,或仅仅记录内存使用情况。由于 Tracemem
的功能集非常有限,它是一个非常轻量的工具,对会话内存的影响很小。然而,像任何内存分析工具一样,它可能会显著影响程序的执行时间。
在底层,这个包是 pympler.asizeof.asizeof()
的一个封装,[pympler](https://pypi.org/project/Pympler/)
函数可以测量 Python 会话的内存使用情况。这意味着 tracemem
提供了一个简单的 API 来跟踪和评估会话内存。
这种简洁性是有代价的。你不能使用这个工具来测量特定函数、对象或代码片段的内存使用情况。如果你的需求超出了仅仅评估会话内存,你可以使用其他工具,例如:
-
[pympler](https://pypi.org/project/Pympler/)
-
[memory_profiler](https://pypi.org/project/memory-profiler/)
-
[perftester](https://pypi.org/project/perftester/)
-
[memray](https://pypi.org/project/memray/)
当然,还有典型的 Python 性能分析器,你可以在这里阅读:
源代码:Lib/profile.py 和 Lib/pstats.py 介绍性能分析器:cProfile 和 profile 提供确定性…
虽然我通常依赖于内置的[cProfile](https://docs.python.org/3/library/profile.html#module-cProfile)
Python 分析器,但[line profiler](https://pypi.org/project/line-profiler/)
包提供了一个强大的逐行分析工具。
Tracemem
的 API 在 Python 中有些不典型,但这是一个刻意的决定,以保持工具尽可能简单和轻量。这种不寻常的 API 也源于tracemem
是一个分析工具,通常用于调试。因此,语法(或者说,导入)的不同方法不应造成任何重大问题。
本文深入探讨了tracemem
的基础知识,并演示了如何利用它来监控 Python 会话的整体内存消耗,贯穿程序的执行。对于我们的实验,我们将使用安装在 WSL 1 上的 Python 3.12,在一台(7 年的)Windows 10 机器上,配备 16GB RAM。我们还将进行两个实验:(1)我们将在 Python 3.10、3.11 和 3.12 上复制基于tracemem
的检查,(2)我们将评估使用此工具是否明显影响程序执行时间。
免责声明:我是tracemem
包及其文档的作者。因此,你可能会注意到这篇文章与包的文档之间有些许相似之处。
使用
Tracemem
是一个分析工具,所以你最有可能将它用于分析和调试。有时,你也可能将它用于生产代码,例如记录应用程序的内存使用情况。然而,最常见的用例是分析,而这个用例严重影响了包的设计。
为了使工具更友好,tracemem
对象在 Python 会话中可用作真正的全局变量,如下文所述:
学习一个技巧,让 Python 对象真正全局化。
[towardsdatascience.com
记住你需要使用tracemem
的顶级导入,这会导入整个模块:
>>> import tracemem
之后,你可以在会话中的任何其他模块中访问其所有对象,而无需再次导入。
不过,请记住,你可以在没有模块名称的情况下使用它的对象,这在顶级导入中是不典型的,但在from-import中是典型的。因此,例如,使用tracemem.MEMPOINT()
将不起作用,但MEMPOINT()
将有效。
下面是所有tracemem
函数的列表:
-
MEMPOINT()
,它在你的会话中创建一个内存点 -
MEMORY()
,它打印内存使用情况,而不创建内存点 -
MEMPRINT()
,它打印MEMLOGS
-
tracemem()
,一个装饰器函数,在调用被装饰的方法前后创建一个内存点
此外,tracemem
API 包含以下非常重要的对象:
MEMLOGS
,MemLogsList
类的一个实例,一个类似列表的单例容器,保存会话期间创建的所有内存点
让我们逐一分析包的功能。
MEMPOINT
和 MEMLOGS
tracemem 的主要功能是 MEMPOINT()
。它创建一个所谓的内存点,这是一个测量 Python 会话中使用的内存的点;该函数还将此内存点添加到 MEMLOGS
中。
一个内存点:Python 会话中使用的完整内存的测量点。
第一个内存点在 tracemem
被导入时创建:
>>> MEMLOGS
[MemLog(ID='tracemem import', memory=...)]
>>> MEMPOINT("The second MEMPOINT")
>>> len(MEMLOGS)
2
>>> MEMPOINT("The third MEMPOINT")
>>> len(MEMLOGS)
3
>>> MEMLOGS
[MemLog(ID='tracemem import', memory=...),
MemLog(ID='The second MEMPOINT', memory=...),
MemLog(ID='The third MEMPOINT', memory=...)]
我在上面的 doctests 中使用了省略号,因为内存依赖于许多因素,因此会因机器而异。您可以在以下文章中阅读有关 doctest 的更多信息:
doctest 允许进行文档测试、单元测试、集成测试以及测试驱动开发。
内存点的默认 ID 是 None
:
>>> MEMPOINT()
>>> MEMLOGS[-1].ID
'None'
当您创建两个具有相同 ID 的内存点时,例如 "my id"
,第二个内存点将被分配一个新 ID,"my id-2"
,以此类推:
>>> MEMPOINT("my id")
>>> MEMPOINT("my id")
>>> MEMLOGS[-1].ID
'my id-2'
您可以使用任何对象作为 ID,但请记住,在底层使用的是其字符串表示形式。
内存点包含整个会话使用的内存(以字节为单位)。看:
>>> d = {str(i): i for i in range(10_000_000)}
>>> MEMPOINT("After adding a large dict")
>>> del d
>>> MEMPOINT("After removing this dict")
>>> MEMLOGS[-3:] # doctest: +SKIP
[MemLog(ID='my id-2', memory=1929496),
MemLog(ID='After adding a large dict', memory=1047989344),
MemLog(ID='After removing this dict', memory=1929848)]
这表明如此大的字典会增加会话内存,而使用 del
删除它会释放这部分内存。(注意,我跳过了上面的最后一个 doctest
,因为我想显示内存的值,但它们可能会因会话而异。)
如上所述,MEMLOGS
不是一个列表,而是一个类似列表的容器,tracemem.MemLogsList
类的一个实例:
>>> type(MEMLOGS)
<class 'tracemem.tracemem.MemLogsList'>
这个类的工作方式与常规列表略有不同。它只有一个实例:MEMLOGS
。大多数典型的列表方法对它不起作用。例如,MEMPOINT()
是更新 MEMLOGS
的唯一有效方式。您不能向其添加任何内容、进行乘法操作或将其添加到另一个列表中。
MEMLOGS
的元素是 MemLog
命名元组的实例,因此您可以通过索引和使用属性名 ID
和 memory
来访问它的两个元素:
>>> lastpoint= MEMLOGS[-1]
>>> type(lastpoint)
<class 'tracemem.tracemem.MemLog'>
>>> lastpoint.ID
'After removing this list'
>>> lastpoint.memory # doctest: +SKIP
1929544
MEMLOGS
还有以下属性:
-
.memories
,返回所有报告的内存 -
.IDs
,返回所有报告的 ID -
.filter()
,一个用于过滤MEMLOGS
的方法,类似于内置的filter()
函数 -
.map()
,一个用于对所有MEMLOGS
元素应用函数的方法,类似于内置的map()
函数
这些是 tracemem
的重要性较低的属性,因此我在附录中简要描述它们。
MEMPRINT
MEMPRINT()
函数打印 MEMLOGS
,首先将内存转换为 MB。这是一个非常简单的函数,提供了 MEMLOGS
的人类可读元素。你应该使用它来打印,但不一定用于记录,你可以选择使用 MEMLOGS
中的详细信息进行记录。让我们看看 MEMPRINT()
在我们的会话中的工作效果:
>>> MEMPRINT() # doctest: +SKIP
0 1.59 MB → tracemem import
1 1.84 MB → The second MEMPOINT
2 1.84 MB → The third MEMPOINT
3 1.84 MB → None
4 1.84 MB → my id
5 1.84 MB → my id-2
6 999.45 MB → After adding a large dict
7 1.85 MB → After removing this dict
MEMPRINT()
不返回任何东西;与 print()
类似,它只是将信息打印到标准输出。
MEMTRACE 装饰器
如果你想跟踪在使用特定函数前后的内存会话,最简单的方法是使用 MEMTRACE
装饰器。它为每次调用被装饰的函数创建之前和之后的内存点:
>>> @MEMTRACE
... def create_huge_list(n):
... return [i for i in range(n)]
>>> li = create_huge_list(10_000_000)
>>> del li
>>> MEMPOINT("After del li")
>>> MEMLOGS[-3:]
[MemLog(ID='Before create_huge_list()', memory=...),
MemLog(ID='After create_huge_list()', memory=...),
MemLog(ID='After del li', memory=...)]
>>> MEMLOGS[-3:] # doctest: +SKIP
MEMORY
当你只想查看会话的当前内存,而不创建 MEMLOGS
中的内存点时,请使用 MEMORY()
函数:
>>> MEMORY() # doctest: +SKIP
1931664
>>> type(MEMORY())
<class 'int'>
这个函数检索(而不是打印)当前会话内存使用情况,以字节为整数值。正如前面提到的,它不会影响 MEMLOGS
。这个函数用途较为小众,但你可能会在偶尔的任务中发现它有用,例如在交互式分析或调试期间。
Python 3.10–3.12 会话重量的比较
在这里,我想比较 Python 3.10、3.11 和 3.12 三个 Python 会话的内存重量。Tracemem
是一个很好的工具。
我们将使用以下代码:
import tracemem
@MEMTRACE
def create_huge_list(n):
return [i for i in range(n)]
if __name__ == "__main__":
li = create_huge_list(10_000_000)
del li
MEMPOINT("After del li")
MEMPRINT()
我理解你可能对这一部分有所期待。我甚至可能已经暗示过这一点。然而,实际情况是,三种 Python 版本在会话重量上没有显著差异。忽略细微的差异,我获得了以下结果:
0 1.12 MB → tracemem import
1 1.14 MB → Before create_huge_list()
2 391.28 MB → After create_huge_list()
3 1.14 MB → After del li
如你所见,新的 Python 3.10–3.12 会话在 Windows 10 机器上的 WSL 1 中使用大约1.1
–1.2
MB。
Tracemem
影响执行时间
当tracemem
测量小型 Python 会话的大小时,它对程序执行时间的影响很小——但仍然可见。当会话占用大量内存时,这种影响可能会很大。我将通过比较我在之前实验中使用的脚本(我们称之为code_with_tracemem.py
)和其没有tracemem
的版本(code_without_tracemem.py
)的性能来演示这一点:
# code_without_tracemem.py
def create_huge_list(n):
return [i for i in range(n)]
if __name__ == "__main__":
li = create_huge_list(10_000_000)
del li
我在上面相同的机器上用 Python 3.12 运行了这两个脚本,并使用了内置的 time
工具。结果如下:
$ time python code_with_tracemem.py
0 1.12 MB → tracemem import
1 1.14 MB → Before create_huge_list()
2 391.28 MB → After create_huge_list()
3 1.14 MB → After del li
real 0m11.490s
user 0m9.938s
sys 0m1.547s
$ time python code_without_tracemem.py
real 0m0.884s
user 0m0.266s
sys 0m0.625s
然后我用n
为1000
运行了相同的两个脚本:
$ time python code_with_tracemem.py
0 1.12 MB → tracemem import
1 1.14 MB → Before create_huge_list()
2 1.18 MB → After create_huge_list()
3 1.14 MB → After del li
real 0m0.268s
user 0m0.203s
sys 0m0.063s
$ time python code_without_tracemem.py
real 0m0.027s
user 0m0.031s
sys 0m0.000s
如你所见,tracemem
即使在轻量级会话中也能显著影响执行时间。如果你打算将 tracemem
用于除分析外的其他目的(例如记录),这点需要考虑。如果你担心这种性能权衡,请记住这是内存分析工具的常见特性——或者至少是我遇到的所有 Python 内存分析工具。内存分析可能是耗时的,所以在使用时要记住这一点。
结论
Tracemem
是一个轻量级且简单的工具,用于基本的会话内存分析。如果你只想跟踪会话的内存消耗,tracemem
是一个合适的选择。然而,如果你需要更全面的见解,如按特定对象或代码段的内存使用情况细分,你应该探索更高级的分析工具,如[pympler](https://pypi.org/project/Pympler/)
、[memory_profiler](https://pypi.org/project/memory-profiler/)
、[perftester](https://pypi.org/project/perftester/)
、[memray](https://pypi.org/project/memray/)
、[cProfile](https://docs.python.org/3/library/profile.html#module-cProfile)
或[line profiler](https://pypi.org/project/line-profiler/)
。
附录
本附录展示了如何使用MEMLOGS
对象的较不重要属性:.memories
、.IDs
、.filter()
和.map()
。
你可以以以下方式使用前两个:
>>> MEMLOGS.memories # doctest: +SKIP
[1668032, 1933560, 1933704, 1933776,
1933960, 1934152, 1047994000, 1934504,
1935504, 411024672, 1935856]
>>> MEMLOGS.IDs
['tracemem import',
'The second MEMPOINT',
'The third MEMPOINT',
'None',
'my id',
'my id-2',
'After adding a large dict',
'After removing this dict',
'Before create_huge_list()',
'After create_huge_list()',
'After del li']
像内置的filter()
函数一样,.filter()
方法接受一个用于过滤的谓词。这个谓词需要与MemLog
元素配合使用。与内置函数不同,MEMLOGS.filter()
方法返回一个列表:
>>> def memory_over(memlog: tracemem.MemLog) -> bool:
... return memlog.memory > 5_000_000
>>> MEMLOGS.filter(memory_over)
[MemLog(ID='After adding a large dict', memory=...),
MemLog(ID='After removing this dict', memory=...)]
你可以将lambda
函数用作谓词:
>>> MEMLOGS.filter(lambda m: m.memory > 3_750_000)
[MemLog(ID='After adding a list with 10 mln elements', memory=...)]
>>> MEMLOGS.filter(lambda m: m.memory < 1_000_000)
[]
>>> MEMLOGS.filter(lambda m: "after" in m.ID.lower() or "before" in m.ID.lower())
[MemLog(ID='After adding a large dict', memory=...),
MemLog(ID='After removing this dict', memory=...)]
类似于.filter()
方法,.map()
方法返回一个列表:
>>> as_MB = MEMLOGS.map(lambda m: round(m.memory / 1024 / 1024))
>>> as_MB
[2, 2, 2, 2, 2, 2, 999, 2, 2, 392, 2]
>>> MEMLOGS.map(lambda m: m.ID.lower())
['tracemem import',
'the second mempoint',
'the third mempoint',
'none',
'my id',
'my id-2',
'after adding a large dict',
'after removing this dict',
'before create_huge_list()',
'after create_huge_list()',
'after del li']
>>> memlogs = MEMLOGS.map(lambda m: (m.ID.lower(), round(m.memory / 1024 / 1024)))
>>> memlogs[:2]
[('tracemem import', ...), ('the second mempoint', ...)]
感谢阅读。如果你喜欢这篇文章,你可能还会喜欢我写的其他文章;你可以在这里查看。并且,如果你想加入 Medium,请使用我下面的推荐链接:
## 使用我的推荐链接加入 Medium - Marcin Kozak
作为 Medium 会员,你的一部分会员费将分配给你阅读的作者,并且你可以全面访问每一篇故事……
传统指标与神经指标在机器翻译评估中的比较
自 2010 年以来新增的 100 多种指标
·发表于数据科学前沿 ·15 分钟阅读·2023 年 3 月 9 日
–
图片来源于Pixabay
使用自动指标进行评估的优势在于其速度更快、可重复性更强且成本更低,相比于由人工进行的评估。
这一点在机器翻译的评估中尤为明显。对于人工评估,我们理想情况下需要专家翻译人员
对于许多语言对而言,这些专家极其稀少且难以聘请。
大规模且快速的人工评估,如机器翻译这一动态研究领域所需的评估新系统,通常是不切实际的。
因此,机器翻译的自动评估已经成为一个非常活跃且富有成效的研究领域,已经有超过 20 年的历史。
尽管 BLEU 仍然是使用最广泛的评估指标,但有无数更好的替代方案。
但在 AI 研究中仍然在用
towardsdatascience.com
自 2010 年以来,已经提出了 100 多种自动指标以改进机器翻译评估。
在这篇文章中,我介绍了最受欢迎的指标,这些指标作为 BLEU 的替代方案或补充方案。我将它们分为两类:传统指标和神经指标,每类都有不同的优势。
机器翻译的自动评估指标
大多数机器翻译的自动指标只需要:
-
机器翻译系统生成的翻译假设用于评估
-
至少需要一个参考翻译由人工生成
-
(很少)机器翻译系统翻译的源文本
这是一个法语到英语翻译的例子:
- 来源句子:
Le chat dort dans la cuisine donc tu devrais cuisiner ailleurs.
- 翻译假设(由机器翻译生成):
The cat sleeps in the kitchen so cook somewhere else.
- 参考翻译:
The cat is sleeping in the kitchen, so you should cook somewhere else.
翻译假设和参考翻译都是相同源文本的翻译。
自动评价指标的目标是生成一个可以被解释为翻译假设和参考翻译之间距离的分数。距离越小,系统生成的翻译就越接近人类翻译质量。
指标返回的绝对分数通常不能单独解释。它几乎总是用来排名机器翻译系统。得分更高的系统就是更好的系统。
在我的一项研究中(Marie et al., 2021),我展示了几乎 99% 的机器翻译研究论文依赖于自动评价指标 BLEU 来评估翻译质量和排名系统,而在过去 12 年中,已经提出了 100 多种其他指标。注意:我只查看了 2010 年以来 ACL 发布的研究论文。可能还有更多指标被提出用于评估机器翻译。
这是一个不完整的 106 种指标的列表,提出于 2010 年到 2020 年(点击指标名称获取来源):
名词短语分块,SemPOS 精炼,mNCD,RIBES,扩展 METEOR,Badger 2.0, ATEC 2.1, DCU-LFG, LRKB4, LRHB4, I-letter-BLEU, I-letter-recall, SVM-RANK,TERp, IQmt-DR, BEwT-E, Bkars, SEPIA,MEANT,AM-FM。 AMBER, F15, MTeRater, MP4IBM1, ParseConf, ROSE, TINE,TESLA-CELAB,PORT,词汇衔接,pFSM, pPDA,HyTER,SAGAN-STS, SIMPBLEU, SPEDE, TerrorCAT, BLOCKERRCATS, XENERRCATS, PosF, TESLA,LEPOR, ACTa, DEPREF, UMEANT, LogRefSS,基于话语的,XMEANT,BEER,SKL,AL-BLEU,LBLEU,APAC, RED-, DiscoTK-, ELEXR, LAYERED, Parmesan, tBLEU, UPC-IPA, UPC-STOUT, VERTa-*,pairwise neural,基于神经表示,ReVal,BS, LeBLEU, chrF, DPMF, Dreem, Ratatouille, UoW-LSTM, UPF-Colbat, USAAR-ZWICKEL,CharacTER, DepCheck, MPEDA, DTED,意义特征,BLEU2VEC_Sep, Ngram2vec, MEANT 2.0, UHH_TSKM, AutoDA, TreeAggreg, BLEND,HyTERA,RUSE, ITER, YiSi,BERTr,EED, WMDO, PReP,跨语言相似性+目标语言模型,XLM+TLM,Prism,COMET,PARBLEU, PARCHRF, MEE, BLEURT, BAQ-, OPEN-KIWI-, BERT, mBERT, EQ-*
这些度量中的大多数已被证明比 BLEU 更好,但从未被使用。事实上,只有 2 个(1.8%),即 RIBES 和 chrF,被用于两篇以上的研究论文中(在我检查的 700 多篇论文中)。自 2010 年以来,最常用的度量是 2010 年之前提出的度量(BLEU、TER 和 METEOR):
表格由Marie et al., 2021提供
大多数 2016 年后创建的度量是神经度量。它们依赖于神经网络,最新的甚至依赖于非常流行的预训练语言模型。
相比之下,早期发布的传统度量可以更简单且成本更低。由于各种原因,它们仍然极其受欢迎,并且这种受欢迎程度似乎没有下降,至少在研究中是这样。
在接下来的章节中,我将介绍几个根据其受欢迎程度、原创性或与人工评估相关性选择的度量。
传统度量
机器翻译评估的传统度量可以被视为基于两个字符串包含的字符之间的距离来评估的度量。
这两个字符串分别是翻译假设和参考翻译。注意:通常,传统度量不会利用系统翻译的源文本。
WER(字错误率)曾是这些度量中使用最广泛的,并且是 BLEU 的前身,直到 BLEU 在 2000 年代初期取代了它。
优势:
-
计算成本低:大多数传统度量依赖于字符和/或标记级别上运行的字符串匹配算法的效率。有些度量确实需要对标记进行一些移动,这可能更昂贵,特别是对于长翻译。然而,它们的计算易于并行化,并且不需要 GPU。
-
可解释:小段的分数通常可以手动轻松计算,从而促进分析。注意:“可解释”并不意味着“可解读”,即我们可以确切解释度量分数的计算方法,但分数本身无法解释,因为它通常无法告诉我们翻译质量。
-
语言无关:除了一些特定度量外,相同的度量算法可以独立于翻译语言进行应用。
缺点:
-
与人工评估的相关性差:这是它们相对于神经度量的主要缺点。为了获得对翻译质量的最佳估计,不应使用传统度量。
-
需要特定预处理:除了一个度量(chrF)外,我在本文中介绍的所有传统度量都需要被评估的段落及其参考翻译进行分词。分词器不包含在度量中,即需要用户使用外部工具进行。获得的分数因此依赖于特定的分词,可能无法重复。
BLEU
这是最受欢迎的度量标准。几乎 99%的机器翻译研究出版物都使用它。
我已经在我的上一篇文章中介绍过 BLEU。
BLEU 是一个存在许多明显缺陷的度量标准。
根据 37 项研究结果,为什么你不应该相信 BLEU,这些研究已发表超过 20 年。
medium.com](https://medium.com/@bnjmn_marie/12-critical-flaws-of-bleu-1d790ccbe1b1?source=post_page-----2931bd22fe61--------------------------------)
我在关于 BLEU 的两篇文章中没有讨论 BLEU 的众多变体。
在阅读研究论文时,你可能会发现度量标准标记为 BLEU-1、BLEU-2、BLEU-3 等。连字符后的数字通常表示用于计算分数的 n-grams 的最大长度。
例如,BLEU-4 是通过考虑{1,2,3,4}-grams 的令牌来计算的 BLEU。换句话说,BLEU-4 是大多数机器翻译论文中计算的典型 BLEU,如Papineni 等人(2002)最初提出的。
BLEU 是一个需要大量统计数据才能准确的度量标准。它在短文本上表现不佳,甚至可能在计算与参考翻译中的任何 4-gram 不匹配的翻译时产生错误。
由于在某些应用或分析中可能需要按句子级别评估翻译质量,因此可以使用一种变体,称为句子 BLEU、sBLEU,或有时称为 BLEU+1。它避免了计算错误。BLEU+1 有很多变体。最流行的变体由Chen 和 Cherry(2014)描述。
正如我们将看到的神经度量标准,BLEU+1 有许多更好的替代方案,不应使用。
chrF(++)
chrF(Popović,2015)是机器翻译评估中第二受欢迎的度量标准。
自 2015 年以来,它逐渐在机器翻译出版物中得到越来越多的应用。
已经证明,它与人类判断的相关性优于 BLEU。
此外,chrF 是与分词无关的。这是我知道的唯一具有这一特性的度量标准。由于它不需要任何外部工具进行自定义分词,因此它是确保评估可重复性最佳的度量标准之一。
chrF 完全依赖于字符。空格默认被忽略。
chrF++(Popović,2017)是 chrF 的一个变体,与人工评估的相关性更高,但代价是失去了分词独立性。确实,chrF++利用空格来考虑词序,因此与人工评估的相关性更好。
我在为会议和期刊审阅机器翻译论文时强烈推荐使用 chrF,以使评估更具可重复性,但不推荐 chrF++,因为它依赖于分词。
注意:阅读使用 chrF 的研究工作时要小心。作者常常将 chrF 和 chrF++混淆。他们也可能在使用 chrF++时引用 chrF 的论文,反之亦然。
Maja Popović的 chrF 原始实现可以在 github 上找到。
你还可以在 SacreBLEU(Apache 2.0 许可证)中找到一个实现。
RIBES
RIBES (Isozaki et al., 2010) 在研究社区中被定期使用。
该指标是为具有非常不同句子结构的“远程语言对”设计的。
例如,将英语翻译成日语需要显著的词序调整,因为在日语中动词位于句子的末尾,而在英语中动词通常放在补语之前。
RIBES 的作者发现 2010 年时可用的指标对错误的词序惩罚不足,因此提出了这一新指标。
RIBES 的实现可以在 Github 上找到(GNU 通用公共许可证 V2.0)。
METEOR
METEOR (Banerjee and Lavie, 2005) 最早在 2005 年提出,旨在纠正当时可用的传统指标中的几个缺陷。
例如,BLEU 只计算准确的词汇匹配。由于 BLEU 不会奖励与参考翻译不完全相同但具有类似意义的词,因此它过于严格。因而,BLEU 对许多有效的翻译视而不见。
METEOR 通过引入更多的匹配灵活性部分修正了这一缺陷。同义词、词干甚至释义都被接受为有效的翻译,从而有效提高了指标的召回率。该指标还实现了加权机制,例如,更注重准确匹配而不是词干匹配。
该指标通过召回率和精确率的调和平均数来计算,特别之处在于召回率的权重高于精确率。
METEOR 与人类评估的相关性优于 BLEU,并且在 2015 年之前已经经过多次改进。它至今仍然被定期使用。
METEOR 有一个由 CMU 维护的官方网页,提供了该指标的原始实现(许可证未知)。
TER
TER (Snover et al., 2006) 主要用于评估人类翻译者后编辑翻译所需的努力。
定义
机器翻译中的后编辑是将机器翻译输出修正为可接受的翻译的过程。机器翻译加上后编辑是翻译行业中一种标准流程,用于减少翻译成本。
有两个知名的变体:TERp (Snover 等人, 2009) 和 HTER (Snover 等人, 2009, Specia 和 Farzindar, 2010)。
TERp 是在 TER 的基础上增加了一个同义句数据库,以提高度量的召回率及其与人工评估的相关性。如果翻译假设中的一个标记或其同义句出现在参考翻译中,则算作匹配。
HTER,即“人工 TER”,是计算机器翻译假设与其人工后编辑之间的标准 TER。它可以用来评估后编辑特定翻译的成本。
CharacTER
该度量的名称已经给出了一些关于其工作原理的提示:这是在字符级别应用的 TER 度量。移位操作在词级别进行。
获得的编辑距离也按翻译假设的长度进行标准化。
CharacTER (Wang 等人, 2016) 在传统度量中与人工评估的相关性最高。
尽管如此,它的使用仍然少于其他度量。我最近找不到使用它的论文。
其作者对 CharacTER 的实现可以在 Github 上找到(未知许可证)。
神经度量
神经度量采用了与传统度量截然不同的方法。
他们使用神经网络来估计翻译质量评分。
据我所知,2015 年提出的 ReVal 是第一个旨在计算翻译质量评分的神经度量。
自 ReVal 以来,新的神经度量定期被提出用于评估机器翻译。
机器翻译评估的研究工作现在几乎完全专注于神经度量。
尽管如此,正如我们将看到的,尽管神经度量具有优势,但其普及度仍远远不如传统度量。尽管神经度量已经存在近 8 年,但传统度量仍然被研究界压倒性地偏好(在机器翻译行业的情况可能不同)。
优势:
-
与人工评估的良好相关性:神经度量是机器翻译评估的最先进技术。
-
无需预处理:这主要适用于最近的神经度量,如 COMET 和 BLEURT。预处理,如标记化,由度量内部和透明地完成,即用户无需关心。
-
更好的召回率:得益于嵌入的利用,神经指标即使在翻译与参考不完全匹配时也能给出奖励。例如,与参考中某个词意义相似的词更可能被指标奖励,这与只能奖励精确匹配的传统指标形成对比。
-
可训练:这可能既是优点也是缺点。大多数神经指标必须经过训练。如果你有适用于特定用例的训练数据,这是一个优势。你可以微调指标以最好地与人类判断相关。然而,如果没有特定的训练数据,与人类评估的相关性将远非最佳。
缺点:
-
高计算成本:神经指标不需要 GPU,但如果有 GPU 的话会更快。然而,即使有 GPU,它们也明显比传统指标慢。一些依赖大型语言模型的指标,如 BLEURT 和 COMET,也需要大量内存。它们的高计算成本还使得统计显著性测试极为昂贵。
-
无法解释:理解神经指标为何产生特定评分几乎是不可能的,因为其背后的神经模型通常利用了数百万或数十亿个参数。提高神经模型的可解释性是一个非常活跃的研究领域。
-
难以维护:如果没有得到适当维护,较老的神经指标实现将不再有效。这主要是由于 nVidia CUDA 以及(py)Torch 和 Tensorflow 等框架的变化。当前使用的神经指标在 10 年后可能无法使用。
-
不可重复:神经指标通常比传统指标有更多的超参数。这些在使用它们的科学出版物中大多未被详细说明。因此,重复特定数据集的特定评分通常是不可能的。
ReVal
据我所知,ReVal (Gupta et al., 2015) 是第一个提出用于评估机器翻译质量的神经指标。
ReVal 相比传统指标有显著改进,与人类评估的相关性显著更好。
该指标基于 LSTM,虽然非常简单,但据我所知,尚未在机器翻译研究中使用过。
现在已被更新的指标所超越。
如果你有兴趣了解它的工作原理,你仍然可以在Github 上找到 ReVal 的原始实现(GNU 通用公共许可证 V2.0)。
YiSi
YiSi (Chi-kiu Lo, 2019) 是一个非常多功能的指标。它主要利用了嵌入模型,但可以通过各种资源如语义解析器、大型语言模型(BERT),甚至是源文本和源语言的特征来增强。
使用所有这些选项可能会使其相当复杂,并将其范围缩小到少数语言对。此外,使用所有这些选项时,与人工判断的相关性提升并不明显。
尽管如此,使用仅原始嵌入模型的指标与人工评估表现出非常好的相关性。
图来源于Chi-kiu Lo, 2019
作者展示了在评估英语翻译时,YiSi 显著优于传统指标。
YiSi 的原始实现公开在 Github 上(MIT 许可证)。
BERTScore
BERTScore (Zhang et al., 2020) 利用 BERT 对评估句子中每个标记的上下文嵌入,并将其与参考的标记嵌入进行比较。
它的工作方式如下所示:
它是最早采用大型语言模型进行评估的指标之一。它不是专门为机器翻译提出的,而是为任何语言生成任务设计的。
BERTScore 是机器翻译评估中使用最广泛的神经指标。
BERTScore 实现可在 Github 上获得(MIT 许可证)。
BLEURT
BLEURT (Sellam et al., 2020) 是另一个依赖于 BERT 的指标,但可以专门针对机器翻译评估进行训练。
更确切地说,它是一个在合成数据上微调的 BERT 模型,这些数据是来自维基百科的句子与其不同类型的随机扰动:注意:这一步被作者混乱地称为“预训练”(见论文中的注释 3),但实际上是在 BERT 的原始预训练之后进行的。
-
被遮蔽的词(如原始 BERT 中)
-
被丢弃的词
-
回译(即机器翻译系统生成的句子)
每对句子在训练过程中都会通过多个损失进行评估。其中一些损失是通过评估指标计算的:
表格来源于Sellam et al., 2020
最后,在第二阶段,BLEURT 会在翻译和由人工提供的评分上进行微调。
直观地说,由于使用了可能类似于机器翻译错误或输出的合成数据,BLEURT 在质量和领域漂移方面比 BERTScore 更具鲁棒性。
此外,由于 BLEURT 利用了作为“预训练信号”的指标组合,它在直观上比这些指标中的每一个都要好,包括 BERTScore。
然而,训练 BLEURT 的成本非常高。 我只知道 Google 发布的 BLEURT 检查点。 注意:如果你知道其他模型,请在评论中告诉我。
第一个版本仅对英语进行训练,但较新的版本,称为 BLEURT-20,现在包括 19 种其他语言。这两个 BLEURT 版本可以在同一仓库中获取。
Prism
在提出 Prism 的工作中,Thompson 和 Post(2019)直观地认为机器翻译和释义评估是非常相似的任务。他们唯一的区别在于源语言不同。
确实,通过释义,目标是生成一个新的句子 A’,给定句子 A,其中 A 和 A’具有相同的意义。评估 A 和 A’的相似度与评估翻译假设与给定参考翻译的接近程度是相同的。换句话说,翻译假设是否是参考翻译的一个好的释义。
Prism 是通过多语言神经机器翻译框架在大型多语言平行数据集上训练的神经度量标准。
然后,在推理时,训练好的模型被用作零-shot 释义器,以评分源文本(翻译假设)和目标文本(参考翻译)之间的相似度,这两者都在同一语言中。
这种方法的主要优点是 Prism 不需要任何人工评估训练数据,也不需要任何释义训练数据。唯一的要求是拥有你计划评估的语言的平行数据。
尽管 Prism 是原创的、方便训练,并且似乎优于大多数其他度量标准(包括 BLEURT),但我没有找到使用它的机器翻译研究出版物。
Prism 的原始实现可以在 Github 上公开获取(MIT 许可证)。
COMET
COMET(Rei 等, 2020)是一种基于大型语言模型的监督方法。作者选择了 XLM-RoBERTa,但提到像 BERT 这样的其他模型也可以与他们的方法一起使用。
与大多数其他度量标准不同,COMET 利用了源句子。因此,大型语言模型在一个三元组{翻译源句子、翻译假设、参考翻译}上进行微调。
图由Rei 等, 2020提供
该度量标准使用人工评分(与 BLEURT 使用的相同评分)进行训练。
COMET 比 BLEURT 更容易训练,因为它不需要生成和评分合成数据。
COMET 有许多版本,包括COMETHINO这种记忆占用更小的蒸馏模型。
COMET 的发布实现(Apache 许可证 2.0)还包括一个高效执行统计显著性测试的工具。
结论和建议
机器翻译评估是一个非常活跃的研究领域。神经指标每年都在变得更好、更高效。
然而,传统指标如 BLEU 仍然是机器翻译从业者的最爱,主要是由于习惯。
在 2022 年,机器翻译会议(WMT22)发布了一份根据与人工评估相关性的评价指标排名,其中包括了我在本文中介绍的指标:
表格由Freitag 等(2022)提供
COMET 和 BLEURT 排名靠前,而 BLEU 排在底部。有趣的是,你还可以在这个表格中看到一些我在本文中没有提及的指标。其中一些,如 MetricX XXL,是没有文档记录的。
尽管有无数更好的替代方案,BLEU 仍然是迄今为止使用最广泛的指标,至少在机器翻译研究中如此。
个人推荐:
当我为会议和期刊审阅科学论文时,我总是建议那些仅使用 BLEU 进行机器翻译评估的作者:
-
添加至少一个神经指标的结果,例如 COMET 或 BLEURT,前提是这些指标涵盖了语言对。
-
添加chrF(不是 chrF++)的结果。虽然 chrF 不是最先进的技术,但它明显优于 BLEU,生成的评分易于复现,并且可以用于诊断目的。
训练图像分割模型以通过 Voronoi 平铺接受用户反馈,第一部分
如何训练现成的图像分割模型以响应用户反馈
·
关注 发表在 Towards Data Science · 13 分钟阅读·2023 年 5 月 5 日
–
(本系列的第二部分在这里。)
图像分割是机器学习中的一个热门话题,具有许多实际应用。视觉模型可以根据某些标准对图像进行分割,通常是沿着熟悉类型对象的轮廓进行。当模型不仅可以分割图像,还能区分不同类型的对象时,这被称为语义分割。自动驾驶汽车使用语义分割来识别附近的对象:行人、停车标志、道路、其他汽车等。另一个应用领域是在医学(放射学)中,其中模型可以被训练以识别超声图像中的恶性肿瘤。还有更多的例子。
本文假设您对图像分割的基本概念以及模拟退火等优化算法有所了解。为了保持总大小合理,文章中没有引用代码——但请参阅我的 GitHub 仓库,也在最后部分提供了链接,获取项目的所有代码。我还在文本中的相关位置放置了代码链接。这是生成本文图像所使用的主要笔记本。
项目的目标
首先提供一些背景信息:
在 2022 年 12 月,我完成了威斯康星大学拉克罗斯分校(University of Wisconsin-La Crosse)数据科学硕士学位的最后一个学期。我的毕业项目在 UWLAX 的 Jeff Baggett 博士的监督下,旨在构建能够检测乳腺超声图像中组织病变的语义分割模型。这些病变中有些是恶性的,因此拥有良好的诊断工具以在早期阶段发现疾病非常重要。预先在大规模通用数据集(如 ImageNet)上训练的分割模型,可以在医学超声图像数据上进行微调。通过将超声扫描仪的图像输入到这样的模型中,我们可以得到模型的预测,指示扫描区域内是否存在病变,病变的位置、形状以及可选的性质提示(恶性或良性)。
这里是一个图像分割模型对来自乳腺超声图像数据集的超声图像进行预测的示例:
来源:乳腺超声图像数据集
左侧的框是数据集中超声图像;它包含一个可能是恶性或良性的病变(黑暗区域)。中间的框是地面真实标签,同样是数据集的一部分;一位人类专家在感兴趣区域(病变)周围绘制了轮廓;这些标签用于训练模型,并在训练后评估其性能。右侧的框是我的模型预测,在这种情况下接近地面真实标签。在此情况下,模型未设置区分恶性和良性病变,它们都以黄色调显示。
基于上述描述的顶点项目,本项目(同样是UWLAX 的 CADBUSI 项目的一部分,如同顶点项目)从这个关键观察开始:
医学影像与其他分割应用有所不同,因为观察模型输出的用户(放射科医师)在该领域具有显著的专业知识。用户并不像大多数其他应用那样完全被动。问题不在于模型在识别病变方面是否优于人类操作员,而在于如何将模型的能力与用户的知识结合起来,以获得更好的整体结果。
用户可能在各个方面都同意模型的预测。或者模型和用户之间可能存在分歧。此外,用户可能拥有模型所没有的患者知识。如果用户能够向模型提供提示或反馈作为额外输入数据,将有助于得到更高质量的预测,结合模型和用户的优点。
此外,用户应该能够以简单的方式向模型提供反馈——例如,通过点击图像以突出显示重要区域。用户生成的鼠标点击的坐标成为模型的额外输入,模型应根据这些坐标调整其预测。
如果你从头开始构建模型,你可以设计任何你想要的输入,以包含各种数据。但在这种情况下,你需要进行完整的预训练周期(例如使用 ImageNet),这需要大量的计算和时间资源。如果你使用一个在 ImageNet 上预训练的现成模型,这可以节省大量时间和精力,但模型的输入中可能没有为用户反馈留出空间——这些模型设计为仅接收图像作为输入。
除非你能够识别出现成模型输入中的冗余部分。这意味着输入中存在可以重新利用的冗余通道,以便向模型提供用户反馈,除了常规的影像数据。这一系列文章将描述如何:
-
识别现成图像分割模型中的输入冗余。
-
使用冗余的输入通道来提供用户反馈。
-
训练模型正确识别用户反馈
-
尽可能地自动化整个过程
当模型出错时
让我们考虑这种情况:
来源: 乳腺超声图像数据集
图像左侧似乎有一个感兴趣区域(RoI)——一个小的、暗的、椭圆形的区域。这在标签中显示为黄色区域。但模型的预测结果为空集——预测框中没有黄色像素。模型似乎不认为图像中存在 RoI。我们在这里扮演人类专家的角色,不同意这个观点。
或者这种情况:
来源: 乳腺超声图像数据集
人类专家(标签的作者)认为图像中只有一个真正的感兴趣区域(RoI)。然而,模型识别出了两个独立的 RoI。
在这些情况下,当预测相当边缘化,而人类专家可能有理由不同意模型的预测时,允许用户反馈来引导或调整模型的预测会很有用,这样可以基于用户拥有但模型没有的信息或知识。
理想情况下,用户应该能够通过非常简单的方法提供反馈,例如点击图像中的各个区域——点击坐标提供的信息应被模型考虑,以调整其预测。点击坐标成为模型输入的一部分。这可以通过多种不同的方式实现。
Liu 等人(2022)的《PseudoClick》论文描述了一种模型架构,其中点击通过单独的输入层提供给模型:模型有一个用于实际图像的输入和一个用于点击的不同输入。如果你从头开始构建你的模型,你可以按自己喜欢的方式设计它,并且可以借鉴 PseudoClick 架构的建议。
但如果你使用现成的模型,你必须使用现有的输入。这将在下一节中描述。
使用颜色通道提供反馈
如果你使用现成的视觉模型,可能它是为了处理彩色图像而构建的——模型的输入实际上是三个相同的层,每个层对应一个颜色通道:红色、绿色和蓝色。如果输入是黑白图像,如超声图像的情况,那么相同的信息(纯亮度)在所有颜色通道中分布的方式是相同的。换句话说,为了相同的信息有三个独立的通道似乎是多余的。
如果只使用一个颜色通道处理单色图像,模型是否会表现相同?假设我们将两个颜色通道(R 和 G)归零,只保留 B 通道中的图像信息。
来源: 乳腺超声图像数据集
使用预训练的图像分割模型(例如 SegFormer MiT-b3(可在 HuggingFace 仓库中找到,在 ImageNet 上预训练))测试这个想法,很明显,用通常的图像分割指标(IoU,Dice)测量的模型性能没有变化。模型的工作方式基本相同。在对单色图像进行预测时,颜色通道的冗余既没有帮助也没有伤害。
这意味着我们可以仅在 B 通道中保留图像数据,而将 R 和 G 通道用于额外的输入——用户生成的点击。不仅如此,我们还有两个独立的通道,R 和 G,这可以为模型提供不同类型的输入。
这正是我们需要的:一种点击应该是“激活”或“正面”的,告诉模型*“这是一个感兴趣的区域,集中注意力”,而另一种点击应该是“抑制”或“负面”的,告诉模型“这里没有活动,避免这个区域”*。
问题是,如何在训练数据中放置点击,以使模型对用户输入作出响应?这将在下一部分中描述。
真阳性、假阳性、假阴性
分割模型的预测是图像中像素以某种方式标记的区域——例如,标记为非零值。当预测区域与该图像的真实标签紧密匹配时,我们说模型表现良好。与标签匹配的预测像素称为真阳性(TP)。
在模型做出非零预测但标签中的像素为零的地方,这些像素被称为假阳性(FP)。在标签中的像素为非零而模型预测为零的地方,这些被称为假阴性(FN)。以下是一个示例:
来源:乳腺超声图像数据集
左侧框架是标签。中间框架是模型的预测。在右侧框架中,我们用白色标记了真阳性(TP):预测像素与标签像素匹配。假阳性(FP)是那些预测值为非零但标签中为零的像素,用绿色标记。假阴性(FN)是那些预测值为零但标签中为非零的像素,用红色标记。
如果我们知道模型容易出错并产生错误预测(FP,FN)的图像区域,我们可以向原始数据集中添加标记,标记这些 FP 和 FN 区域。由于我们已经将所有图像信息转移到了蓝色通道,我们可以利用红色和绿色通道来标记这些区域。
例如,我们可以在红色通道中标记 FP 区域的点击。我们希望这些点击能成为“抑制性”或“负面”点击,指导模型避免在此处进行预测。FN 区域可以在绿色通道中标记点击,这些点击将成为“激活性”或“正面”点击,引导模型更多地关注这些区域。示例:
来源:乳腺超声图像数据集
来源:乳腺超声图像数据集
在上面显示的图像中,我们在错误预测的区域(模型预测存在 RoI 但实际上不存在)放置了红色通道点击(负点击),并在假阴性区域(模型预测什么都没有但实际上有 RoI)放置了绿色通道点击(正点击)。为了保险起见,在真阳性区域我们还放置了几个绿色通道点击,以“锚定”预测并确保 TP 区域保持稳定。
使用点击完成的完整训练过程如下所述。
如何使用点击训练模型
这些是训练模型以响应用户输入的主要步骤:
-
决定采用特定的现成图像分割模型,例如预先训练的 SegFormer MiT-b3 模型(使用 ImageNet 进行预训练)
-
处理所有单色图像,使得图像数据仅存在于 B 通道中;将 R 和 G 通道置空
-
将图像数据集分成 5 折;对每个折进行微调,这将创建 5 个模型,每个模型针对数据集中的不同折进行微调;我们称这些为基线模型
-
使用每个基线模型对其未在训练中见过的图像进行预测;这将为数据集中的所有图像生成预测结果
-
对于每个预测,确定 TP、FP、FN 区域;使用上述描述的 R 和 G 点击覆盖 FP、FN 和可选的 TP 区域;小区域每个接收一个点击;大区域接收多个点击;暂时假设点击坐标手动生成(稍后详述)
-
将点击嵌入数据集中的 R 和 G 通道中,但保留 B 通道不变;每个点击将成为 R 或 G 通道中的一个 3x3 像素区域,其中我们将像素值设为该通道的最大值(例如 uint8 类型的 255)
-
使用相同的 5 折,对数据集进行训练并在 R 和 G 通道中添加点击训练 5 个新模型;我们称这些为点击训练模型
换句话说,我们训练基线模型来“预测”模型倾向于错误的地方,我们根据需要在“错误”区域添加点击,然后使用添加了点击的数据集训练新模型(点击训练模型)。我们希望通过 R 和 G 通道提供的点击,点击训练模型能够作出响应。完整的代码在这里显示。
为了清楚起见,处理后加上点击的图像将如下所示:
来源:乳腺超声图像数据集
来源:乳腺超声图像数据集
左侧框架是我们尝试匹配的真实标签。中间框架是基线模型的预测。右侧框架显示了用于训练点击训练模型的处理图像——所有图像数据已转移到 B 通道,点击已根据需要添加到 R 和 G 通道,以修改预测。
你不需要实际修改数据集来添加 R 和 G 点击。你可以简单地注册点击坐标,并修改数据加载器,以便在训练模型时动态应用点击。如果你需要重新生成点击坐标,这种变体要灵活得多。
此外,你可以借鉴图像增强技术,以一定概率在训练中应用点击。点击并不总是添加到输入中,而只是随机地在某些情况下添加。我使用了 0.9 的概率,效果很好。这个想法是避免让模型过度依赖点击。微调这个参数可能需要进一步探索。
结果
它有效吗?
确实如此。这里是一个经过点击训练的模型进行预测,然后实时响应用户反馈的例子:
我们要求模型进行预测,它覆盖了图像上半部分的两个暗区。我们不同意模型的预测——我们认为左侧区域不是感兴趣区域,因此我们在上面放置了一个抑制(红色)点击。我们还在右侧区域放置了一个激活(绿色)点击。现在模型的预测跟随我们提供的额外信息。
应该注意的是,单个点击(约 3x3 像素)会影响模型在直径几百像素区域内的预测。此外,模型会考虑点击的存在以及图像中可见的特征:在 RoI 中放置一个点击会使模型用预测掩膜填充整个区域,遵循图像中可见的轮廓。
模型有时会很容易遵循用户反馈——这是在模型输出预测存在高歧义/低置信度的情况下。还有其他情况,模型的预测会抵抗被负点击“驱逐”——这是在模型自身输出具有低歧义/高置信度时。
可扩展性
到目前为止描述的技术的主要问题是可扩展性。我们假设点击坐标是手动生成的。换句话说,需要人工操作员逐一筛选所有图像,比对真实标签与预测,决定点击位置及数量,并记录所有点击坐标。
显然,这种方法无法扩展。为包含数百张图像的数据集生成一组点击已经很繁琐且耗时,但还不是不可能的。如果数据集包含数千张图像,或者特别是如果在基准模型更改时需要重新生成点击集,则任务变得不可能。需要某种形式的自动化。
这将是本系列第二部分的主题。我们将展示如何自动化创建点击坐标,以便训练过程可以完全无监督地运行。第二部分将描述一种生成点击的算法,这种算法与人工操作员做出的决策非常相似。
链接、引用、评论
这个项目是我在数据科学硕士课程最后一个学期的毕业设计项目的扩展:github.com/FlorinAndrei/datascience_capstone_project
这个项目和我的毕业设计都在威斯康星大学拉克罗斯分校的计算机辅助乳腺超声影像(CADBUSI)项目中完成,由杰夫·巴格特博士监督。datascienceuwl.github.io/CADBUSI/
本文代码的 GitHub 仓库:github.com/FlorinAndrei/segmentation_click_train
本文中使用的所有超声图像均属于乳腺超声图像数据集,按 CC BY 4.0 许可证提供。引用链接:
Al-Dhabyani, W., Gomaa, M., Khaled, H., & Fahmy, A. (2019). 乳腺超声图像数据集。ResearchGate。检索于 2023 年 5 月 1 日,www.sciencedirect.com/science/article/pii/S2352340919312181
其他链接、引用和评论:
Liu, Q., Zheng, M., Planche, B., Karanam, S., Chen, T., Niethammer, M., & Wu, Z. (2022). PseudoClick: 通过点击模仿进行交互式图像分割。arXiv.org。检索于 2023 年 5 月 1 日,arxiv.org/abs/2207.05282
Xie, E., Wang, W., Yu, Z., Anandkumar, A., Alvarez, J. M., & Luo, P. (2021). SegFormer: 使用变换器进行语义分割的简单高效设计。arXiv.org。检索于 2023 年 5 月 1 日,arxiv.org/abs/2105.15203
HuggingFace 上的预训练 SegFormer 模型:huggingface.co/docs/transformers/model_doc/segformer
本文中不属于乳腺超声图像数据集的图像由作者创建。
通过 Voronoi 分割训练图像分割模型以接受用户反馈,第二部分
如何训练现成的图像分割模型以响应用户反馈
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 5 月 5 日
–
这是关于训练图像分割模型以便这些模型可以响应用户反馈并根据反馈(鼠标点击)调整预测的系列文章的第二部分。
第一部分中,我们描述了训练现成图像分割模型以响应用户反馈的一般策略。第一部分结束时识别出的问题是,手动生成训练模型所需的点击是繁琐且耗时的,如果数据集非常大和/或模型需要频繁重新训练,这可能根本不可行。生成点击需要自动化——这就是本文的主题。
问题
让我们再看一看我们试图解决的问题:
来源:乳腺超声图像数据集
左侧框是带有真实标签的图像;人类专家用黄色标记了感兴趣区域(RoI);这是我们期望模型预测的理想形状。中间框是模型的实际预测。右侧框显示了真实阳性区域(标签和预测重合的地方)、假阳性区域(模型预测为 RoI,但标签中没有此区域),以及假阴性区域(模型未预测任何内容,但实际存在 RoI)。TP 区域以白色显示,FP 区域以绿色显示,FN 区域以红色显示。
为了引导模型的预测,我们在 TP 和 FN 区域放置了正点击(绿色),在 FP 区域放置了负点击(红色),然后用包含点击的图像训练了新的模型。
对于人类操作员来说,放置点击直观上是简单的。但如果将过程分解成独立的逻辑步骤和标准,它会变得相当复杂:
-
将 TP、FP、FN 区域拆分成独立的连续段
-
丢弃非常小的段作为无关内容
-
对于每个剩余的区域,根据区域面积决定要放置的总点击数
-
点击不能彼此太近
-
点击不能太靠近区域边缘
最后两个标准很困难。模糊性(“不能太近”)以及这两个标准相互矛盾,使得生成点击的过程看似很难保证能收敛到模拟人类操作员所做的解决方案。
然而,我们将展示一种方法,它结合了数学概念(Voronoi 镶嵌)与物理学的提示(能量和模拟退火),以产生所需的结果。
Voronoi 镶嵌
维基百科页面对该概念的解释相当到位,如果你研究过聚类算法,这可能会感到熟悉,但我在这里也补充几句。
在左侧框架中,我们有一个带有几个种子点的正方形区域。对于任何种子点,框架中必须有一个区域(一个瓦片),其中所有像素距离该种子点比距离其他所有种子点更近。在右侧框架中,我们显示了这些瓦片,用颜色编码以匹配种子。每个瓦片被称为 Voronoi 单元,而寻找这些瓦片的过程称为 Voronoi 镶嵌。
这个例子中的种子是随机选择的。我们得到的镶嵌不是均匀的。为了获得均匀的镶嵌,种子还必须是其对应瓦片的质心(或接近质心)——这称为质心 Voronoi 镶嵌。这里是一个非常简单的例子:
为了找到能导致区域的质心 Voronoi 镶嵌或其近似的点击坐标(种子),可以使用类似于Lloyd 算法的东西,它非常快速(是 k 均值聚类的标准解算器)。这是一个 Lloyd 算法的模拟器,可以在你的浏览器中实时运行。但这里有两个问题:
-
Lloyd 算法通常用于镶嵌矩形区域。尚不清楚它是否(或如何)推广到我们需要镶嵌的任意区域形状。
-
我们只希望在形状(区域和瓦片)为凸时使用质心镶嵌。当形状为凹时,质心可能会落在我们镶嵌区域之外,这完全不是我们想要的(点击点会在区域之外)。
所以我们需要一种能处理任意形状的方案,它能在形状凹陷时保持点击点在区域内,并且在处理简单矩形区域时表现得像 Lloyd 算法。这是对任意区域(即使是凹形)的质心 Voronoi 镶嵌的概括。这是下一节的主题。
能量的模拟退火
考虑这种镶嵌:
点击分布足够均匀,点击坐标与质心的距离不远(所有形状均为凸形)。这对于我们的目的来说分布还不错。我们能否找到一个与点击坐标对应的目标函数,尝试迭代地最大化或最小化,以达到这样的分布?
让我们看看点击周围的空间:
像素能量的热图
想象每个图像中的像素都被分配了一个“能量”。只有一个点击对像素的能量有贡献——最接近的点击。其他所有点击都没有贡献。任何像素的能量与其到最近点击的距离成反比。因此,要找出任何像素的能量,我们需要:
-
找到最近的点击
-
计算像素到点击的距离
-
计算距离的倒数,即像素的能量
上面展示的图像只是给定点击分布的像素能量热图。Voronoi 图块的边缘已经由邻近点击之间最暗的区域所提示。
如果我们计算所有像素的总能量,然后移动点击,寻找提供最高总能量的点击位置,这会导致区域的均匀铺设吗?实际上,这正是上面展示的点击分布的获取方式:
-
从完全随机的点击坐标开始
-
计算感兴趣区域内所有像素的总能量
-
对点击坐标应用模拟退火,以找到最大化总能量的坐标
完整代码在这里展示。该算法很强大,可以很好地处理凹形状——这是一个在镰刀形分割中以视觉上均匀的方式放置点击的示例:
在凹形中放置点击
分割区域是较浅的蓝色阴影,形状像镰刀,与深色背景形成对比。点击是最亮的点。
点击不会离分割边缘太近,因为那样会减少总能量(分割边缘之外的像素没有能量)。它们也不会彼此靠得太近,因为那样不会“激活”远离紧密点击组的像素。该算法是自我调节的。
注:这个问题与手机塔覆盖问题有相似之处,即你尝试在地图上放置 N 个手机塔,使得信号在大多数区域尽可能强。
返回到分割模型
总结一下,我们尝试训练图像分割模型,使其对用户反馈(鼠标点击)做出响应。总体过程是:
-
将图像数据集分成 5 个折
-
为每个折训练一个分割模型;这会生成一组 5 个基线模型
-
使用基线模型对所有图像进行预测;每个模型对训练中未见过的图像进行预测
-
比较预测与标签;提取所有包含真正阳性、假阳性、假阴性预测的区域
-
将 TP、FP、FN 区域拆分成连续的段;丢弃最小的段(少于 100 个像素或更少)
-
对于每个区域,生成均匀的点击,如本文所示;点击次数取决于区域的大小:较大的区域会收到更多的点击,直到一个合理的限制(例如,512x512 像素的图像大约为 4 … 5 次)。
-
将所有图像信息移至 B 通道,腾出 R 和 G 通道,将点击嵌入 R 和 G 通道;TP 和 FN 区域的点击在 G 通道(正点击);FP 区域的点击在 R 通道(负点击)。
-
使用点击增强图像,在相同的折叠上训练 5 个新的模型;这些是点击训练模型,是整个项目的最终成果。
这里是从实际基线模型预测中生成的图像区域均匀点击的一些示例。我们从数据集中选择了 3 张图像,使用基线模型进行预测,并查看每张图像的 TP、FP、FN 区域。每个区域的颜色比背景色浅,点击是每个区域中最亮的点。
来源:乳腺超声图像数据集
点击最终被放置在一个或多或少符合人工操作员放置位置的地方:彼此之间距离不太近,距离边缘也不太近,偏向于每个区域中的大面积或宽区域。点击分布似乎在视觉上是均匀的。
最终思考
我们训练了现成的图像分割模型来响应用户反馈,而没有以任何方式更改其架构,也没有在 ImageNet 上从头开始重新训练它们。
点击训练模型的表现不一定优于基线模型。正如这里所示,使用点击进行训练只是使模型能够响应用户反馈。当然,点击训练模型在用于创建 5 个训练折叠的数据集上会远远优于基线模型。这是因为创建点击本质上在训练和测试之间泄漏了数据。在 100% 之前未见过的数据上,点击训练模型和基线模型的表现相同。
这里有更多点击训练模型对用户反馈响应的示例。
视频中展示了两张不同的图像。在这两张图像中,你可以看到模型对其预测的 RoI 有很高的信心。尝试在预测的 RoI 中放置负点击不是很成功——模型继续预测该区域为 RoI。
模型接受对其他区域作为潜在 RoI 的建议。
在这两种情况下,你可以看到模型的两种输出:纯分割和热图。热图只是 RoI 可能性的地图。
链接、引用、评论
这个项目是我在数据科学硕士学习最后一个学期的顶点项目的扩展:github.com/FlorinAndrei/datascience_capstone_project
本毕业设计及相关工作都在威斯康辛大学拉克罗斯分校的乳腺超声图像计算机辅助诊断(CADBUSI)项目中完成,由 Jeff Baggett 博士监督。datascienceuwl.github.io/CADBUSI/
这篇文章的 GitHub 代码库:github.com/FlorinAndrei/segmentation_click_train
本文中使用的所有超声图像都属于乳腺超声图像数据集,可在 CC BY 4.0 许可下使用。引用链接:
Al-Dhabyani, W., Gomaa, M., Khaled, H., & Fahmy, A. (2019). 乳腺超声图像数据集。ResearchGate。2023 年 5 月 1 日检索自www.sciencedirect.com/science/article/pii/S2352340919312181
其他链接、引用和评论:
Liu, Q., Zheng, M., Planche, B., Karanam, S., Chen, T., Niethammer, M., & Wu, Z. (2022). PseudoClick:带有点击仿真的交互式图像分割。arXiv.org。2023 年 5 月 1 日检索自arxiv.org/abs/2207.05282
Xie, E., Wang, W., Yu, Z., Anandkumar, A., Alvarez, J. M., & Luo. P. (2021). SegFormer:用于语义分割的简单高效设计与 Transformers。arXiv.org。2023 年 5 月 1 日检索自arxiv.org/abs/2105.15203
HuggingFace 的预训练 SegFormer 模型:huggingface.co/docs/transformers/model_doc/segformer
本文中不属于乳腺超声图像数据集的图像由作者创建。
使用自动梯度下降训练 ImageNet,无需超参数
迈向面向架构的优化
·
关注 发表在 Towards Data Science · 8 分钟阅读 · 2023 年 4 月 19 日
–
TL;DR 我们提出了一种名为自动梯度下降(AGD)的优化器,可以在无需超参数的情况下训练 ImageNet。这消除了对昂贵且耗时的学习率调整、学习率衰减调度器选择等的需求。我们的论文可以在这里找到。
我与Jeremy Bernstein、Kevin Huang、Navid Azizan 和Yisong Yue一起工作。请查看 Jeremy 的GitHub以获取干净的 Pytorch 实现,或者查看我的GitHub以获取更多功能的实验版本。图 1 总结了 AGD、Adam 和 SGD 之间的区别。
图 1 实线表示训练准确度,虚线表示测试准确度。左侧: 与我们的方法相比,Adam 和使用默认超参数的 SGD 在 CIFAR-10 的深度全连接网络(FCN)上表现较差。中间: Adam 和 SGD 的学习率网格搜索。我们的优化器表现得与完全调优的 Adam 和 SGD 相当。右侧: AGD 在 ImageNet 上训练达到了令人满意的测试准确度。
动机
任何训练过深度神经网络的人都可能需要调整学习率。这是为了 (1) 确保训练的最大效率,以及 (2) 因为找到合适的学习率可以显著提高整体泛化能力。这也是一件非常麻烦的事。
图 2 为什么学习率对优化如此重要。为了最大化收敛速度,你需要找到恰到好处的 学习率: 大,但又不至于让目标函数中的非线性项使你偏离轨道。
然而,对于 SGD,最佳学习率高度依赖于正在训练的架构。找到它通常需要代价高昂的网格搜索程序,覆盖多个数量级。此外,其他超参数,如动量和学习率衰减调度器,也需要选择和调整。
我们提出了一种叫做自动梯度下降(AGD)的优化器,它不需要学习率来训练各种架构和数据集,甚至可以扩展到 ImageNet 上的 ResNet-50。这消除了任何超参数调整的需求(因为有效学习率和学习率衰减从分析中排除),节省了计算成本,并大大加快了模型训练的过程。
我们为什么需要超参数呢?
深度学习系统由许多相互关联的组件组成:架构、数据、损失函数和梯度。这些组件的交互方式有一定的结构,但至今没有人完全确定这个结构,因此我们仍需进行大量调优(如学习率、初始化、调度器),以确保快速收敛,并避免过拟合。
但是,完美地表征这些交互作用可能会移除优化过程中的所有自由度——这些自由度目前由手动超参数调整处理。二阶方法目前使用 Hessian 表征目标对权重扰动的敏感性,并以这种方式移除自由度——然而,这些方法可能计算量大,因此在大型模型中不实用。
我们通过分析这些交互作用来推导 AGD:
-
我们在给定数据和架构的情况下,将神经网络的输出变化与权重变化联系起来。
-
我们将目标(批次中所有输入的总损失)的变化与神经网络的输出变化联系起来。
-
我们将这些结果结合在一种所谓的主次优化方法中。我们主次优化目标——即,我们推导出一个与目标相切的目标上界。然后我们可以最小化这个上界,知道这样做会使我们向下移动。这在图 3中得到了可视化,其中红色曲线显示了目标函数的主次优化,如蓝色曲线所示。
图 3 左面板展示了主次优化的基本思想——通过最小化一系列上界或主次优化(红色),来最小化目标函数(蓝色)。右面板展示了权重的变化如何引起函数的变化,这进而引起单个数据点上的损失变化,最终引起目标的变化。我们将∆L与∆W相关联,并利用它来构建我们的主次优化。
Pytorch 中的 AGD
在本节中,我们将逐步讲解算法的所有关键部分。有关推导的简略内容,请参见附录 A。
关于参数化
我们使用的参数化方法与传统的 PyTorch 默认设置略有不同。虽然可以在不假设这种参数化的情况下推导 AGD,但使用这种参数化可以简化分析。对于完全连接层l,我们使用正交初始化,并将其缩放,以使奇异值的大小为 sqrt(l的输入维度/ l的输出维度)。
我们使用这种归一化,因为它具有 PyTorch 默认参数化所没有的一些良好特性,包括宽度的稳定性、对激活值爆炸的抵抗力以及促进特征学习。这类似于Greg Yang 和 Edward Hu 的 muP。
关于更新
这一步可以分成两个独立的部分。第一部分是计算 eta(η),即“自动学习率”,它会缩放所有层的更新。Eta 对梯度范数有对数依赖——当梯度较小时,eta 大致线性(像标准优化器一样),但当梯度非常大时,对数会自动执行一种梯度裁剪。
每层的更新使用 η 乘以层的权重范数,再乘以标准化梯度,最后除以深度。除以深度的操作负责与深度的缩放。有趣的是,梯度标准化在分析中消失了,因为其他优化器(如 Adam)以启发式方式结合了类似的思想。
实验
这些实验的目标是测试 AGD 的能力:(1) 在广泛的架构和数据集上收敛,以及 (2) 实现与调整过的 Adam 和 SGD 相当的测试准确率。
图 4 显示了从全连接网络(FCN)到 ResNet-50 的四种架构在 CIFAR-10 到 ImageNet 数据集上的学习曲线。我们将 AGD(用实线表示)与标准优化器(用虚线表示,ImageNet 上为 SGD,其他三种数据集上为调整过的 Adam)进行了比较。第一行 显示了训练目标(损失)和自动学习率 η。第二行 显示了训练和测试准确率。图 5 比较了 AGD、调整过的 Adam 和调整过的 SGD 在一个 8 层 FCN 上的表现。我们看到这三种算法的性能非常相似,测试准确率几乎一致。
图 6 显示 AGD 可以在广泛的深度(2 到 32 层)和宽度(64 到 2048)上训练 FCNs。图 7 显示了 AGD 对批量大小(从 32 到 4096)的依赖性,测试了一个 4 层的 FCN。无论批量大小如何,AGD 似乎都能收敛到一个良好的最优解!
图 4 AGD 与 Adam 在四种架构上的比较:CIFAR-10 上的深度 16 FCN、CIFAR-10 上的 ResNet-18、CIFAR-100 上的 VGG-16 和 ImageNet-1k 上的 ResNet-50。AGD 在超参数调整的 Adam(需要在多个数量级上进行网格搜索)中保持了合理的速度!这些实线表示 AGD,虚线表示 Adam(除了 ImageNet,我们使用了 SGD)。第一行 显示了训练目标(即损失)和训练期间自动学习率 η 的值。第二行 显示了训练和测试准确率。
图 5 AGD 与 Adam 和 SGD 在一个深度 8 的 FCN 上的比较,损失为均方误差。Adam 和 SGD 的学习率经过调整。在左侧,我们绘制了训练和测试目标函数(即损失)。中间 显示了训练和测试准确率。右侧 显示了每个周期中权重的平均、最小和最大变化。
图 6 AGD 能够在极大的深度和宽度范围内开箱即用地收敛。较小的架构由于能力不足,无法实现低损失,但 AGD 仍能训练它们!
图 7 为了确认 AGD 不仅仅适用于批量大小 128,这里展示了一个深度 4 FCN 的各种批量大小。
结论
总结来说,这里有一个“架构感知”优化器:自动梯度下降(AGD),能够在各种批量大小下训练从 CIFAR-10 上的 FCN 等小型系统到 ImageNet 上的 ResNet-50 等大型系统,无需手动调整超参数。
虽然使用 AGD 并没有消除机器学习中的所有超参数,但剩下的超参数——批量大小和架构——通常属于“尽可能大以填满时间/计算预算”的类别。
然而,仍有许多工作要做。我们没有明确考虑由于批量大小引入的梯度的随机性。我们也没有研究像权重衰减这样的正则化。虽然我们在增加对仿射参数(在批量归一化层中)和偏置项的支持方面做了一些工作,但我们尚未进行广泛测试,也没有像这里的其他结果那样得到理论上的充分证明。
也许最重要的是,我们仍需进行对变压器的分析,并在 NLP 任务上测试 AGD。与 GPT-2 在 OpenWebText2 上的初步实验表明,AGD 在这里也有效!
最后,可以查看 Jeremy 的GitHub以获取干净版本,或者我的GitHub以获取支持偏置项和仿射参数的开发版本,如果你想尝试 AGD!我们希望你会觉得它有用。
附录 A
我们将在这里简要介绍证明的重要步骤。这是为了那些想要了解主要思想如何汇聚的人,而无需查看我们论文中的完整证明,论文可以在这里找到。
方程(1)明确规定了如何将数据集S上的总体目标分解为单个数据点。L表示损失,x为输入,y为目标,w为权重。方程(2)展示了目标的线性化误差的分解——在给定某些权重变化Δw时,高阶项对损失变化ΔL(w)的贡献。目标的线性化误差很重要,因为它等于在权重w处扩展的损失中高阶项的贡献——界定这一点将告诉我们可以移动多远,直到高阶项变得重要,并确保我们迈出的步伐是合理的、向下的。
方程(2)右侧的第一个项是两个高维向量的内积,即模型的线性化误差和关于*f(x)*的损失的导数。由于没有明确的理由说明这两个向量应该对齐,我们假设它们的内积为零。
将*L(W+ΔW)*添加到方程(2)的两边,并注意到损失的线性化误差恰好是 Bregman 散度,我们可以简化符号:
Bregman 散度是衡量两点之间距离的度量(在这种情况下是神经网络的两个不同参数选择的输出),其定义基于严格凸函数——在这里是损失函数。
计算 Bregman 散度对于均方误差损失实际上非常简单,并给出了
其中 dₗ 是网络的输出维度。我们现在断言以下缩放。这些缩放有些任意,但将它们以这种形式表示将使分析变得更加简单。
我们使用了对网络输出大小的以下两个界限。方程 (5) 对网络输出的幅度进行界限,这来自于对全连接网络应用(输入缩放)和(权重缩放)。方程 (6) 对权重 W 的变化带来的 f(x) 的最大变化进行界限。方程 (6) 中的第二个不等式在大深度时最紧,但适用于任何深度。
现在,我们将方程 (6) 代回方程 (4) 中,并将所有项展开以得到方程 (7)。
我们可以将方程 (7) 中的总和替换为 G,这在方程 (8) 中定义,并在关于梯度条件的附加假设下讨论,详细内容请参阅论文。最后,我们得到方程 (9)——这就是主化——图 3中的红线。我们通过对 η 求导来最小化主化,并求解结果中的二次方程,保留正解。这给出了以下更新
这就结束了我们对自动梯度下降的推导。如果你有任何评论、问题或其他反馈,请告知我们。
博客中的所有图像均由我们论文的作者制作。图 2 的灵感来源于这个图表。
通过更改仅一行代码,在 GPU 上训练你的 ML 模型
利用 cuML 和 ATOM 让你的机器学习管道快速如闪电
·
关注 发表在 Towards Data Science ·5 分钟阅读·2023 年 3 月 20 日
–
图片由 Thomas Foster 提供,来源于 Unsplash
简介
图形处理单元 (GPU) 可以显著加速预处理步骤或训练机器学习模型的计算。训练模型通常涉及计算密集型矩阵乘法和其他可以利用 GPU 大规模并行架构的操作。在单个处理器上训练大型数据集可能需要几个小时。然而,如果将这些任务转移到 GPU 上,你可以将训练时间减少到几分钟。
在这个故事中,我们将展示如何使用 ATOM 库轻松在 GPU 上训练你的机器学习管道。ATOM 是一个开源 Python 包,旨在帮助数据科学家加速机器学习管道的探索。如果你想了解库的温和介绍,请阅读 这个故事。
设置
ATOM 使用 cuML 作为 GPU 训练的后端库。cuML 是一套快速、GPU 加速的机器学习算法,旨在数据科学和分析任务中使用。不幸的是,cuML 不能通过 pip 安装,因此未作为 ATOM 的依赖项安装。请阅读 这里 了解如何安装它。
cuML 的要求需要考虑:
-
操作系统:Ubuntu 18.04/20.04 或 CentOS 7/8,使用 gcc/++ 9.0+,或 Windows 10+,使用 WSL2
-
GPU:NVIDIA Pascal™ 或更高版本,具有 计算能力 6.0+
-
驱动程序:CUDA 和 NVIDIA 驱动程序版本 11.0、11.2、11.4 或 11.5
提示: 查看 这个仓库 以在 SageMaker Studio Lab 上安装 cuML。
示例
在使用 GPU 的 atom 中训练变换器和模型就像
使用参数 device="gpu"
初始化 atom。device
参数接受任何遵循 SYCL_DEVICE_FILTER 过滤器选择器的字符串。示例包括:
-
device=”cpu” (使用 CPU)
-
device=”gpu” (使用默认 GPU)
-
device=”gpu:0" (使用第一个 GPU)
-
device=”gpu:1" (使用第二个 GPU)
注意: ATOM 不支持多 GPU 训练。如果有多个
机器上的 GPU 和 device
参数未指定使用哪个
一个可用时,默认使用第一个。
使用engine
参数在 cuML 和 sklearnex 执行引擎之间进行选择。在这个故事中,我们将专注于 cuML。XGBoost、LightGBM和CatBoost模型配备了自己的 GPU 引擎。设置 device="gpu"即可通过 GPU 加速这些模型,无论 engine 参数如何。点击这里查看支持 GPU 加速的变换器和模型概述。
提示: 如果你没有 GPU 访问权限,可以使用在线云服务,如Google Colab或Sagemaker Studio Lab来尝试。请确保选择 GPU 计算类型。查看这个笔记本以开始使用。
让我们开始这个例子。
from atom import ATOMClassifier
from sklearn.datasets import make_classification
# Create a dummy dataset
X, y = make_classification(n_samples=100000, n_features=40)
atom = ATOMClassifier(X, y, device="gpu", engine="cuml", verbose=2)
不仅模型,变换器也可以从 GPU 加速中受益。例如,将特征缩放到 mean=0 和 std=1。
atom.scale()
print(atom.dataset)
由于我们声明要使用 cuML 引擎,ATOM 会在有可用时自动从该库中选择变换器。
print(f"Scaler used: {atom.standard}")
print(f"Scaler's module: {atom.standard.__class__.__module__}")
让我们训练三个模型:随机森林在 cuML 中可用,随机梯度下降不可用,XGBoost有自己的 GPU 实现。
atom.run(models=["RF", "SGD", "XGB"])
atom.results
注意模型之间训练时间的巨大差异!
如果我们检查底层估算器,会发现 RF 模型确实是在 GPU 上训练的,SGD 没有(因为它在 cuML 中不可用,ATOM 回退到默认的 sklearn 实现),而 XGB 模型确实使用其本地模块在 GPU 上训练。
for m in atom.models:
print(f"{m}'s module: {atom[m].estimator.__class__.__module__}")
最后,分析结果像往常一样简单。
atom.evaluate()
结论
我们已经展示了如何使用 ATOM 包在 GPU 上训练你的机器学习管道。ATOM 也能在 CPU 上加速。阅读这个故事了解如何操作。
欲了解更多关于 ATOM 的信息,请查看该包的文档。对于错误或功能请求,请随时在GitHub上提出问题或发邮件给我。
参考文献:
- 所有图表和图片(除封面图外)均由作者创建。
相关故事:
ATOM: 一个用于快速探索机器学习管道的 Python 包
自动化优化建模工具(ATOM)是一个开源 Python 包,旨在帮助数据科学家进行…
如何使用正确的包在 Python 中进行快速深度学习实验的简单指南
如何通过仅更改一行代码显著减少训练时间
自定义 YOLOv7 对象检测与 TensorFlow.js
照片由 Martijn Baudoin 拍摄,来源于 Unsplash
在 PyTorch 中训练自定义 YOLOv7 模型,并将其转换为 TensorFlow.js 以实现浏览器上的实时离线检测
·
关注 发布于 Towards Data Science ·6 分钟阅读·2023 年 3 月 28 日
–
最近,我开源了一个YOLOv7 在 Tensorflow.js 中的实现,我收到的最常见问题是:
你是如何将模型从 PyTorch转换为 Tensorflow**.js**的?
本文将通过使用自定义YOLOv7模型在浏览器和离线环境下直接运行来解决实际问题。
我们将解决的行业是实体零售。除了最近发生的数字化转型——主要是在疫情期间—— 实体店仍然是客户最偏爱的购物目的地。
商店中的一切都关乎体验。 零售反馈组(RFG)已经跟踪了大约 15 年的杂货购物体验,一贯发现 影响客户满意度的最关键因素是顾客是否能够在访问期间找到他们所需的一切,无论是在店内还是在线。
所以零售商们不断关注产品可用性和适当的商品组合以满足客户需求。
在上一篇文章中,我展示了如何创建一个 TensorFlow.js 模型来识别 SKU。在这篇文章中,我们将探讨如何使用自定义 YOLOv7 模型来识别空货架——一切都在实时、离线和智能手机的浏览器中运行。
本文将涵盖的内容:
-
配置环境;
-
收集数据;
-
准备模型进行训练;
-
训练和重新参数化模型;
-
评估模型;
-
转换为 TensorFlow.js;
-
在网页浏览器上部署模型。
配置环境
整个训练流程将使用 Google Colab 提供的免费 GPU 执行。如果你想访问包含所有整合步骤的笔记本,点击这里。
收集数据;
我们将用来训练模型的数据集是 零售空货架——缺货(CC0 1.0 许可证)。它有 1155 张图像中的 3608 个标注和一个独特的类别:缺货。
数据集样本来自 哈佛数据中心
注释应采用 YOLOv7 格式,每张图像都有相应的txt
文件。每行的第一个值表示类别——对于 Stockout 数据集,所有类别都相同,等于0
。行中的其余四个数字表示边界框的坐标。
如果你想创建自己的数据集,可以使用像CVAT这样的工具。
aYOLOv7 注释示例 | 图片作者提供
要下载并解压库存数据集,请使用以下代码:
为训练准备模型
第一步是克隆 YOLOv7 的代码库并安装依赖:
然后,下载在COCO 2017 数据集上预训练的权重。这些权重将用于初始化模型并加速训练——这种技术被称为迁移学习。
由于我们处理的是单类问题,我们选择了 YOLOv7-tiny,这是 YOLOv7 的轻量级变体。
在开始训练之前,我们需要配置一个.yaml 文件,设置我们想使用的参数。
训练和重新参数化模型
训练过程直观且可定制,允许你调整诸如 epoch 数量和批次等参数,以适应数据集的需求。
执行结束时,你将得到保存在yolov7/runs/train/your-model-name/weights/best.pt
的权重。
要以可视化格式查看训练指标,请启动 TensorBoard 或打开图像yolov7/runs/train/your-model-name/results.png
。
运行 tensorboard | 图片作者提供
results.png | 图片作者提供
现在你已经训练好了模型,是时候重新参数化权重以进行推理。
除了实时目标检测的架构优化,YOLOv7 还引入了额外的模块和方法,这些模块可以提高训练效率和目标检测精度。这些模块被称为免费礼包,必须优化以实现高效推理。有关更多信息,请参阅模型论文。
检查下面代码中的权重路径并执行,以生成重新参数化的模型。
评估模型
现在模型已针对推理进行了优化,接下来是运行一些测试图像,以查看它是否检测到货架上的空白区域。
转换为 TensorFlow.js
转换模型可能会很具挑战性,因为它需要经过几个转换:PyTorch 到 ONNX,ONNX 到 TensorFlow,最后是 TensorFlow 到 TensorFlow.js。
以下代码将为你处理所有事情。只需确保模型路径配置正确,然后运行它!
完成后,生成的 TensorFlow.js 模型将由 Google Colab 自动下载。
假设一切顺利,模型现在将被转换为TensorFlow.js layers format。
下载到本地的文件夹应包含一个model.json文件和一组二进制格式的分片权重文件。model.json包含模型拓扑(即“架构”或“图”:层及其连接方式的描述)和权重文件的清单。
└ stockout_web_model
├── group1-shard1of6.bin
├── group1-shard2of6.bin
├── group1-shard3of6.bin
├── group1-shard4of6.bin
├── group1-shard5of6.bin
├── group1-shard6of6.bin
└── model.json
部署模型
现在模型已经准备好可以加载到 JavaScript 中。如前所述,我开源了一个 YOLOv7 的 JavaScript 代码。因此,我们可以利用相同的代码库,并用我们刚刚训练的模型替换原有模型。
克隆代码库:
git clone https://github.com/hugozanini/yolov7-tfjs.git
安装这些包:
npm install
进入[public](https://github.com/hugozanini/yolov7-tfjs/tree/master/public)
文件夹并粘贴训练好的模型。你应该有如下结构:
├── git-media
├── index.html
├── LICENSE
├── node_modules
├── package.json
├── public
│ ├── stockout_web_model
│ └── yolov7_web_model
├── README.MD
└── src
前往[src/App.jsx](https://github.com/hugozanini/yolov7-tfjs/blob/master/src/App.jsx)
并将[line 29](https://github.com/hugozanini/yolov7-tfjs/blob/4efa4cba39d1168d4bcf7bfe79274e54e87d2eaa/src/App.jsx#L29)
上的模型名称更改为 stockout:
const modelName = "stockout";
要执行应用程序,请前往根文件夹并运行以下命令:
npm start
将启动本地服务器,你应该会看到类似于这样的内容:
运行示例 | 作者提供的图像
这个模型也部署在 CodesSandbox 上。访问下面的链接查看其运行情况。
利用 YOLOv7,可以检测到多达 80 种不同的类别。对于零售行业的企业而言,这是一个提升店内产品执行的绝佳机会。
通过训练模型识别公司所有产品,零售商可以确保他们的产品正确摆放在货架上,以便顾客轻松找到所需的商品。
为验证训练模型的有效性,我带着手机去了一家杂货店和药店,并记录了一些实时运行检测器的示例。
在下面的视频中,您可以验证解决方案在真实环境中准确检测空货架的能力。
与SKU 识别模型结合使用的缺货检测器可以极大地提升零售业务的效率和效果。
虽然存在基于云的解决方案,但有时它们可能很慢,需要长达 24 小时才能处理一次检测。相比之下,使用 TensorFlow.js 模型在智能手机浏览器上进行离线实时识别,可以让企业做出更即时的决策,并更快速地响应缺货情况。
总的来说,将缺货检测器与 SKU 识别模型结合使用,可以为优化零售操作和提升顾客购物体验提供强大的途径。通过实时分析和离线识别能力,企业可以做出明智的决策,并快速响应变化的情况。
如果您有任何问题或想分享一个用户案例,请通过Linkedin或Twitter联系我。本文中使用的所有源代码均可在项目仓库中找到。
谢谢阅读 😃
训练深度学习模型以检测微控制器上的 DoS 攻击
一个端到端的项目演练,包括一些来自 ChatGPT 的有用帮助
·
关注 发布于 Towards Data Science ·10 分钟阅读·2023 年 4 月 11 日
–
感谢 Hamed 提供的图片!
ESP32 是一款广泛用于物联网项目的 MCU(微控制器单元),因其低成本和 ESP-IDF(Espressif 物联网开发框架)而受到青睐。该开发板配备了双核 32 位处理器、单芯片 Wi-Fi 和蓝牙连接,广泛用于物联网、家居自动化和机器人项目。
物联网应用可以解决各行各业的问题,包括农业、家庭控制和智慧城市。令人遗憾的是,我还没有在我们生活中看到物联网项目,主要是在巴西。当我想到物联网项目时,第一个想到的就是它非常技术化,我们无法在没有大量资金的情况下实现。幸运的是,这种情况不再成立,因为我们现在拥有像 ESP32 这样的开发板,只需大约 3 美元,我们就可以构思和实验物联网项目的想法。
在了解到 ESP32 之后,我的目标是找到一种方式来融入这个物联网世界。LACNIC 是一个国际组织,负责分配和管理互联网编号资源,并推动拉丁美洲和加勒比地区的互联网发展。该组织的研究领域之一是密码学、安全性和弹性。我一直缺乏参与研究项目的感觉,于是我提交了一份技术论文提案,重点关注物联网和网络安全,作为他们的IT 女性指导计划的一部分。
我项目的主要参考是 T800:物联网的防火墙工具和基准测试[1],这是一个由巴西研究人员创建的项目,为 ESP32 创建了一个数据包过滤器以扫描攻击。他们在 Github 上公开了所有代码。由于这是我第一次使用 ESP32 和 lwIP 协议栈(稍后会详细介绍),如果没有这个参考,我将完全迷失。
很酷,但我的项目到底是什么呢?安全性应该是物联网系统开发的一个重要部分,这些设备因更新和维护周期较差而成为网络攻击的易受攻击目标[1]。我的工作重点是检测流量攻击,特别是检测 DoS(服务拒绝)攻击。在流量攻击中,攻击者的目标是在短时间内向设备发送大量网络数据包,旨在中断系统的操作。
然后主要任务是训练一个机器学习模型来检测 DoS 攻击并将其部署到 ESP32 上。大多数用于检测 DoS 和 DDoS 攻击的模型使用体积和统计特征,如“聚合记录的平均持续时间”和“源到目的地的数据包计数”。我的目标是弄清楚是否可以创建一个基于原始数据包特征的模型,如 tcp.window_size 和 udp.length。这是一次相当艰难的旅程,今天我们将深入讨论主要的障碍以及 ChatGPT 如何在其中帮助我。
挑战 1 — 我的训练数据集的特征是否与生产中的特征匹配?
我使用的数据集是CIC IoT Dataset 2022,由加拿大网络安全研究所 (CIC) 创建,旨在生成用于不同 IoT 设备的状态-of-the-art 数据集,用于分析、行为分析和漏洞测试 [2]。他们使用 wireshark 作为网络协议分析器,捕获并保存网络数据包到 pcap 文件中。
在建模部分,我的主要参考是 DeepDefense: Identifying DDoS Attack via Deep Learning [3]。他们设计了一种递归深度神经网络,通过使用仅 20 个网络流量字段来学习网络流量序列中的模式。
为了与其他网络设备通信,ESP32 需要支持 TCP/IP 协议。为此,它使用 lwIP(轻量级 IP),这是一个为低内存和低计算能力的嵌入式系统设计的开源 TCP/IP 协议栈。我的第一个任务是弄清楚 ESP32 的 lwIP 实现中有哪些特性,并找出如何从 pcap 文件中提取它们。
起初我在阅读wireshark 文档和esp-lwip 代码以尝试找到对应的特征,但后来我想“如果问 chatGPT 呢”?结果发现大模型非常有用:
使用 chatGPT 获取 ESP32 代码中的“ip.ttl”
当然,我必须做一些调整,[1] 的代码是我的主要指南,但我通过使用模型帮助我获取所需特征的名称和变量,节省了很多时间。
挑战 2— 我如何进行训练和测试数据的拆分?
数据集中包含 3 个包含洪水攻击的 pcap 文件,针对每个 IoT 设备。攻击是基于 HTTP、UDP 和 TCP 协议进行的。我想“好吧,我可以用 2 个攻击进行训练/验证,用 1 个进行测试”,但主要问题是如何将恶意流量与合法流量合并以创建数据拆分。
每个设备的攻击 pcap 文件,每个 pcap 文件只有恶意流量
数据集的合法流量是按天捕获的,因此我们有一个每天的 pcap(总共 30 天)。我最初的计划是使用类似于 [3] 的递归深度神经网络,因为这种神经网络会从 网络流量序列中学习模式,随机取样每个桶(合法与恶意)将效果不好。更糟糕的是,我读过的所有论文只提到他们使用了 80/20 切分,但没有解释他们是怎么做的。
我的目标是使训练数据集和测试数据集尽可能接近真实场景。经过大量的思考和反复试验,我找到了一种解决方案,虽然不确定是否完全实现了这个目标,但这也是一种进展。策略是使用 pcap 的 _ws.col.Time 特性在不同时间插入随机攻击。算法如下:
-
以一天满是合法数据包的数据作为起点。
-
对于每个攻击 pcap,在全天数据包的随机位置插入一些攻击数据包(50000 到 70000 个),调整攻击数据包的 _ws.col.Time 并每次对数据集进行排序。
-
对每个设备 pcap 重复第 2 步,使用每个设备的前 2 个攻击 pcap(第三个 pcap 将用于测试数据集)。
对于测试数据集,我做了相同的操作(从另一整天的合法数据包开始),但不是仅取 50000 到 70000 个攻击数据包,而是使用所有的攻击数据包。这使得测试数据集非常不平衡,但这正是现实中的情况。
挑战 3 — 随时学习 C++
用于编程 ESP32,我们使用 C++。除了大学里学过的一些基本 C 类别,我没有相关经验,但我想“我已经用 Python 编程很长时间了,只需处理一些语法上的变化就好了,我会没问题的”。确实大多数时候我都没问题,但对于一些事情,很难找到准确的查询来搜索我所需的信息,因此 chatGPT 变得非常有用。
一个例子是 &
操作符。如果你是一个 Python 程序员,从未用过 C++,你认为下一行代码会做什么?
(TCPH_FLAGS(tcphdr) & TCP_SYN)
返回 0 或 1 无论它们是否相同,对吧?错了!
chatGPT 帮助我们搞清楚代码的作用
在训练数据集中,tcp.flags.syn 的值是 0 和 1,因此为了在 ESP32 上部署模型时得到相同的结果,我们需要这样做:
input->data.f[7] = (TCPH_FLAGS(tcphdr) & TCP_SYN) ? 1 : 0;
这个简单的问题让 chatGPT 帮我节省了大量的调试时间。谢谢 chatGPT。
很好,但模型有效吗?
模型使用了 10 个特征,架构非常简单:
train_features = [
"_ws.col.Time",
"ip.hdr_len",
"ip.flags.df",
"ip.frag_offset",
"ip.proto",
"ip.ttl",
"tcp.window_size",
"tcp.flags.syn",
"tcp.flags.urg",
"tcp.hdr_len",
]
simple_nn_model = tf.keras.models.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(10), dtype=tf.float32),
tf.keras.layers.Dense(9, activation="relu", kernel_regularizer="l2"),
tf.keras.layers.Dense(9, activation="relu", kernel_regularizer="l2"),
tf.keras.layers.Dense(units=1, activation="sigmoid", kernel_regularizer="l2"),
],
name="simple_nn",
)
simple_nn_model.compile(
loss="binary_crossentropy",
optimizer=tf.keras.optimizers.Adam(),
metrics=["accuracy"],
)
测试数据集的准确率为 97%。
测试数据集的结果
当我开始研究时,我想知道为什么我们不直接使用每秒数据包的阈值或其他方法,因为 DoS 攻击者会在短时间内向设备发送大量网络数据包。这不起作用,因为存在不同特征的体积攻击。机器学习模型是最佳选择,因为它们更稳健,能够检测到潜在的可疑流量,无论流量速率如何。
但模型是否有效也取决于应用程序和物联网系统。我想了解模型在物联网系统中的表现,因此我在 ESP32 上用一个简单的 UDP 服务器应用程序进行了测试。
我使用了一个 Python 脚本来生成合法流量,另一个 Python 脚本作为攻击者。结果如下:
| | Malicious packets dropped (true positives) | Malicious packets processed (false negatives) | Legitimate packets processed |
| ------------------ | ------------------------------------------ | --------------------------------------------- | ---------------------------- |
| Without dosguard32 | \- | 1798 | 20 |
| With dosguard32 | 1855 | 495 | 20 |
攻击运行了 60 秒,真实流量运行了 80 秒。总处理的数据包包括所有数据包(合法和恶意),而总合法数据包仅包括来自我模拟的真实传感器数据的数据包。防火墙在处理恶意数据包方面减少了 72.44%,同时没有干扰合法流量。恶意数据包的召回率为 0.789,这明显低于测试数据集的结果。需要注意的是,UDP 服务器实验的结果是初步的,因为它们基于单次执行。
如果你感兴趣,这是我用来模拟真实流量的代码(是的,我也向 chatGPT 请求了这段代码):
import socket
import random
import time
start_time = time.time()
# UDP server config
UDP_IP = "10.0.0.105"
UDP_PORT = 3333
# Create socket UDP
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Sending data
while (time.time() - start_time) < 80:
temperature = round(random.uniform(18.0, 30.0), 2)
humidity = round(random.uniform(30.0, 80.0), 2)
data = f"Temp: {temperature} C, Humidity: {humidity} %".encode()
sock.sendto(data, (UDP_IP, UDP_PORT))
time.sleep(1)
我认为下一步主要是研究如何创建真正类似于真实网络流量的训练和测试数据集。在测试数据集中,我们获得了 99% 的召回率,而在我的“真实场景”测试中,很多攻击数据包仍被分类为合法。也许我的 Python 脚本攻击的特征与训练数据集中的攻击特征非常不同?模型是否过拟合?这些都是继续研究的有趣问题。
最后的想法
这个项目完全让我走出了舒适区,因为我不得不处理许多新的事物:
-
TCP/IP 协议(计算机网络)
-
DoS 攻击(网络安全)
-
ESP32 上使用的 TPC/IP 协议实现(嵌入式设备)
-
了解如何使用 TensorFlow Lite 将模型嵌入 ESP32(嵌入式设备 + 数据科学)
到目前为止,最困难的部分是处理 ESP32 上的 TCP/IP 栈,主要因为关于它的学习资料不多。幸运的是,[1] 的作者公开了他们的代码,我也根据他们的工作顺利上了轨道。这也是我喜欢开源社区的原因之一 ❤。
令我惊讶的是,chatGPT 也帮了我很多。我原本对它是否能解释一些 lwIP 方法和变量持怀疑态度,但它做得非常好。感觉就像我有一个 lwIP 专家可以提问和讨论想法。
关于大型语言模型是否会取代我们以及各种争论有很多,但我们也需要讨论它们如何改变游戏规则。它们在提高我们学习新事物的能力和使知识对每个人更易获取方面非常强大。
当然,你可以在这里查看我们今天讨论过的所有代码:github.com/dmesquita/dosguard32
如果你读到这里,非常感谢你的阅读! 😄
参考文献
[1] Fernandes, Gabriel Victor C., 等。“为物联网设备实现智能包过滤器。” XL 巴西计算机网络与分布式系统研讨会论文集。SBC,2022。
[2] Sajjad Dadkhah、Hassan Mahdikhani、Priscilla Kyei Danso、Alireza Zohourian、Kevin Anh Truong、Ali A. Ghorbani,“朝着开发现实的多维物联网分析数据集”,提交至:第 19 届国际隐私、安全与信任年会(PST2022),2022 年 8 月 22–24 日,加拿大弗雷德里克顿。
[3] 袁小勇、李传煌、李晓林。 “DeepDefense:通过深度学习识别 DDoS 攻击。” 2017 IEEE 国际智能计算会议(SMARTCOMP)。IEEE,2017。
[4] Hamza, Ayyoob、Hassan Habibi Gharakheili 和 Vijay Sivaraman。“物联网网络安全:需求、威胁和对策。” arXiv 预印本 arXiv:2008.09339(2020)。