TowardsDataScience 2023 博客中文翻译(二百八十八)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

照亮您数据科学之旅的可转移技能

原文:towardsdatascience.com/shining-light-on-transferrable-skills-for-your-data-science-journey-a4c67c3d0de8

我对那些从学术界转向商业数据科学的关键可转移技能的看法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Kirill Lepchenkov

·发布于Towards Data Science ·阅读时间 9 分钟·2023 年 4 月 7 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

光束形状图像(作者拍摄)

前言

我在激光物理、非线性光学和固态激光工程领域担任研究员已有 5 年。虽然我完全沉浸于这个领域,并对自己所做的工作感到兴奋,但在某个时刻,我过渡到了商业数据科学行业。

在数据科学领域工作了额外 6 年后,我感到我在应用物理领域发展起来的技能在与激光物理完全无关的商业项目中得到了完美应用。

关于学术经验可能多么有用已经有很多讨论,但我决定表达我个人对这一主题的看法。

为了阐明我的观点,我决定根据每个技能组的实用性及其原因进行评级。

这篇文章适合谁?

我认为我写这篇文章主要是为了那些考虑从学术环境转向商业领域的人,但也是为了我自己,反思两者之间工具、技能和思维方式的交集。

文献综述经验 → 7/10

为什么文献综述在商业数据科学中是如此重要且可转移的技能(习惯)?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在我物理学时期的文献综述(作者的桌面)

在我看来,文献综述在商业数据科学中有点被忽视和误解。我并不是说我们对全新模型架构和框架设计的阅读不够(这部分做得非常好)。

但当涉及到快速有效地获取关于项目主题的更结构化和有价值的信息时——在我看来,这是数据科学领域存在的最大空白。

文献综述 可能不是这里最好的术语。我也可以称其为背景研究,或最先进的分析

在处理商业问题时,我认为对问题主题有一定的理论基础是至关重要的。文献综述的作用:

  • 为数据战略的可靠决策奠定基础。 了解领域内现有的技术和方法。

  • 加快入职过程。 如果你对自己正在从事的领域不熟悉,尽快获取相关知识是实现价值生成的第一步。

  • 提高与领域专家的沟通质量。 领域专家,也称为主题专家,对于解决数据问题至关重要。但他们通常不编程,而且非常忙。因此,数据科学家必须掌握一些领域特定的术语和概念,以便与这些专家有效沟通和顺畅合作。

  • 大幅提升洞察力的质量。 根据我的经验,文献综述为数据收集、预处理、建模和评估提供了决策基础,最终提高了你提供的洞察力的质量。在我的经验中,它有效,但并非总是如此。

关注文献综述,并投入时间和精力,体现了一种特定的心态——开放、谦虚和好奇。 文献综述有助于避免重新发明轮子或陷入确认偏差的陷阱。

我相信,随着大语言模型和基于这些模型的服务的扩展,文献综述的过程会发生变化,但我们还未到达那一步。

记录→ 9/10

将学术界的记录实践转移到商业数据科学中,对我来说非常有价值。除了多个实际好处,它在经历研究人员工作生活中的起伏时,给你一种无价的连续感。在我看来,通过采用保持实验室笔记本这一关键习惯,数据科学家可以轻松跟踪实验、记录想法和观察,监控个人和职业成长。我写了一整篇文章来阐述这样做的好处,欢迎查阅!

实验室笔记本作为数据科学从业者的选择武器

我的一套有效笔记记录原则,以实验室笔记本的形式呈现

towardsdatascience.com

编程知识 → 6/10

在我的科学历程中,我每天都在处理实验数据、进行数值模拟和统计学习。编程对于开发和测试新的激光设计(数值模拟)也是必不可少的。

