升级到 PyTorch 2.0 的技巧和窍门
迁移到全新 “编译模式” 时需要注意的事项
·
关注 发表在 Towards Data Science · 19 分钟阅读 · 2023 年 5 月 21 日
–
图片由 Mohamed Nohassi 提供,刊登在 Unsplash
作者
任何新的 AI 开发框架、AI 加速器或 AI 计算平台的发布,都可能带来运行时优化和成本降低,从而改善我们的 AI 开发生命周期。最近发布的 PyTorch 2.0 也不例外。PyTorch 2.x 的亮点之一是引入了torch.compile,据报道,PyTorch 2.x 可以显著加速训练和推理过程。与我们熟悉的 PyTorch 逐步执行模式不同,在该模式下每个 PyTorch 操作都是“急切地”运行,compile API 将你的模型转换为中间计算图(FX 图),然后将其编译成适合底层训练加速器的低级计算内核,使用如 内核融合 和 乱序执行 等技术(有关更多详细信息,请参见这里)。
在这篇文章中,我们将展示这一令人兴奋的新特性及使用过程中可能遇到的一些问题和行为。你可能已经看到一些帖子,突出了使用 torch 编译的便捷性或性能提升效果。或者(像我一样),你可能在过去的两周里一直在与新 API 作斗争,试图让其在你的模型上良好运作。实际上,对于许多公共模型,只需用 torch.compile 调用包裹它们即可(如这里所报道)。然而,正如我们将看到的那样,有许多因素可能干扰图的编译和/或达成期望的性能改进。调整你的模型和/或成功达到最佳性能可能需要你重新设计项目或修改一些编码习惯。
在开始之前,我们应该提到几件事。我们在这篇文章中的意图是分享一些我们在适应 torch.compile API 过程中遇到的问题示例。这些示例绝非全面。你可能会遇到本文未提及的问题。还要记住,torch.compile 仍在积极开发中。我们所写的内容可能在你阅读时已经不再适用。务必保持最新,关注最新的发布和文档。
在 torch 编译中,存在许多创新技术,包括TorchDynamo、FX Graph、TorchInductor、Triton等。虽然我们在此不会深入探讨这些不同的组件,但我们鼓励你从PyTorch 文档、2022 年 PyTorch 大会或这篇有用的 TDS 帖子中了解它们。通常,对幕后发生的事情有一个好的理解可以帮助你弄清楚为什么模型没有编译成功以及如何解决这个问题。
本文绝不应被视为官方 PyTorch 文档的替代品(例如,这里)。本文也不应被视为对 PyTorch 相对于 TensorFlow(或其他 ML 训练框架)、编译模式相对于急切模式,或任何我们提到的工具、库或平台的认可。我发现所有框架都有其优缺点。我对任何特定框架没有强烈的偏好或热情。我的热情在于解决有趣的技术挑战——挑战越难越好——无论它们存在于何种平台或框架上。你可以说我对框架是中立的。尽管如此,请允许我对 PyTorch 和 TensorFlow 库如何随时间演变进行两个完全无关紧要的观察。可以跳过这些观察,直接回到正题。
TensorFlow 与 PyTorch 战争的两个完全无关紧要的观察
观察 1:在过去,当生活很简单时,PyTorch 和 TensorFlow 之间有明显的区别。PyTorch 使用急切执行模式,TensorFlow 使用图模式,大家都很满意,因为我们都知道自己在争论什么。但后来出现了 TensorFlow 2,它将急切执行作为默认执行模式,TensorFlow 变得有点像 PyTorch。现在,PyTorch 也推出了自己的图编译解决方案,变得有点像 TensorFlow。TensorFlow 与 PyTorch 的战争依然继续,但两者之间的差异正在慢慢消失。请参见这条推文,其中对 PyTorch 演变的评论我觉得很有趣。
观察 2:AI 开发是一项时尚的业务。与时尚行业类似,流行的 AI 模型、模型架构、学习算法、训练框架等都会随季节变化而变化。与时尚行业一样,AI 也有自己的出版物和会议,你可以通过这些途径跟上最新的趋势。直到几年前,我们大多数工作的模型都是用 TensorFlow 编写的。而人们对此不满。他们主要的两点抱怨是高层的 model.fit API 限制了他们的开发灵活性,以及图模式使他们无法进行调试。他们说:“我们必须转到 PyTorch”,因为“我们可以按照自己的方式构建模型并轻松调试”。几年的时间过去了,同样的人现在却在说:“我们必须适应 PyTorch Lightning(或其他高层 API),并且必须通过 torch.compile 加速训练”。要明确的是……我不是在评判。我只是想说,也许我们应该更加自我觉察。
回到实际内容
本文的其余部分组织成一系列关于如何开始使用 PyTorch 2 编译 API 的技巧以及你可能遇到的一些潜在问题。根据你项目的具体细节,将模型适配到 PyTorch 的图模式可能需要非同小可的努力。我们的希望是这篇文章能帮助你更好地评估这一努力,并决定采取最佳的步骤。
安装 PyTorch 2
从 PyTorch 安装文档 来看,安装 PyTorch 2 与安装其他版本的 PyTorch 没有什么不同。实际上,你可能会遇到一些问题。首先,PyTorch 2.0(截至本文撰写时)似乎需要 Python 3.8 或更高版本(见 这里)。希望你已经更新到最新的 Python 版本,这不会成为问题,但在不太可能(且不幸)的情况下你没有更新,这可能会成为你升级的另一个动机。此外,PyTorch 2 包含了一些之前版本中不存在的包依赖(最显著的是 pytorch-triton),这可能引入新的冲突。更有甚者,即使你成功构建了 PyTorch 2 环境,你可能会发现调用 torch.compile 会导致严重且完全无法解释的 段错误。
节省麻烦的一种方法是使用一个预构建并经过验证的 PyTorch 2.0 Docker 镜像。在下面的示例中,我们将使用一个官方的 AWS Deep Learning Container 镜像,其中包含 PyTorch 2.0。具体来说,我们将使用763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.0.0-gpu-py310-cu118-ubuntu20.04-sagemaker 镜像,该镜像设计用于在 Amazon SageMaker 上的 GPU 实例进行训练,使用 Python 3.10 和 PyTorch 2.0。
向后兼容性
PyTorch 2 的一个好处是它完全向后兼容。因此,即使你选择继续使用 eager 执行模式而不使用 torch.compile,你仍然被强烈鼓励升级到 PyTorch 2.0 并从其他 新功能和增强 中受益。
玩具示例
让我们从一个图像分类模型的玩具示例开始。在下面的代码块中,我们使用 timm Python 包(版本 0.6.12)构建一个基本的 Vision Transformer (ViT) 模型,并在一个假数据集上训练 500 步。我们定义 use_compile 标志以控制是否进行模型编译(torch.compile),并定义 use_amp 以控制是否使用 自动混合精度 (AMP) 还是全精度 (FP)。
import time, os
import torch
from torch.utils.data import Dataset
from timm.models.vision_transformer import VisionTransformer
use_amp = True # toggle to enable/disable amp
use_compile = True # toggle to use eager/graph execution mode
# use a fake dataset (random data)
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor(data=[index % 1000], dtype=torch.int64)
return rand_image, label
def train():
device = torch.cuda.current_device()
dataset = FakeDataset()
batch_size = 64
# define an image classification model with a ViT backbone
model = VisionTransformer()
if use_compile:
model = torch.compile(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, num_workers=4)
loss_function = torch.nn.CrossEntropyLoss()
t0 = time.perf_counter()
summ = 0
count = 0
for idx, (inputs, target) in enumerate(data_loader, start=1):
inputs = inputs.to(device)
targets = torch.squeeze(target.to(device), -1)
optimizer.zero_grad()
with torch.cuda.amp.autocast(
enabled=use_amp,
dtype=torch.bfloat16
):
outputs = model(inputs)
loss = loss_function(outputs, targets)
loss.backward()
optimizer.step()
batch_time = time.perf_counter() - t0
if idx > 10: # skip first few steps
summ += batch_time
count += 1
t0 = time.perf_counter()
if idx > 500:
break
print(f'average step time: {summ/count}')
if __name__ == '__main__':
train()
在下表中,我们展示了在使用ml.g5.xlarge实例类型和Amazon SageMaker 上运行训练脚本时的性能比较结果。模型编译的影响会因平台而异(例如,参见这里)。一般而言,现代服务器级 GPU 上的加速效果会更高。请记住,这些仅是您可能看到的结果类型示例。实际结果将高度依赖于项目的具体细节。
性能结果(按作者)
我们可以看到,使用AMP(28.6%)相较于使用 FP(4.5%),模型编译带来的性能提升显著。这是一个众所周知的差异(例如,参见这里)。如果您尚未使用 AMP 进行训练,您可能会发现从 FP 到 AMP 的过渡可以实现最显著的性能增益。我们还可以看到,在我们的模型案例中,性能提升伴随着 GPU 内存利用的轻微增加。
请注意,由于在编译图上实现分布式训练的方式,当扩展到多个 GPU 时,性能对比可能会发生变化。有关更多详细信息,请参见这里。
高级编译选项
torch.compile
API 包含多个选项,用于控制图的创建。这些选项使您能够针对您的特定模型进行精细调整编译,并可能进一步提升性能。下面的代码块包含了函数签名(来自这个source)。
def compile(model: Optional[Callable] = None, *,
fullgraph: builtins.bool = False,
dynamic: builtins.bool = False,
backend: Union[str, Callable] = "inductor",
mode: Union[str, None] = None,
options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False) -> Callable:
"""
Optimizes given model/function using TorchDynamo and specified backend.
Args:
model (Callable): Module/function to optimize
fullgraph (bool): Whether it is ok to break model into several subgraphs
dynamic (bool): Use dynamic shape tracing
backend (str or Callable): backend to be used
mode (str): Can be either "default", "reduce-overhead" or "max-autotune"
options (dict): A dictionary of options to pass to the backend.
disable (bool): Turn torch.compile() into a no-op for testing
"""
编译模式:编译模式允许您选择减少编译所需开销(“reduce-overhead”)和最大化潜在性能提升(“max-autotune”)之间的权衡。有关更多详细信息,请参见这里。
在下表中,我们比较了上述 ViT 模型在不同编译模式下的编译结果。
性能结果(作者提供)
我们可以看到编译模式的表现基本符合宣传,“reduce-overhead” 在额外内存利用的代价下减少了编译时间,而 “max-autotune” 在编译时间开销高的情况下实现了最大性能。
编译器后端:编译 API 允许您确定使用哪个后端将中间表示(IR)计算图( FX 图)转换为低级内核操作。这个选项对于调试图编译问题和更好地了解 torch.compile 的内部机制(如在这个有趣的示例中所示)非常有用。在大多数情况下(截至撰写本文时),默认的 TorchInductor 后端似乎提供了最佳的训练性能结果。请参见这里获取当前现有后端的列表,或者运行下面的代码查看您的环境中支持的后端。如果您愿意,也可以添加自己的后端 😃.
from torch import _dynamo
print(_dynamo.list_backends())
例如,通过修改上述代码以使用 nvprims-nvfuser后端,我们比急切模式获得了 13% 的性能提升(相比于默认后端的 28.6% 提升)。
强制单一图:fullgraph标志是确保您没有任何不希望的图断裂的极其有用的控制。有关更多信息,请参见下文。
动态形状标志:截至撰写本文时,对具有动态形状的张量的编译支持仍然有限。编译具有动态形状的模型的常见副作用是过度重新编译,这可能显著增加开销并大幅减慢训练速度。如果您的模型确实包含动态形状,将dynamic标志设置为True将会带来更好的性能,特别是减少重新编译的次数。
性能分析
我们已经广泛讨论了(例如,这里)对训练性能进行分析的重要性,作为加速训练速度和降低成本的一种手段。我们用来分析 PyTorch 模型性能的关键工具之一是 PyTorch Profiler。PyTorch Profiler 允许我们评估和分析图编译如何优化训练步骤。在下面的代码块中,我们用 torch.profiler 包装了我们的训练循环,并为 TensorBoard 生成了结果。我们将输出保存在 SM_MODEL_DIR 中,该目录会在训练任务结束时自动上传到持久存储。
out_path = os.path.join(os.environ.get('SM_MODEL_DIR','/tmp'),'profile')
from torch.profiler import profile, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=20,
warmup=5,
active=10,
repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
dir_name=out_path)
) as p:
for idx, (inputs, target) in enumerate(data_loader, start=1):
inputs = inputs.to(device)
targets = torch.squeeze(target.to(device), -1)
optimizer.zero_grad()
with torch.cuda.amp.autocast(
enabled=use_amp,
dtype=torch.bfloat16
):
outputs = model(inputs)
loss = loss_function(outputs, targets)
loss.backward()
optimizer.step()
p.step()
下图截取自 TensorBoard PyTorch Profiler 标签的 GPU 内核 视图。它提供了在编译模型试验的训练步骤中,运行在 GPU 上的内核的详细信息。
TensorBoard PyTorch Profiler 标签下的内核视图截图(作者提供)
通过将这些图表与急切执行运行的图表进行比较,我们可以看到图编译增加了 GPU 的 Tensor Cores 的使用率(从 51% 增加到 60%),并且引入了使用 Triton 开发的 GPU 内核。
诊断模型编译问题
PyTorch 编译仍在积极开发中(目前处于测试阶段),你在编译模型时遇到问题是完全有可能的。如果你运气好,你会得到一个有用的错误信息,并有一种简单(且合理)的解决办法。如果你运气不好,你可能需要更加努力地找出问题的根源,和/或得出结论:在目前的成熟度水平下,模型编译无法满足你的需求。
解决编译问题的主要资源是 TorchDynamo 故障排除页面,其中包含了调试工具的列表,并提供了 诊断错误 的逐步指南。不幸的是,撰写本文时,这些工具和技术似乎更多地面向 PyTorch 开发者,而非 PyTorch 用户。它们可以帮助找出编译问题的根本原因,提供一些关于如何绕过这些问题的提示,和/或将问题报告给 PyTorch。然而,你可能会发现它们在实际解决问题上并没有帮助。
在下面的代码块中,我们展示了一个简单的分布式模型,该模型包括对torch.distributed.all_reduce的调用。该模型在急切模式下按预期运行,但在图编译过程中(截至本文撰写时)失败,出现“属性错误”(torch.classes.c10d.ProcessGroup 没有名为 ‘shape’ 的字段)。通过将日志级别提高到INFO,我们发现错误发生在计算的“步骤 #3”中,即 TorchInductor。我们可以通过验证“急切”模式和“aot_eager”后端的编译成功来确认这一点。最后,我们可以创建一个最小的代码示例,使用PyTorch Minifier重现该错误。
import os, logging
import torch
from torch import _dynamo
# enable debug prints
torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.verbose=True
# uncomment to run minifier
# torch._dynamo.config.repro_after="aot"
def build_model():
import torch.nn as nn
import torch.nn.functional as F
class DumbNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1176, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = torch.flatten(x, 1)
x = self.fc1(x)
with torch.no_grad():
sum_vals = torch.sum(x,0)
# this is the problematic line of code
torch.distributed.all_reduce(sum_vals)
# add noise
x = x + 0.1*sum_vals
return x
net = DumbNet()
return net
def train():
os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR',
'localhost')
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT',
str(2222))
torch.distributed.init_process_group('nccl', rank=0,
world_size=1)
torch.cuda.set_device(0)
device = torch.cuda.current_device()
model = build_model()
model = torch.compile(model)
# replace with this to verfiy that error is not in TorchDynamo
# model = torch.compile(model, 'eager')
# replace with this to verfiy that error is not in AOTAutograd
# model = torch.compile(model, 'aot_eager')
model.to(device)
rand_image = torch.randn([4, 3, 32, 32], dtype=torch.float32).to(device)
model(rand_image)
if __name__ == '__main__':
train()
遗憾的是,在我们的示例中,运行生成的 minifier_launcher.py
脚本会导致一个不同的属性错误(‘Repro’ 对象没有属性 ‘_tensor_constant0’),尽管整个过程很有趣,但记录的调试步骤在解决我们演示的编译问题时并没有太大帮助。
显然,我们希望你不会遇到任何编译问题。如果你遇到了,请知道:1. 你并不孤单 😃,2. 尽管它们可能与这里演示的问题不同,但按照故障排除指南中描述的步骤,可能会对它们的来源有所指示。
常见的图断裂
PyTorch 急切模式最受推崇的优势之一是能够将纯 Python 代码与 PyTorch 操作交错使用。不幸的是,(截至目前)当使用 torch.compile 时,这种自由度被显著限制。原因在于,某些 Python 操作会导致 TorchDynamo 将计算图拆分成多个组件,从而阻碍潜在的性能提升。你的目标应该是尽可能地减少图断裂。作为最佳实践,你可能会考虑在将模型迁移到 PyTorch 2 时使用fullgraph标志进行编译。这不仅会促使你移除任何导致图断裂的代码,还会教会你如何最佳地调整 PyTorch 开发习惯以适应图模式。然而,请注意,你需要禁用此标志以运行分布式代码,因为当前 GPU 之间的通信方式需要图断裂(例如,请参见此处)。或者,你可以使用torch._dynamo.explain 工具来分析图断裂,详细说明请参见这里。
以下代码块展示了一个简单的模型,在其前向传递过程中有四个潜在的图断裂(截至目前)。在一个典型的 PyTorch 模型中,看到这些操作中的任何一个都不罕见。
import torch
from torch import _dynamo
import numpy as np
def build_model():
import torch.nn as nn
import torch.nn.functional as F
class DumbNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1176, 10)
self.fc2 = nn.Linear(10, 10)
self.fc3 = nn.Linear(10, 10)
self.fc4 = nn.Linear(10, 10)
self.d = {}
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = torch.flatten(x, 1)
assert torch.all(x >= 0) # graph break
x = self.fc1(x)
self.d['fc1-out'] = x.sum().item() # graph break
x = self.fc2(x)
for k in np.arange(1): # graph break
x = self.fc3(x)
print(x) # graph break
x = self.fc4(x)
return x
net = DumbNet()
return net
def train():
model = build_model()
rand_image = torch.randn([4, 3, 32, 32], dtype=torch.float32)
explanation = torch._dynamo.explain(model, rand_image)
print(explanation)
if __name__ == '__main__':
train()
重要的是要强调,图断裂不会导致编译失败(除非设置了 fullgraph 标志)。因此,完全有可能你的模型在编译和运行时实际包含多个图断裂,这会导致性能下降。
训练问题排查
尽管成功编译模型是一项值得庆祝的成就,但这并不能保证训练一定成功。如上所述,运行在 GPU 上的低级内核在急切模式和图模式之间会有所不同。因此,某些高级操作可能会表现出不同的行为。特别是,你可能会发现急切模式下运行的操作在图模式下会失败(例如,我们遇到的这个 torch.argmin 失败)。或者,你可能会发现计算中的数值差异对你的训练产生影响。
更糟的是,图模式下的调试比急切模式下要困难得多。在急切模式中,每一行代码都是独立执行的,这使我们可以在代码的任何位置设置断点并评估当前张量值。另一方面,在图模式中,我们代码定义的模型在处理之前会经历多个转换,因此,您的断点可能不会被触发。
过去,我们扩展了图模式下调试的难点并提出了几种解决方法。当您遇到问题时,可以尝试以下两步方法。首先,恢复到急切模式,那里调试较少困难,并祈祷问题能够重现。如果没有,尝试在编译的计算图中评估感兴趣的中间张量,通过故意在模型中插入图断点来实现。您可以通过将模型明确地分成两个(或更多)部分并分别应用 torch.compile,或通过插入print和/或Tensor.numpy调用来生成图断点。根据您的操作方式,您甚至可能成功触发代码中的断点。然而,请记住,以这种方式拆分图形可能会修改低级操作的顺序,因此可能无法准确重现完全编译的图形执行。但它确实给您提供了更多的灵活性,以便深入探究问题。
如果您遇到编译模式和急切模式之间意外的差异,请参见准确性调试部分以及故障排除指南。
将损失函数包含在图中
正如我们在上面的示例中所演示的,通过用 torch.compile 调用包装 PyTorch 模型(或函数)可以启用图执行模式。您可能已经观察到损失函数不在编译调用中,因此不在生成的图中。在许多情况下,包括我们展示的那些,损失函数是训练步骤中的一个相对小的部分,急切运行不会造成太多开销。然而,如果您有一个特别重的损失函数,您可以通过将其包含在编译的计算图中来进一步提升性能。例如,在下面的代码块中,我们定义了一个损失函数,用于(天真地)从一个大型 ViT 模型(具有 24 个 ViT 块)到一个较小的 ViT 模型(具有 12 个 ViT 块)进行模型蒸馏。
import torch
from timm.models.vision_transformer import VisionTransformer
class ExpensiveLoss(torch.nn.Module):
def __init__(self):
super(ExpensiveLoss, self).__init__()
self.expert_model = VisionTransformer(depth=24)
if torch.cuda.is_available():
self.expert_model.to(torch.cuda.current_device())
self.mse_loss = torch.nn.MSELoss()
def forward(self, input, outputs):
expert_output = self.expert_model(input)
return self.mse_loss(outputs, expert_output)
我们的实现包括一个在每个输入批次上调用大模型的损失函数。这是一个比上面提到的 CrossEntropyLoss 更加计算密集的损失函数,急切 运行它并不理想。
我们描述了解决这个问题的两种方法。第一种方法是将损失函数简单地包装在一个 torch.compile 调用中,如下所示:
loss_function = ExpensiveLoss()
compiled_loss = torch.compile(loss_function)
这种选项的缺点是,损失函数的编译图与模型的编译图不相交。第二种选项通过创建一个包含两者的包装模型并返回结果损失作为输出,来将模型和损失函数一起编译。此选项在下面的代码块中演示:
import time, os
import torch
from torch.utils.data import Dataset
from torch import nn
from timm.models.vision_transformer import VisionTransformer
# use a fake dataset (random data)
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor(data=[index % 1000], dtype=torch.int64)
return rand_image, label
# create a wrapper model for the ViT model and loss
class SuperModel(torch.nn.Module):
def __init__(self):
super(SuperModel, self).__init__()
self.model = VisionTransformer()
self.expert_model = VisionTransformer(depth=24 if torch.cuda.is_available() else 2)
self.mse_loss = torch.nn.MSELoss()
def forward(self, inputs):
outputs = self.model(inputs)
with torch.no_grad():
expert_output = self.expert_model(inputs)
return self.mse_loss(outputs, expert_output)
# a loss that simply passes through the model output
class PassthroughLoss(nn.Module):
def __call__(self, model_output):
return model_output
def train():
device = torch.cuda.current_device()
dataset = FakeDataset()
batch_size = 64
# create and compile the model
model = SuperModel()
model = torch.compile(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, num_workers=4)
loss_function = PassthroughLoss()
t0 = time.perf_counter()
summ = 0
count = 0
for idx, (inputs, target) in enumerate(data_loader, start=1):
inputs = inputs.to(device)
targets = torch.squeeze(target.to(device), -1)
optimizer.zero_grad()
with torch.cuda.amp.autocast(
enabled=True,
dtype=torch.bfloat16
):
outputs = model(inputs)
loss = loss_function(outputs)
loss.backward()
optimizer.step()
batch_time = time.perf_counter() - t0
if idx > 10: # skip first few steps
summ += batch_time
count += 1
t0 = time.perf_counter()
if idx > 500:
break
print(f'average step time: {summ/count}')
if __name__ == '__main__':
train()
这种方法的缺点是,当需要以 推理 模式运行模型时,内部模型需要从包装模型中提取出来。
在我们的案例中,两种选项都带来了大约 8% 的性能提升,展示了这种优化的重要性。当损失函数被急切运行时,总步长时间为 0.37 秒,而当损失函数被编译时,总步长时间为 0.34 秒。
动态形状
根据 文档 的报告,动态形状模型的编译支持有限(截至本文撰写时)。根据动态性的细节,动态模型可能会带来显著的性能开销,可能会引入图断裂和/或触发过多的 图重编译。图重编译发生在原始编译期间对模型所做的假设(称为 guards)被违反时。
torch.compile API 包含 动态 标志,用于指示编译器优化动态形状。然而,截至本文撰写时,这种优化的效果尚不明确。如果你在编译和优化动态图时遇到问题,可能需要等到支持水平成熟后再考虑使用这一功能。
摘要
PyTorch 2.0 编译模式具有显著提高训练和推理速度的潜力,因此可以实现显著的成本节约。然而,实现这一潜力所需的工作量可能差异很大。许多公共模型只需更改一行代码即可完成。其他模型,尤其是包含非标准操作、动态形状和/或大量交织的 Python 代码的模型,可能需要更多的努力。然而,现在可能是开始调整你的模型的最佳时机,因为编译模式似乎会长期存在。
在 Polars 中处理字符串的技巧与窍门
原文:
towardsdatascience.com/tips-and-tricks-for-working-with-strings-in-polars-ec6bb74aeec2
从排序列名到拆分列
·发布在 Towards Data Science ·阅读时长 9 分钟·2023 年 1 月 17 日
–
图片来源 Raphael Schaller 于 Unsplash
在我过去关于 Polars 的文章中(medium.com/search?q=wei-meng+lee+polars
),我深入探讨了如何开始使用 Polars,它的惰性计算模式如何帮助优化查询和提高处理大数据集的效率,以及如何利用它进行各种任务,如数据清理、数据分析和数据可视化。
我没有深入探讨的一个领域是字符串处理,这在处理数据框时是一个非常常见的话题。在本文中,我将介绍一些你可以在 Polars 中进行字符串处理时使用的技巧和方法。它们包括:
-
排序 DataFrame 列
-
计算字符串长度
-
根据标题选择列
-
使用正则表达式过滤行
-
拆分字符串列
-
替换字符串值
Polars 中的所有标题必须是字符串类型
在我们深入各种技巧和窍门之前,重要的是要记住,在 Polars 中,所有列标题都是字符串类型。请考虑以下示例:
import polars as pl
import numpy as npdf = pl.DataFrame(np.random.randint(0, 100, size=(10, 4)),
columns=list('CDAB'))
df
上面的示例展示了一个具有四列的 Polars DataFrame。请注意,与 Pandas 不同,在 Pandas 中列标题可以是数字类型,而在 Polars 中,所有列标题必须是字符串类型。以下是不允许的:
df = pl.DataFrame(np.random.randint(0, 100, size=(10, 4)),
columns=[1,2,3,4]) # error
如果你真的想使用数字作为列标题,你需要将它们转换为字符串:
df = pl.DataFrame(np.random.randint(0, 100, size=(10, 4)),
columns=list('1234'))
列排序
要排序 Polars DataFrame 中的列,首先使用 Python 的 sorted()
函数对列名进行排序,然后使用 select()
函数重新排列列的顺序:
df.select(
sorted(df.columns)
)
select()
函数返回一个新的 DataFrame,列的顺序已重新排列:
本文中的所有图像均由作者创建
sorted()
函数的默认排序顺序是字母顺序。要以相反的顺序进行排序,将 reverse
参数设置为 True
:
df.select(
sorted(df.columns, reverse=True)
)
你可能会想使用方括号表示法进行排序(如 Pandas 中所示):
df[sorted(df.columns, reverse=True)] # not recommended in Polars
但是,如果你打算将结果与其他 Polars 函数一起使用,这种方法不推荐,因为这种方法不能与延迟计算一起使用。
如果你只是想大致了解数据的样子,可以使用方括号表示法。
字符串长度计数
有时你需要计算特定列中字符串的长度。为此,我有一个名为 names.csv 的自制 CSV 文件:
name,age
Kristopher Ruch,23
Lakiesha Halton,45
Yun Frei,23
Sharyn Llanos,76
Lorretta Herren,21
Merrilee Akana,67
Boyd Gilman,89
Heidy Smalling,11
Leta Batalla,45
Siu Wayne,67
Sammie Gurule,23
Jayne Whetzel,11
Byron Doggett,67
Luke Alcazar,90
Petra Doutt,12
Tula Parkhurst,67
Davina Hess,26
Enda Cornelius,78
Merlyn Cora,89
Jeanett Hardeman,34
首先,将其加载到 Polars DataFrame 中:
import polars as pl
q = (
pl.scan_csv('names.csv')
)
q.collect()
你可以使用 lengths()
函数获取每个名称的长度,然后将其存储在名为 length_of_name 的新列中:
import polars as pl
q = (
pl.scan_csv('names.csv')
.select(
[
'name',
'age',
pl.col('name').str.lengths().alias('length_of_name'),
])
)
q.collect()
基于标题选择列
在我之前关于 Polars 的文章中 (towardsdatascience.com/getting-started-with-the-polars-dataframe-library-6f9e1c014c5c
),我展示了如何使用 select()
函数从数据框中选择列。让我们通过 Titanic 数据集详细了解一下。
数据来源:本节的数据来源于
www.kaggle.com/datasets/tedllh/titanic-train
。许可 — 数据库内容许可(DbCL)v1.0
opendatacommons.org/licenses/dbcl/1-0/
加载 Titanic 数据集时,你会看到它有 12 列:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
)
q.collect()
如果你只想检索 Name 和 Age 列,将它们放入列表中并传递给 select()
函数:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
.select(
['Name','Age']
)
)
q.collect()
如果你想要 所有 列,除了 PassengerId 列,可以使用 pl.exclude()
函数:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
.select(
pl.exclude('PassengerId')
)
)
q.collect()
exclude()
函数支持正则表达式。以下示例检索所有列,除了那些名称以“S”开头的列:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
.select(
pl.exclude("^S.*$") # exclude all columns that starts with S
)
)
q.collect()
如果你想检索特定的列,可以使用 pl.col()
函数。以下示例检索所有以“S”开头的列:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
.select(
pl.col('^S.*$') # get all columns that starts with S
)
)
q.collect()
使用正则表达式筛选行
除了 pl.col()
函数,contains()
函数也支持正则表达式。以下代码片段检索所有名称以“William”结尾的行:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
.filter(
pl.col('Name').str.contains('William$')
)
.select(
[
'Name'
]
)
)
q.collect().to_pandas()
你也可以尝试以下表达式:
-
[Ww]illiam
-
(?i)illiam
-
^William
你能弄清楚它们的作用吗?
这是一个挑战:你如何找到所有不以“William”结尾的名字?
好吧,你可以使用正则表达式来做到这一点,但最简单的方法是使用is_not()
函数来否定contains()
函数中指定的条件:
import polars as pl
q = (
pl.scan_csv('Titanic_train.csv')
.filter(
pl.col('Name').str.contains('William$').is_not()
)
.select(
[
'Name'
]
)
)
q.collect().to_pandas()
请注意,Polars 不支持正则表达式中的前瞻和回顾。
字符串列拆分
另一个关于字符串的流行活动是列拆分。现在让我们基于空格使用split()
函数拆分name列(在names.csv文件中):
q = (
pl.scan_csv('names.csv')
.select(
[
'name',
pl.col('name').str.split(' ').alias('splitname'),
'age',
])
)
q.collect()
结果是一个名为split_name的新列,如下所示:
请注意,名字现在被拆分成了字符串列表。接下来你需要做的是将字符串列表转换为多个列,表示名字和姓氏。你可以使用with_column()
和pl.struct()
函数来完成:
q = (
pl.scan_csv('names.csv')
.select(
[
'name',
pl.col('name').str.split(' ').alias('split_name'),
'age',
])
.with_column(
pl.struct(
[
pl.col('split_name').arr.get(i).alias(
'first_name' if i==0 else 'last_name')
for i in range(2)
]
).alias('split_name')
)
)
q.collect()
with_column()
函数返回一个更新了列的新 DataFrame。在这种情况下,我将使用pl.struct()
函数更新split_name
列,该函数遍历split_name列中的名字列表。将结构视为一组列。更新后的数据框现在如下所示:
值得注意的是,split_name列现在是struct
类型。最后一步是使用unnest()
函数将struct
列拆分为单独的列:
q = (
pl.scan_csv('names.csv')
.select(
[
'name',
pl.col('name').str.split(' ').alias('split_name'),
'age',
])
.with_column(
pl.struct(
[
pl.col('split_name').arr.get(i).alias(
'first_name' if i==0 else 'last_name')
for i in range(2)
]
).alias('split_name')
)
.unnest('split_name')
)
q.collect()
最终结果如下:
现在让我们用 Titanic 数据集尝试另一个例子。特别是,我想关注Name列:
q = (
pl.scan_csv('Titanic_train.csv')
)
q.collect()['Name'].to_pandas().unique()
这是Name列中唯一名字的快照:
'Braund, Mr. Owen Harris',
'Cumings, Mrs. John Bradley (Florence Briggs Thayer)',
'Heikkinen, Miss. Laina',
'Futrelle, Mrs. Jacques Heath (Lily May Peel)',
'Allen, Mr. William Henry', 'Moran, Mr. James',
'McCarthy, Mr. Timothy J', 'Palsson, Master. Gosta Leonard',
'Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)',
'Nasser, Mrs. Nicholas (Adele Achem)',
...
名字以以下格式保存:
姓氏,头衔。名字
理想情况下,你可以在split()
函数中使用正则表达式。你可以使用以下正则表达式将Name列拆分为姓氏、头衔和名字:
‘([\’A-Za-z ()”//.-]+), ([A-Za-z]+). ([A-Za-z ()”//.-]*)’
这个正则表达式在我之前的文章中讨论过:
## 使用正则表达式(RegEx)进行特征工程(Pandas DataFrame)
探索如何使用正则表达式轻松操作你的字符串列
towardsdatascience.com
不幸的是,在写作时,Polars 中的split()
函数不支持正则表达式。
所以策略是多次执行拆分:
这是第一次拆分:
q = (
pl.scan_csv('Titanic_train.csv')
.select(
[
pl.col('Name').str.split(r', ').alias('split_name'),
])
.with_column(
pl.struct(
[
pl.col('split_name').arr.get(i).alias(
'Last Name' if i==0 else 'First Name')
for i in range(2)
]
).alias('split_name')
).unnest('split_name')
)
q.collect()
结果现在显示姓氏被提取出来,接着是名字,其中包含标题和名字:
这是第二次分割:
q = (
pl.scan_csv('Titanic_train.csv')
.select(
[
pl.col('Name').str.split(r', ').alias('split_name'),
])
.with_column(
pl.struct(
[
pl.col('split_name').arr.get(i).alias(
'Last Name' if i==0 else 'First Name')
for i in range(2)
]
).alias('split_name')
).unnest('split_name')
#---Second split---
.select(
[
pl.exclude('First Name'), # get all columns except first name
pl.col('First Name').str.split(r'. ').alias('split_name'),
]
)
.with_column(
pl.struct(
[
pl.col('split_name').arr.get(i).alias(
'Title' if i==0 else 'First Name')
for i in range(2)
]
).alias('split_name')
).unnest('split_name')
)
q.collect()
现在从名字中提取标题:
替换字符串值
替换 DataFrame 中的字符串值是你常常进行的另一个任务。在 Titanic 数据集中,乘客名称中有很多重复的标题。例如,对于女性乘客,使用的一些标题有 Ms、Miss、Mlle、Mlle 和 Mme。你通常会想将标题数量减少到一个更可管理的范围。
要在列中替换字符串,请使用 replace()
函数:
q = (
pl.scan_csv('Titanic_train.csv')
.select(
[
pl.col('Name').str.replace('Mlle.','Miss.'),
])
)
q.collect()
上面的代码片段将 Name 列中所有的 “Mlle.” 替换为 “Miss.”。如果你想将多个标题替换为 “Miss.”,你可以使用正则表达式,如下所示:
q = (
pl.scan_csv('Titanic_train.csv')
.select(
[
pl.col('Name').str.replace('Mlle.|Ms.|Mme.','Miss.'),
])
)
q.collect()
上面的代码片段将 “Mlle.”、 “Ms.” 和 “Mme.” 替换为 “Miss.”。
如果你喜欢阅读我的文章并且它对你的职业/学习有所帮助,请考虑注册成为 Medium 会员。每月 5 美元,它让你可以无限制地访问 Medium 上的所有文章(包括我的)。如果你使用以下链接注册,我将赚取少量佣金(对你没有额外费用)。你的支持意味着我将能花更多时间撰写像这样的文章。
[## 通过我的推荐链接加入 Medium - Wei-Meng Lee]
阅读 Wei-Meng Lee 的每一个故事(以及 Medium 上其他成千上万的作者)。你的会员费直接支持…
weimenglee.medium.com](https://weimenglee.medium.com/membership?source=post_page-----ec6bb74aeec2--------------------------------)
总结
在 Polars 中处理字符串类似于 Pandas。此外,如果你熟悉正则表达式,它一定会让你的工作变得更轻松。我在这篇文章中使用了相当多的函数,所以这里有一个方便的指南,供你在下一次处理 Polars 中的字符串时参考:
-
select()
— 从 DataFrame 中选择你需要的列 -
with_column()
— 返回一个更新了列的新 DataFrame -
unnest()
— 将一个结构体列分解为单独的列 -
str.lengths()
— 返回字符串的长度 -
str.contains()
— 检查字符串是否包含指定的字符串 -
str.split()
— 基于指定的字符串分割字符串 -
str.replace()
— 用另一个字符串替换字符串 -
pl.exclude()
— 排除特定列 -
pl.col()
— 包括特定的列 -
pl.struct()
— 包含一组列的列
提高你 R 技能的技巧和窍门
原文:
towardsdatascience.com/tips-and-tricks-to-improve-your-r-skills-b0f58006d0c1
学习如何编写高效的 R 代码
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 5 月 11 日
–
1234567890-=照片来源于 AltumCode 在 Unsplash
R 广泛用于商业和科学领域作为数据分析工具。该编程语言是数据驱动任务的基本工具。对于许多统计学家和数据科学家来说,R 是解决统计问题的首选。
数据科学家们通常处理大量数据和复杂的统计问题。内存和运行时间在这里扮演了重要角色。你需要编写高效的代码以实现最佳性能。在本文中,我们将介绍一些可以直接在下一个 R 项目中使用的技巧。
使用代码性能分析
数据科学家们经常希望优化他们的代码以提高运行速度。在某些情况下,你会依赖直觉尝试一些方法。这种方法的缺点是你可能优化了代码的错误部分,因此浪费了时间和精力。只有了解代码慢的部分,才能进行优化。解决方案是 代码性能分析。代码性能分析可以帮助你找到慢的代码部分!
Rprof() 是一个内置的代码性能分析工具。不幸的是,Rprof() 并不是很用户友好,因此我们不推荐直接使用它。我们推荐使用 profvis 包。Profvis 允许可视化来自 Rprof() 的代码性能数据。你可以通过 R 控制台使用以下命令来安装该包:
install.packages("profvis")
在下一步中,我们将通过一个示例进行代码性能分析。
library("profvis")
profvis({
y <- 0
for (i in 1:10000) {
y <- c(y,i)
}
})
如果你在 RStudio 中运行这段代码,你将得到以下输出。
火焰图(图片由作者提供)
在顶部,你可以看到你的 R 代码以及每行代码的内存和运行时间条形图。这种显示方式提供了代码中可能存在的问题的概览,但无法帮助你确定确切的原因。在内存列中,你可以看到每次调用分配的内存(右侧条形图)和释放的内存(左侧条形图)(以 MB 为单位)。时间列显示每行代码的运行时间(以 ms 为单位)。例如,你可以看到第 4 行的时间为 280 ms。
在底部,你可以看到带有完整调用栈的火焰图。该图提供了整个调用序列的概览。你可以将鼠标指针移动到单个调用上以获取更多信息。还可以注意到垃圾回收器()消耗了大量时间。为什么呢?在内存列中,你可以看到第 4 行的内存需求增加。第 4 行分配并释放了大量内存。每次迭代都会创建 y
的另一个副本,导致内存使用增加。请避免这种复制-修改的任务!
你还可以使用数据选项卡。数据选项卡为你提供了所有调用的简洁概述,特别适合复杂的嵌套调用。
数据选项卡(图像由作者提供)
如果你想了解更多关于 provis 的信息,你可以访问 Github 页面。
向量化你的代码
也许你听说过向量化。那么它是什么呢?向量化不仅仅是避免使用 for()
循环。它更进一步。你需要从向量而不是标量的角度来思考。向量化对于加速 R 代码非常重要。向量化的函数使用用 C 编写的循环,而不是 R。C 中的循环开销更小,使其速度更快。向量化意味着找到在 C 中实现的与任务紧密匹配的现有 R 函数。函数 rowSums()
、colSums()
、rowMeans()
和 colMeans()
对加速你的 R 代码非常有用。这些向量化矩阵函数总是比 apply()
函数更快。
为了测量运行时间,我们使用了 R 包 microbenchmark。在这个包中,所有表达式的评估都在 C 中完成,以最小化开销。作为输出,该包提供了统计指标的概述。你可以通过 R 控制台使用以下命令安装 microbenchmark 包:
install.packages("microbenchmark")
现在,我们将 apply()
函数的运行时间与 colMeans()
函数进行比较。以下代码示例演示了这一点。
install.packages("microbenchmark")
library("microbenchmark")
data.frame <- data.frame (a = 1:10000, b = rnorm(10000))
microbenchmark(times=100, unit="ms", apply(data.frame, 2, mean), colMeans(data.frame))
# example console output:
# Unit: milliseconds
# expr min lq mean median uq max neval
# apply(data.frame, 2, mean) 0.439540 0.5171600 0.5695391 0.5310695 0.6166295 0.884585 100
# colMeans(data.frame) 0.183741 0.1898915 0.2045514 0.1948790 0.2117390 0.287782 100
在这两种情况下,我们计算数据框每一列的均值。为了确保结果的可靠性,我们使用 microbenchmark 包进行了 100 次运行(times=10
)。结果显示,colMeans()
函数大约快三倍。
如果你想了解更多关于向量化的知识,我们推荐在线书籍 R Advanced。
矩阵与数据框
矩阵与数据框有一些相似之处。矩阵是一个二维对象。此外,一些函数的工作方式相同。不同之处在于:矩阵的所有元素必须具有相同的类型。矩阵常用于统计计算。例如,函数lm()
会将输入数据内部转换为矩阵,然后进行计算。通常,矩阵的速度比数据框快。现在,我们来比较矩阵和数据框之间的运行时间差异。
library("microbenchmark")
matrix = matrix (c(1, 2, 3, 4), nrow = 2, ncol = 2, byrow = 1)
data.frame <- data.frame (a = c(1, 3), b = c(2, 4))
microbenchmark(times=100, unit="ms", matrix[1,], data.frame[1,])
# example console output:
# Unit: milliseconds
# expr min lq mean median uq max neval
# matrix[1, ] 0.000499 0.0005750 0.00123873 0.0009255 0.001029 0.019359 100
# data.frame[1, ] 0.028408 0.0299015 0.03756505 0.0308530 0.032050 0.220701 100
我们使用 microbenchmark 包进行 100 次运行以获得有意义的统计评估。可以看出,矩阵访问第一行的速度比数据框快约 30 倍。这非常令人印象深刻! 矩阵明显更快,因此你应该优先使用矩阵而不是数据框。
is.na() 和 anyNA
你可能知道函数is.na()
来检查向量是否包含缺失值。还有函数anyNA()
来检查向量是否有任何缺失值。现在我们测试哪一个函数的运行时间更快。
library("microbenchmark")
x <- c(1, 2, NA, 4, 5, 6, 7)
microbenchmark(times=100, unit="ms", anyNA(x), any(is.na(x)))
# example console output:
# Unit: milliseconds
# expr min lq mean median uq max neval
# anyNA(x) 0.000145 0.000149 0.00017247 0.000155 0.000182 0.000895 100
# any(is.na(x)) 0.000349 0.000362 0.00063562 0.000386 0.000393 0.022684 100
评估结果表明,anyNA()
的平均速度显著快于 is.na()
。如果可能的话,你应该使用 anyNA()
。
if() … else() 与 ifelse()
if() ... else()
是标准控制流函数,而 ifelse()
更加用户友好。
Ifelse()
按以下方案工作:
# test: condition, if_yes: condition true, if_no: condition false
ifelse(test, if_yes, if_no)
从许多程序员的角度来看,ifelse()
比多行的替代方案更易于理解。缺点是 ifelse()
在计算效率上不如 if() ... else()
。以下基准测试表明,if() ... else()
的运行速度比 ifelse()
快 20 倍以上。
library("microbenchmark")
if.func <- function(x){
for (i in 1:1000) {
if (x < 0) {
"negative"
} else {
"positive"
}
}
}
ifelse.func <- function(x){
for (i in 1:1000) {
ifelse(x < 0, "negative", "positive")
}
}
microbenchmark(times=100, unit="ms", if.func(7), ifelse.func(7))
# example console output:
# Unit: milliseconds
# expr min lq mean median uq max neval
# if.func(7) 0.020694 0.020992 0.05181552 0.021463 0.0218635 3.000396 100
# ifelse.func(7) 1.040493 1.080493 1.27615668 1.163353 1.2308815 7.754153 100
在复杂循环中应避免使用 ifelse()
,因为它会显著减慢你的程序速度。
并行计算
大多数计算机有多个处理器核心,可以并行处理任务。这个概念叫做并行计算。R 包 parallel 实现了 R 应用中的并行计算。该包在基本 R 中预安装。使用以下命令,你可以加载该包并查看你的计算机有多少个核心:
library("parallel")
no_of_cores = detectCores()
print(no_of_cores)
# example console output:
# [1] 8
并行数据处理非常适合蒙特卡罗模拟。每个核心独立地模拟模型的一个实现。最后,结果被汇总。以下示例基于在线书籍 Efficient R Programming。首先,我们需要安装 devtools 包。借助此包,我们可以从 GitHub 下载 efficient 包。你必须在 RStudio 控制台中输入以下命令:
install.packages("devtools")
library("devtools")
devtools::install_github("csgillespie/efficient", args = "--with-keep.source")
在 efficient 包中,有一个 snakes_ladders()
函数,它模拟了一场蛇梯棋游戏。我们将使用模拟来测量 sapply()
和 parSapply()
函数的运行时间。parSapply()
是 sapply()
的并行化变体。
library("parallel")
library("microbenchmark")
library("efficient")
N = 10⁴
cl = makeCluster(4)
microbenchmark(times=100, unit="ms", sapply(1:N, snakes_ladders), parSapply(cl, 1:N, snakes_ladders))
stopCluster(cl)
# example console output:
# Unit: milliseconds
# expr min lq mean median uq max neval
# sapply(1:N, snakes_ladders) 3610.745 3794.694 4093.691 3957.686 4253.681 6405.910 100
# parSapply(cl, 1:N, snakes_ladders) 923.875 1028.075 1149.346 1096.950 1240.657 2140.989 100
评估显示,parSapply()
的模拟计算速度平均比 sapply()
函数快约 3.5 倍。哇! 你可以快速将这个技巧融入到你现有的 R 项目中。
R 语言与其他语言的接口
有时 R 会很慢。你使用各种技巧,但你的 R 代码仍然太慢。在这种情况下,你应该考虑用另一种编程语言重写你的代码。对于其他语言,R 提供了以 R 包形式的接口。例如,Rcpp 和 rJava。编写 C++ 代码很简单,特别是如果你有软件工程背景。然后你可以在 R 中使用它。
首先,你需要使用以下命令安装 Rcpp:
install.packages("Rcpp")
以下示例展示了这种方法:
library("Rcpp")
cppFunction('
double sub_cpp(double x, double y) {
double value = x - y;
return value;
}
')
result <- sub_cpp(142.7, 42.7)
print(result)
# console output:
# [1] 100
C++ 是一种强大的编程语言,使其最适合于代码加速。对于非常复杂的计算,我们建议使用 C++ 代码。
结论
在这篇文章中,我们学习了如何分析 R 代码。provis 包支持你分析你的 R 代码。你可以使用 rowSums()
、colSums()
、rowMeans()
和 colMeans()
等矢量化函数来加速你的程序。此外,如果可能的话,你应该优先使用矩阵而不是数据框。使用 anyNA()
而不是 is.na()
来检查向量是否有缺失值。通过使用 if() ... else()
而不是 ifelse()
来加速你的 R 代码。此外,你可以使用 parallel 包中的并行函数进行复杂的模拟。通过使用 Rcpp 包,你可以实现复杂代码段的最大性能。
有一些书籍用于学习 R。你将在以下找到我们认为非常适合学习高效 R 编程的三本书:
-
高效 R 编程:更智能编程的实用指南
-
动手编程 R:编写你自己的函数和模拟
-
R 数据科学:导入、整理、转换、可视化和建模数据
👉🏽 加入我们的免费每周 Magic AI 时事通讯,获取最新的 AI 更新!
免费订阅 以在我们发布新故事时获得通知:
[## 订阅邮件,以便在 Janik 和 Patrick Tinz 发布新内容时收到通知。
订阅邮件,以便在 Janik 和 Patrick Tinz 发布新内容时收到通知。如果你还没有 Medium 账户,注册时将会创建一个…
tinztwinspro.medium.com](https://tinztwinspro.medium.com/subscribe?source=post_page-----b0f58006d0c1--------------------------------)
了解更多关于我们的信息,请访问我们的关于页面。不要忘记关注我们的X。非常感谢你的阅读。如果你喜欢这篇文章,欢迎分享。祝你一天愉快!
通过我们的链接注册成为 Medium 会员,以阅读无限量的 Medium 故事。
学术研究出版技巧
如果你是一个尝试发表论文的研究生,请查看这个!
·
关注 发布于 Towards Data Science ·6 分钟阅读·2023 年 4 月 10 日
–
图片由 Lala Azizli 提供,来源于 Unsplash
本文的目标是与研究生分享一些关于在期刊和会议上发表论文的指南。它基于我作为博士生的过去经验。我的一篇 AI(深度学习)论文[1]最近在 Google Scholar 上获得了 1000 多次引用[2]。根据 Web of Science™数据库[引用影响力—3],只有约 0.026%的论文引用次数超过 1000 次。尽管看到论文发表后的影响力很高,但在论文被接受之前经历了许多反复尝试。这篇论文被多次拒绝,花了几年时间,并经历了多次修改才被接受。我尝试将我的经验提炼成一个指导过程。希望以下指南能帮助你在出版过程中。
问题
发布很困难,尤其是当你的学校要求提交到高影响力的期刊/会议时。大多数博士项目都有出版要求,这甚至可能成为拖延你毕业的原因。
目标
-
缩短发表时间
-
放大论文的影响力
指南
这里是一些我的经验,希望能帮助你减轻与出版相关的压力。举个例子,我认为启动公司和博士论文发表过程有一些共同点。两者初期都有不确定的未来,但最终都可能在各自领域产生新的贡献。就像创业一样,你的出版过程需要调查、愿景、策略、迭代和扩展。我的思路按以下顺序整理。
图 1:发表论文的 5 个里程碑 | 作者提供的图片
调查
阅读你尝试发表领域中的最新技术状态。记住,你是想对已有的工作进行补充。许多出版物是开放获取的,因此找到最新研究的论文不应是问题。此外,你的大学也应能提供必要的访问权限。Google Scholar 和 ResearchGate 是很好的开放资源。鉴于技术和科学研究的快速发展,我还会关注你研究领域的顶尖研究人员和公司,在 LinkedIn、Twitter 等类似网站上获取最新动态。如果适用,我还会探索统计数据和数据存储库进行探索性分析。关于调查的速度,推荐的数量从每周阅读 1 到 7 篇论文以准备文献综述,最终将 30 到 200 篇论文纳入/引用到你的综述中。
视野
调研后,制定你想要出版的主题愿景。一个创业公司从一个它打算解决的问题开始。创始人制定一个与他们和他们打算解决的问题相一致的愿景。为了帮助你找到匹配的方向,请从以下问题开始:
-
在你的研究领域/专业中有哪些未解决的问题?例如,这个 [4] 是一个关于各学科未解决问题的好参考列表。
-
在调研论文中讨论了一些活跃的研究领域是什么?
-
在你资格考试课程中,哪些课程让你最兴奋?
-
你的导师有哪些专业领域?
图 2:帮助制定出版愿景的维恩图 | 图片作者提供
尝试从上面维恩图中至少有三个重叠的区域选择一个主题。写一个摘要并获取你导师的反馈。确保它包括你的主要目标和论文的提纲。反馈越关键,此阶段需要的调整就会越多。
策略
在创业公司的初期阶段,其主要目标是实验、最大化学习,并找到市场真正需要的创新产品,即找到难以捉摸的产品-市场匹配。对你来说,那就是找到一个出版物,即论文-出版匹配。两者都需要策略。
在制定策略之前,请考虑以下问题:
-
你能投入多少时间和精力来进行出版工作?你将如何平衡其他承诺?
-
你的导师有多少时间来指导你?尝试安排一个定期的沟通会议。
-
你毕业所需的必要出版物的目标时间是什么?确保这一目标是现实的。
-
你是如何管理压力的?请注意,北美约 50%的博士生在获得学位之前会中途退学 [5]。你将需要一些生活技巧来完成博士学业。不出所料,三分之二的创业公司也失败了,即投资回报为负 [6]。
一旦你弄清楚这些问题并做了一些回顾,制定一个包含以下内容的出版策略:
-
论文主题/问题
-
论文类型、提纲和目标长度
-
目标出版物列表:
-
按接受难度/影响因子排名
-
包括他们的反馈时间(审稿周期)
-
确保出版物符合你所在机构设定的所有标准
-
-
预期实现出版成功的目标日期
- 计算n次迭代和修改论文的时间n次
迭代和调整
一旦你和你的导师对策略/重点达成一致,开始你的研究,并尝试达到一个你可以展示的成果(也就是在创业术语中的最小可行产品或 MVP)。当你有了 MVP 论文时,首先提交给那些难度较高但审稿周期较短的期刊/会议。较短的审稿周期可以让你更快地迭代和吸收反馈,并在需要时更早地调整策略。
不要因为被拒绝而灰心丧气,应该利用批评(或严厉的)反馈对你的论文和/或策略进行大胆的修改。确保你在评估、回应并纳入评审人提供的反馈/差距。带有反馈的拒绝是一个祝福,是过程中的重要部分。把它看作是训练 AI 模型的过程。人工神经网络的权重根据标记/训练数据的反馈在训练过程中进行调整。正如下面图 3 所示,这些调整最初较大,随着接近论文发表的契合度(或 AI/ML 中的全局最小值)逐渐减小。
图 3:随机梯度下降,一种迭代的机器学习训练算法,用于达到全局/局部最小值(用+号标记)。论文发表契合度遵循类似的迭代 | 作者提供的图片
选择与你论文研究相符的期刊/会议,并仔细阅读提交指南。没有什么比等待几个月的反馈更糟糕的了,结果却发现你提交的论文因为不符合期刊/会议的主题或错过了重要的提交步骤而无法被考虑。如果适用,研究你所提交期刊/会议的过去论文并引用其中相关的工作会有帮助。另外,如果你的合著者(例如导师)在某些期刊上有过成功的论文,也可以尝试那些期刊。
规模
经过几次迭代,你将逐渐掌握接近发表的过程,因为拒绝的原因会变得不那么严重。达到那个阶段可能需要很长时间。请注意,你的第一篇论文可能是最困难的,并且需要最长的时间。
在机器学习中,有一种叫做迁移学习的技术[7],你可以将解决一个任务获得的知识应用到另一个相关任务中,所需的学习努力会减少。就像这样,你的第一篇论文中验证的所有学习将非常有助于加快你第二篇、第三篇以及随后的发表进程。
图 4:展示努力与发表(结果)的图表 | 作者提供的图片
更快的反馈周期和迭代是关键。利用这一点来达到必要的出版数量。坚持适合你的方法,并继续学习和调整流程,以便获得更有影响力的出版物。祝好运!
免责声明
这些建议基于我的学习经验,可能并不适用于每个人。
参考文献
-
A. Shrestha 和 A. Mahmood,“深度学习算法与架构综述”,发表于 IEEE Access,卷 7,第 53040–53065 页,2019 年,doi: 10.1109/ACCESS.2019.2912200. (pdf 下载 链接)
-
谷歌学术档案 — A. Shrestha
-
为什么初创公司失败:
hbr.org/2021/05/why-start-ups-fail
-
AI/ML 中的迁移学习 —
en.wikipedia.org/wiki/Transfer_learning
理解 R 中的正则表达式的提示
原文:
towardsdatascience.com/tips-to-understand-regular-expressions-in-r-5d25be06f2a8
使用 stringR 获得更多的正则表达式知识
·发表于Towards Data Science ·8 分钟阅读·2023 年 1 月 18 日
–
照片由Jason Leung提供,来源于Unsplash
介绍
当一个项目涉及文本分析,比如情感分析、文本挖掘或你需要执行的其他分析任务时,通常会在某个时刻需要解析文本。这意味着需要提取文本的一部分或在文本中找到给定的模式,以便提取见解,例如。
如果我们处理的是简单的模式,比如一个单词或一个数字,那么在 R 中处理起来相对容易。
假设你有以下文本,并且你想找到单词random
。我们可以使用许多简单的函数来执行这个任务。但首先加载library(stringr)
。然后查看它的三个函数。
text <- "This is a random text. If you want to try to find a pattern here,
let's say the numbers 1 or 2 or 3, you can use stringR."
# Find out if the word "random" is present
str_find(text, pattern='random')
[1] TRUE
# Extract the word "random"
str_extract(text, pattern= 'random')
[1] "random"
# Find where the pattern is located in the text
str_locate(text, 'random')
start end
[1,] 11 16
现在,假设我们想找另一种类型的模式,非常具体的,比如句子中的任何数字,甚至更具体,字母和数字的组合。这就是正则表达式或regexp非常有用的地方。
正则表达式是一系列字符,帮助我们在文本中查找模式。
正则表达式 101
要编写你的第一个正则表达式,首先要理解这些模式可以由哪些元素组成。此外,记住始终将你的模式包裹在引号内,无论是单引号还是双引号。
开始吧。
旁注:
*str_view_all*
函数展示了由给定模式捕获的所有可能值,但它不会提取任何内容。这可以通过 stringR 库中的其他函数来完成。尽管如此,它是一个出色的工具,可以查看是否获得了正确的文本。
一个字符:方括号
如果你想查找一个字符,[]
将是你最好的朋友。当你写它们时,无论你放入什么,都表示你想要查找的那个字符。让我们通过实际示例来更好地理解这一点。
如果我们想查找在我们之前的文本中是否有字母k
,这里是要使用的模式:[k]
。
str_locate(text, pattern= '[k]')
start end
[1,] 81 81
是的,我们确实有一个字母k。它在位置 81 上。考虑到这仅对教育目的有用,我们再往上走一步。如果我们想知道我们的文本中是否包含字母 k、w 或 b 中的任意一个,我们可以使用这个模式:'[kwb]'
。
注意我们是这样写的:
-
我们使用
[]
来告诉程序我们想要查找一个字符 -
在方括号内,我们有我们要查找的模式,它表示k 或 w 或 b,哪个先出现就查找哪个。结果在第 31 个位置上是字母w。
str_locate(text, pattern= '[kwb]')
start end
[1,] 31 31
我们可以使用[]
来表示其他许多模式。让我告诉你,这是我通常编写的正则表达式中最常用的模式之一。甚至还有一些预先创建的模式可以帮助我们在编写代码时。如果你点击这个链接,在 stringR 的速查表中你可以找到许多快速模式,如:
-
[:digit:]
用于查找句子中的所有数字。在我们的文本中:“1” “2” “3”
-
[:alpha:]
用于查找字母。
使用‘[:alpha]’在我们的文本中找到的内容。图片由作者提供。
[:punct:]
用于查找标点符号。在我们的文本中:"." "," "'" "," "."
我们还可以在方括号内使用值范围或区分大写或小写字母。例如,[a-z]
查找从 a 到 z 之间的一个小写字母,总是返回首先找到的内容。类似地,[A-Z]
做同样的事,但查找大写字母。
点号.
意味着查找任何东西
在正则表达式中,点号.
表示任何字符。因此,如果我们创建一个类似'[Ii].'
的模式并将其应用于我们的文本变量,我们就是在说我们要找的是一个字母I 或 i 后面跟着任何其他字符的模式。
str_view_all(text, '[Ii].')
这就是生成的图像。
字母 i 后面跟着任何其他字符。图片由作者提供。
查找某物的存在或缺失
正则表达式的另一个有趣部分是使用一些模式来查找空格、数字、字母的存在或缺失。通常,这会是一个双反斜杠后跟一个字母表示存在,字母的大写形式表示缺失。
-
\\s
表示空格,\\d
表示数字,\\w
表示单词,\\b
表示边界。 -
\\S
表示非空格,\\D
表示非数字,\\W
表示非单词。
# The new text variable
text2 <- "This is a random text 2\. Use stringR and Regexp."
# View all digits
str_view_all(text2, '\\d')
# View all EXCEPT digits
str_view_all(text2, '\\D')
# View all spaces
str_view_all(text2, '\\s')
# View all EXCEPT spaces
str_view_all(text2, '\\S')
# View all boundaries
str_view_all(text2, '\\b')
以下是相应的结果。
正则表达式的结果。图片由作者提供。
重要的是要注意,边界可以是任何包围模式的东西。想象一下你有一个 ID 号码如ID-3234
。在这里,如果我们想仅使用这个正则表达式'\\b\\d+'
提取数字,-
将被识别为边界。
0、1 或更多次重复
+, * , ? , { }
是你指示字符无或多次重复的方式。
-
***** : 使用星号表示正则表达式可以发生零次或多次。
-
? : 问号表示正则表达式发生零次或一次。
-
+ : 使用加号表示正则表达式可以发生一次或多次。
-
{2,4}: 大括号用于定制重复次数。这里,从 2 次到 4 次。
# Creating the text
text3 <- "This is a random text 3\. Hellooooo."
# Finds the l and one or more reps of o.
str_view(text3, '[l][o]+')
[OUT] looooo
# Finds exactly 3 reps of o.
str_view_all(text3, 'o{3}')
[OUT] ooo
# Finds exactly 2 to 5 reps of o.
str_view_all(text3, 'o{2,5}')
[OUT] ooooo
# One letter or the other: N or M?
text4 <- "nice or mice"
str_extract_all(text4, '[nm]ice')
[1] "nice" "mice"
使用除外符号^
我们可以在方括号内使用^
作为除外的同义词。在我们的text2
字符串中,我们可以随意说我们不想让模式捕获字母a, m, d, x, e。下面的正则表达式模式与^
将移除方括号内列出的字母。
# Character NOT a, m, d, x or e
str_view_all(text2, '[^amdxe]')
代码返回下图。
我们正则表达式中的字母没有被选择。图像由作者提供。
之前和之后
要提取给定模式之前或之后的内容,这里是正则表达式。假设我们有一个text5 <- ‘Extract the ID 321’
,我们想提取单词 ID 之前和之后的任何内容。
# Everything before ID
# . for any character
# + for one or more occurrences
str_extract_all(text5, '.+(?=ID)')
[1] "Extract the "
# Everything after ID
str_extract(text5, '(?<=ID ).+')
[1] "321"
以…开始和结束
下一些代码是用来提取以某物开始或结束的模式。
#Text
text6 <- 'ID-1234'
# starts with ID
str_extract(text6, '^ID')
[1] "ID"
# ends with numbers
# \\d = digits
# + = one or more occurrences
# $ = ends with
str_extract(text6, '\\d+$')
[1] "1234"
分步构建
在学习了正则表达式的基础知识后,这里有一个好的提示来开始创建你自己的:分步构建。查看你想要创建的模式,思考它,然后开始编写每一部分,记住本文中学到的元素。
实践
模式:提取字母 A 后面跟着任何数量的数字。
-
第一个元素是限定的字母 A。因此
[A]
将处理它。 -
下一部分是任意数量的数字。所以
\\d
代表数字,+
是一个或多个重复的符号。 -
正则表达式 =
'[A]\\d+'
text <- 'A234 B334 C434 A334 B345'
str_extract_all(text, pattern= '[A]\\d+')
[1] "A234" "A334"
模式:从书单中,如示例Romeo and Juliet by William Shakespeare (6389)
,提取书名、作者名和下载次数。
-
要提取书名,我们必须获取直到单词
by
的所有内容。因此,我们使用.+
表示任意字符出现一次或多次(?= by)
直到by。 -
要提取作者名,我们必须获取单词
by
之后的所有内容。这里,(?<=by )
是寻找by之后的部分,.+
与前面的项目相同,[^(\\d)]+
意味着^不( )包含数字。 -
下载次数仅为数字
\\d
,出现一次或多次+
。
text <- "Romeo and Juliet by William Shakespeare (6389)
A Room with a View by E. M. Forster (5146)
Middlemarch by George Eliot (4913)
Little Women; Or, Meg, Jo, Beth, and Amy by Louisa May Alcott (4682)
Moby Dick; Or, The Whale by Herman Melville (4521)
The Enchanted April by Elizabeth Von Arnim (4490)
The Complete Works of William Shakespeare by William Shakespeare (4432)
The Blue Castle: a novel by L. M. Montgomery (4418)
Cranford by Elizabeth Cleghorn Gaskell (4276)
The Adventures of Ferdinand Count Fathom — Complete by T. Smollett (4257)"
# Extracting the names of the books to a data.frame
books = str_extract_all(text, pattern= '.+(?= by)')
# Extracting the names of the authors to a data.frame
authors = str_extract_all(text, pattern= '(?<=by ).+ [^(\\d)]+')
# Extracting the downloads
downloads = str_extract_all(text, pattern= '\\d+')
# As dataframe
gutenberg_top10 = data.frame(book = books,
author= authors,
downloads= downloads)
# Correct columns names
colnames(gutenberg_top10) <- c('book', 'author', 'downloads')
从古腾堡提取前 10 本最下载的书籍的练习结果。图像由作者提供。
在你离开之前
好吧,这就是关于正则表达式的基础知识。我希望你能感受到它的强大。你可以使用正则表达式基本上找到任何东西。
我知道一开始可能看起来很吓人,但这只是一个练习的问题,你会开始对它更加自信。
你可以随时使用 Regex 101 网站 进行练习,参考资料。
如果你喜欢这些内容,请关注我的博客以获取更多信息。
阅读 Gustavo Santos 在 Medium 上的文章。数据科学家。我从数据中提取见解,以帮助个人和公司…
gustavorsantos.medium.com](http://gustavorsantos.medium.com/?source=post_page-----5d25be06f2a8--------------------------------)
参考资料
正则表达式(缩写为 regex 或 regexp,有时称为有理表达式)是一系列…
en.wikipedia.org](https://en.wikipedia.org/wiki/Regular_expression?source=post_page-----5d25be06f2a8--------------------------------) [## 正则表达式
正则表达式是一种简洁且灵活的工具,用于描述字符串中的模式。本小册子描述了关键…
cran.r-project.org](https://cran.r-project.org/web/packages/stringr/vignettes/regular-expressions.html?source=post_page-----5d25be06f2a8--------------------------------) [## regex101:构建、测试和调试正则表达式
带有语法高亮、解释、PHP/PCRE、Python、GO、JavaScript 等的速查表的正则表达式测试器…
regex101.com](https://regex101.com/?source=post_page-----5d25be06f2a8--------------------------------) [## 简单、一致的常用字符串操作包装器
一套一致、简单且易于使用的封装在出色的 stringi 包周围的包装器。所有函数和参数…
stringr.tidyverse.org](https://stringr.tidyverse.org/?source=post_page-----5d25be06f2a8--------------------------------)
github.com/rstudio/cheatsheets/blob/main/strings.pdf
厌倦了二维码?自己制作一个标志性标记
二维码无处不在:想要创建一个更原创的解决方案吗?让我们制作自己的标志性标记,并学习如何检测和解码它。
·
关注 发布于Towards Data Science · 15 分钟阅读 · 2023 年 11 月 4 日
–
照片由Michael Dziedzic提供,发布于Unsplash
在这篇文章中,让我们学习如何制作一个新的标志性标记以及如何通过训练对象检测模型来检测它。然后,我们将学习如何使用图像处理技术解码我们的标记。
让我们将其分解为三个步骤:
-
创建标志性标记
-
在图像中检测标记
-
解码标记
创建基准标记
目前已经存在很多用于计算机视觉的基准标记,最著名的是二维码。还有其他二维码,使用程度和鲁棒性各不相同,也可以使用。下面是一个不完全的代码列表。
一些最著名的基准标记及其名称和创建日期(来源 www.mdpi.com/1424-8220/21/16/5407
,在 CC-BY 许可下)
如上图所示,基准标记可以有很大不同,但它们都有相同的目的:包含易于解码的信息。
什么是好的基准标记?
理想情况下,一个好的基准标记具有以下属性:
-
易于检测:在能够解码标记之前,你必须能够准确地在图像中检测到它。
-
易于解码:标记必须易于解码且没有任何歧义(即,解码后的标记产生唯一的值)
基于这些属性,现在让我们从现有的标记中构建我们自己的标记。
设计我们的基准标记
我个人喜欢 RUNE 标记(出于非常随意的理由):
-
圆形和点的形状使其比方形标记更柔和
-
这看起来非常明显,使得对象检测模型很可能容易检测到。
-
它很容易自定义:我们可以调整每个圆圈上的点数以及圆圈的数量,以满足我们的需求和期望的美学。
但它在原始形式下并不完美:两个旋转的标记可能会导致相同或不同的解码结果。
这两个标签之间的唯一区别是 90°的旋转:它们无法被区分。图像作者提供。
为了减轻这个问题,我们将对标记添加一个条件:一个且只有一个扇区没有黑点,如下所示。
一个只有一行且没有黑点的标签,可以解决旋转问题。图像作者提供。
这样的标记可以很容易解码:假设每个扇区可以有三种可能的值:0、1 或 2,具体取决于三种可能的情况:
-
一个小黑点:0
-
一个大黑点:1
-
两个点:2
一些扇区的表示:扇区 0 是唯一一个没有黑点的,而其他扇区总是至少有一个黑点。图像作者提供。
更一般地说,考虑一个具有 C 圆圈层的标记,一个扇区可以有最多 2ᶜ−1 种值(因为没有黑点的情况保留给扇区 0)。
最终,对于一个有 d+1 个点的标记,可能的组合数等于 (2ᶜ— 1)ᵈ。对于一个 2 个圆圈层和每个圆圈 20 个点的标签,这意味着 3¹⁹ ~ 11.6 亿个可能值。
构建我们的基准标记
下面我们解释一段用于生成随机基准标记图像的代码。
生成随机标签的方法概要。有关完整工作的代码链接,请参见文章末尾。
如你所见,第一步是生成一个随机值列表。考虑到C为圆圈层数和d+1为每层圆圈的点数,我们使用 numpy 生成一个在 0 到 2ᶜ−1 之间的d个值的列表。
基于这个随机值列表,我们计算点值:0 表示白色点,1 表示黑色点。最后,我们绘制最终标签,给定一个像素大小,并将输出保存为图像。当然,完整的代码库链接在文章末尾提供并记录。
我们选择了一个标记设计,并知道如何生成这样的标记。为了能够在实际条件下使用这样的标记,我们需要一个能够在图像中检测和解码这种标记的解决方案。这非常简单,分为两个步骤:
-
使用目标检测检测标记
-
解码检测到的标记
现在让我们进入这个管道的第一步。
检测标记
首先的步骤是检测给定图像中标记的存在和位置。为此,有许多目标检测模型。我们将在这里使用一个YOLOv8 模型,它非常容易训练和在生产环境中使用。
但在实际训练目标检测模型之前,我们需要数据:来自不同背景和环境的图像,包含来自不同缩放级别和视角的标签。
我们将生成和使用合成数据来训练模型,而不是收集和标记数据,这可能非常耗时。
生成数据
我们只需要两个要素来生成合成数据,以训练一个目标检测模型:
-
各种免费使用的背景图像,例如可以从Unsplash获取。
-
我们将随机生成的标记图像
有了这两个要素,我们所需要做的就是使用Albumentations进行一些增强,生成大量独特的合成图像及其相关标签。
下面提供了一段代码,允许生成图像,给定背景图像的路径和标记特征,如圆圈层数和每层圆圈中的点数。
生成合成数据的代码概要。有关完整工作的代码链接,请参见文章末尾。
这是一段相当长的代码,随意深入了解,但简单来说,它做了以下几件事:
-
生成一个随机标签,图像边界是边界框标签。
-
应用如仿射、透视或缩放等变换,感谢 Albumentations。
-
随机将该标签插入到随机选择的背景图像中
-
根据需要多次执行
使用这种方法,我们可以轻松生成足够大的数据集,包含数百或数千张图像。以下是一些创建的图像示例,带有红色边界框标签。
一些生成的图像,其中标签为红色边框。图像由作者提供。
正如我们所见,生成的图像相当多样,因为我们添加了背景和增广处理,如模糊和透视。
当然,我们不会对训练集和验证集使用相同的背景图像,以确保模型评估尽可能不受偏见影响。
一种允许在正确的文件夹中生成图像及其相关标签的 Python 脚本已在 GitHub 仓库中提供。
训练和评估模型
使用之前创建的数据集,我们现在可以在这些数据上训练一个目标检测模型。借助 YOLOv8 库,只需几行代码即可训练一个新模型。
YOLOv8 小模型在 100 个周期上的训练代码 gist。有关完整代码的链接见文章末尾。
正如我们所见,我们只需要实例化一个模型并在数据上进行训练。经过 100 个周期(如果在训练过程中遇到早停条件,则可以更少,比如我在这里大约 80 个周期后),我得到了约 0.5 的 mAP@50,如下面生成的结果所示。
YOLOv8 库生成的结果。结果的 mAP@50 约为 0.5。图像由作者提供。
尽管结果远未完美,但对于仅使用合成数据训练的数据集来说已经足够好。现在让我们用网络摄像头的实况测试这个模型。
为此,我们可以使用以下 gist 中的代码:
YOLOv8 模型在网络摄像头实况上的推理代码 gist。有关完整代码的链接见文章末尾。
这个代码相当直接:
-
我们加载模型并获取网络摄像头的实况
-
对于每张新图像,我们计算模型推理并显示任何检测到的边界框
-
当按下 Escape 键时,我们停止实况
我用我手机上的标记图像运行了这段代码,正如我们在下面的图像中看到的那样,它效果非常好。
仅在合成数据上训练的 YOLO 模型的检测结果,在我的网络摄像头实况上测试。图像由作者提供。
虽然它在所有配置中并不能完美地检测标记,但对于仅使用合成数据训练的模型来说已经足够好。为了获得更好的结果,我相信可以稍微调整数据增广,当然真实的标记数据将会非常有帮助。
现在我们完成了管道的第一部分,让我们进入第二步:标签解码。
解码标记
我们现在拥有了生成和检测新基准标记的完全可用的代码。
一旦你可以在图像中检测到标签,下一步当然是解码它。让我们从经过我们之前训练的模型检测到的标记的裁剪图像开始。
来自我们目标检测模型的裁剪图像。让我们解码这个标记。
我开发了一个由以下步骤组成的解码算法:
-
斑点检测以检测点
-
外圆检测和椭圆拟合
-
用于单应性计算的点选择
-
单应性矩阵计算和图像展开
-
最后,进行标记解码
主要思路如下:只要我能将检测到的标记与参考标记匹配(知道每圈的圆层数和每圈的点数),我就可以通过检查图像是白色还是黑色来相对容易地解码它。但为了做到这一点,我首先需要将图像展开以使其与参考标记匹配。
让我们一起回顾这些步骤。
检测点
首先任务是检测 YOLO 模型检测到的图像中的点。
从输入的裁剪图像中,我们将使用 OpenCV 应用以下图像处理列表:
-
将图像转换为灰度图
-
使用 Otsu 算法 二值化图像
-
使用斑点检测器查找点
以下 gist 中的代码完成了这些操作:
从裁剪图像中检测点的代码。请查看文章末尾的链接以获取完整的代码。
如我们所见,为了最大化实际标记点的有效检测,设置了许多参数,如最小和最大面积,以及最小圆度。这些参数的微调花费了不少时间,但可以随意调整这些参数。
使用此代码处理我们的裁剪图像会得到以下的斑点检测结果。
输入裁剪图像上的斑点检测结果:点检测得很好。图像作者提供。
如我们所见,点检测得很好。下一步是检测外圆。
检测外圆
现在,我们需要检测最外层的圆圈(无论标签中的圆圈数量,这种解决方案都可以推广)。这将允许我们找到外圆上的点,以便我们最终展开图像。
要计算椭圆,我们所做的就是保留较大的点(在 OpenCV 中称为关键点),并从这些点拟合椭圆方程。这就是以下代码的功能:
代码允许从检测到的点计算椭圆方程。请注意,此代码始终计算一个中心估计,这在接下来的几个步骤中会很有用。
当我应用此代码并将检测到的点作为散点图展示,并展示拟合的椭圆时,得到如下结果:
检测到的斑点的散点图和最外圈的拟合椭圆。图像作者提供。
如我们所见,拟合的椭圆定义明确且与点的位置一致。请注意,由于我们在拟合椭圆,无论检测到的标记由于透视而变形程度如何,它都能正常工作。
现在我们需要找到实际上在这个椭圆上的点。这很简单:我们只需找到满足我们刚刚计算出的椭圆方程(带有给定阈值)的点位置即可。这是通过以下代码实现的:
用于返回椭圆上关键点的代码要点。请参见文章末尾的链接以获取完整代码。
现在我们知道点的位置,以及哪些点在最外圈上,我们可以使用这些点来计算单应性矩阵并解扭曲图像。
为单应性计算选择点
现在的目标是找到一些与参考图像匹配的点,以便计算单应性矩阵。
参考标签的图像,所有点都填充。作者提供的图像。
基于上面的参考图像,我们需要使用正确的单应性矩阵来解扭曲检测到的斑点。
为了计算单应性矩阵,我们可以简单地使用 OpenCV 函数findHomography。此函数需要参考图像和输入图像中至少 4 个匹配点作为输入,以便找到单应性矩阵。这个单应性矩阵将允许我们解扭曲检测到的图像,并与参考图像匹配。
从我们检测到的最外圈上的斑点来看,无法确定这些点在原始参考图像上的位置。因此,我们将选择最外圈中最近邻点的最长链条,以便与参考图像匹配。为此,有两个步骤:
-
计算邻接矩阵,以便我们知道每个点的相邻点(如果有的话)
-
从邻接矩阵中计算相邻点的最长链条
对于第一步,我们可以使用以下代码:
计算邻接矩阵的代码。请参见文章末尾的链接以获取完整代码。
这段代码将计算邻接矩阵,作为一个 Python 字典:对于最外圈上每个现有的点索引作为键,相应的值是找到的相邻点索引列表。
从这个邻接矩阵中,现在很容易找到最长的相邻点链条。为此,我使用了以下代码:
用于计算相邻点最长链条的代码要点。请参见文章末尾的链接以获取完整代码。
这段代码将高效地找到相邻点的最长链条,并返回它们的索引列表。
如果我们在这个输出中至少有 4 个点,我们可以理论上计算单应性矩阵。不幸的是,在大多数情况下,这不会非常准确,因为这些点几乎在同一条线上,无法准确计算单应性矩阵。为了解决这个问题,我们将添加一个额外的点:一个相对于中心对称放置的点:这将使单应性计算更准确。
我们可以用以下代码找到相对于中心的对称点(在进行椭圆拟合时计算得到):
找到对称点的代码,给定输入的最长链和椭圆上的所有关键点。请参见文章末尾的链接以获取完整的工作代码。
请注意,由于我们处于椭圆上,使用中心估计来找到给定点的对称点并不是 100%可靠的方法:它可能会输出错误的点。这一点在计算解码时我们会记住。
最终,我们得到以下图像中的结果,其中蓝色圆圈是最长链上的点,红色圆圈是预期的对称点(其中一个是最长链的一部分)。
点选择的结果。蓝色圆圈中的点是最长链上的点(除了最左侧的点)。红色圆圈中的点是检测到的对称点。中央的红点是估计的椭圆中心。图片由作者提供。
如我们所见,我们确实选择了 7 个相邻点的链,并选择了另一个点作为链中最左侧点的对称点。
解开图像
既然我们已经在输入图像中选择了一些点,接下来让我们在参考图像中找到匹配的点并计算单应性矩阵。为此,我们需要以下输入:
-
裁剪图像中所选点的位置:这是我们刚刚计算的内容
-
这些点在参考图像中的等效位置:需要计算,知道参考标记
要计算这些点的位置,我们将使用以下代码,允许计算点的位置。
用于生成参考点位置以计算单应性矩阵的代码要点。请参见文章末尾的链接以获取完整的工作代码。
请注意,我们通过一个名为symmetry_index_offset的参数增加了一个自由度:这将允许处理对称点计算中的可能错误,通过将偏移量添加到对称点的位置。
通过正确的点位置在裁剪图像和参考图像中,我们现在可以计算单应性矩阵并解开图像。为了确保我们在对称点上没有犯错误,我们将在[-2, 2]的范围内以 1 为步长进行计算,如下面的代码片段所示:
单应性矩阵计算和图像解开的代码要点。请参见文章末尾的链接以获取完整的工作代码。
我们在这里做的是用 OpenCV 的functionfindHomography计算单应性矩阵,然后用warpPerspective解开图像。我们对 5 个偏移值执行此操作,以便得到 5 张解开的图像。
结果图像如下:
结果解开的图像。只有-1 偏移量的图像正确解开。图片由作者提供。
正如我们所看到的,根据偏移量,未校正结果相差甚远。尽管通过视觉检查很容易理解 -1 的偏移量是正确的,但我们希望将这一检查自动化。我们将在下一步中处理这个问题:实际的标记解码。
解码标记
从给定的未校正图像开始,最后一步是解码标记。我们非常接近,这一步可能是最简单的。
我们需要做的就是检查每个预期点的位置,未校正图像的颜色。由于图像经过了 Otsu 二值化,这非常简单。我们只需检查预期点位置周围 3x3 像素区域内是否有黑色像素:如果有,则存在点;如果没有,则不存在点。
计算未校正图像的列表代码所使用的代码概要。请参阅文章末尾的完整代码链接。
这基本上是上面代码所做的。然后根据位置,我们分配一个值。这样,这个函数的输出就是一个值的列表。最后,我们寻找一个 -1 值(表示预期的区域没有黑点,请参考 设计我们的基准标记 部分以获取相关提醒),并重新排列数组,将其放在最后一个索引位置。
例如,以下是每个未校正图像计算出的代码:
-
偏移量 -2: [0, 2, 0, -1, 1, -1, 0, 0, 0, 2, 2, 1, 2, 2, 0, 2, 2, 2, 0, -1]
-
偏移量 -1: [2, 2, 2, 0, 2, 0, 1, 1, 1, 2, 2, 1, 2, 2, 0, 2, 2, 2, 0, -1]
-
偏移量 0: [0, -1, 2, 2, 0, 0, -1, 0, 0, -1, 0, 1, 2, 2, 0, 2, 2, 2, 0, -1]
-
偏移量 1: [-1, 2, 2, 2, 2, 0, -1, -1, 0, 0, 0, 0, 0, 2, 0, 2, 2, 2, 0, -1]
-
偏移量 2: [-1, 2, 1, 2, 1, 0, 1, -1, -1, -1, -1, 0, 0, 0, 2, 2, 0, 0, 2, -1]
正如我们所见,只有一张图像在最后一个索引位置只有一个 -1 值:使用 -1 偏移量的未校正图像。这是我们已经很好地校正的图像(如视觉检查所见),能够实际解码标记。
由于这个代码列表对于每个可能的标记都是唯一的,你可以在这里停止,或者计算一个唯一的整数值。可以通过以下代码片段很容易地计算出唯一值:
计算解码最终值所使用的代码概要。请参阅文章末尾的完整代码链接。
在我们的例子中,这将返回 -1 对于所有错误校正的图像,并返回 377667386 对于实际标记。
就是这样,我们从输入图像一路走到了实际的唯一代码!现在让我们总结一下,并反思我们所做的工作中的局限性。
创建一个完整的流程
现在我们已经拥有所有的构建模块,我们只需要将它们组合起来,以获得一个漂亮的、自定义的基准标记解码器,它可以替代 QR 代码!
总结一下,以下是一个完整工作流程中的步骤:
-
从输入图像中,使用对象检测来检测标记
-
对于每个检测到的对象,裁剪图像并进行下一步
-
使用 Otsu 二值化和斑点检测来检测点
-
使用椭圆拟合找到最外层的点
-
使用最近邻点和对称点的最长链计算同质变换矩阵
-
使用同质变换矩阵去畸变图像
-
使用参考图像解码标记
就这样!我不会让你自己编写所有代码,一切都可以在github 仓库中找到,还有一个预训练的目标检测模型。
你会在这个仓库中找到用于运行子步骤(例如生成合成图像、训练目标检测模型等)的 python 脚本,以及一个使用你的摄像头作为输入运行完整流程的 python 脚本,这样你就可以进行测试!
最后想法
希望你喜欢这篇文章并从中学到了东西!我个人非常喜欢这个项目,因为它结合了机器学习和传统的图像处理。
尽管如此,我开发的算法仍然有一些限制,我希望能克服它们。确实,并非所有标记都可以解码:
-
外圈上没有超过 2 个相邻点的标记将无法正确解码
-
对于没有对称点的标记,它会给出不可靠的结果,因为同质变换矩阵不准确。
另一个限制是有时在去畸变过程中,同质变换会镜像图像,导致列表代码被反转,从而最终解码的整数值不同。
如果你有任何想法来克服这些限制,欢迎给我发消息或提议拉取请求!
另一个话题是,这里的解码只给出一个整数值。你需要将这个整数值与应用中的任何相关内容(如链接、项目、图像等)匹配,以使其真正有用。我相信可以直接将这样的标记解码为 ASCII 字符列表,但我自己没有尝试过:再次强调,任何贡献都是非常欢迎的。
参考文献
原始 RUNE-Tag 论文:
F. Bergamasco, A. Albarelli, E. Rodolà 和 A. Torsello,“RUNE-Tag: 高精度的具有强遮挡恢复能力的基准标记”,CVPR 2011,科罗拉多斯普林斯,美国,2011,第 113–120 页,doi: 10.1109/CVPR.2011.5995544。
原始 RUNE-Tag 仓库: github.com/artursg/RUNEtag
对你的数据工程师角色感到厌倦吗?
我是如何转型为数据分析工程师的
·
关注 发表在 Towards Data Science ·8 min read·2023 年 8 月 19 日
–
照片由 Campaign Creators 在 Unsplash 提供
几年前,我曾经在职业生涯中感到不满。我在数据工程领域工作了三年,最初对科技世界的兴奋感已经消退。我开始意识到我对工作并不像我曾经希望的那样充满热情。
我认为,无论你身处何地,跟随自己的兴趣非常重要,以找出你真正想做的事情。这可能意味着追求那些让你快乐的工作之外的爱好,或者参与你已经工作的公司中的不同团队。
我记得在大学时我非常喜欢市场营销和业务方面的东西,所以我决定开始探索那方面的选择。我开始与每天使用数据解决业务问题的数据分析师交谈。他们就像是有更多业务曝光的数据工程师!
通过要求参与数据分析项目,我学会了使用dbt并进一步发展了我的 Python 技能。玩弄新技术帮助我看到,我仍然喜欢从事数据工作,只是需要使用正确的工具和解决正确的问题。最终,这促使我寻求一个与这些新发现的兴趣和技能集更匹配的不同角色。这个角色恰好就是数据分析工程师。
如何知道你应该过渡到数据分析工程
很多人害怕从数据工程师过渡到数据分析工程师,因为他们不知道时机是否合适。事实是,永远不会有“正确”的时间去做任何事情。然而,如果你感到没有挑战和满足感,你永远不会过早做出这个选择。
从数据工程到数据分析工程可能适合你如果你对数据本身感到好奇,而不是构建支持数据的产品。数据分析工程更以客户为导向,因为你使用客户的数据来回答关键业务问题。你更专注于增加收入和洞察,而不是构建事物。
这让我引出了第二点。如果你希望更紧密地联系业务,并希望做出推动公司增长的决策,数据分析工程可能适合你。作为数据工程师,你的任务由项目经理分配。你不一定能决定自己解决什么问题或认为需要优先处理什么。然而,在数据分析工程中,你可以。
我是如何发现数据分析工程的
说实话,当我意识到自己想要换数据角色时,我对数据分析工程一无所知。我认为我唯一能更接近业务的选项就是成为数据分析师。这正是我尝试做的。
我申请了很多数据分析师职位,但运气不佳。我缺乏分析师所需的深入业务经验,以及构建适当仪表盘的技能。我开始专注于那些结合了我已有技能和我想要学习技能的职位,而不是单纯关注职位名称。
最终,我偶然发现了分析工程师这一角色。这个角色要求掌握 SQL、Python、AWS、编排、dbt 和数据仓库等技能,而这些技能正是我作为数据工程师时所获得的。然而,它还要求具有使用现代数据堆栈工具的经验、与业务团队沟通的能力,以及一些基础的 BI 报告技能。
尽管我没有太多现代数据堆栈的经验,但我很幸运找到了一家相信我的热情和学习欲望的公司。有时候,如果匹配合适,公司愿意在没有所有要求的情况下雇用你!
在转向分析工程师之前需要发展的技能
在寻求分析工程师角色之前,我花了一些时间真正发展以下几个技能。这些是最重要的技能,会让你在其他候选人中脱颖而出。如果你专注于这三项技能,我相信你可以轻松地在工作中自学成为一名优秀的分析工程师所需的其他技能。
dbt
dbt,或称为数据构建工具,是一种真正为分析工程师开辟道路的数据转换工具。事实上,它背后的公司也是首创“分析工程师”这一名称的!虽然你不一定需要了解 dbt 才能成为一名分析工程师,但这是一项许多公司在招聘时寻找的技能。
dbt 是一个基于 SQL 的工具,所以如果你已经知道 SQL,学习起来相对简单。你需要熟悉设置 dbt 项目、数据建模最佳实践以及每种 dbt 数据模型的目的。我建议查看 dbt 样式指南,以了解在 dbt 中编写 SQL 代码的“应做”与“忌做”之处。这将帮助你学习当你加入团队时可能已经存在的标准。
dbt 还使用了一种称为 Jinja 的模板语言用于文档和工具内的功能。然而,dbt 称这些为宏,而不是函数。宏是更高级的功能,相当于函数。你可以使用它们在数据模型中自动化 SQL 输出。虽然这不是获得工作的必要条件,但这无疑是使你作为分析工程师的工作更轻松的有用技能。
商业沟通
这可能是作为数据工程师最难以发展和练习的技能。我们习惯于与其他工程师进行技术性对话,以至于忘记了用易于理解的方式解释事物。在与业务团队沟通时,你需要知道如何向非技术观众解释技术概念。越简单越好。你所沟通的听众几乎永远不会像你一样具备技术背景。
此外,你需要理解这些团队所使用的业务术语。你会不断听到不同的指标,如 CAC、MRR、NPM 和 ROI。确保你知道这些是什么意思!很可能,这些就是你构建数据模型的原因。你需要理解最终目标/指标以及计算这些指标的数据。
构建数据管道
幸运的是,作为一个数据工程师,你很可能已经具备了这个技能!构建数据管道是许多数据工程师的常见技能,这些技能在分析工程师中也同样适用。然而,在这个背景下,你需要知道如何协调不同的数据源与 dbt 模型。你需要能够处理来自多个不同源系统的依赖关系。
在分析工程中,常见的数据管道工具包括 Airflow、Prefect 和 Dagster。这些工具都是使用简单的 Python 构建的,Python 也是构建管道的重要技能。选择一个工具并熟悉它!一旦你学会了一个,你很可能也能学会其他的工具。
我希望在转型为分析工程师之前知道的事情
当然,回顾一下从数据工程师转型为分析工程师的过程,我有一些希望自己早知道的事情,这些事情会让我的旅程更加愉快。
你的数据工程技能每天仍然会被使用,并且非常有价值。
当你从数据工程师转型为分析工程师时,很容易认为你的整个角色和技能集都会发生变化。然而,分析工程师和数据工程师之间有很多重叠的部分!通常,你作为数据工程师学到的技能会成为你最大的超能力。你可能能做其他分析工程师做不到的事情,所以要充分利用这些技能!
我实际上认为,这会让你在申请职位时比其他候选人更具优势,特别是如果你正在寻求一个小公司里的数据职位,该公司刚刚开始组建数据团队。很可能,他们希望以最小的投入获得最大的回报,并且希望找到一个可以做各种事情的人。一个充满热情的数据工程师转变为分析工程师,正是适合这个工作的最佳人选!
一开始你不需要知道所有的东西。
很容易陷入需要一次性学习所有东西的状态。我需要知道如何在申请那个职位之前设置一个像 Airbyte 这样的开源数据连接器。我需要在申请那个职位之前有在 dbt 中构建自定义宏的经验。我不能申请那个职位,直到我掌握每种 SQL 窗口函数。
不要专注于你不知道的东西,而要关注你知道的!是什么让你独特?为什么有人会选择雇佣你而不是别人?思考那些在工作中无法学到的让你脱颖而出的事情。行业总在变化,这意味着总会有你不知道的东西。接受它!这就是分析工程的最佳部分之一。
当我开始作为分析工程师时,我根本不知道什么是维度建模。我实际上是在最初试验组织我们的数据仓库和建模数据时通过 dbt 的文档了解到的。最终,我学会了什么是维度建模,应用这些技术到我的工作中,并提高了我的数据建模技能。现在我可以和你谈论维度建模整整一天!
申请新角色从来不会太早。
如果你不开心,就要采取行动。你在不适合你的职业中呆得越久,就越错过学习和成长的机会。每次我换职业,都是因为我的工作不再给我带来挑战。如果我们不学习新事物,不成长为一个更好的人,我们就停滞不前。而这难道不是更糟糕的事情吗?
即使你觉得自己一开始没有合适的技能,通过探索你的选择,你将会学习到东西。浏览不同的角色,看看公司在寻找什么。职位发布中的技能有什么共同点?你是否认为分析工程角色中有你不喜欢的地方?查看职位描述是发现这些信息的好方法!
结论
回头看,我无法从数据工程到分析工程有更棒的过渡。我深知纯技术角色对我而言并不合适。我需要动手处理数据,利用数据解决客户和业务问题。
生活太短暂,不要留在你不完全享受的角色中!不要害怕探索其他可能的领域,无论是否是分析工程。最坏的情况是什么?如果你跳槽后发现这实际上不适合你,你总是可以回到之前的角色。但是,你永远不会知道除非你尝试!
想了解更多关于分析工程的信息,订阅我的免费每周通讯,我会分享学习资源、教程、最佳实践等。
查看我的第一本电子书,分析工程基础,这是一本关于入门分析工程角色的全面指南。
1 还是 0:图像分类中的像素攻击
原文:
towardsdatascience.com/to-1-or-to-0-pixel-attacks-in-image-classification-ec323555a11a
探索对抗性机器学习的领域
·发表于Towards Data Science ·阅读时间 13 分钟·2023 年 11 月 23 日
–
图片由Pietro Jeng提供,来源于Unsplash
嗨,大家好!
今年,我参加了我的第一次Capture The Flag (CTF)比赛,这次经历可以说非常引人入胜。挑战,特别是那些涉及像素攻击的挑战,引起了我的注意,并成为了这篇文章的主要焦点。虽然我最初打算分享我在比赛中进行的一个简单的像素攻击,但这篇文章的目标也是深入探讨如何增强机器学习模型,以更好地抵御像比赛中遇到的像素攻击。
在我们深入理论之前,让我们通过一个引人入胜的场景来引起你的注意。
想象一下:我们的公司 MM Vigilant 致力于开发一款前沿的物体检测产品。概念简单而革命性——客户拍下所需物品的照片,几天后它就会送到客户家门口。作为幕后出色的数据科学家,你打造了终极的基于图像的物体分类模型。分类结果无可挑剔,模型评估指标一流,利益相关者也非常满意。模型投入生产,客户也很高兴——直到投诉接踵而至。
经调查,发现有人在图像到达分类器之前对其进行干扰。具体来说,每张钟表的图像都被恶意地分类为镜子。结果如何?任何期待钟表的人都会收到意外的镜子。这真是个意外的转折,不是吗?
我们在 MM Vigilant 的利益相关者对这种事故的发生感到既担忧又好奇,更重要的是,如何采取措施来防止它。
我们刚刚探讨的场景是一个假设情境——尽管图像篡改是一个非常可能的情况,特别是当模型存在漏洞时。
那么,让我们仔细看看这种图像操作的一个例子……
图像分类中的像素攻击
像素攻击,特别是在图像分类的背景下,旨在欺骗机器学习(ML)分类器,将图像分类为其他类别。虽然可以讽刺地认为,一个不佳的模型已经表现出这种行为,但这里的目标是击败最先进的模型。不用说,这些攻击有很多方法和动机,但这篇文章,限于范围,将重点关注黑箱、针对性像素攻击及其相关的预防措施。
让我们从直观上来理解这个概念。实际上,任何输入到神经网络的图像都经过每个像素的一系列数学运算来进行分类。改变一个像素,因此,会导致这些数学运算的结果发生变化,从而改变预测得分。这可以推断到这样一种程度,如果一个主要/“对分类至关重要”的像素被操控,它将对类别的预测得分产生足够大的影响,从而导致误分类,如下图所示。
图片来源于作者
在图像分类攻击领域,有两种知名的方法,取决于误分类的期望结果:
-
针对性攻击
-
未针对性攻击
针对性攻击
针对性像素攻击涉及一种有目的的转换,目标是将图像分类为特定类别。例如,想象一个故意尝试将熊分类为船或将苹果分类为考拉的行为。这些攻击有两个目标:最小化原始类别的得分,同时最大化目标类别的得分。
未针对性攻击
相反,未针对性像素攻击的前提是避免将图像分类为其原始类别。任务简化为最小化指定类别的预测得分。换句话说,目标是确保一只熊,例如,被分类为除了熊之外的任何东西。
值得注意的是,每个针对性攻击都可以被认为是未针对性攻击,但反之则不一定成立。
除了攻击类型之外,还有两种不同的方法来实现这些攻击,具体取决于被攻击模型的可用性(传统/白盒方法)或仅有的结果分数(黑盒方法)。
传统攻击
在传统或白盒攻击中,模型通常是可用的。可以获取梯度信息并用于如 快速梯度符号方法(FGSM) 这样的攻击。这种方法通过沿梯度方向对输入数据进行小幅扰动,导致误分类,而不会显著改变图像的视觉外观。
可以在以下代码库中找到该方法的简单 GitHub 实现。
[## GitHub - ymerkli/fgsm-attack: 目标和非目标快速梯度符号方法的实现
目标和非目标快速梯度符号方法的实现 - GitHub - ymerkli/fgsm-attack: 实现了…
github.com](https://github.com/ymerkli/fgsm-attack/tree/master?source=post_page-----ec323555a11a--------------------------------)
黑盒攻击
黑盒攻击则完全依赖于模型预测。可以使用如 差分进化 这样的技术来执行这种类型的攻击。
差分进化是一种模拟自然选择的优化算法。它通过在迭代中创建和组合潜在解决方案,基于设定的标准从一个种群中选择表现最佳的解决方案。这种方法在探索解决方案空间方面效果显著,并且常用于对机器学习模型的对抗攻击。
鉴于我们的挑战集中在黑盒目标像素攻击上,让我们直接进入实现部分。
CTF 挑战
对于 CTF 挑战,其目标是将一张清晰的狼的图像误分类为格兰尼·史密斯苹果——向“小红帽”故事致敬。数据集中包含大约 1000 个类别,图像分辨率为 768x768 像素,超越了 MNIST、CIFAR 甚至 ImageNet 的分辨率,困难在于通过识别最少的像素数来欺骗模型以达到目标误分类。值得注意的是,尽管高分辨率图像复杂,但机器学习分类的本质,如上所述,在于任务的非直观性,将图像简化为一组值和一系列非常依赖这些个体值的数学运算。
在我们深入研究代码之前,让我们先看看我们狼的原始图像。难道它看起来不具有伪装成苹果的潜力吗?那绿色的眼睛、圆圆的脸以及绿色的背景——这些都是一个果味冒名顶替者的所有特征。
由 AI Village 提供,许可证 署名 4.0 国际 (CC BY 4.0)
在开始的旅程中,黑箱模型对“木狼”类别的初始评分约为 0.29,而对“青苹果”类别的评分为 0.0005。我最初考虑应用 scipy 的差分进化 方法。这种方法在涉及 CIFAR 和 ImageNet 数据集的像素攻击中已显示成功。差分进化技术涉及从 n 个随机样本开始,代表种群大小。在每一步,选择最好的后代,通过模型评分来确定,最终导致我们期望的结果。然而,鉴于时间限制和任务涉及仅更改单个图像的评分,我选择了更直接的策略。
方法
我首先将原始图像划分为逐渐更小的块,从 2x2 开始,一直到 16x16。针对目标青苹果(绿色),我逐一将块中的值更改为苹果绿色,并观察对木狼类别和青苹果类别分数的影响。然后,我手动选择了 2-3 个 16x16 的块,在这些块中应用了一种差分进化的方法。这意味着在该区域内随机选择的像素进行了大约 50-75 次迭代的单像素更改。
尽管我在给定的两天内无法准确定位到那个臭名昭著的单一像素,但我成功地进行了高度像素化的攻击,将狼的分类改变为青苹果,从而获得了三部分任务的两个子问题的标志。
现在我们有了背景信息,让我们跳入一些代码中,以便你能从这篇文章中学到一些东西。
Python 代码
我将其视为黑箱问题,当给定图像时,查询会提供所有类别的预测列表。预测按值排序,因此预测的类别是列表中的第一个值。
import requests
import base64
import cv2
import numpy as np
import matplotlib.pyplot as plt
def query(input_data):
response = requests.post({link to get the blackbox score},
json={'data': input_data})
return response.json()
get_scores
函数以正确的格式将图像输入查询,并在大多数情况下以字典形式获得所需的结果。
def get_scores(input_image):
# Some preprocessing since the query accepted only bytes
_, input_image = cv2.imencode('.png', input_image)
image_bytes = input_image.tobytes()
input_data = base64.b64encode(image_bytes).decode()
result = query(input_data)
"""
the result is a json dict {} with the variable 'output' or 'flag',
the output consists of scores for 1000 classes of which two are timber
wolf and granny smith. Initially the score for timber wolf is around 0.29
and the score for granny smith id 0.0005
"""
dict_score = {"timber wolf" : 0, "Granny Smith" : 0}
try:
print(result['flag'])
except:
pass
# the scores in the output are always sorted so the first score
# is always the predicted score
dict_score["predicted_class"] = result['output'][0][1]
dict_score["predicted_score"] = result['output'][0][0]
# next we get the scores for our wanted target and our original class
count = 0
for sublist in result['output']:
score, class_name = sublist
if class_name == "timber wolf":
dict_score['timber wolf'] = score
count+=1
elif class_name == "Granny Smith":
dict_score["Granny Smith"] = score
if count ==1:
break
return dict_score
相关代码
核心思想是选择在苹果颜色的 RGB 范围内的像素,并测试约 50-75 个像素,以找到最大化“青苹果”类别分数并最小化“狼”类别分数的像素。我逐渐增加了选择区域的大小,并根据需要修改优化过程。例如,当“青苹果”类别的分数超过“狼”类别的分数时,我考虑所有增加“青苹果”类别分数的像素,只要它高于“狼”类别的分数,而不是专注于减少“狼”类别的分数,这显然加快了一些进程。
尽管没有找到那个难以捉摸的单一像素,我成功执行了高度像素化的攻击。
# Load your image
input_image = cv2.imread('/timber_wolf.png')
# Get the dimensions of the original image
image_height, image_width, _ = input_image.shape
# Define the size of the window (dxd)
# initially I had a large window size for testing purposes
# to identify regions of high interest
window_size = 1 #image_height//64
# get the initial scores
scores = get_scores(input_image)
dict_pixels ={'pixels':[]}
best_score_tw = scores['timber wolf'] #the current/best score for timber wolf
best_score_gs = scores['Granny Smith'] #the current/best score for granny smith
print(best_score_tw, best_score_gs)
max_iter = 75
iter_1=-1
pixel_count = -1 # number of pixels that have been changed
max_pixel_count = 40 # number of pixels we want to change
temp_image = input_image
rand_red_best, rand_green_best = (0, 0)
row_best, col_best = (0, 0)
while pixel_count < max_pixel_count:
while iter_1 < max_iter:
# although I did change the values from time to time
row = np.random.randint(192,388)
col = np.random.randint(192,388)
iter_1 +=1
output_image = input_image.copy()
left = row
upper = col
right = min(x + window_size, image_width)
lower = min(y + window_size, image_height)
# the pixel values for RGB were kept close to the color of the apple
rand_red = np.random.randint(0,153)
rand_green = np.random.randint(170,255)
rand_blue = 0#np.random.randint(0,255)
output_image[upper:lower, left:right] = [rand_red, rand_green, rand_blue]
scores = get_scores(output_image)
# I actually also changed this a couple of times depending on where the output was
#if (scores['timber wolf'] - scores['Granny Smith']) < min_score :
# initially I wanted pixels that bridged that gap between both classes the most.
# Once granny smith score crossed the timberwolf score I only cared about increasing
# score for granny smith class as long as timberwolf stayed below granny smith sclass
if (best_score_tw > scores['timber wolf']) and (best_score_gs < scores['Granny Smith'])
temp_image = output_image
best_score_tw = scores['timber wolf']
best_score_gs = scores['Granny Smith']
rand_red_best = rand_red
rand_green_best = rand_green
min_diff = scores['timber wolf'] - scores['Granny Smith']
best_row, best_col = row, col
print(iter_1, [rand_red,rand_green,0], ':', row,col, ";\n",min_diff,'\n')
pixel_count += 1
input_image = temp_image
scores = get_scores(input_image)
print(pixel_count,
'\n', row, col, [rand_red_best, rand_green_best, 0],
'\n', scores, '\n')
dict_pixels['pixels'].append(([row_best,col_best],[rand_red_best,rand_green_best,0]))
np.save('/output_image.npy', input_image)
np.save('/pixel_data.npy', dict_pixels)
scores = get_scores(input_image)
best_score_tw = scores['timber wolf']
best_score_gs = scores['Granny Smith']
print(best_score_tw, best_score_gs)
结果看起来像这样,它被分类为青苹果。
原始图像由 AI Village 提供,许可 署名 4.0 国际 (CC BY 4.0)(由作者编辑)
放大结果
我的狼图像显然被篡改了,但这是一个高分辨率图像,攻击成功了。我相信如果再给点时间,可能会用更少的像素达到更好的欺骗效果。
像素化的狼被分类为青苹果。由 AI Village 提供,许可 署名 4.0 国际 (CC BY 4.0)(由作者编辑)
提醒一句…
在目睹了一个可能的像素攻击版本,这个版本仅根据预测分数和试验与错误导致了模型的错误分类之后,让我们进一步探讨如何避免这种情况。
当然,这里的目标不是鼓励执行像素攻击,除非这是对自己模型的抗性检查。探索对抗性机器学习实践的复杂性本质上是为了培养如何保护模型不受这些方法影响的意识。
所以让我们深入探讨如何避免这些情况…
像素攻击的可能弱点
像素攻击,特别是在黑箱设置中,已经涉及到大量的试验和错误,但各种策略可以进一步增强模型对这些攻击的鲁棒性。
1. 使用高分辨率图像
高分辨率图像更难被攻击,因为它们需要更多的资源和更高数量的改变特征/像素,因此更难以微妙地篡改。
图像由作者提供
澄清:例如,考虑一个来自 CIFAR 的 32x32 图像,它的像素较少,使其更容易受到篡改。相比之下,高分辨率图像由于像素数量更多,因此不易受到像素攻击。另一方面,这些图像虽然在隐蔽篡改时更具挑战,但在训练过程中可能会产生更高的计算成本。因此,需要在安全性和计算效率之间找到平衡。
提高接受结果的预测分数阈值
由于被攻击的图像具有较低的预测分数,可以利用分数阈值来检测潜在的对抗性攻击。
作者提供的图片
澄清:例如,设置一个阈值,低于该阈值的预测被视为不确定,这为防御对抗性攻击提供了额外的安全层。
再次值得指出的是,这是一个权衡,高阈值提升了信心,但可能会限制分类器的敏感度。找到合适的平衡对于避免拒绝有效预测同时抵御对抗性攻击至关重要。
考虑到 CNN 在关键应用中的鲁棒性
结果表明,尽管卷积神经网络(CNNs)并非免疫,但由于利用了空间层次结构,它们对这种对抗性攻击的敏感性较低。
作者提供的图片
澄清:简单来说,虽然平均模型将像素视为单独的输入,但卷积神经网络(CNN)通过内核窗口考虑预定义的关联,从而增强了对抗性操控的鲁棒性。
预测前的图像预处理
在将图像输入神经网络进行预测之前,应用一种稳健的预处理技术可能是值得的,从而限制黑箱攻击。
作者提供的图片
澄清:例如,图像压缩有助于减少篡改的影响,而计算机视觉算法可以识别图像中的失真或异常。此外,由于被操控的像素可能与原始图像的颜色或模式不完全匹配,因此可以应用插值技术。
安全的机器学习模型
上述方法虽然有效,但并非一刀切。最终,保护某个机器学习模型免受对抗性攻击包括在各种条件下(包括暴露于潜在对抗性输入)对模型进行严格的测试和验证。
决定添加多少安全性以及多频繁更新模型取决于其重要性和可能面临的威胁类型。但是,了解伦理考虑和理解可能的威胁有助于减少攻击风险。
总结……
虽然像素攻击或任何图像的操控对基于图像的人工智能系统确实是一个大问题,但我们也可以做很多事情来防范这些攻击。攻击者可以篡改单个像素来欺骗模型,使其犯错误,从而危害图像识别和安全系统等关键应用的可靠性。这不仅会导致安全漏洞,还会破坏客户和利益相关者的信任。
另一方面,机器学习从业者拥有工具来确保模型不容易受到这些攻击的影响。
在这篇文章中,我尝试探索了像素攻击,受到 CTF 挑战的启发,并深入研究了欺骗图像分类模型的一些复杂性。虽然狼确实变成了一个青苹果,但这需要大量的计算和反复试验,如果模型采取了一些预防措施,攻击将会失败。
我在下面列出了一些类似的方法资源,希望你能发现这些话题对保持模型安全有用。
资源
[## GitHub - Hyperparticle/one-pixel-attack-keras: Keras 实现的“单像素攻击”…
使用差分进化在 Cifar10 上进行“单像素攻击以欺骗深度神经网络”的 Keras 实现和…
github.com [## GitHub - max-andr/square-attack: Square Attack: 一个查询高效的黑箱对抗攻击,通过…
Square Attack: 一个查询高效的黑箱对抗攻击,通过随机搜索 [ECCV 2020] - GitHub …
github.com [## GitHub - kenny-co/procedural-advml: 无任务通用黑箱攻击计算机视觉…