我一直在不断使用它来处理典型的数据科学任务:

  • 实验数据处理(Python,Wolfram

  • 数值模拟(Wolfram, Matlab, Python)

  • 统计学习(Wolfram, Matlab, Python)

  • 数据可视化(Origin Pro, Python, R)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我的“数据工作”科学工具栈

Wolfram(更具体地说是 Wolfram Mathematica)是使用最频繁的工具,因为我们在实验室里有它的许可证。它有很棒的工具集用于求解非线性微分方程,我们广泛使用它进行数值模拟。

Python 是我处理实验生成的数据(光束形状、振荡图)的首选工具。

说到数据可视化,Origin 是主要工具,因为它允许将视觉元素嵌入到文本文档中,同时保持可编辑性。折线图、直方图(包括核密度估计器)、回归分析——Origin 是一个很棒的工具。Origin 有一个图形用户界面,所以这不仅仅是编程的问题,我提到它是为了确保 Python 和 R 不会独占所有数据可视化的功劳。

总的来说,我对上述提到的每一个工具都有扎实的使用经验:我了解语法,并且能够以相当高的效率解决问题。那么为什么只有 6/10 呢?为什么在学术界获得的编程技能在商业数据科学中相对难以转移?这确实是一个相当强的声明,但我认为学术经验的缺点可能会超过其优点。主要是因为许多科学环境中完全忽视了良好的软件实践

警告:这一说法基于我在应用物理领域的个人经验,并且绝对不适用于所有在学术界工作的人。对这一部分的内容要持保留态度!

一方面,忽视良好的软件原则是研究人员优化研究速度和发表数量,而非代码质量和可维护性的自然结果。另一方面,几乎没有人从正统的软件开发转到学术界(出于经济原因),因此根本没有真正的生产专业知识。我还应该提到,设计实验、进行文献综述、收集测量数据、编写处理代码以及获得有价值的见解——所有这些同时进行是非常耗费精力的。因此,你根本没有足够的资源去学习软件开发。

测量能力→ 9/10

这一点难以解释,所以请耐心听我说。应用激光物理中的测量工作本身就是一个独立的学科。提供有价值的测量是一项需要多年训练的技能!原因有很多:你必须理解过程的物理学,遵循测量协议,并且具备操作复杂且昂贵仪器的专业知识和训练。

例如,我一直在使用二极管泵浦脉冲固态激光器,测量激光束的多个参数:脉冲持续时间、脉冲能量、重复频率、束形、发散度、偏振、光谱内容、时间特性和光束腰部。进行这些测量中的任何一项都非常困难。比如说,你想测量束形(见下图)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

beam profiles 3d(作者拍摄)

光束形状指的是激光束在其截面或横截面上的强度空间分布。

理论上,你只需将激光光束对准 CCD 相机,几秒钟内就能获得光束形状。但实际操作起来则大相径庭。如果你正在使用脉冲固态激光且有相当的脉冲能量,并且你知道自己在做什么,你会将激光光束导向高质量的光学楔子,将大部分脉冲能量集中到一个陷阱中,并使用仅有原始光束一部分能量的反射光束进行工作。这样做是为了保护 CCD 相机免受灾难。但使用楔子还不够。你还需安装一个可调的光束衰减器,将其锁定到最暗模式,然后逐渐降低吸收率,直到在 CCD 相机上获得正确的曝光。

如果你正在使用对人眼不可见的红外激光,你会面临一个问题:你必须在看不到实际光束的情况下通过小孔引导光束。这项技能只能通过训练和实践获得。顺便说一下,每一步光束操作都必须极其小心,以遵守安全规定:你必须佩戴适当的防护眼镜,使用保护屏幕等。

好的,继续,现在你的光束被衰减并完美地对准了 CCD 相机。但你还有很多工作要做:将 CCD 相机连接到激光电源单元以实现同步并产生稳定的图像。如果你做对了所有步骤——你就能获得图像。等等,图像?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

光束轮廓 2D(由作者捕获)

然后你意识到,如果你的激光以 50 Hz 的脉冲重复频率运行,这意味着它每秒产生 50 个脉冲。每个脉冲可能有略微不同的光束轮廓。你该如何获得结果?你应该随便选一个脉冲并捕捉图像?还是应该使用一定数量的脉冲生成平均图像?哦,管理 CCD 相机的软件默认启用了平均功能?

让我们结束这个“测量光束形状”的废话吧。根据我一生中的所有测量经验,我有两个关键的可转移品质:警惕(永远不要仅仅相信表面)和对元数据的细致关注(数据是如何测量或记录的,使用了哪些工具,甚至最初发生的原因)。这两者在处理实际数据时都是金标准。因为它让你在产生实际影响时更加高效,而不会陷入麻烦之中。这在学术界和商业数据科学中都很受重视。

数据通信熟练度 → 10/10

当我在学术界时,我并没有认为数据沟通是一个特别值得关注或有价值的写作主题。处理数据可视化、讨论数据和理论,以及撰写科学论文只是工作的组成部分。但经过多年的研究,你在不同层次(正式和非正式)上获得了扎实的数据沟通技能。

写作科学论文是正式数据沟通类型中最具挑战性的技能之一。要能写出一个结构合理(摘要 → 引言 → 文献综述 → 方法论 → 结果 → 讨论 → 结论 → 致谢)的引人入胜的文章,需要大量的练习。文章的结构本身假设你有一个故事要写。而且这不仅仅是写作:你还必须知道如何制作引人注目且有目的的数据可视化。这一切都是为了将你的信息传达给观众。

我将这一技能的可转移性评分为 10 分(满分 10 分),因为商业数据科学毫不意外地依赖于人与人之间的互动、传达你的思想和结果。

结论

总体而言,我相信拥有科学背景的人可以为数据科学领域带来独特的视角和宝贵的技能。对于那些认为转向商业数据科学意味着放弃所有辛勤工作和专业知识的学术界人士,我提供一个不同的观点:你有大量的价值可以带到桌面上。在我看来,最佳的行动方案是利用你现有的技能,同时掌握你转型领域的新技术和最佳实践(我们都知道这是一个终身的旅程)。

最短路径(Dijkstra)算法:一步步的 Python 指南

原文:towardsdatascience.com/shortest-path-dijkstras-algorithm-step-by-step-python-guide-896769522752

使用 OSMNX 1.6 和长距离路径的更新

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Bryan R. Vallejo

·发布于 Towards Data Science ·6 分钟阅读·2023 年 10 月 4 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。摩洛哥的最短路径(约 350 公里)

这个著名的算法在 Python 库 OSMNX 中实现,可以用来寻找两个位置之间按距离或时间加权的最短路径。该算法使用 OpenStreetMap(OSM)网络,通过 Python 库 NETWORKX 在后台寻找驾驶、步行或骑车的路线。

我写这个更新是因为函数的参数稍有变化,且有人询问为什么我的代码在旧博客文章中无法工作,这只是因为代码是用旧版本的 osmnx 编写的。

旧教程包含了相当有价值的过程,但我决定做一个一步步的指南,这样获取最短路径的过程会更准确,使用这个指南的分析师可以真正理解整个过程。

这里是旧教程,如果你想查看一下。

在芬兰赫尔辛基,使用不同的网络

## OSM 街道网络中使用的最短路径算法

车辆、自行车和行人最短路径分析的 GIS 自动化技巧

towardsdatascience.com

在爱沙尼亚塔尔图,使用步行网络

## 使用 OSM 步行网络的最短路径算法

使用 OSM 数据在爱沙尼亚塔尔图寻找最短步行路径

towardsdatascience.com

OSM 数据许可

介绍

在这个实践中,我将使用摩洛哥的两个位置。这个实践由我的一位读者 Hanae 提出,她提供了原点和目的地。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。原点和目的地位置。

编码实践

正如我提到的,我将做一个逐步指南,所以让我们开始。在此之前,让我们导入所需的库。

import osmnx as ox
import geopandas as gpd
from shapely.geometry import Point, LineString
import pandas as pd
import matplotlib.pyplot as plt

1. 定义原点和目的地

简单地,我们将创建几何对象作为点:

# origin and destination geom

origin_geom = Point(-5.6613932957355715, 32.93210288339607)

destination_geom = Point(-3.3500597061072726, 34.23038027794419)

2. 提取 OSM 图对象

然后,我们将提取图形,用于生成最短路径。我们逐步来看。

  • 从原点和目的地创建 GeoDataFrames
# create origin dataframe
origin =  gpd.GeoDataFrame(columns = ['name', 'geometry'], crs = 4326, geometry = 'geometry')
origin.at[0, 'name'] = 'origin'
origin.at[0, 'geometry'] =origin_geom

# create destination dataframe
destination =  gpd.GeoDataFrame(columns = ['name', 'geometry'], crs = 4326, geometry = 'geometry')
destination.at[0, 'name'] = 'destination'
destination.at[0, 'geometry'] = destination_geom
  • 获取包含原点和目的地的图

我们将使用 Geopandas 的 envelope 函数来将多边形用作掩膜以提取图形。

首先一个简单的函数。

def get_graph_from_locations(origin, destination, network='drive'):
    '''
    network_type as drive, walk, bike
    origin gdf 4326
    destination gdf 4326
    '''
    # combine and area buffer
    combined = pd.concat([origin, destination])

    convex = combined.unary_union.envelope # using envelope instead of convex, otherwise it breaks the unary_union

    graph_extent = convex.buffer(0.02)

    graph = ox.graph_from_polygon(graph_extent, network_type= network)

    return graph

然后,使用它并绘制结果。

graph = get_graph_from_locations(origin, destination)
fig, ax = ox.plot_graph(graph, node_size=0, edge_linewidth=0.2)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。图包含原点和目的地

3. 找到原点和目的地的最近节点

获取使用原点和目的地位置的网络中最接近的节点。节点代码可以使用 osmnx 函数获得。

# ------------- get closest nodes

# origin
closest_origin_node = ox.nearest_nodes(G=graph, 
                                       X=origin_geom.x, 
                                       Y=origin_geom.y)

# destination
closest_destination_node = ox.nearest_nodes(G=graph, 
                                           X=destination_geom.x, 
                                           Y=destination_geom.y)

你可以检查并注意到我们目前只有代码。

4. 找到最短路径

然后,使用最短路径函数来获取路线。

# run
route = ox.shortest_path(graph, 
                         orig = closest_origin_node, 
                         dest = closest_destination_node, 
                         weight = 'length')

这将返回一堆路径中节点的代码。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像节点代码

5. 从节点创建 Line Geometry

我们将从图中提取节点的几何形状,并创建一个表示最短路径的 LineString 几何体。

首先为此创建一个函数。

def nodes_to_route(graph_nodes, path_nodes):

    # Extract the route nodes of the graph
    route_nodes = graph_nodes.loc[path_nodes]

    # ---> note! If you have more routes, check for each one, to be removed in length is 1\.  A path can not be built with only 1 node.

    # Create a LineString out of the route
    list_geom = route_nodes.geometry.to_list()
    path = LineString(list_geom)

    # Append the result into the GeoDataFrame
    route_df = gpd.GeoDataFrame( [[path]] )

    # Add a column name
    route_df.columns = ['geometry'] 

    # Set geometry
    route_df = route_df.set_geometry('geometry')

    # Set coordinate reference system
    route_df.crs = graph_nodes.crs

    # remove nans
    route_df = route_df.dropna(subset=['geometry'])

    return route_df

获取节点,并在函数中使用它们。

# get all network nodes
graph_nodes = ox.graph_to_gdfs(graph, edges=False)

# get the line geometries from osm nodes
route_gdf = nodes_to_route(graph_nodes, route)

6. 计算距离

我们将使用墨卡托投影来测量路线的米数。如果你想要更准确的结果,可以使用位置投影。

首先,为此创建一个函数。

def compute_distance(shortest_path_gdf):
    '''
    Compute distance in EPSG:3387

    '''

    # project WGS84 to EPSG3387
    distances = shortest_path_gdf.to_crs("EPSG:3387").geometry.length

    # add
    shortest_path_gdf['distance'] = distances

    return shortest_path_gdf

然后,使用它:

# calculate distance m
route_distance_gdf = compute_distance(route_gdf)

它将测量约 351.243 米的路线。

7. 保存网络和路径

将网络和路径保存到本地磁盘上用于地图。

提取网络并定义 GeoDataFrame:

# fetch network
network = ox.graph_to_gdfs(graph, nodes=False)

# get only needed columns
network_gdf = network.reset_index(drop=True)[['geometry']]

然后存储:

network_gdf.to_file(r'osm_network.gpkg')
route_distance_gdf.to_file(r'osm_shortest_path.gpkg')

你可以使用这些数据来创建自己的地图。例如,这个在 QGIS 中:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。QGIS 中的最短路径和网络

8. 绘制结果

我们将通过绘制所有元素来检查我们的工作是否正确。

# plot network
ax = network_gdf.plot(figsize=(12, 10), linewidth = 0.2, color='grey', zorder=0);

# origin and destination
origin.plot(ax=ax, markersize=46, alpha=0.8, color='blue', zorder=1)
destination.plot(ax=ax, markersize=46, alpha=0.8, color='green', zorder=2)

# route
route_distance_gdf.plot(ax=ax, linewidth = 3, color='red', alpha=0.4, zorder=3)

plt.axis(False);

结果将会是这样的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。最短路径、网络、起点和终点在 Matplotlib 中

已知限制

最短路径是通过节点网络的联合生成的,线条并不完全匹配道路。这完全没问题,因为我们要的只是一个近似值。如果你需要导航,应该使用 Google API 进行路由,或其他提供商。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。线条是通过节点创建的。

结论

使用 OSMNX 的最短路径算法提供了路线的近似值,并且可以广泛用于城市或区域规模的可达性研究。这个 Python 库不断更新,函数或参数可能会有所变化,因此建议在我们的工作流程中持续更新库版本。

如果你有问题或需要定制分析,欢迎联系我:

Bryan R. LinkedIn

深度伪造技术是否应该开源?

原文:towardsdatascience.com/should-deepfakes-be-open-sourced-87d7644a0765?source=collection_archive---------9-----------------------#2023-05-25

意见

讨论了开放深度伪造技术的利弊

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jack Saunders

·

关注 发布于Towards Data Science ·5 分钟阅读·2023 年 5 月 25 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片生成使用了DreamStudio

我是一名博士研究人员,创建可以被认为是深度伪造的技术用于我的研究。我对创建逼真的数字双胞胎和提升娱乐水平的能力感到着迷。在开始我的研究之前,我曾认为这些模型可能会造成太多伤害,不适合向公众发布。在过去几个月里,我注意到越来越多的顶尖声音主张将开源软件作为人工智能领域的核心原则。虽然这次讨论几乎完全集中在大规模语言模型(LLMs)上,但我认为这一理念在整个领域中普遍存在。我完全支持几乎所有人工智能模型的开源,但对于我自己的研究领域,我不太确定。

几乎没有哪个领域的误用潜力像深度伪造技术那样高。

到目前为止,我的方法是走中间道路。尽量以一种不需要博士学位才能理解的水平来传达深度伪造技术的工作原理。然而,我经常怀疑这是否是正确的方法。本文的目的是尝试启动关于我们深度伪造研究人员应采取方向的讨论。 有鉴于此,本文讨论了开源的一些优缺点。

开源的优点

我们常常听到关于深度伪造技术的负面信息。有人可能会问,像我这样的研究人员为什么还要考虑创建开源模型。然而,实际上有很多合理的理由这样做:

  • 透明性:对我来说,这是最重要的一点。如果大多数主要的深度伪造模型完全开源,那么它们就会变得透明。这样,监管者就可以理解他们正在处理的内容,其他研究人员也可以开发更好的检测算法。对于深度伪造技术来说,那些希望利用它们进行恶意行为的“坏演员”和那些试图防止这种伤害的“好演员”之间将会有一场军备竞赛。你可以确定“坏演员”无论我们是否发布我们的模型都会开发他们的深度伪造技术。在开源中,我们可以给“好演员”提供更多的数据,以构建他们的防害模型。

  • 公平性:如果我们选择不进行开源,只有那些拥有人才和计算资源的机构才能创建深度伪造技术。根据经验,开发这些模型需要很长时间,没有开源软件的话,很少有人能做到这一点。这可能进一步将权力集中在已经强大的手中。深度伪造技术可以在多个市场中使用,并且可能具有数十亿美元的潜在价值。例如,仅配音市场就预计超过 35 亿美元。如果只有像谷歌这样的公司能够创建深度伪造技术,那么只有谷歌这样的公司才能从中获益。

  • 意识: 深度伪造技术正在迅速发展。我们很可能很快会达到这样一个阶段,即你无法相信在线上看到的任何视频,除非它以其他方式经过验证。虽然很多人对此有模糊的认识,但我认为很少有人真正理解其含义。作为深度伪造研究人员,我们有责任帮助教育公众。我们需要真正鼓励每个人保持良好的数字怀疑态度,检查他们在线上看到的任何媒体的来源,并质疑其真实性。开源软件有所帮助。当模型可以在网上自由获取时,教育变得更加容易。如果你能自己创建深度伪造,你自然会留意其他人的伪造。

缺点

当然,将深度伪造模型开源存在许多潜在的缺点,从明显的到更微妙的都有。

  • 人们将滥用它们: 无论我们如何监管或检测模型有多么有效,总会有一小部分人会出于最糟糕的理由使用深度伪造。从报复色情虚假信息,这项技术有一些非常恶劣的应用。如果我们开源模型,就会使所有人更容易访问它们,这无疑会造成伤害。确实,一些坏人无论如何都会做到这一点,尤其是大型犯罪或国家组织,但大多数寻求伤害的人可能本来无法做到,如果没有开源模型的话。

  • 保护措施可能被移除: 保护深度伪造技术不被滥用的较好方法之一是引入保护措施。特别是,大多数创建深度伪造的团队都使用水印技术。水印涉及将数据添加到创建的视频中,以一种对人类和大多数软件都不可见但可以被拥有“密钥”的人轻松识别的方式。这意味着,例如,YouTube 或 Twitter 可以快速检测到视频是否由深度伪造平台创建,并将其移除。由于水印只能被那些获得了这个秘密密钥的人看到,坏人无法移除它们。如果我们开源深度伪造生成,那么坏人将可以简单地跳过添加水印。 这使得深度伪造变得不可检测。

  • 单向共享: 如果我们再次考虑所谓的好人和坏人之间的军备竞赛,那么我们可以看到开源的另一个缺点。如果我们这些好人开源了我们所有的软件,那么坏人可以在此基础上进行构建。另一方面,坏人不会开源他们的模型,这意味着信息只是在一个方向上共享。这给坏人带来了显著的优势。

总结

正如所见,这不是一个容易回答的问题。利弊权衡很多,无论哪种情况,潜在的危害都很大。在写这篇文章的过程中,我进行了许多对话。让我感到惊讶的一点是开源绝对主义者的数量。我经常听到的一个论点是,深伪技术已经存在,不可能把魔鬼收回瓶子。许多人,包括我自己在内,都认为,我们将迎来一个深伪技术与现实难以区分的时代。如果到那时我们都没有足够的意识去质疑这些技术,我们可能会面临很大麻烦,因为坏分子可能在未被察觉的情况下行动。这是一个我认为最近对 AI 发展暂停的呼吁忽视的点。然而,虽然开源可能会减少长期的危害,但它也为那些可能想要滥用深伪技术但没有技术能力的人打开了短期危害的大门。

尽管我对开源问题仍然未作决定,但我比以往任何时候都更自信地认为,这个讨论是必须进行的,至少深伪技术的研究需要在公开的环境下进行,并且要向公众传达。我强烈鼓励每个人发表意见。如果你有我没有涉及到的看法或任何问题,请留下评论或直接联系我,我真的希望听到尽可能多的人反馈。

我真的应该吃这个蘑菇吗?

原文:towardsdatascience.com/should-i-really-eat-that-mushroom-9edeaa69d934

使用 CatBoost 梯度提升决策树对可食用和有毒蘑菇进行分类

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Caroline Arnold

·发表于Towards Data Science ·阅读时间 6 分钟·2023 年 8 月 17 日

大多数教育和现实世界的数据集包含分类特征。今天我们将讨论来自CatBoost库的梯度提升决策树,该库原生支持分类数据。我们将使用一个蘑菇数据集,这些蘑菇要么是可食用的,要么是有毒的。蘑菇通过分类特征如颜色、气味和形状进行描述,我们想要回答的问题是:

基于其分类特征,这种蘑菇是否安全食用?

如你所见,风险很高。我们希望确保机器学习模型的准确性,以免我们的蘑菇煎蛋卷以灾难收场。作为额外奖励,最后我们将提供一个特征重要性排名,告诉你哪个分类特征是蘑菇安全性的最强预测因素。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Andrew Ridley提供,来自Unsplash

介绍蘑菇数据集

蘑菇数据集可以在这里找到:archive.ics.uci.edu/dataset/73/mushroom [1]。为了清晰展示,我们将从原始的简短变量中创建一个pandas DataFrame,并用适当的列名和长格式变量进行注释。我们使用 pandas 的replace函数,长格式变量来自数据集描述。目标变量只能取TrueFalse值——数据集创建者采取了保守的方式,将可疑的蘑菇归类为不可食用。

在检查数据集缺失值后,我们发现只有一列——stalk_root——受到影响。我们删除了这一列。

数据集的探索揭示数据相当平衡:在 8124 个蘑菇中,4208 个是可食用的,3916 个是有毒的。我们将数据框分为目标变量is_edible和其余的蘑菇特征。然后,我们通过对目标变量进行分层,将数据集分为训练数据和测试数据。这确保了两个拆分中的类别分布是可比较的。

CatBoost 库

CatBoost 是一个开源的机器学习包,用于梯度提升决策树。可以通过按照安装说明来获取 CatBoost Python 包。对我们来说最重要的组件是catboost.Pool,它组织数据集并指定分类特征和数值特征,以及我们的模型catboost.CatBoostClassifier。分类特征在机器学习算法中可能难以处理,它们必须被编码成数值才能用于训练。每个分类值都与一个数字相关联,例如蘑菇颜色的brown->0, black->1, yellow->2, ...CatBoost 可以自动处理分类输入变量,这样我们就不用再添加 独热编码 到流程中。这不仅方便,而且 CatBoost 算法也经过优化,以便更快地训练分类变量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Haithem FerdiUnsplash 拍摄

梯度提升决策树

决策树是成熟的机器学习算法,根据特征值将样本分类为不同的类别。单棵决策树容易过拟合。 因此,通常使用决策树的集合来实现更好的性能。在梯度提升决策树中,树的集合通过迭代更新树来构建。每次迭代的树通过在应用前一个树后留下的残差上进行训练,提供了比前一次迭代更小的改进。该过程在损失收敛时停止,即当添加更多树没有增值时,或达到固定的总树数时。有关梯度提升决策树的更详细介绍,请参见页面底部推荐的博客文章。

蘑菇分类

在蘑菇数据集中,所有特征都是分类的,并在Pool中相应指定。我们为训练和测试分别构建一个Pool。目标变量被转换为数值,因为这与CatBoostClassifier的损失函数更好地集成。分类器本身的格式类似于 scikit-learn。可以调整许多属性,包括学习率、树的总数和树的正则化。损失函数是log-loss,因为我们处理的是二分类问题。

对于二分类目标类的预测,使用的是对数损失或交叉熵函数。实际值 y 与模型提供的概率 p 进行比较。

我们在下面的代码框中定义数据集和模型。为了比较,我们训练了一个单一的决策树和一个完整的梯度提升决策树。

评估

现在我们准备评估分类器在测试数据上的表现。食用毒蘑菇可能导致严重的健康问题,因此我们关注减少假阳性。我们计算精确度指标,即实际可食用的蘑菇数量与预测为可食用的蘑菇数量的比例。

单一决策树的精确度为 97%,对于一个分类算法来说相当不错。但是通过梯度提升树,我们可以将精确度提高到 100%,测试数据集中没有毒蘑菇被误标为可食用。混淆矩阵显示,梯度提升决策树在测试集上提供了最佳表现。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

单一决策树(左)和梯度提升决策树(右)的混淆矩阵。

特征重要性

这很好,但我们可能没有整天的时间来确定每种我们想吃的蘑菇的 22 个特征。那么,确定蘑菇是否可食用的最重要特征是什么呢?

为了回答这个问题,我们使用内置的模型属性feature_importances_来推导梯度提升树分类器的特征重要性排名。结果显示,气味在特征重要性排名中占据主导地位,其次是孢子 印刷颜色数量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从训练好的 CatBoostClassifier 中获得的蘑菇数据集的特征重要性排名。

进一步观察可能的气味值可以发现,这个特征本身已经是一个很好的预测因素,能够判断一只蘑菇是会成为你餐点的美味补充,还是会让你一天结束在医院。数据集中所有散发茴香杏仁气味的蘑菇都是可食用的。没有气味的蘑菇大多也是可食用的。你应该远离腥臭、辛辣、刺鼻、腐臭、木焦油霉味的蘑菇——老实说,这些蘑菇听起来一开始就不怎么美味。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对蘑菇数据集中的气味特征进行详细分析。

摘要

我们介绍了蘑菇数据集,其中包含仅由分类变量描述的可食用和有毒蘑菇样本。我们引入了 catboost 包,它对分类数据效果很好,并提供了梯度提升决策树。训练了一个模型来对蘑菇进行分类,并取得了令人满意的表现。气味是预测蘑菇安全性的最强指标。希望你喜欢这篇博客文章,并对模型在真实蘑菇上的应用不负责任 😃。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Zhen H提供,来自Unsplash

深入阅读

数据集参考

[1] 蘑菇。UCI 机器学习库(1987 年)。doi.org/10.24432/C5959T. 本数据集采用创意共享署名 4.0 国际(CC BY 4.0)许可。

我们是否应该更依赖数据?有时候。

原文:towardsdatascience.com/should-we-be-more-data-driven-sometimes-3dcf5e2753ae?source=collection_archive---------3-----------------------#2023-08-17

何时应该依赖数据,何时数据依赖反而会成为障碍。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Robert Yi

·

关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 8 月 17 日

我在 Airbnb 担任数据科学家时,Covid-19 爆发了。正如你所料,Covid-19 对于一个依赖良好人际互动的业务来说特别残酷。当世界正在形成孤立的社交圈时,想要让人们住在陌生人的家中是非常困难的。因此,正如你所预期,我们的指标急剧下滑——我们的核心指标下降到了个位数的同比值。没有人再预订 Airbnb 了,可以肯定的是,也没有人愿意开设新的 Airbnb。

当我们面临那突如其来的指标悬崖时,我们的首席执行官布赖恩迅速作出了令人钦佩的回应。尽管我们都在设置家庭办公室,并从好市多囤积卫生纸和罐头食品,布赖恩却召开了紧急全员大会。他明确告诉我们:“我们所知的旅行已经结束。”他没有明确的下一步计划,但在风暴中却有一个灯塔般的指示:停止一切非关键工作,弄清楚如何在疫情中生存下来。

随后发生的事情令人印象深刻。公司有效地转变了方向,在如此大规模的公司中参与其中是非常激动人心的。我们在创纪录的时间内推出了 Airbnb 在线体验。我们以“近在咫尺即为远方”为新的口号,策划并推动人们前往那些在疫情期间适合作为避难所的地点。明显不符合未来方向的新举措被关闭(我曾参与一个名为“社交住宿”的团队,尽管投入巨大,我们还是迅速终结了这一项目)。我们进行了新的融资,重组了公司。公司每天做出数百甚至上千个决定,因此,成功地在疫情最严重时期游刃有余,表现出尽可能好的灵活性。

话虽如此,尽管业务变动颇具趣味,我实际上更想在这篇文章中讨论这一时期数据的作用以及我们可以从中获得的经验教训。我最令人震惊的认识是:数据,曾经在几乎所有战略对话中扮演关键角色,却在一夜之间变成了次要因素。那时,为了争取“数据驱动决策”而奋斗将会是可笑的——不是因为数据在这一过渡期没有用,而是因为在危机中数据不应成为主导。接下来,我将讨论这种思维方式转变的根本原因:紧迫性。让我们考虑不同的决策情况,然后讨论我们应该如何利用数据。是时候真正谈谈“数据驱动”应该意味着什么了。

决策的划分

你可以通过两个轴来清晰地划分决策过程:决策的紧迫性决策的重要性。根据你的决策在帕内特方格中的位置,分析的参与程度可以并且应该有所不同。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

低紧迫性,高重要性

一方面,当一个决定极其重要但并不特别紧急时,我们可以按照理想的方式进行分析——与利益相关者紧密迭代,以更好地导航可能的行动空间。例如,假设你公司的高管想要彻底修改你的登陆页面,但他们希望你支持决定应该放置什么内容。你团队中的机器学习软件工程师跳转到卡片分类解决方案,但你和你的利益相关者知道,更关键的决定是是否首先要应用这种分类解决方案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。

当前的主页运行良好,因此所需的更改并不紧急,但决定的影响很大——你的更改将影响每一个访客的体验。因此,应该利用分析来更好地导航决策空间:你可以筛选过去的实验,汇总可能有助于当前决策的学习;你可以进行小规模的机会大小检查,以查看任何更改的范围;你可以提供人口统计/渠道/其他分布数据,以更好地了解你可能需要重点关注的内容。

利益相关者必须处理大量的选项,而你可以帮助他们以一种有度量、以假设驱动的方式进行。这就像你在买车一样。花时间进行市场调研是一个好的投资。

高紧迫性,高重要性

另一方面,让我们重新考虑上面的 Covid-19 Airbnb 情况。公司正处于危机状态,领导层已经确定了前进的最佳行动方案:我们需要确定一些市场,推向那些对 Covid 隔离所具有吸引力的市场。你可以像之前的例子那样采取相同的方法——仔细分析细分市场,筛选过去的实验结果等。但每推迟一天做出选择,你将失去两样东西:

  1. 有机会利用新市场。

  2. 有机会进行测试并学到东西。

因此,你提出了一个简单的假设:如果你选择一些与主要城市相对接近的地点,那么你将最大化预订量,因为客人将(a)感到足够隔离于 Covid,但也(b)足够接近以便在紧急情况下能回到家中与朋友和家人团聚。你在几小时内回到高管那里,他们发起了一个推进这些地点的倡议,你发现一些地点效果更佳,从而为你的第二批选择提供了参考。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。

在这里分析的最佳参与程度与低紧急情况有所不同——你仍在帮助你的利益相关者在思想迷宫中导航,但所做的决策大多是由直觉驱动的,因此你的参与自然较为浅薄。这并不是说你应该盲目顺从,强化反应性先例——仍然要理解原因,但接受你的参与将会是较少结构化和严谨的。尽管你可以在足够的时间里帮助利益相关者做出更好的决策,但你没有足够的时间,现在做出 80%正确的决策远比明天做出 90%正确的决策要有价值得多。

你遇到车祸了。获取一些数据来评估你和对方司机的健康状况,以及到最近医院的最佳路线是有用的,但你可能不应该花几个小时阅读医院评价。

低重要性

最后,有时决策实际上并没有那么重要。你在用户支持页面上移动一个按钮,实验没有收敛,但你的利益相关者想知道结果的真相。这时候你应该反驳——分析确实可以提供答案,但结果会改变什么行动?你会学到什么?利益相关者已经知道这是一个更好的体验,他们询问是为了确认,但你知道在这种实验曝光水平下,确定性是不可能的。

如果我们的决策没有因为我们的数据工作而发生变化,或至少我们没有从探索数据中学到些什么,我们可能根本不应该做这项工作。学会预测你的工作的影响——如果你帮助做出这个决策,潜在的提升是什么?——然后据此行动。

最终评论

为了明确,我并不主张在这里做一个严格的截断,但在选择适合任务的分析时,速度重要性应当被考虑。当决策非常紧急时,数据几乎总是应该退居二线,依赖直觉。当决策极为重要时,数据应当被更仔细地使用来验证假设,并对直觉进行监督。当决策不重要时,你不应该花很多时间担心这个决策,因此任何分析工作都应该在完成前重新考虑。

👋 你好!我是罗伯特, Hyperquery 的首席产品官,曾是数据科学家和分析师。此帖最初发布在 Win With Data,我们每周讨论如何最大化数据的影响。如果你想了解更多关于 Hyperquery 如何帮助你最大化影响的信息,请随时联系我。你可以在 LinkedIn Twitter上找到我。

我们是否应该虚拟化我们的数据科学系统——还是不虚拟化?

原文:towardsdatascience.com/should-we-be-virtualizing-our-data-science-systems-and-or-not-6cb69b4850f3

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者当前的家庭实验室设置

导航虚拟化数据科学过程的优缺点可能很困难,但有些性能趋势无法忽视

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Will Keefe

·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 9 月 12 日

随着“巨量数据”的利用在各个行业解决问题变得越来越相关,家庭实验室和数据湖规模的数据存储库需要比以往更多的并行计算能力来提取、转换、加载和分析数据。在创建自己的家庭实验室时,是否在虚拟机上还是在硬件上原生创建并行化设置让我感到困惑,我难以找到性能比较。在本文中,我们将探讨每种设置的优缺点,并对每种方法的虚拟和原生性能及基准测试进行逐一对比。

介绍

许多并行计算集群包括多个节点,即指定处理集群中分配任务的计算机。管理这些节点可能是一个大麻烦,这也是为什么数据工程如此有利可图相比于它们的分析对手。通常,公司会管理整个集群的队列,这使得几乎不可能对单独的节点给予个别关注,因此,“高可用性”设置,如 Proxmox、Kubernetes 和 Docker Swarm,是现代企业的必备工具。你可能已经与这些集群互动过,甚至没有意识到这一点——我今天午餐吃的 Chick-fil-A 鸡肉三明治就是通过一个边缘计算 Kubernetes 集群与他们的销售点系统完成的。

在虚拟化机器上计算有许多好处,包括:

  • 整个操作系统可以从企业服务器快速部署到现场,几乎是瞬时的

  • 图像可以实时备份

  • 部署可以容器化以限制范围并增加安全性

  • 在硬件故障的情况下,系统可以在最小的停机时间内迁移

这些并不是新概念,但随着每个组织对数据分析需求的增加,访问并行化部署的方式可以并且应该有所不同,因为虚拟化的缺点通常是你离裸金属越远,你的系统性能受到的影响就越大。虽然一个开发者在处理一个 Excel 文件时可能不会受到影响,但在处理几 GB 甚至 TB 的数据时,需要仔细考虑如何以及何时使用虚拟工具,并建立考虑处理能力的设置。

设置我们的比较

为了验证这一点,我们可以比较使用 readily available 企业硬件的小型到中型组织的设置(我负担不起那些高级设备)。在我的家庭实验室中,我有一个由多个翻新的企业单元构建的计算集群。我在下面的一些文章中链接了如何构建此设置以及我的用途,但现在让我们比较虚拟系统和裸金属系统之间的性能,并特别测量虚拟化的影响。

关于启动自己的数据分析家庭实验室的完整指南

现在是启动你的数据科学家庭实验室以分析对你有用的数据的最佳时机,存储……

towardsdatascience.com [## 在家庭实验室集群上使用 Python 构建分布式机器学习模型

使用我们自己设置的经济实惠的家庭实验室设备,设置并行和分布式机器非常简单……

betterprogramming.pub

自从写了上述文章后,我稍微升级了我的设置,增加了六台配备 Intel Core i7–7700 处理器、32 GB DDR4–2400 RAM 和 256–512 GB SATA III SSD 的 HP EliteDesk 800 G3 Mini。我在一个拍卖网站上以约 80 美元一台的便宜价格购买了这些设备,并额外支付了大约 30–40 美元以将它们升级为新 RAM 和硬盘。处理器都是 65W 型号,配有 90W 电源。即使按照今天的标准,处理器也不容小觑,超频达到 4Ghz 和 4 核心,8 线程。

今天的比较中,我有两个节点并排放置。一个节点运行 Proxmox,这是一个优化用于虚拟化和部署的 Linux 操作系统,运行一个 Windows 10 Pro 虚拟机,另一个节点在裸机上运行 Windows 10 Pro。没有“正确”的操作系统可用,因为这严重依赖于个人喜欢使用的工具,但每个操作系统都有其优缺点。

Proxmox

Proxmox 的一个优点是,它声称对基线处理器的影响很小。静止时,我们可以看到我在节点上部署的虚拟机的资源使用非常低。下面的截图捕捉了仅一个节点的性能摘要。我们可以看到在空闲时,CPU 的使用率仅为极小的百分比,这也与非常有限的功耗(和电费)相关。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

针对特定节点的仪表板 — 作者截图

一旦部署了来宾操作系统,情况就完全不同了。此时的资源利用几乎完全由虚拟机配置决定。

我从玩弄 Proxmox 中学到的一些笔记包括:

  • 学习曲线相当陡峭。这是一个企业工具,虽然使用 Proxmox 的文档非常丰富,但你需要花费大量时间阅读文档、论坛,甚至 Reddit 线程。

  • 从另一个积极的角度来看,解决问题时有大量的文档可供参考,而且围绕该平台建立的社区非常强大。

  • 尽管 Proxmox 具有非常直观的 GUI,但解决问题仍需要一些“跳出框框”的思维。例如,一旦我创建了一个 Windows 虚拟机,并将其调整到我想要的标准,我不能像第一次启动镜像时那样轻松地将其“拖放”到另一个节点。我不得不通过将外部硬盘添加到我运行 Windows 的破旧笔记本电脑中来创建网络附加存储(NAS)(有些人可能还记得我在第一篇文章中提到的那台发光的笔记本电脑)。这个存储充当了中介和备份库,用于克隆和迁移我的虚拟机。

  • 我不是一个很精通 Linux 的人。我知道,我知道,我确实应该深入研究一下,但多年来 Mac 和 Windows PC 的便利性使得我在尝试用 CLI 完成操作时会感到挣扎,而这些操作我通常可以通过点击来完成。

  • Proxmox 非常容易扩展。将第一个节点添加到集群或所谓的“数据中心”花了一些时间才弄明白,但添加其他节点则没有花费任何时间。一开始我能够通过完成分配静态 IP 地址等管理员任务,按照我的要求自定义每个节点。一旦掌握了技巧,部署虚拟机也只需几分钟。

  • 这非常酷,我非常喜欢那个展示所有操作统计数据的仪表板,我在操作过程中会密切监控。下图中,我们可以看到仪表板不仅监控了一个节点的使用情况,还记录了其他节点的情况。能够在数据中心的节点之间灵活切换是巨大的,当在裸机上监控可能需要在实例之间远程操作时。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Proxmox 图形界面在家庭实验室 — 作者截屏

最终,回顾起来我想到的一个问题是,不管我花了多少时间来配置 Windows 虚拟机以使其完全按照我希望的方式运行(本周早些时候,我花了一整晚来配置嵌套虚拟化以使 Docker 能够运行),总是会有一个额外的障碍或瓶颈。

容器化

我甚至不会对 Docker 进行比较,因为我在虚拟机中尝试启动的容器(一个为每半年一次的大学朋友 Minecraft 夜晚准备的 Minecraft 服务器)甚至无法达到令人满意的性能水平(服务器无法跟上,且无法进行游戏)。虽然如果我打算使用嵌套虚拟化,我的周末计划略有受阻,但也有实际应用可能会受到影响。

我经常用于工作和娱乐的一个工具是 PyCaret,一个专注于集成机器学习模型的 Python 库。机器学习模型常常有处理器或架构特定的注意事项,比如 PyCaret 不适用于 M1 Macs,因为 ARM 架构的原因,Tensorflow 没有针对我使用的 Radeon 显卡进行优化,而 Autogluon 在我的 i7 Mac 上无法构建(我甚至不知道为什么)。因此,这些包通常被容器化成 Docker 应用程序以实现便携性。我还在研究本地化到 Docker 的 DynamoDB,以利用强大的 AWS NoSQL 架构,而无需支付云端相关的高额费用。时间和速度是这些工具的卖点,而嵌套虚拟化对 Docker 的影响是巨大的(至少在这些 PC 上)。实际上,性能下降在每一级虚拟化中都是递增的,每一级的性能下降超过 10%

对于那些可能会指出这一点的人来说,一个额外的反思是,能够运行 Docker 的 LXCs(Linux 容器)是在主机操作系统上运行的,因此像 ML 模型这样的大型程序可能会导致内核崩溃,如内存交换失败,不仅会杀死容器,还会影响操作系统(而不仅仅是客操作系统)。因此,我甚至没有考虑在这里使用它们,尽管它们无疑是轻量级应用程序的有用工具。

即使没有测量,我们也可以看到尽可能接近裸金属的方式能提升工具的性能。然而,一些人能够在虚拟化环境中仍然实现出色的性能。例如,AWS Nitro就是该领域中的一个真正的差异化因素,它为亚马逊的大规模计算和数据仓储成功做出了贡献,尽管这需要巨大的成本,使得一些数据科学工具如 Sagemaker 的费用相当于我为每台桌面电脑在一个月内支付的费用来租用一台笔记本电脑。我们可以看到下面一个标准的 Sagemaker 工作室笔记本实例,每天使用八小时,一周五天,规格与我们的机器相似(甚至时钟速度有限),大约每月需 $75。总体来看,每个单元的成本可能在 $100–120 之间,升级后,功耗在待机状态下约为 10–15W,峰值为 65W。这大致相当于每月 $2–3 的电费。这与整年相比节省了近一个数量级。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Sagemaker 计算成本通过 AWS 的公共成本估算器 — 作者截图

话虽如此,我相信随着时间的推移以及更好、更快的硬件在消费者和二手市场上的出现,虚拟化性能与实际性能之间的差距会缩小。如果英特尔想送我一台 i9–13900K 或NVIDIA AI送我一台 RTX 4090,我会很乐意进行测试并向大家汇报。与此同时,我将满足于我的 HP mini 电脑和 AWS 免费套餐来满足我的数据分析和虚拟化需求。

比较

为了实际比较虚拟性能与物理性能,我们将对 Windows 10 虚拟机和物理系统分别进行一般化的基准测试,然后进行 Python 性能测试。在这里我要说明的是,我为虚拟机和 PC 分配了相同数量的核心和 RAM(虽然这让我在虚拟机方面有点冒险,因为我曾经在分配“所有”核心时遇到过问题,因为这影响了宿主虚拟化程序的性能,导致系统故障)。

毫不拖延,以下是基于 userbenchmark.com 为虚拟机构建的 基准测试 快照,运行于本地。以下我们可以看到 VCPU 的性能远低于基线和平均水平,我们的 RAM 也仅稍微低于基线。这表明要么我们的 CPU 没有得到充分利用,要么在测试期间这些数学密集型操作中,托管虚拟机的 CPU 存在大量开销。具体的整数计算性能等指标见截图。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者截图

整体表现不算特别好,尽管单位本身目前输出了一些 BTUs。

让我们通过运行一个简单的 Python 脚本作为基准来评估性能(请注意,由于 GIL,该脚本是单线程运行的)。下面的非科学 Python 脚本是我编写的,用于创建一个粗略的“速度计算”以比较相对性能。

在两个单元格中,我们首先进行:

  1. 计算一个任意大的数字并将每个数字添加到列表中

  2. 在循环中乘以越来越大的数字

每个测试都有时间限制,并且会重复执行,以建立相对性能的基线,用于比较处理速度和内存 IO。以下是代码片段,供你感兴趣时自行运行以进一步比较。

def test1(n):
    l = []
    for i in range(n):
        l.append(i)
n = 100000
%timeit -r 5 -n 1000 test1(n)

def test2(n):
    for i in range(n):
        i * (i-1)
n = 100000
%timeit -r 5 -n 1000 test2(n)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在虚拟机的 Jupyter Notebook 实例中运行 Python — 作者截图

第一个脚本平均完成时间为 8.84ms,第二个脚本平均完成时间为 11.5ms。我们将很快在虚拟机和裸金属之间比较这些数据。

在运行我们的脚本后,我们可以看到相当一部分 RAM 在使用中,然而,CPU 的利用率几乎没有增加,不过如果尝试在多个线程中分配此任务,我会担心弹性带来的问题。12.5GB 的空闲内存是一个显著的开销,应通过进一步的研究进行优化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者的任务管理器截图

现在谈谈裸金属……

在实际硬件上原生运行 Windows 10 Pro,我们可以看到使用相同测试套件的性能基准显著提高,这与虚拟机上使用的测试套件相同。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者截图

我们的处理器在 Windows 原生模式下的整数比较性能几乎是虚拟化模式的两倍,这使得虚拟机的表现大打折扣。我们现在更接近基准线,一般而言,处理过程的平均水平也有所提高。至于我们的 RAM,读写速度显著提升,当在单核上运行时,吞吐量几乎提高了 3 倍。直接在裸金属上运行对 IO 和处理速度确实产生了巨大的影响。

在运行我们的 Python 测试时,我们注意到性能有类似的跳跃。我们的测试脚本运行速度是虚拟化模式的两倍,测试 1 为 4.06 毫秒,测试 2 为 6.04 毫秒。这是虚拟机上原始测试速度的一半。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者的裸金属 Jupyter Notebook 截图

在未虚拟化的情况下,我们还可以看到使用的 RAM 是虚拟机空闲时的一半。总的来说,这表明与运行虚拟机的相同硬件相比,裸金属运行可以显著改善处理和内存性能。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者的任务管理器截图

每个团队独特的数据科学需求没有一刀切的解决方案。对于企业而言,花更多的钱用于虚拟分析工具可能更为合理,因为这些工具的安全性和性能可以得到严格监控。较小的公司也可以利用云工具,具体取决于它们的预算。然而,对于个人和小型研究团队来说,基于裸金属的构建可能是实现最佳性能的必要条件。

我用于构建项目和管道的策略不是专注于管理特定的主机和节点,而是保持一个新安装的 Windows(去除多余软件)的备份,其中预安装了我所需的所有内容——某些 Python 包、代码重新分发包、服务器连接等。项目的其余部分集中在一个代码库中,我可以在运行时将其复制并部署到节点进行处理。从 2010 年代中期开始,大多数计算机都具备的千兆连接速度足够快,可以传输数据科学库和包。因此,对高可用计算和操作系统正常运行时间的需求减少了,因为我大规模管理硬盘,不会在本地进行大规模更改。一些服务仍然需要主动管理,例如在 Docker 中运行的容器,但这些服务无论如何都需要相当活跃的管理,将其放在我的本地 Windows 10 专业版安装中,更符合我如何支配时间的方式,因为无论如何都会发生故障。

你怎么看?你在哪里运行你的代码?你倾向于使用哪些工具和平台来托管你的数据科学工作流?请在下面告诉我,或者随时在 LinkedIn 上与我建立联系!

对我使用的硬件感兴趣?查看我在 www.willkeefe.com 上的评论。

你应该使用 slots 吗?Slots 如何影响你的类,何时以及如何使用它们

原文:towardsdatascience.com/should-you-use-slots-how-slots-affect-your-class-when-and-how-to-use-ab3f118abc71

一行代码能带来 20%的性能提升?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Mike Huls

·发表于 Towards Data Science ·6 分钟阅读·2023 年 8 月 12 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(图片来源:Sébastien GoldbergUnsplash)

Slots 是一种机制,它允许你声明类属性并限制其他属性的创建。你可以确定你的类有哪些属性,从而防止开发者动态添加新属性。这通常会导致20%的速度提升

Slots 在有大量类实例且属性集已知的程序中特别有用。比如视频游戏或物理模拟;在这些情况下,你跟踪大量的实体。

你可以通过添加一行代码将 slots 添加到你的类中,但这总是一个好主意吗?在本文中,我们将探讨为什么如何使用slots使你的类更快以及何时使用它们。总体目标是更好地理解 Python 的类内部工作原理。开始编码吧!

Slots 让 Python 类更快

通过使用slots,你可以提高类的内存使用效率和性能。一个具有 slots 的类占用更少的内存,执行速度更快。

如何让我的类使用 slots?

告诉 Python 让一个类使用 slots 非常简单。你只需添加一个特殊的属性__slots__,它指定所有其他属性的名称:

class Person:
  first_name:str
  last_name:str
  age:int

  __slots__ = ['first_name', 'last_name', 'age']    # <-- this adds slots

  def __init__(self, first_name:str, last_name:str, age:int):
    self.first_name = first_name
    self.last_name = last_name
    self.age = age

在上述类中,我们看到Person有三个属性:first_namelast_nameage。我们可以告诉 Python 我们希望Person类使用 slots,通过添加__slots__属性来实现。这个属性必须指定所有其他属性的名称。

## 参数与关键字参数:哪种方式在 Python 中调用函数最快?

timeit模块的清晰演示

towardsdatascience.com

slotted 类的速度提升了多少?

我们上面使用的 Person 类使用 slots 后几乎 小了 60%(从 488 字节减少到 206 字节)。

关于速度,我已经对实例化、访问和赋值进行了基准测试。我发现 速度提高了多达 20%!你需要对这些结果持保留态度;虽然这些百分比看起来相当令人印象深刻,但这 20% 仅代表 10 万次实例化类的 0.44 秒。这相当于每个实例 可忽略的 44 纳秒(大约比一秒小 3030 万倍)。

查看用于基准测试的 内存速度代码;

## 为什么 Python 很慢以及如何加速

看看底层,了解 Python 的瓶颈所在

towardsdatascience.com

为什么 slotted 类更小且更快?

这与 Python 类的 动态字典 有关。这个字典让你可以为 Python 类分配属性:

class Person:
  pass

mike = Person()

mike.age = 33  # <-- create a new attribute

在上面的例子中,我们定义了一个没有任何属性的类,创建了一个实例,然后动态地创建 age 属性并赋值。

在底层,Python 将所有属性信息存储在一个字典中。通过调用类上的 __dict__ 魔法方法可以访问这个字典:

# 1\. Define class
class Person:
  name:str

  def __init__(self, name:str):
      self.name = name

# 2\. Create instance
mike = Person(name='mike')
# 3\. Create a new variable
mike.age = 33
# 4\. Create new attribute throught the __dict__
mike.__dict__['website'] = 'mikehuls.com'
# 5\. Print out the dynamic dictionary
print(mike.__dict__)  
# -> {'name': 'mike', 'age': 33, 'website': 'mikehuls.com'}

动态字典使得 Python 类非常灵活,但它有一个缺点:使用属性时 Python 会在这个字典中进行搜索,这相对较慢。

## 用两行代码线程化你的 Python 程序

通过同时做多件事来加速你的程序

towardsdatascience.com

slots如何影响动态字典?

当你告诉 Python 为你的类使用 slots 时,不会创建动态字典。相反,Python 创建了一个 固定大小的数组,其中包含对变量的引用。这就是你必须将属性名称传递给 __slots__ 属性的原因。

访问这个数组不仅速度更快,而且占用的内存空间也更少。较小的内存占用对内存分配和垃圾回收也有积极的影响。

插槽有什么副作用?

插槽改变了你的类;它变得有点不灵活,因为你的类变得更静态。这意味着你不能在运行时添加属性;你必须事先指定你的属性:

# 1\. Define class
class Person:
  name:str

  def __init__(self, name:str):
      self.name = name

# 2\. Create instance
mike = Person(name='mike')

# 3\. Add a new attribute?
mike.website = 'mikehuls.com'     # this will not work!
# ERROR: AttributeError: 'Person' object has no attribute 'website'

# 4\. Print out dynamic dict
print(mike.__dict__)              # this will not work
# ERROR: AttributeError: 'Person' object has no attribute '__dict__'

有一种(虽然有点乱的)变通方法:通过将 "__dict__" 的值添加到你的 __slots__ 数组中:

# 1\. Define class
class Person:
  name: str

  __slots__ = ["name", "__dict__"] # <- We've added __dict__

  def __init__(self, name: str):
    self.name = name

# 2\. Create instance
mike = Person(name='mike')

# 3\. Add a new attribute
mike.website = 'mikehuls.com'     # no error this time!

最后一个需要注意的事项是,有些包可能期望使用“普通”的 Python 类,而不是使用插槽类。

towardsdatascience.com ## 6 步骤让 Pandas DataFrame 操作快 100 倍

Cython 用于数据科学:将 Pandas 与 Cython 结合,以实现令人难以置信的速度提升

[towardsdatascience.com

这在数据类中也适用吗?

是的!从 Python 3.10 开始,你还可以添加插槽数据类。使用数据类更简单,只需向 @dataclass 装饰器添加一个参数即可。只需像下面这样定义你的数据类:

@dataclasses.dataclass(slots=True)
class Person:
    name: str

使用插槽有什么好处?

显然,速度内存效率,但也许还有安全性:如果我想覆盖类中的 age 属性但打错字,例如输入 mike.aage = 34,那么未使用插槽的类将创建一个新属性,而保持 age 属性不变。当你使用插槽时,Python 会抛出一个错误,因为它不知道类中有 aage 属性。

何时使用插槽?

速度:尽管插槽从百分比上加速了你的类,但每次操作的绝对时间增加是相当微不足道的。因此,如果你需要创建大量实例,或者需要多次重写或访问属性,插槽的使用会更具吸引力。

内存:如果你内存不足且希望节省每一个字节,使用插槽可能会有好处,因为它们显著减少了内存使用量。我们的简单类减少了 60% 的内存占用。

安全性:插槽防止你使用错误的属性和动态创建新属性。如果你尝试修改一个未知的属性,插槽类会抛出错误。

towardsdatascience.com ## 绝对初学者的 Cython:两步实现 30 倍更快的代码

轻松编译 Python 代码,实现极快的应用程序

[towardsdatascience.com

结论

正如我们在这篇文章中看到的,slots 以三种方式影响你的类:

  • 大小:slots 消除了 Python 创建动态字典的需要,而是依赖于 更小 的固定大小数组,这间接通过减少对垃圾回收的需求来加速你的应用。

  • 速度:slots 允许直接访问内存,绕过搜索字典的需要,这样会更快。速度提升在绝对意义上是相当微小的;节省了几纳秒。

  • 灵活性:slots 防止在运行时添加属性,因此你的类变得有点不那么灵活。这也可能是件好事,因为当你使用动态属性创建时,你的代码可能会变得杂乱无章。

在我看来,减少的灵活性是我不常遇到的缺点:我从不动态创建属性,我喜欢 slots 保持属性静态。因此,我会尽可能使用 slots。在最坏的情况下,依赖关系可能会出问题,但在这种情况下,很容易再次移除 slots。

我希望这篇文章能像我希望的那样清晰,但如果不是这种情况,请告诉我我可以做些什么来进一步澄清。与此同时,查看我的其他文章,涵盖了各种编程相关的话题,例如:

编程愉快!

— Mike

附言:喜欢我做的事吗? 关注我!

## 使用我的推荐链接加入 Medium — Mike Huls

阅读 Mike Huls 的每一篇故事(以及 Medium 上其他数千名作家的故事)。你的会员费用直接支持 Mike…

mikehuls.medium.com

在你的 Medium 博客中展示 Streamlit 应用

原文:towardsdatascience.com/show-streamlit-apps-in-your-medium-blog-520e98c7d51d

跟随这个教程将你的应用在线发布到 Medium 帖子中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Gustavo Santos

·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 5 月 30 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Lloyd Dirks 提供,来源于 Unsplash

介绍

我喜欢写关于Streamlit的文章。我会说它是我最喜欢的 Python 包之一。

Streamlit 是一个非常容易学习的包,它使人们能够快速创建仪表板、Web 应用,甚至将模型部署到生产环境中。我已经写过几次关于它的文章,甚至教你如何用这个优秀的包来构建和部署你的第一个应用(见下方链接)。

2022 年,Snowflake 收购了 Streamlit,从那时起,他们不断为这个包添加更多功能,使其变得更好。最近,他们推出了一个非常酷的功能,我想在这个简短的教程中展示:如何将你的应用嵌入到像这样的 Medium 博客中

[## Streamlit 基础:构建你的第一个 Web 应用

学习创建每个基本功能的代码,并在几分钟内部署你的 Web 应用。

medium.com [## Streamlit 基础:部署你的第一个 Web 应用

学习如何使用 Streamlit 的分享功能部署一个 Web 应用。

medium.com

使用案例

在我看来,将应用程序嵌入 Medium 博客中的好处有很多。

首先,我认为这是任何数据科学家建立优秀作品集的好方法。如果你在科技行业工作,你可能知道拥有一个有趣的作品集来展示你的技能,是吸引注意力的好方法,同时也能保持你的专业形象光鲜亮丽。

另一个好处是将研究成果展示给世界或甚至是客户。你可以写一篇博客文章,并在良好的市场分析或书面报告之后,将你的应用程序添加进去,作为你工作的一个很好的视觉补充。

我甚至看到有人将餐厅菜单做成 Streamlit 应用程序,连接客户和厨房,这真的很棒。所以,看看我们可以想到多少种选项。

创建一个简单的应用程序

好的,为了嵌入一个应用程序,我们需要先创建一个。所以,让我们创建一个简单的应用程序,只需几行代码。

我决定使用meteostat包创建一个应用程序,该包允许你根据经纬度坐标从几乎任何位置检索温度信息。

下面是我将在这个快速项目中使用的包:

# Imports
import pandas as pd
import streamlit as st
from datetime import datetime
from meteostat import Point, Monthly

接下来,让我们将 streamlit 页面设置为宽布局,这样我们的应用程序可以占据页面的更多部分,而不是将所有功能居中显示。作为输入数据,我使用了这个包含许多世界城市及其经纬度的 csv 文件。

# Set Page Layout
st.set_page_config(layout='wide')

# Load the Dataset
cities = pd.read_csv('world_cities.csv')

很好。下一步是放置两个下拉框,让用户选择他们想要检索信息的城市和年份。

# Select box
st.subheader('Temperature History App')

col1, col2 = st.columns(2, gap='medium')
# column 1 - Table weather history
with col1:
    # Title of the select box
    selected_city = st.selectbox(label='Select a city for weather information',
                                 options=cities['city'].sort_values().unique())

with col2:
    selected_year = st.selectbox(label= 'Select an year',
                                 options= range(2022,1999,-1) )

下一个代码片段展示了从meteostat包中实际获取的信息。我们首先设置一个时间框架,使用用户选择的年份。然后我们使用选择的城市从 CSV 中获取经纬度信息,使用简单的 Pandas 查询,将结果数据转换为浮点数。接下来,我们使用meteostat中的Point()函数创建一个对象,该对象将放入Monthly()函数中,后者接收时间框架和位置,创建一个最终对象data,用于提取我们应用程序所需的数据。

# Collect the Weather Information
# Set time period
start = datetime(selected_year, 1, 1)
end = datetime(selected_year, 12, 31)

# Create Point
lat = cities.query('city == @selected_city')['latitude'].astype('float').tolist()[0]
long = cities.query('city == @selected_city')['longitude'].astype('float').tolist()[0]
city_loc = Point(lat, long)

# Get daily data for 2018
data = Monthly(city_loc, start, end)
data = data.fetch()
#data['mth'] = data.index.month
data = data[ ['tavg', 'tmax', 'tmin'] ]

收集了所需的数据框后,我们将创建两列来绘制选定位置的月度温度线图,并显示包含检索信息的表格。

col1, col2 = st.columns(2, gap='large')
# column 1 - Table weather history
with col1:
    st.text('| TEMPERATURES IN °C')
    st.line_chart(data=data)

# column 2 - Graphics
with col2:
    # WEATHER INFORMATION TABLE
    st.text('| WEATHER HISTORY')
    st.write(data)

最后,我们将添加另一个部分,展示一个地图,显示所选位置。

# Division
st.markdown('---')
# Map
st.subheader('| WHERE THIS CITY IS')
df_map = cities.query('city == @selected_city')
st.map(df_map, zoom=5)

生成的应用程序被保存为.py文件,与 requirement.txt 和城市 csv 一起放在一个GitHub 存储库中,链接在此,这是部署前所需的步骤。

然后我们可以直接去share.streamlit.io/将应用程序部署到网络上。一旦完成,你将获得应用程序的链接,并将其粘贴到你的 Medium 博客文章中。

在下面的序列中,你可以看到我们在最后几段中刚刚构建的嵌入应用程序。它在你的 Medium 帖子中直接功能齐全*(请耐心等待,可能需要几秒钟才能完全加载)*。

在本教程中创建的天气应用。图片由作者提供。

在你离开之前

哇!这真是太棒了。当我看到这个新功能时,我迫不及待地想要测试它并与你分享。希望你也喜欢,并找到很好的方法与大家分享和展示你的工作。

Streamlit 使用起来真的很简单,你可以通过他们的文档或互联网上的许多教程了解更多。我相信你会喜欢用它编写应用程序的简单性。

如果你喜欢这些内容,记得关注我获取更多信息。

[## Gustavo Santos - Medium

阅读 Gustavo Santos 在 Medium 上的文章。他是一名数据科学家,从数据中提取洞察,帮助个人和公司……

gustavorsantos.medium.com](https://gustavorsantos.medium.com/?source=post_page-----520e98c7d51d--------------------------------)

LinkedIn 上找到我,或者通过 TopMate.io 预约时间与我讨论数据科学

参考文献

[## 嵌入你的应用 - Streamlit 文档

嵌入 Streamlit Community Cloud 应用可以通过集成交互式、数据驱动的应用程序来丰富你的内容……

docs.streamlit.io](https://docs.streamlit.io/streamlit-community-cloud/get-started/embed-your-app?utm_medium=email&_hsmi=259535966&_hsenc=p2ANqtz-9deEEF-Z6E5LeUsWM_TiXef4GoXNX6wpR27Fz5CYkwa9nRbwFaYVnGkLwIy9hmvE_gN6GwZsaFmkGDq8iQFCS3wfRp3g&utm_content=259535966&utm_source=hs_email&source=post_page-----520e98c7d51d--------------------------------#embedding-with-oembed) [## Streamlit 文档

Streamlit 不仅仅是创建数据应用的一种方式,它还是一个创作者社区,分享他们的应用和想法……

docs.streamlit.io](https://docs.streamlit.io/?source=post_page-----520e98c7d51d--------------------------------) [## Snowflake 以 8 亿美元收购 Streamlit,帮助客户构建基于数据的应用

Snowflake 帮助客户在云中存储和管理大量数据,而不受云供应商的锁定。Streamlit 是一个……

techcrunch.com [## Streamlit 基础知识:构建你的第一个 Web 应用

学习如何编写代码以创建每一个基本功能,并在几分钟内部署你的网页应用。

medium.com [## Streamlit 基础知识:部署你的第一个 Web 应用

学习如何使用 Streamlit 的分享功能部署 Web 应用。

medium.com

Siamese 神经网络与三重损失和余弦距离

原文:towardsdatascience.com/siamese-neural-networks-with-tensorflow-functional-api-6aef1002c4e

理论与代码实践:使用三重损失和余弦距离进行 Siamese 网络在 CIFAR-10 数据集上的训练

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Tan Pengshi Alvin

·发表于 Towards Data Science ·阅读时长 11 分钟·2023 年 5 月 12 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Alex Meier 提供,来源于 Unsplash

如果我们可以将每个对象图像(如人脸等)编码成一个模板——一个数字向量呢?之后,我们可以通过对比它们的模板——计算距离——来客观地确定对象之间的相似性。在深度学习中,这正是 Siamese 神经网络希望实现的目标。

Siamese 神经网络基本上是经过训练后为每个输入对象生成独特特征向量(模板)的模型。尽管这些模型通常用于对象图像的模板(计算机视觉),但它们也可以用于文本和声音数据。

除了安全认证,如人脸识别和签名比对,Siamese 神经网络还常用于电子商务平台中测量产品相似性。例如,一些电子商务平台允许你通过上传你想寻找的对象的图像来搜索类似产品。在 Kaggle 上,甚至有一个由东南亚领先电子商务公司 Shopee 举办的 产品匹配竞赛

在这篇文章中,我们将探索一个在 Tensorflow 中常见的数据集——CIFAR-10——该数据集与产品相似性搜索问题有些相似,只不过兴趣对象是汽车——如汽车、飞机、卡车、船只等——以及动物(或者说宠物也行!)——如猫、狗、马、鸟、鹿等。

在开始之前,我们首先需要理解 Siamese 神经网络背后的理论。之后,我们将探索在 CIFAR-10 数据集上训练和评估简单 Siamese 神经网络的代码。

准备好了吗?开始吧!

1. 孪生网络理论

我不得不承认本文中的封面图片有点误导——‘Siamese’一词实际上并不是源于‘暹罗猫’。而是来源于‘暹罗双胞胎’,即身体某部分连在一起的双胞胎。

因此,孪生神经网络基本上指的是双胞神经网络,这些网络通常在最后——Lambda 层,如我们将看到的——连接在一起,然后将模型输出输入损失函数。在训练这些双胞神经网络的过程中,它们的权重在初始化、前向传播和反向传播过程中完全相同。

由于我们通常处理的是图像,每对孪生神经网络通常是卷积神经网络(CNN)。如果你对 CNN 不熟悉或需要刷新记忆,我这里有一篇关于 CNN 的优秀文章:

[## 迁移学习与卷积神经网络(CNN)

从 CNN 到迁移学习的完整指南,适用于 Kaggle 的猫狗数据集

medium.com](https://medium.com/mlearning-ai/transfer-learning-and-convolutional-neural-networks-cnn-e68db4c48cca?source=post_page-----6aef1002c4e--------------------------------)

牢记这一点,我们将介绍两种常见的孪生神经网络:

1.1 对比损失孪生网络

第一种类型是基于计算双胞 CNN 的嵌入层(特征向量)之间的欧几里得/余弦距离,然后与真实值(1:匹配,0:不匹配)比较来确定对比损失的孪生神经网络。

以下是这种模型的示意图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对比损失的孪生神经网络示例。图片改编自SigNet 论文

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对比损失公式与欧几里得距离,其中 Y 为真实值。图片作者提供。

1.2 三重损失孪生网络

第二种类型的孪生神经网络基于计算三重 CNN 的嵌入层(特征向量)之间的两个欧几里得/余弦距离——即锚点和正样本之间,锚点和负样本之间——然后在 Lambda 层中完全计算三重损失,而不与任何真实值进行比较。

因为研究表明这种三重损失模型通常比对比损失模型更鲁棒,所以我们将在本文中重点讨论这种类型的孪生网络。

以下是这种模型的示意图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

三重损失孪生神经网络的示例。图片改编自SigNet 论文

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用欧几里得距离的三元组损失公式,其中 A 是锚点图像输入,P 是正样本图像输入,N 是负样本图像输入。图片由作者提供。

1.3 孪生网络的目标

现在,我们已经看到孪生神经网络的大致架构。但是在训练网络后,我们打算达到什么目标?让我们看看下面的插图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

孪生网络的训练减少了相似图像之间的距离,同时增加了不相似图像之间的距离。图片来源于FaceNet 论文

我们看到孪生网络正在学习在同一类别图像之间重建相似的特征向量。因此,训练后,相似图像模板之间的距离将减少,而不相似图像模板之间的距离将增加。

话虽如此,在训练过程中覆盖尽可能多的图像类别是很重要的,以便模型也能推广到未见过的类别(签名、面孔等)。

最后,在模型评估期间,我们主要关注生成输入图像数据的模板。因此,在进行模板推理时,仅提取单个 CNN 网络或双胞胎/三胞胎网络的主体,而不包括 Lambda 层。

1.4 欧几里得距离与余弦距离

在我们开始编码之前,让我们首先区分两个常见的向量距离度量——欧几里得距离和余弦距离。到目前为止,在上述插图中,我们展示了欧几里得距离,因为它更直观易懂,但在构建更好的模型时并不一定优于余弦距离。下面我们来说明一下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2D 空间中两个向量的欧几里得距离和余弦距离的插图。图片由作者提供。

从上述内容来看,欧几里得距离只是两个特征向量之间的“坐标距离”,而余弦距离是它们之间“角度距离”的一种度量。因此,当两个特征向量远离时,我们可以看到欧几里得距离和余弦距离都很大。但它们之间存在微妙的差别,下面我们来看一下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

欧几里得距离和余弦距离在小角度下但向量长度不同的比较。图片由作者提供。

虽然余弦距离仅测量特征向量之间的角度差异,但欧几里得距离测量第二维度——长度差异。因此,虽然更直观,但欧几里得距离本质上比余弦距离更复杂。

一般来说,欧几里得距离和余弦距离都被广泛使用,选择取决于经验探索。然而,对于较小的数据集和特定数量的类别,采用余弦距离作为损失函数可能是一个更好的选择,这也是我们为 CIFAR-10 数据集所做的。

2. 孪生网络代码练习

接下来,让我们开始编码吧。我们将基于 TensorFlow CIFAR-10 数据集构建三元组损失孪生网络。我们将基于余弦距离来构建三元组损失,然后在测试集评估时,通过角度相似度来比较测试图像。

注意*:使用的是角度相似度,因为它基于余弦距离,但值范围缩放在 0%到 100%之间。*

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

角度相似度的公式。图片由作者提供。

还需要注意的是,在模型初始化过程中,我们将采用 TensorFlow 的功能性 API(对比之前在迁移学习和 CNN 文章中使用的顺序 API),以及自定义 Lambda 层和自定义损失函数。

毫不犹豫地,让我们开始编码吧!

2.1 探索 CIFAR-10 数据集

# import necessary libraries
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# set random seed
np.random.seed(42)

# load CIFAR-10 data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

# check data size
assert X_train.shape == (50000, 32, 32, 3)
assert X_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

# combine data first - we will generate test set later.
X = np.concatenate([X_train,X_test],axis=0)
y = np.concatenate([y_train,y_test],axis=0)
y = np.squeeze(y)

assert X.shape == (60000, 32, 32, 3)
assert y.shape == (60000,)

# check number of data in each class
unique, counts = np.unique(y,return_counts=True)
np.asarray([unique,counts]).T

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

# Plot Class N (0-9)

TARGET = # Class index here
NUM_ARRAYS = 10

arrays = X[np.where(y==TARGET)]
random_arrays_indices = np.random.choice(len(arrays),NUM_ARRAYS)
random_arrays = arrays[random_arrays_indices]

fig = plt.figure(figsize=[NUM_ARRAYS,4])
plt.title('Class 0: Plane',fontsize = 15)
plt.axis('off')

for index in range(NUM_ARRAYS):
     fig.add_subplot(2, int(NUM_ARRAYS/2), index+1)
     plt.imshow(random_arrays[index])

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2.2 生成三元组

# initialize triplets array
triplets = np.empty((0,3,32,32,3),dtype=np.uint8)

# get triplets for each class
for target in range(10):

    locals()['arrays_'+str(target)] = X[np.where(y==target)].reshape(3000,2,32,32,3)
    locals()['arrays_not_'+str(target)] = X[np.where(y!=target)]

    random_indices = np.random.choice(len(locals()['arrays_not_'+str(target)]),3000)
    locals()['arrays_not_'+str(target)] = locals()['arrays_not_'+str(target)][random_indices]

    locals()['arrays_'+str(target)] = np.concatenate(
        [
            locals()['arrays_'+str(target)],
            locals()['arrays_not_'+str(target)].reshape(3000,1,32,32,3)
        ],
        axis = 1
    )

    triplets = np.concatenate([triplets,locals()['arrays_'+str(target)]],axis=0)

# check triplets size
assert triplets.shape == (30000,3,32,32,3)

# plot triplets array to visualize
TEST_SIZE = 5
random_indices = np.random.choice(len(triplets),TEST_SIZE)

fig = plt.figure(figsize=[5,2*TEST_SIZE])
plt.title('ANCHOR | POSITIVE | NEGATIVE',fontsize = 15)
plt.axis('off')

for row,i in enumerate(range(0,TEST_SIZE*3,3)):
    for j in range(1,4):
        fig.add_subplot(TEST_SIZE, 3, i+j)
        random_index = random_indices[row]
        plt.imshow(triplets[random_index,j-1])

# save triplet array
np.save('triplets_array.npy',triplets)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2.3 准备模型训练/评估

# Import all libraries

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import Input, optimizers, Model
from tensorflow.keras.layers import Layer, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model

from sklearn.metrics import precision_recall_curve, roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split

from scipy import spatial
triplets = np.load('triplets_array.npy')

triplets = triplets/255 #normalize by 255
labels = np.ones(len(triplets)) #create a fixed label

assert triplets.shape == (30000,3,32,32,3)
# Split data into our train and test set

X_train, X_test, y_train, y_test = train_test_split(
    triplets,
    labels,
    test_size=0.05,
    random_state=42
)
# Load pretrained model for transfer learning

pretrained_model = MobileNetV2(
    weights='imagenet', 
    include_top=False, 
    input_shape=(32,32,3)
)

for layer in pretrained_model.layers:
    layer.trainable = True

2.4 模型训练

# Initialize functions for Lambda Layer

def cosine_distance(x,y):
    x = K.l2_normalize(x, axis=-1)
    y = K.l2_normalize(y, axis=-1)
    distance = 1 - K.batch_dot(x, y, axes=-1)
    return distance

def triplet_loss(templates, margin=0.4):

    anchor,positive,negative = templates

    positive_distance = cosine_distance(anchor,positive)
    negative_distance = cosine_distance(anchor,negative)

    basic_loss = positive_distance-negative_distance+margin
    loss = K.maximum(basic_loss,0.0)

    return loss
# Adopting the TensorFlow Functional API

anchor = Input(shape=(32, 32,3), name='anchor_input')
A = pretrained_model(anchor)

positive = Input(shape=(32, 32,3), name='positive_input')
P = pretrained_model(positive)

negative = Input(shape=(32, 32,3), name='negative_input')
N = pretrained_model(negative)

loss = Lambda(triplet_loss)([A, P, N])

model = Model(inputs=[anchor,positive,negative],outputs=loss)
# Create a custom loss function since there are no ground truths label

def identity_loss(y_true, y_pred):
    return K.mean(y_pred)

model.compile(loss=identity_loss, optimizer=Adam(learning_rate=1e-4))

callbacks=[EarlyStopping(
    patience=2, 
    verbose=1, 
    restore_best_weights=True,
    monitor='val_loss'
    )]

# view model
plot_model(model, show_shapes=True, show_layer_names=True, to_file='siamese_triplet_loss_model.png')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

# Start training - y_train and y_test are dummy

model.fit(
    [X_train[:,0],X_train[:,1],X_train[:,2]],
    y_train,
    epochs=50, 
    batch_size=64,
    validation_data=([X_test[:,0],X_test[:,1],X_test[:,2]],y_test),
    callbacks=callbacks
)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2.5 模型评估

X_test_anchor = X_test[:,0]
X_test_positive = X_test[:,1]
X_test_negative = X_test[:,2]

# extract the CNN model for inference
siamese_model = model.layers[3]

X_test_anchor_template = np.squeeze(siamese_model.predict(X_test_anchor))
X_test_positive_template = np.squeeze(siamese_model.predict(X_test_positive))
X_test_negative_template = np.squeeze(siamese_model.predict(X_test_negative))

y_test_targets = np.concatenate([np.ones((len(X_test),)),np.zeros((len(X_test),))])
# Get predictions in angular similarity scores

def angular_similarity(template1,template2):

    score = np.float32(1-np.arccos(1-spatial.distance.cosine(template1,template2))/np.pi)

    return score

y_predict_targets = []

for index in range(len(X_test)):
    similarity = angular_similarity(X_test_anchor_template[index],X_test_positive_template[index])
    y_predict_targets.append(similarity)

for index in range(len(X_test)):
    similarity = angular_similarity(X_test_anchor_template[index],X_test_negative_template[index])
    y_predict_targets.append(similarity)
# Get prediction results with ROC Curve and AUC scores

fpr, tpr, thresholds = roc_curve(y_test_targets, y_predict_targets)

fig = plt.figure(figsize=[10,7])
plt.plot(fpr, tpr,lw=2,label='UnoFace_v2 (AUC={:.3f})'.format(roc_auc_score(y_test_targets, y_predict_targets)))
plt.plot([0,1],[0,1],c='violet',ls='--')
plt.xlim([-0.05,1.05])
plt.ylim([-0.05,1.05])
plt.legend(loc="lower right",fontsize=15)

plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('Receiver Operating Characteristic (ROC) Curve',weight='bold',fontsize=15)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

# Getting Test Pairs and their Corresponding Predictions

positive_comparisons = X_test[:,[0,1]]
negative_comparisons = X_test[:,[0,2]]

positive_predict_targets = np.array(y_predict_targets)[:1500]
negative_predict_targets = np.array(y_predict_targets)[1500:]

assert positive_comparisons.shape == (1500,2,32,32,3)
assert negative_comparisons.shape == (1500,2,32,32,3)

assert positive_predict_targets.shape == (1500,)
assert negative_predict_targets.shape == (1500,)

np.random.seed(21)
NUM_EXAMPLES = 5
random_index = np.random.choice(range(len(positive_comparisons)),NUM_EXAMPLES)
# Plotting Similarity Scores for Positive Comparisons 
# (Switch values and input to plot for Negative Comparisons)

plt.figure(figsize=(10,4))
plt.title('Positive Comparisons and Their Similarity Scores')
plt.ylabel('Anchors')
plt.yticks([])
plt.xticks([32*x+16 for x in range(NUM_EXAMPLES)], ['.' for x in range(NUM_EXAMPLES)])
for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):
    t.set_color('green') 
plt.grid(None)
anchor = np.swapaxes(positive_comparisons[:,0][random_index],0,1)
anchor = np.reshape(anchor,[32,NUM_EXAMPLES*32,3])
plt.imshow(anchor)

plt.figure(figsize=(10,4))
plt.ylabel('Positives')
plt.yticks([])
plt.xticks([32*x+16 for x in range(NUM_EXAMPLES)], positive_predict_targets[random_index])
for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):
    t.set_color('green') 
plt.grid(None)
positive = np.swapaxes(positive_comparisons[:,1][random_index],0,1)
positive = np.reshape(positive,[32,NUM_EXAMPLES*32,3])
plt.imshow(positive)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3. 结论

恭喜你完成理论和代码练习!希望这个教程为你提供了关于孪生网络及其在对象相似度应用方面的全面介绍。

在结束之前,我还要补充的是,如何处理对象相似度分数取决于问题陈述。

如果我们在生产中进行 1:1 对象比较(即两个对象是否相似或不同),通常需要基于测试时的假匹配率(FMR)设置一个相似度阈值。另一方面,如果进行 1:N 对象匹配,通常会返回相似度得分最高的对象,并进行排序。

注:有关完整的代码,请查看我的 GitHub

感谢您的时间,希望您喜欢本教程。我还想介绍一个在这篇文章中详细阐述的极其重要的话题——以数据为中心的机器学习

[## 以数据为中心的 AI — 数据收集和增强策略]

关于以数据为中心的机器学习项目的数据生成策略的综合指南

pub.towardsai.net

感谢阅读!如果您喜欢我的内容,可以浏览我在 Medium 上的其他文章,并在 LinkedIn 上关注我。

支持我! — 如果您没有订阅 Medium,并且喜欢我的内容,请考虑通过我的推荐链接来支持我。

[## 通过我的推荐链接加入 Medium - Tan Pengshi Alvin]

阅读 Tan Pengshi Alvin 的每一个故事(以及 Medium 上成千上万的其他作家)。您的会员费用直接…

tanpengshi.medium.com

相似性搜索,第三部分:结合倒排文件索引和产品量化

原文:towardsdatascience.com/similarity-search-blending-inverted-file-index-and-product-quantization-a8e508c765fa?source=collection_archive---------1-----------------------#2023-05-19

了解如何结合两种基本的相似性搜索索引,以获得两者的优势

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vyacheslav Efimov

·

关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 5 月 19 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相似性搜索 是一个问题,其中给定一个查询的目标是找到数据库中与之最相似的文档。

引言

在数据科学中,相似性搜索常见于自然语言处理(NLP)领域、搜索引擎或推荐系统中,这些系统需要为查询检索出最相关的文档或项目。在海量数据中提升搜索性能的方法有很多种。

在本系列的前两部分中,我们讨论了信息检索中的两种基本算法:倒排文件索引产品量化。这两者都优化了搜索性能,但关注的方面不同:前者加快了搜索速度,而后者则将向量压缩为更小、更节省内存的表示。

## 相似性搜索,第一部分:kNN 与倒排文件索引

相似性搜索是一个热门问题,其中给定一个查询 Q,我们需要在所有文档中找到最相似的文档。

## 相似性搜索,第二部分:产品量化

在本系列文章的第一部分,我们查看了用于执行相似性搜索的 kNN 和倒排文件索引结构。

medium.com

由于这两种算法侧重于不同方面,自然会产生一个问题,即是否可以将这两种算法合并成一种新算法。

在本文中,我们将结合这两种方法的优点,以产生一种快速且节省内存的算法。供参考,大多数讨论的想法将基于这篇论文

在深入细节之前,有必要了解残差向量是什么,并对其有用的属性有一个简单的直观认识。我们将在设计算法时使用它们。

残差向量

想象一下执行了一个聚类算法,并产生了几个簇。每个簇都有一个质心和与之相关的点。残差是一个点(向量)与其质心之间的偏移。基本上,要找出特定向量的残差,需要从其质心中减去该向量。

如果簇是由 k-means 算法构建的,那么簇的质心是所有属于该簇的点的均值。因此,从任何点找出残差将等同于从中减去簇的均值。通过从属于特定簇的所有点中减去均值,这些点将围绕 0 中心对齐。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

原始的点簇显示在左侧。然后从所有簇点中减去簇质心。结果的残差向量显示在右侧。

我们可以观察到一个有用的事实:

用残差替换原始向量不会改变它们之间的相对位置。

也就是说,向量之间的距离始终保持不变。我们可以简单地查看下面的两个方程。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

减去均值不会改变相对距离

第一个方程是两个向量之间的欧几里得距离公式。在第二个方程中,从两个向量中减去簇的均值。我们可以看到,均值项会被消去——整个表达式变得与第一个方程中的欧几里得距离完全相同!

我们通过使用 L2 度量(欧几里得距离)的公式证明了这一声明。重要的是要记住,这个规则可能不适用于其他度量。

因此,如果对于给定的查询,目标是找到最近的邻居,可以仅从查询中减去簇均值,然后在残差中进行正常的搜索过程。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从查询中减去均值不会改变其相对位置。

现在让我们看看下图中的另一个例子,其中两个簇的向量残差分别计算。

从每个簇的对应质心中减去均值将使所有数据集向量围绕 0 中心

这是一个有用的观察,将在未来使用。此外,对于给定的查询,我们可以计算到所有簇的查询残差。查询残差使我们能够计算到簇的原始残差的距离。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从每个簇中减去均值后,所有点都围绕 0 中心。查询和查询残差与相应簇中其他点的相对位置保持不变。

训练

考虑到上一节中的有用观察,我们可以开始设计算法。

给定一个向量数据库,构建一个倒排文件索引,将向量集划分为n个 Voronoi 分区,从而减少推理过程中的搜索范围。

在每个 Voronoi 分区内,从每个向量中减去质心的坐标。结果是,所有分区中的向量变得彼此更接近,并且围绕 0 中心。此时,无需原始向量,因为我们存储它们的残差。

之后,对所有分区中的向量运行产品量化算法。

重要方面:产品量化不会对每个分区单独执行——那样会很低效,因为分区的数量通常很高,这将需要大量的内存来存储所有的码本。相反,算法会对所有分区的残差同时执行

实际上,现在每个子空间包含来自不同 Voronoi 分区的子向量。然后,对于每个子空间,执行一个聚类算法,构建出如常规的 k 个簇及其中心点。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练过程

替换向量为其残差的重要性是什么? 如果向量没有被其残差替换,那么每个子空间将包含更多的各种子向量(因为子空间将存储来自不同不相交的 Voronoi 分区的子向量,而这些子向量可能在空间中相距很远)。现在来自不同分区的向量(残差)彼此重叠。由于现在每个子空间的变化更小,因此表示向量所需的重现值也更少。换句话说:

使用之前相同长度的 PQ 代码,向量可以更准确地表示,因为它们的方差更小。

推断

对于给定的查询,找到 Voronoi 分区的 k 个最近中心点。所有这些区域内的点都被视为候选点。由于原始向量在每个 Voronoi 区域中最初被其残差所替代,查询向量的残差也需要被计算。在这种情况下,查询残差需要为每个 Voronoi 分区单独计算(因为每个区域有不同的中心点)。只有来自所选 Voronoi 分区的残差才会成为候选点。

查询残差随后被拆分为子向量。与原始的产品量化算法相同,对于每个子空间,计算包含从子空间中心点到查询子向量的距离的距离矩阵 d。必须记住,查询残差在每个 Voronoi 分区中都是不同的。这基本上意味着距离矩阵 d 需要为每个查询残差单独计算。这是所需优化的计算代价。

最后,部分距离被汇总,就像在产品量化算法中之前所做的那样。

排序结果

在计算所有距离后,需要选择 k 个最近邻点。为了高效完成这一过程,作者建议使用一个 MaxHeap 数据结构。它的容量有限为 k,并在每一步中存储 k 个当前最小的距离。每当计算出一个新距离时,只有当计算出的值小于 MaxHeap 中的最大值时,该值才会被添加到 MaxHeap 中。计算完所有距离后,查询的答案已经存储在 MaxHeap 中。使用 MaxHeap 的优点是其构建时间很快,为 O(n)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

推断过程

性能

该算法利用了倒排文件索引和产品量化。根据推理过程中 Voronoi 分区的数量,相同数量的子向量到质心矩阵 d 需要计算并存储在内存中。这可能看起来像是一个缺点,但与整体优势相比,这是一个相当好的折衷。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

该算法从倒排文件索引继承了良好的搜索速度,从产品量化继承了压缩效率。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++ 编写的 Python 库,用于优化相似性搜索。该库提供了不同类型的索引,这些数据结构用于高效地存储数据和执行查询。

根据 Faiss 文档 的信息,我们将了解如何将倒排文件和产品量化索引组合在一起形成新的索引。

Faiss 在 IndexIVFPQ 类中实现了上述算法,该类接受以下参数:

  • quantizer:指定计算向量之间距离的方式。

  • d:数据维度。

  • nlist:Voronoi 分区的数量。

  • M:子空间的数量。

  • nbits:编码单个簇 ID 所需的位数。这意味着单个子空间中的簇总数将等于 k = 2^nbits

此外,可以调整 nprobe 属性,该属性指定在推理过程中用于搜索候选项的 Voronoi 分区数量。更改此参数无需重新训练。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Faiss 的 IndexIVFPQ 实现

存储单个向量所需的内存与原始产品量化方法相同,只是现在我们增加了 8 个字节,用于在倒排文件索引中存储关于向量的信息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

结论

利用之前文章部分的知识,我们探讨了一个先进算法的实现,该算法实现了高效的内存压缩和加速的搜索速度。该算法在处理大量数据时广泛用于信息检索系统。

资源

除非另有说明,所有图片均由作者提供。

相似性搜索,第一部分:kNN 与倒排文件索引

原文:towardsdatascience.com/similarity-search-knn-inverted-file-index-7cab80cc0e79?source=collection_archive---------1-----------------------#2023-04-28

介绍 kNN 的相似性搜索及其通过倒排文件的加速。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vyacheslav Efimov

·

关注 发布于 Towards Data Science ·9 分钟阅读·2023 年 4 月 28 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相似性搜索是一个问题,给定一个查询,目标是找到数据库中与之最相似的文档。

介绍

在数据科学中,相似性搜索常出现在自然语言处理领域、搜索引擎或推荐系统中,需要为查询检索最相关的文档或项目。通常,文档或项目以文本或图像的形式表示。然而,机器学习算法不能直接处理原始文本或图像,这就是为什么文档和项目通常被预处理并存储为向量的原因。

有时,向量的每个组件可以存储语义信息。在这种情况下,这些表示也称为嵌入。这些嵌入可以有数百维,并且其数量可以达到数百万!由于这些巨大的数字,任何信息检索系统必须能够迅速检测到相关文档。

在机器学习中,向量也被称为对象

索引

为了加速搜索性能,数据集嵌入之上建立了一个特殊的数据结构。这种数据结构称为索引。在这一领域已有大量研究,并且发展出了许多类型的索引。在选择适用于特定任务的索引之前,有必要了解其内部操作原理,因为每种索引有不同的用途,并且各自有优缺点。

在本文中,我们将看看最简单的方法——kNN。基于 kNN,我们将转到倒排文件——一种用于更可扩展搜索的索引,可以加速搜索过程数倍。

kNN

kNN是相似性搜索中最简单和最原始的算法。考虑一个向量数据集和一个新的查询向量 Q。我们希望找到与 Q 最相似的前 k 个数据集向量。首先要考虑的是如何测量两个向量之间的相似性(距离)。实际上,有几种相似性度量可以实现这一点。下面的图中展示了其中一些。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相似性度量

训练

kNN 是机器学习中为数不多的无需训练阶段的算法之一。选择合适的度量后,我们可以直接进行预测。

推理

对于一个新的对象,该算法会穷尽地计算与所有其他对象的距离。之后,它会找到距离最小的 k 个对象,并将其作为响应返回。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

kNN 工作流程

显然,通过检查与所有数据集向量的距离,kNN 可以保证 100%的准确结果。然而,这种蛮力方法在时间性能上非常低效。如果一个数据集由 n 个具有 m 维度的向量组成,那么对于每个 n 向量,需要 O(m) 时间来计算与查询 Q 的距离,总时间复杂度为 O(mn)。正如我们稍后将看到的,存在更高效的方法。

此外,原始向量没有压缩机制。想象一下一个包含数十亿个对象的数据集。将所有这些对象存储在 RAM 中几乎是不可能的!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

kNN 性能。具有 100%的准确率和没有训练阶段会导致在推理过程中进行穷举搜索以及向量的无内存压缩。注意:这种图示显示了不同算法的相对比较。根据情况和选择的超参数,性能可能会有所不同。

应用

kNN 的应用范围有限,应该仅在以下情况之一中使用:

  • 数据集的大小或嵌入维度相对较小。这一方面确保了算法仍然能够快速执行。

  • 算法的要求准确度必须达到 100%。在准确度方面,没有其他最近邻算法能够超越 kNN。

基于指纹检测一个人的例子是需要 100%准确度的问题。如果一个人犯了罪并留下了指纹,检索到的结果必须完全正确。否则,如果系统不是 100%可靠,则可能会错误地将另一人定罪,这是一种非常严重的错误。

基本上,改进 kNN 有两种主要方法(稍后我们将讨论):

  • 缩小搜索范围。

  • 降低向量的维度。

使用这两种方法之一时,我们将不会再次进行穷举搜索。这些算法被称为近似最近邻(ANN),因为它们不保证 100%的准确结果。

倒排文件索引

“倒排索引(也称为文档列表文档文件倒排文件)是一个数据库索引,存储内容的映射,例如单词或数字,及其在表格、文档或文档集合中的位置” — 维基百科

在执行查询时,计算查询的哈希函数,并从哈希表中获取映射值。这些映射值中的每一个都包含自己的一组潜在候选者,然后在条件下完全检查是否为查询的最近邻。通过这样做,所有数据库向量的搜索范围被缩小。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

倒排文件索引工作流程

这种索引有不同的实现方式,具体取决于哈希函数的计算方式。我们将要研究的实现是使用Voronoi 图(或Dirichlet 镶嵌)的方法。

训练

该算法的思想是创建几个不相交的区域,每个数据集点将属于其中一个区域。每个区域都有自己的质心,指向该区域的中心。

有时Voronoi 区域也被称为单元分区

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Voronoi 图示例。白点是各自分区的中心,分区内包含一组候选者。

Voronoi 图的主要特性是,质心到其区域内任意点的距离小于该点到另一质心的距离。

推理

当给定一个新对象时,计算所有 Voronoi 分区质心的距离。然后选择距离最小的质心,并将该分区中的向量作为候选项。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过给定查询,我们搜索最近的质心(位于绿色区域)

最终,通过计算与候选项的距离并选择前k个最近的候选项,返回最终答案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在选定区域内查找最近邻居

如你所见,这种方法比之前的要快得多,因为我们不需要遍历所有数据集向量。

边缘问题

随着搜索速度的提高,倒排文件也有一个缺点:它不能保证找到的对象始终是最近的。

在下图中,我们可以看到这样的场景:实际的最近邻居位于红色区域,但我们仅从绿色区域选择候选项。这种情况被称为边缘问题

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

边缘问题

这种情况通常发生在查询对象靠近另一区域边界时。为了减少这种情况下的错误数量,我们可以扩大搜索范围,并基于与对象最接近的前m个质心选择多个区域来搜索候选项。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在几个区域内查找最近邻居(m = 3)

探索的区域越多,结果越准确,但计算所需的时间也越长。

应用

尽管存在边缘问题,倒排文件在实践中表现出色。它在需要在准确度和速度提升之间进行权衡时非常适合使用。

一个使用案例示例是基于内容的推荐系统。假设它根据用户过去观看的其他电影向用户推荐一部电影。数据库包含一百万部电影供选择。

  • 使用 kNN 时,系统确实会选择对用户最相关的电影并进行推荐。然而,执行查询所需的时间非常长。

  • 假设使用倒排文件索引,系统推荐第 5 个最相关的电影,这在现实生活中可能是这种情况。搜索时间比 kNN 快 20 倍。

从用户体验的角度来看,很难区分这两种推荐结果的质量:第 1 个和第 5 个最相关的结果都是来自百万个可能电影中的良好推荐。用户可能对这些推荐中的任何一个都感到满意。从时间的角度来看,倒排文件显然是赢家。这就是为什么在这种情况下,最好使用后者的方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

反向文件索引性能。在这里,我们稍微降低了准确度以在推理过程中获得更高的速度。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++编写的 Python 库,用于优化的相似性搜索。该库展示了不同类型的索引,这些索引是用来高效存储数据和执行查询的数据结构。

根据Faiss 文档中的信息,我们将了解索引的创建和参数化过程。

kNN

实现 kNN 方法的索引在 Faiss 中被称为flat,因为它们不压缩任何信息。它们是唯一保证正确搜索结果的索引。实际上,Faiss 中存在两种类型的 flat 索引:

  • IndexFlatL2。相似度计算为欧几里得距离。

  • IndexFlatIP。相似度计算为内积。

这两种索引在构造函数中都需要一个单一的参数d:数据维度。这些索引没有任何可调参数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Faiss 对 IndexFlatL2 和 IndexFlatIP 的实现

存储一个向量的单一分量需要 4 字节。因此,存储一个维度为d的向量需要4 * d字节。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

反向文件索引

对于描述的反向文件,Faiss 实现了IndexIVFFlat类。与 kNN 的情况一样," Flat "一词表示没有对原始向量进行解压,它们被完全存储。

为了创建这个索引,我们首先需要传递一个量化器——一个决定数据库向量如何存储和比较的对象。

IndexIVFFlat有两个重要参数:

  • nlist:定义在训练过程中创建的区域(Voronoi 单元)的数量。

  • nprobe:决定搜索候选区域的数量。更改 nprobe 参数不需要重新训练。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Faiss 对 IndexIVFFlat 的实现

与之前的情况一样,我们需要4 * d字节来存储一个向量。但现在我们还需要存储有关 Voronoi 区域的信息,这些区域是数据集向量所属的。在 Faiss 实现中,这些信息每个向量占用 8 字节。因此,存储单个向量所需的内存为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

结论

我们已经探讨了相似性搜索中的两种基础算法。实际上,朴素的 kNN 几乎不应该用于机器学习应用,因为它在扩展性方面表现不佳,除非在特定情况下。另一方面,倒排文件提供了加速搜索的良好启发式方法,其质量可以通过调整超参数来提高。从不同的角度仍然可以提升搜索性能。在本系列文章的下一部分,我们将深入探讨一种旨在压缩数据集向量的方法。

## 相似性搜索,第二部分:产品量化

学习一种有效压缩大数据的强大技术

towardsdatascience.com

资源

除非另有说明,否则所有图片均由作者提供。

相似性搜索,第四部分:分层可导航的小世界(HNSW)

原文:towardsdatascience.com/similarity-search-part-4-hierarchical-navigable-small-world-hnsw-2aad4fe87d37?source=collection_archive---------0-----------------------#2023-06-16

发现如何构建高效的多层图以提升在海量数据中的搜索速度

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vyacheslav Efimov

·

关注 发表在 数据科学进展 · 13 分钟阅读 · 2023 年 6 月 16 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相似性搜索 是一个问题,其中给定一个查询,目标是找到与之最相似的文档,这些文档位于所有数据库文档中。

介绍

在数据科学中,相似性搜索通常出现在自然语言处理领域、搜索引擎或推荐系统中,这些系统需要为查询检索最相关的文档或项目。在海量数据中,有各种不同的方法可以提高搜索性能。

分层可导航小世界(HNSW)是一种用于近似邻居搜索的最先进算法。在背后,HNSW 构建了优化的图结构,使其与本系列文章前面讨论的其他方法大相径庭。

HNSW 的主要思想是构建一个图,使得任意一对顶点之间的路径可以在少量步骤内遍历。

一个著名的类比是著名的 六度分隔理论,与这种方法相关:

所有人彼此之间的社交联系最多为六层。

在深入探讨 HNSW 的内部工作之前,我们先讨论跳表和可导航小世界——HNSW 实现中使用的关键数据结构。

跳表

跳表 是一种概率数据结构,允许在排序列表中以 O(logn) 的平均时间复杂度插入和搜索元素。跳表由多个层次的链表构成。最低层包含所有元素的原始链表。当移动到更高的层级时,被跳过的元素数量增加,从而减少了连接数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在跳表中找到元素 20

对于某个值的搜索程序从最高层开始,将其下一个元素与该值进行比较。如果值小于或等于该元素,则算法继续到下一个元素。否则,搜索程序降到连接更多的较低层,并重复相同的过程。最后,算法降到最低层并找到所需的节点。

根据 维基百科 的信息,跳表有一个主要参数 p,它定义了一个元素出现在多个列表中的概率。如果一个元素出现在层 i 中,则它出现在层 i + 1 的概率等于 pp 通常设置为 0.5 或 0.25)。平均而言,每个元素会出现在 1 / (1 — p) 个列表中。

正如我们所见,这个过程比普通的链表线性搜索要快得多。实际上,HNSW 继承了相同的思想,但它使用的是图而不是链表。

可导航的小世界

可导航的小世界 是一个具有多对数 T = O(logᵏn) 搜索复杂度的图,它使用贪心路由。路由 指的是从低度顶点开始搜索过程,并以高维度顶点结束。由于低度顶点的连接非常少,算法可以在它们之间迅速移动,从而高效地导航到可能存在最近邻的区域。然后,算法逐渐放大并切换到高维度顶点,以在该区域的顶点中找到最近邻。

顶点有时也被称为节点

搜索

首先,通过选择一个入口点进行搜索。为了确定算法下一步移动的顶点(或顶点),它计算查询向量到当前顶点邻居的距离,并移动到最近的一个。某些时候,当算法找不到比当前节点更靠近查询的邻居节点时,它会终止搜索过程。这个节点被返回作为查询的响应。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可导航的小世界中的贪心搜索过程。节点 A 被用作入口点。它有两个邻居 B 和 D。节点 D 比 B 更接近查询。因此,我们移动到 D。节点 D 有三个邻居 C、E 和 F。E 是距离查询最近的邻居,所以我们移动到 E。最终,搜索过程将导致节点 L。由于 L 的所有邻居都比 L 本身离查询更远,我们停止算法,并将 L 作为查询的答案返回。

这种贪心策略不能保证找到确切的最近邻,因为该方法仅使用当前步骤的局部信息来做出决策。早期停止 是该算法的问题之一。特别是在搜索过程的开始阶段,当没有比当前节点更好的邻居节点时,早期停止现象尤为明显。在大多数情况下,这可能发生在起始区域有太多低度顶点时。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

早期停止。当前节点的两个邻居都比查询更远。因此,算法返回当前节点作为响应,尽管存在距离查询更近的节点。

可以通过使用多个入口点来提高搜索精度。

构建

NSW 图是通过打乱数据集点并逐个将它们插入当前图中来构建的。当插入一个新节点时,它会通过边连接到其M个最近的顶点。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

节点的顺序插入(从左到右),M = 2。在每次迭代中,向图中添加一个新顶点,并将其链接到其 M = 2 个最近邻居。蓝线表示连接到新插入节点的边。

在大多数情况下,长距离边缘可能会在图构建的初始阶段创建。它们在图导航中扮演着重要角色。

在构建开始时插入的元素的最近邻链接随后变成了网络中心之间的桥梁,这些桥梁保持了整个图的连通性,并允许在贪婪路由过程中对跳数进行对数缩放。 — Yu. A. Malkov, D. A. Yashunin

从上图中的示例可以看出,在开始时添加的长距离边缘AB的重要性。设想一个查询需要遍历从相对远离的节点AI的路径。拥有边缘AB允许通过直接从图的一侧导航到另一侧来快速完成这个过程。

随着图中顶点数量的增加,新连接到新节点的边的长度变短的概率也增加。

HNSW

HNSW 基于与跳表和可导航小世界相同的原理。它的结构表现为一个多层次的图,其中顶部层次的连接较少,而底部层次的区域则更为密集。

搜索

搜索从最高层开始,每次在层节点中贪婪地找到局部最近邻,然后逐层向下。最终,找到的最低层上的最近邻即为查询的答案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

HNSW 中的搜索

类似于 NSW,通过使用多个入口点可以提高 HNSW 的搜索质量。与其在每层上仅找到一个最近邻,不如使用efSearch(一个超参数)找到与查询向量最接近的最近邻,并将每个邻居作为下一层的入口点。

复杂度

原始论文的作者声称,在任何层上查找最近邻所需的操作数都由一个常数限制。考虑到图中的所有层数是对数级的,我们得到了总的搜索复杂度,即O(logn)

构建

选择最大层

节点在 HNSW 中是一个接一个地顺序插入的。每个节点会随机分配一个整数l,表示该节点可以出现在图中的最大层。例如,如果l = 1,则该节点只能在第 0 层和第 1 层找到。作者为每个节点随机选择l,其指数衰减概率分布由非零乘数mL(mL = 0 结果是 HNSW 中的单层和非优化的搜索复杂度)进行归一化。通常,大多数l值应该等于 0,因此大多数节点仅存在于最低层。较大的mL值增加了节点出现在更高层的概率。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个节点的层数 l 是根据指数衰减概率分布随机选择的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于标准化因子mL的层数分布。横轴表示均匀分布(0, 1)的值。

为了实现可控层次结构的最佳性能优势,不同层之间的邻居重叠(即也属于其他层的元素邻居的百分比)必须很小。 — Yu. A. Malkov, D. A. Yashunin。

减少重叠的一个方法是减小mL。但重要的是要记住,减少mL通常会导致在每层贪婪搜索过程中需要更多的遍历。因此,选择一个能够平衡重叠和遍历次数的mL值至关重要。

论文的作者建议选择mL的最佳值,即1 / ln(M)。该值对应于跳表的参数p = 1 / M,它是层间的平均单元素重叠。

插入

节点被分配l值后,有两个插入阶段:

  1. 算法从上层开始,贪婪地找到最近的节点。找到的节点随后被用作下一层的入口点,搜索过程继续。一旦达到层l,插入过程就进入第二步。

  2. 从层l开始,算法在当前层插入新节点。然后,它像之前一样执行第 1 步,但不是仅找到一个最近邻,而是贪婪地搜索efConstruction(超参数)个最近邻。然后从efConstruction个邻居中选择M个,并建立从插入节点到它们的边。之后,算法下降到下一层,每个找到的efConstruction节点作为入口点。算法在新节点及其边被插入到最低层 0 后终止。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 HNSW 中插入一个节点(蓝色)。新节点的最大层随机选择为 l = 2。因此,节点将被插入到层 2、1 和 0。在每一层,节点将连接到其 M = 2 个最近邻。

选择构造参数的值

原始论文提供了如何选择超参数的几个有用见解:

  • 根据模拟,M的良好值在 5 到 48 之间。较小的M值适合较低的召回率或低维数据,而较大的 M 值则更适合较高的召回率或高维数据。

  • 更高的efConstruction值意味着更深层次的搜索,因为会探索更多的候选项。然而,这也需要更多的计算。作者建议选择一个efConstruction值,以便在训练过程中回忆接近0.95–1

  • 另外,还有一个重要的参数 Mₘₐₓ — 一个顶点可以拥有的最大边数。除此之外,还存在一个相同的参数 Mₘₐₓ₀,但仅针对最低层。建议选择一个接近 2 * MMₘₐₓ 值。大于 2 * M 的值可能会导致性能下降和过度的内存使用。同时,Mₘₐₓ = M 会导致高召回率下性能差。

候选选择启发式

上面提到,在节点插入过程中,从 efConstruction 候选节点中选择 M 个来建立边。让我们讨论选择这些 M 个节点的可能方法。

天真的方法选取 M 个最近的候选节点。然而,这并不总是最优选择。下面是一个演示这个问题的例子。

想象一个如下面图所示的图结构。正如你所见,图中有三个区域,其中两个区域彼此没有连接(在左侧和顶部)。因此,例如,从点 AB 需要通过另一个区域经过很长的路径。为了更好的导航,将这两个区域以某种方式连接起来是合乎逻辑的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

节点 X 被插入到图中。目标是将其最优地连接到其他 M = 2 个点。

然后一个节点 X 被插入到图中,并且需要连接到 M = 2 个其他顶点。

在这种情况下,天真的方法直接选择 M = 2 个最近的邻居(BC),并将 X 连接到它们。尽管 X 已经连接到其真实的最近邻居,但这并没有解决问题。让我们来看一下作者们发明的启发式方法。

启发式算法不仅考虑节点之间的最近距离,还考虑图中不同区域的连通性。

启发式算法选择第一个最近的邻居(在我们的例子中是 B)并将插入的节点 (X) 连接到它。然后算法按照排序的顺序逐个选择下一个最接近的邻居 (C),并仅当该邻居到新节点 (X) 的距离小于该邻居到所有已经连接的顶点 (B) 到新节点 (X) 的距离时,才建立一条边。之后,算法继续处理下一个最近的邻居,直到建立 M 条边。

回到例子,启发式过程如下面的图所示。启发式算法选择 B 作为 X 的最近邻居,并建立了边 BX。然后算法选择 C 作为下一个最近邻居。然而,这次 BC < CX。这表明将边 CX 添加到图中并不是最优的,因为已经存在边 BX,且节点 BC 非常接近。相同的类比适用于节点 DE。之后,算法检查节点 A。这一次,它满足条件,因为 BA > AX。因此,新边 AX 和两个初始区域变得互相连接。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

左侧的示例使用了简单的方法。右侧的示例使用了选择启发式,使两个初始不相交的区域相互连接。

复杂度

插入过程与搜索过程非常相似,没有显著的差异需要非恒定数量的操作。因此,单个顶点的插入需要 O(logn) 的时间。要估计总复杂度,应该考虑给定数据集中的所有插入节点 n。最终,HNSW 构建需要 O(n * logn) 时间。

将 HNSW 与其他方法结合使用

HNSW 可以与其他相似性搜索方法结合使用,以提供更好的性能。最常见的方法之一是将其与倒排文件索引和产品量化(IndexIVFPQ)结合使用,这在本系列文章的其他部分中已有描述。

[## 相似性搜索,第三部分:融合倒排文件索引和产品量化

在本系列的前两部分中,我们讨论了信息检索中的两个基本算法:倒排……

medium.com](https://medium.com/@slavahead/similarity-search-blending-inverted-file-index-and-product-quantization-a8e508c765fa?source=post_page-----2aad4fe87d37--------------------------------)

在这个范式中,HNSW 充当粗量化器的角色,负责找到最近的 Voronoi 划分,从而可以缩小搜索范围。为此,必须在所有 Voronoi 质心上构建 HNSW 索引。给定查询时,使用 HNSW 找到最近的 Voronoi 质心(而不是之前通过比较每个质心的距离进行的暴力搜索)。之后,查询向量在相应的 Voronoi 划分中被量化,并通过 PQ 代码计算距离。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过在 Voronoi 质心上建立的 HNSW 中找到最近邻,选择最接近的 Voronoi 质心。

当仅使用倒排文件索引时,最好将 Voronoi 划分的数量设置得不太大(例如 256 或 1024),因为会执行暴力搜索以找到最近的质心。通过选择较少的 Voronoi 划分,划分内的候选项数量变得相对较大。因此,算法迅速识别查询的最近质心,并且大部分运行时间集中在 Voronoi 划分内找到最近邻上。

然而,将 HNSW 引入工作流需要调整。考虑仅在少量质心(256 或 1024)上运行 HNSW:由于质心数量较少,HNSW 在执行时间上与简单的暴力搜索相对相同,因此不会带来显著的好处。此外,HNSW 需要更多的内存来存储图结构。

这就是为什么在合并 HNSW 和倒排文件索引时,建议将 Voronoi 质心的数量设置得比平时大得多。这样,每个 Voronoi 分区内的候选者数量会大大减少。

这种范式的转变导致了以下设置:

  • HNSW 以对数时间快速识别最近的 Voronoi 质心。

  • 之后,执行各自 Voronoi 分区内的穷举搜索。因为潜在候选者的数量较少,所以不应成为问题。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++编写的 Python 库,用于优化相似性搜索。该库提供了不同类型的索引,这些索引是用于高效存储数据和执行查询的数据结构。

根据Faiss 文档的信息,我们将探讨如何将 HNSW 与倒排文件索引和乘积量化结合使用。

IndexHNSWFlat

FAISS 有一个类IndexHNSWFlat实现了 HNSW 结构。通常,“Flat”后缀表示数据集向量完全存储在索引中。构造函数接受 2 个参数:

  • d:数据维度。

  • M:在插入过程中需要添加到每个新节点的边的数量。

此外,通过hnsw字段,IndexHNSWFlat 提供了几个有用的属性(可以修改)和方法:

  • hnsw.efConstruction:构造时要探索的最近邻数量。

  • hnsw.efSearch:搜索时要探索的最近邻数量。

  • hnsw.max_level:返回最大层级。

  • hnsw.entry_point:返回入口点。

  • faiss.vector_to_array(index.hnsw.levels):返回每个向量的最大层级列表。

  • hnsw.set_default_probas(M: int, level_mult: float):允许分别设置MmL值。默认情况下,level_mult 设置为 1 / ln(M)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Faiss 实现的 IndexHNSWFlat

IndexHNSWFlatMₘₐₓ = MMₘₐₓ₀ = 2 * M 设置值。

IndexHNSWFlat + IndexIVFPQ

IndexHNSWFlat 也可以与其他索引结合使用。一个例子是前面部分描述的IndexIVFPQ。创建这个复合索引分两个步骤进行:

  1. IndexHNSWFlat 被初始化为粗量化器。

  2. 量化器作为参数传递给IndexIVFPQ的构造函数。

训练和添加可以使用不同或相同的数据完成。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FAISS 实现的 IndexHNSWFlat + IndexIVFPQ

结论

在这篇文章中,我们研究了一种强大的算法,该算法在处理大型数据集向量时表现尤为出色。通过使用多层图表示和候选选择启发式方法,其搜索速度在保持合理的预测准确性的同时得以高效扩展。值得注意的是,HNSW 还可以与其他相似性搜索算法结合使用,使其非常灵活。

资源

除非另有说明,否则所有图像均由作者提供。

相似性搜索,第五部分:局部敏感哈希(LSH)

原文:towardsdatascience.com/similarity-search-part-5-locality-sensitive-hashing-lsh-76ae4b388203?source=collection_archive---------0-----------------------#2023-06-24

探索如何将相似性信息融入哈希函数

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vyacheslav Efimov

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 6 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相似性搜索 是一个问题,给定一个查询的目标是在所有数据库文档中找到与其最相似的文档。

介绍

在数据科学中,相似性搜索通常出现在自然语言处理(NLP)领域、搜索引擎或推荐系统中,其中需要为一个查询检索到最相关的文档或项目。存在多种不同的方法来提升在海量数据中的搜索性能。

在本系列文章的前几部分中,我们讨论了倒排文件索引、产品量化和 HNSW 以及它们如何结合使用以提高搜索质量。在本章中,我们将探讨一种主要不同的方法,这种方法既能保持高搜索速度,又能保证高质量。

## 相似性搜索,第三部分:融合倒排文件索引和产品量化

在本系列的前两部分中,我们讨论了信息检索中的两个基本算法:倒排…

towardsdatascience.com ## 相似性搜索,第四部分:分层可导航小世界(HNSW)

分层可导航小世界(HNSW)是一种最先进的算法,用于近似搜索最近的…

towardsdatascience.com

局部敏感哈希(LSH)是一组方法,用于通过将数据向量转换为哈希值来缩小搜索范围,同时保留有关其相似性的信息。

我们将讨论传统方法,该方法包括三个步骤:

  1. 切片:将原始文本编码成向量。

  2. MinHashing:将向量转换为一种称为 签名 的特殊表示形式,这种表示形式可以用于比较它们之间的相似性。

  3. LSH 函数:将签名块哈希到不同的桶中。如果一对向量的签名至少有一次落入同一个桶中,则它们被视为候选。

我们将逐步深入探讨这些步骤的细节。

切片

切片是从给定文本中收集 k-grams 的过程。k-gram 是一组 k 个顺序排列的标记。根据上下文,标记可以是单词或符号。切片的最终目标是通过使用收集到的 k-grams 来编码每个文档。我们将使用独热编码来完成这一点。然而,也可以应用其他编码方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

收集句子“学习数据科学很有趣”的长度为 k = 3 的唯一切片

首先,收集每个文档的独特k-gram。其次,为了对每个文档进行编码,需要一个词汇表,它代表了所有文档中独特k-gram 的集合。然后,为每个文档创建一个长度等于词汇表大小的零向量。对于文档中出现的每个k-gram,确定其在词汇表中的位置,并在文档向量的相应位置放置一个*“1”。即使相同的k*-gram 在文档中出现多次也没关系:向量中的值始终为 1。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一热编码

MinHashing

在这个阶段,初始文本已经被向量化。可以通过Jaccard 指数比较向量的相似性。记住,Jaccard 指数定义为两个集合中共同元素的数量除以所有元素的总长度。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Jaccard 指数定义为两个集合的交集与并集之比

如果取一对编码向量,则 Jaccard 指数公式中的交集是两个都包含 1 的行数(即k-gram 在两个向量中都出现),并且并集是至少包含一个 1 的行数(k-gram 至少在一个向量中出现)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Jaccard 指数公式

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用上述公式计算两个向量的 Jaccard 指数的示例

当前的问题是编码向量的稀疏性。计算两个一热编码向量之间的相似性得分将耗费大量时间。将它们转换为稠密格式可以使后续操作更高效。最终目标是设计一个将这些向量转换为较小维度的函数,同时保留它们的相似性信息。构建这样的函数的方法称为 MinHashing。

MinHashing 是一种哈希函数,它对输入向量的组件进行排列,然后返回排列向量组件等于 1 的第一个索引。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算给定向量和排列的 minhash 值的示例

为了获得由n个数字组成的向量的稠密表示,可以使用n个 minhash 函数来获得n个 minhash 值,这些值构成一个签名

一开始可能不太明显,但可以使用多个 minhash 值来近似向量之间的 Jaccard 相似性。实际上,使用的 minhash 值越多,近似值就越准确。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算签名矩阵及其如何用于计算向量之间的相似性。使用 Jaccard 相似性和签名计算的相似性应该通常大致相等。

这只是一个有用的观察。事实证明,背后有一个完整的定理。让我们来了解为什么 Jaccard 指数可以通过使用签名来计算。

陈述证明

假设给定的一对向量仅包含011011类型的行。然后对这些向量进行随机排列。由于所有行中至少存在一个 1,因此在计算两个哈希值时,这两个哈希值计算过程中的至少一个会在具有对应哈希值为 1 的向量的第一行停止。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第二个哈希值等于第一个的概率是多少?显然,只有当第二个哈希值也等于 1 时才会发生。这意味着第一行必须是11类型。由于排列是随机的,这种事件的概率等于P = count(11) / (count(01) + count(10) + count(11))。这个表达式与 Jaccard 指数公式完全相同。因此:

基于随机行排列,两个二进制向量获得相同哈希值的概率等于 Jaccard 指数

然而,通过证明上述陈述,我们假设初始向量不包含00类型的行。显然,00类型的行不会改变 Jaccard 指数的值。同样,包含00类型行时获得相同哈希值的概率不会影响它。例如,如果第一个排列行是 00,则 minhash 算法只是忽略它,转到下一行,直到找到至少一个 1。当然,00类型的行可能导致不同的哈希值,但获得相同哈希值的概率保持不变

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们已经证明了一个重要的陈述。但是,如何估计获得相同的 minhash 值的概率呢?当然,可以生成所有可能的向量排列,然后计算所有的 minhash 值以找到所需的概率。出于显而易见的原因,这种方法效率不高,因为一个大小为n的向量的可能排列数等于n!。不过,概率可以大致评估:我们可以使用很多哈希函数来生成大量的哈希值。

两个二进制向量的 Jaccard 指数大致等于它们签名中对应值的数量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数学符号

很容易注意到,采用更长的签名会导致更准确的计算。

LSH 函数

目前,我们可以将原始文本转换为长度相等的密集签名,从而保留关于相似性的的信息。然而,在实践中,这些密集签名通常仍具有高维度,直接比较它们效率不高。

考虑到n = 10⁶ 个文档,每个文档的签名长度为 100. 假设一个签名的单个数字需要 4 字节来存储,那么整个签名将需要 400 字节。存储n = 10⁶ 个文档需要 400 MB 的空间,这在现实中是可行的。但以蛮力方式比较每个文档需要大约 5 * 10¹¹次比较,这太多了,尤其是当n 更大的时候。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

为了避免这个问题,可以建立一个哈希表来加速搜索性能,但即使两个签名非常相似,仅在 1 个位置上有所不同,它们仍可能具有不同的哈希值(因为向量的余数可能不同)。然而,我们通常希望它们落入同一个桶中。这就是 LSH 派上用场的地方。

LSH机制构建一个哈希表,该表由几个部分组成,如果一对签名有至少一个对应的部分,它们就会被放入同一个桶中。

LSH 将签名矩阵水平分成相等的b部分,称为,每部分包含r 。而不是将整个签名插入到一个哈希函数中,签名被分成b部分,每个子签名由一个哈希函数独立处理。因此,每个子签名落入不同的桶中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

LSH 的示例。两个长度为 9 的签名被分成 b = 3 个带,每个带包含 r = 3 行。每个子向量被哈希到 k 个可能的桶之一。由于第二个带中存在匹配(两个子向量具有相同的哈希值),我们将这两个签名对视为最近邻候选。

如果两个不同签名的对应子向量之间至少有一个碰撞,那么这些签名被视为候选。如我们所见,这个条件更灵活,因为考虑向量作为候选者时,它们不需要完全相等。然而,这增加了假阳性的数量:一对不同的签名可能只有一个对应的部分,但总体上完全不同。根据问题的不同,优化参数brk 总是更好的。

错误率

使用 LSH,可以估计两个具有相似度s的签名被视为候选的概率,给定带数b和每个带中的行数r。让我们分几个步骤找到它的公式。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

两个签名的任意一行相等的概率

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一随机带有 r 行的概率相等

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一随机带有 r 行的概率不同

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表中所有 b 个带不同的概率

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

至少有一个 b 带相等的概率,即两个签名是候选的

注意,公式没有考虑当不同的子向量意外地哈希到同一个桶中时的碰撞。因此,签名成为候选的真实概率可能会略有不同。

示例

为了更好地理解我们刚刚得到的公式,我们考虑一个简单的例子。考虑两个长度为 35 符号的签名,它们被平均分成 5 个带,每个带有 7 行。以下表格表示了基于 Jaccard 相似度至少有一个相等带的概率:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于相似度 s,至少获得一对签名具有对应带的概率 P

我们注意到,如果两个相似的签名具有 80% 的 Jaccard 相似度,那么在 93.8% 的情况下它们有一个对应带(true positives)。在剩余的 6.2% 情况下,这样的一对签名是 false negative

现在让我们考虑两个不同的签名。例如,它们的相似度只有 20%。因此,在 0.224% 的情况下,它们是 false positive 候选。在其他 99.776% 的情况下,它们没有相似的带,所以它们是 true negatives

可视化

现在让我们可视化相似度 s 和两个签名成为候选的概率 P 之间的关系。通常,随着签名相似度 s 的提高,签名成为候选的概率应当更高。理想情况下,情况如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

理想的场景。只有当签名的相似度大于某个阈值 t 时,才认为一对签名是候选的

基于上述获得的概率公式,典型的线如下图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一条典型的线在开始和结束时缓慢上升,并在图中所示的近似概率公式的阈值 t 处有一个陡峭的斜率

可以通过改变带的数量b,将图中的线向左或向右移动。增加 b 将线向左移动,并导致更多的 FP,减少则将其向右移动,导致更多的 FN。根据问题找到一个好的平衡点是很重要的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

带的数量增加,线会向左移动;减少则向右移动

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将阈值向左移动会增加 FP,而向右移动则增加 FN

采用不同数量的带和行进行实验

以下为不同值的br构建的几条线图。根据具体任务调整这些参数通常更为有效,以成功检索所有相似文档对,并忽略那些具有不同签名的文档。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

调整带的数量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

调整行数

结论

我们已经讲解了 LSH 方法的经典实现。LSH 通过使用低维签名表示和快速哈希机制来优化搜索速度,从而减少候选项的搜索范围。同时,这也会影响搜索的准确性,但在实践中,差异通常微不足道。

然而,LSH 对高维数据比较敏感:更多维度需要更长的签名长度和更多计算来保持良好的搜索质量。在这种情况下,建议使用其他索引。

实际上,存在不同的 LSH 实现,但所有这些实现都基于将输入向量转换为哈希值的相同范式,同时保留关于它们相似性的信息。基本上,其他算法只是定义了获得这些哈希值的不同方式。

随机投影是另一种 LSH 方法,将在下一章中介绍,并且在Faiss库中实现为 LSH 索引,用于相似性搜索。

资源

所有图像除非另有说明,均由作者提供。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值