TowardsDataScience 2023 博客中文翻译(一百一十)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

Dijkstra 算法在 OSM 网络中按旅行时间加权

原文:towardsdatascience.com/dijkstras-algorithm-weighted-by-travel-time-in-osm-networks-792aa92e03af

使用 OSMNX 1.6 寻找最快和最短路径

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Bryan R. Vallejo

·发表于 Towards Data Science ·7 min 阅读·2023 年 10 月 10 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。摩洛哥示例中的最快路线(红色)和最短路线(橙色)

最短路径(Dijkstra)算法可以应用于 OSM 网络中,如驾驶、骑行和步行,以找到起点和终点之间的最优路线。但该算法在网络中计算基于距离的最短路线,这并不意味着最优路线。在道路网络中,距离可以是相对的,当我们考虑道路的速度时。显然,在所有道路速度相等的情况下,两点之间的最优路线可能是最短的。如果我们比较高速公路与城市街道的速度,我们会重新调整这个想法,理解最优路线是最快的。

“在道路网络中,距离可能是相对的,当我们考虑道路的速度时”

借助 Python 库 OSMNX,可以在全球范围内为不同类型的道路添加速度,并计算 OSM 网络中节点之间的旅行时间。这使得 Python 库可以处理以旅行时间加权的最短路径算法。

这一实践是之前一个教程的延续,那个教程使用了最短路径算法来计算摩洛哥两个位置之间的最短路线。

摩洛哥 OSM 网络中的最短路线

## 最短路径(Dijkstra)算法逐步 Python 指南

使用 OSMNX 1.6 的更新及长距离路径

towardsdatascience.com

访问编码教程

如果你还不是Medium的会员,你需要订阅才能访问这些故事。你可以通过使用我的个人链接来跟随更多编码教程并支持我的工作。成为这段编码之旅的一部分。

这里加入 👉 bit.ly/3yjLsSL

OSM 数据许可证

介绍

接下来的步骤将指导如何使用旅行时间应用最短路径。我们将比较最快路线和最短路线,以了解旅行时间和长度的变化。

此外,我们还将使用 OSMNX 中的更多函数来改进结果,如utils_graph.route_to_gdf(),以及计算旅行时间add_edge_speeds()add_edge_travel_times()

这是由Hanae建议的起点和终点的快速视图。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。起点和终点位置。

编码实践

开始获取我们需要的库。

import osmnx as ox
import geopandas as gpd
from shapely.geometry import Point
import pandas as pd
import matplotlib.pyplot as plt

1. 定义起点和终点为 GeoDataFrames

开始添加坐标并使用 Geopandas 创建新的 GDF。

# --- origin and destination geom

origin_geom = Point(-5.6613932957355715, 32.93210288339607)
destination_geom = Point(-3.3500597061072726, 34.23038027794419)

# --- create origin dataframe

origin =  gpd.GeoDataFrame(columns = ['name', 'geometry'], crs = 4326, geometry = 'geometry')
origin.at[0, 'name'] = 'origin'
origin.at[0, 'geometry'] =origin_geom

# --- create destination dataframe

destination =  gpd.GeoDataFrame(columns = ['name', 'geometry'], crs = 4326, geometry = 'geometry')
destination.at[0, 'name'] = 'destination'
destination.at[0, 'geometry'] = destination_geom

2. 获取图网络

当我们有长途路线时,建议使用envelope函数来获取图。

使用之前定义的函数

def get_graph_from_locations(origin, destination, network='drive'):
    '''
    network_type as drive, walk, bike
    origin gdf 4326
    destination gdf 4326
    '''
    # combine and area buffer
    combined = pd.concat([origin, destination])

    convex = combined.unary_union.envelope # using envelope instead of convex, otherwise it breaks the unary_union

    graph_extent = convex.buffer(0.02)

    graph = ox.graph_from_polygon(graph_extent, network_type= network)

    return graph

应用并可视化

# --- Get Graph
graph = get_graph_from_locations(origin, destination)
fig, ax = ox.plot_graph(graph, node_size=0, edge_linewidth=0.4, bgcolor='black', edge_alpha=0.2,  edge_color='yellow')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。包含起点和终点的道路网络图

3. 向道路网络添加旅行时间

我们将使用函数add_edge_speeds()添加旅行时间,该函数输入速度(以 km/h 为单位)。插补将按高速公路类型添加道路的平均最大速度值。然后,我们使用函数add_edge_travel_time()计算旅行时间

我们将保留相同的变量graph

# --- add edge speed
graph = ox.add_edge_speeds(graph)

# --- add travel time
graph = ox.add_edge_travel_times(graph)

如果你想修改高速公路类别,可以传递一个基于本地速度值的字典

# --- add speeds define by local authorities (example)
hwy_speeds = {"residential": 35, "secondary": 60, "tertiary": 75}

graph = ox.add_edge_speeds(graph, hwy_speeds)
graph = ox.add_edge_travel_times(graph)

现在,通过获取边缘类别来快速查看按高速公路类型的旅行时间。

# --- get the edges as GDF
edges = ox.graph_to_gdfs(graph, nodes=False)[['highway', 'speed_kph', 'length', 'travel_time', 'geometry']].reset_index(drop=True)

# --- see mean speed/time values by road type
edges["highway"] = edges["highway"].astype(str)
edges.groupby("highway")[["speed_kph", "travel_time"]].mean().round(0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。道路网络中的速度和旅行时间

4. 查找起点和终点的最近节点

使用起点和终点坐标获取网络中最接近的节点。该函数在新版本 1.6 中为nearest_nodes()

# ------------- get closest nodes

# origin
closest_origin_node = ox.nearest_nodes(G=graph, 
                                       X=origin_geom.x, 
                                       Y=origin_geom.y)

# destination
closest_destination_node = ox.nearest_nodes(G=graph, 
                                           X=destination_geom.x, 
                                           Y=destination_geom.y)

然后,我们在节点之间应用最短路径算法。

5. 计算使用旅行时间的最短路径

我们将使用 shortest_path() 函数来计算我们的路线。我们将同时使用距离和时间来比较路线在 weight 参数下的不同之处。

# --- calculate shortest path with length and travel time

# time
fastest_route = ox.shortest_path(graph, 
                                orig = closest_origin_node, 
                                dest = closest_destination_node, 
                                weight="travel_time")

# distance
shortest_route = ox.shortest_path(graph, 
                                orig = closest_origin_node, 
                                dest = closest_destination_node,  
                                weight="length")

这将返回一组属于该路线的节点代码。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。路线的节点

6. 从节点创建路线(单行代码)

osmnx 实现了一个函数 utils_graph.route_to_gdf(),可以在一行中将节点转换为 GeoDataFrame。很方便,我们可以动态获取感兴趣的列。

# --- get gdf of routes

# fastest
fastest_route_gdf = ox.utils_graph.route_to_gdf(graph, fastest_route, weight='travel_time')[['highway', 'speed_kph', 'travel_time', 'geometry']]

# shortest
shortest_route_gdf = ox.utils_graph.route_to_gdf(graph, shortest_route, weight='length')[['highway', 'speed_kph', 'travel_time', 'geometry']]

7. 快速比较(时间和距离)

我们将比较两条路线的旅行时间和长度。

# --- comparison

d1 = fastest_route_gdf['length'].sum()
d2 = shortest_route_gdf['length'].sum()

t1 = fastest_route_gdf['travel_time'].sum()
t2 = shortest_route_gdf['travel_time'].sum()

打印值

print(f'Fastest Route: Time {round(t1/3600, 2)} hours, Distance {round(d1/1000, 2)} km')
print(f'Shortest Route: Time {round(t2/3600, 2)} hours, Distance {round(d2/1000, 2)} km')

*最快路线:时间 5.42 小时,距离 378.93 公里

最短路线:时间 5.6 小时,距离 362.84 公里*

结果显示最快的路线更长。如预期的那样,涉及速度时,距离变得相对。

8. 保存文件并可视化

保存网络和路线

# --- save

edges.to_file('osm_drive_network.gpkg')

fastest_route_gdf.to_file('fastest_route.gpkg')

shortest_route_gdf.to_file('shortest_route.gpkg')

在 QGIS 中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。最快的路线(橙色)和最短的路线(黄色)

在 Matplotlib 中

# --- plot network
ax = edges.plot(figsize=(12, 10), linewidth = 0.1, color='grey', zorder=0);

# --- origin and destination
origin.plot(ax=ax, markersize=100, alpha=0.8, color='blue', zorder=1)
destination.plot(ax=ax, markersize=100, alpha=0.8, color='green', zorder=2)

# --- route
fastest_route_gdf.plot(ax=ax, linewidth = 4, color='red', alpha=0.4, zorder=3)
shortest_route_gdf.plot(ax=ax, linewidth = 4, color='yellow', alpha=0.4, zorder=4)

plt.axis(False);

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。较快的路线(红色)和最短的路线(黄色)

已知改进

从 osmnx 使用的新功能 utils_graph.route_gdf() 创建了一个干净的路线。它使用了道路段,而不仅仅是节点之间的联接,新路径覆盖了 OSM 道路网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像。路线覆盖了道路网络。

结论

最短路径算法可以使用 OSMNX 函数 add_edge_speeds()add_edge_travel_times() 计算道路速度。这种不同的方法表明,如果在路线计算中实施速度(旅行时间),最短路径是相对的。正如预期的那样,最快的路线最终比最短路线更长,但它以最短的旅行时间到达了目的地。

生成覆盖道路网络的路线的改进使得在城市区域级别的可达性和邻近性研究中,距离和旅行时间的计算更加准确。

致谢

多亏了Geoff Boeing提供的资料,我得以探索这些功能并理解 OSMNX 的功能。

如果你想就问题或定制分析联系我:

Bryan R. LinkedIn

维度缩减:面对维度诅咒

原文:towardsdatascience.com/dimension-reduction-facing-the-curse-of-dimensionality-63a743e4b199?source=collection_archive---------4-----------------------#2023-04-13

PCA 与动态因子模型的比较

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Victor Graff

·

跟进 发表在 Towards Data Science ·10 分钟阅读·2023 年 4 月 13 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Kolleen Gladden 的照片,来自 Unsplash

许多数据科学家不得不面对维度的挑战。数据集可能包含大量变量,使得理解和计算变得复杂。例如,资产管理者可能会被与其投资组合相关的许多动态变量所困扰,处理大量数据可能导致计算问题。降维是一种将大量变量的信息提取到较小的降维变量集合中的方法,而不会丧失过多的解释性。换句话说,降维方法可以被认为是寻找一个最小化重构误差的子空间。

存在几种方法来进行信息提取,每种方法都适用于不同的用例。本文旨在提供这两种方法的详细比较:主成分分析(PCA)和动态因子模型(DFM)。PCA 可以用于任何类型的结构化数据集,而动态因子模型则用于时间序列应用,因为它嵌入了时间序列的演变。

分析基于经济和金融数据。用于本研究的数据是克拉克、托德;卡里耶罗、安德烈亚;马切利诺、马西米利亚诺的文章测量不确定性及其对经济的影响中使用数据的复制版,数据可在Harvard dataverse上获取。数据包括 18 个宏观经济变量和 12 个金融变量,涵盖了这些变量从 1960 年到 2014 年的演变。在通过降维算法处理之前,数据被转换以确保平稳性。

整个代码可在Github上获取。

主成分分析(PCA)

理论

PCA 可以看作是一种无监督的降维方法。假设我们有大量的变量。所有这些变量似乎都对分析有用,但没有明显的方法将这些变量汇总成类别。在这种情况下,算法将负责在没有模型师特定输入的情况下进行降维。换句话说,算法将创建更少的变量,称为降维成分,这些成分能够接近地重现初始变量。

PCA 的方法基于变量的协方差。如果两个变量高度协方差,这意味着它们遵循相同的趋势。第一个变量在重现第二个变量方面非常高效,使得只保留第一个变量而不丧失在需要时重建第二个变量的能力成为可能。PCA 创建一个变量子集,最大化与初始变量集的协方差,以便在较低维度中存储尽可能多的信息。

该方法的思路是计算由原始变量集创建的空间的正交基。创建这个基的向量是方差-协方差矩阵的特征向量。通过选择最能代表初始数据的特征向量,即包含最多协方差的特征向量,可以轻松地减少维度。特征值量化了向量存储的协方差量:特征值越大,其相关的向量就越有趣。

PCA 算法的过程如下:

1. 计算协方差矩阵

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2. 计算其特征向量和特征值

3. 对特征值进行排序,以保留包含最多信息的向量

每个特征值与所有特征值之和的比率表示其相关特征向量中包含的协方差量。剩下的任务是确定保留的特征向量数量。我们将在下一节中看到为此选择的不同参数。

应用到数据

Python 使得定义 PCA 模型变得简单,因为它包含在库 sklearn 中。属性 n_components 可以初步设置为一个较大的值,以便比较特征值,然后选择保留的组件数量。一旦拟合,特征值将按降序显示,以帮助我们做出决策。下面的图显示了每个特征值所包含的协方差比率。

from sklearn.decomposition import PCA
pca = PCA(n_components=5).fit(u_data)
plt.plot(pca.explained_variance_ratio_)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

选择合适特征向量数量的通常规则是查看图中表示的“肘部”。取到肘部的向量数量提供了信息保留和结果维度之间的有趣折衷。在这种情况下,我们保留前两个组件。

前两个组件的协方差比率为 69.6% 和 9.7%。因此,通过仅保留两个组件,我们保留了初始数据中几乎 80% 的信息,同时将维度从 30 减少到 2!

总结来说,PCA 是一种很好的降维工具。它易于部署,并且在保留信息方面产生了良好的结果,同时显著减少了维度。然而,PCA 像一个黑箱,阻止了对结果组件的有意义理解。此外,PCA 适用于任何类型的结构化数据,但如果数据以时间序列的形式存在,则不包含数据的动态性。

下一节将讨论动态因子模型,它们可能是应对这些局限性的潜在解决方案。

动态因子模型(DFM)

理论

动态因子模型用于观察 N 个变量随时间的演变(这些变量组合成一个向量 Xt),并使用较少数量的动态公共因子。这种方法的优势在于它将大量变量的共同运动嵌入到较少的成分中。

这种方法适用于时间序列应用。因此,它们在金融和经济学中被广泛使用,因为许多关键变量随时间共同演变。

DFM 将向量 Xt 定义为减少因子(ft)过去和当前值的线性组合。这些因子本身是动态的,即以自回归方式定义。减少的成分数量为 q,自回归的滞后为 p。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个 λ 是一个 (N x q) 矩阵,其中 q 是减少的成分数量,每个 ft 是一个 (q x 1) 向量,每个 ψ 是一个 (q x q) 矩阵。动态性体现在每个减少的向量 ft 遵循向量自回归过程,因此它本身是基于 f 的过去值计算的。此外,向量 X 受当前和过去数据的影响。

DFM 的一个非常重要的方面是组件的数量是在计算前基于对数据的定性知识定义的。如果变量容易分类,这可以是一个有趣的特征,但如果没有出现有意义的类别,这也可能是一个挑战。

一旦因子的数量被定义,估计成分的主要方法是使用高斯最大似然估计器(MLE)。ε 和 η 被假设为遵循高斯分布,MLE 的目标是通过调整高斯参数(均值和标准差)来最大化获得样本数据(Xt, ft)的概率。幸运的是,这一步骤在 Python 库中直接实现,使得计算变得容易。

一旦估计完成,这些计算出的成分将代表它们被分配的类别。这就意味着我们得到的成分数量与我们定义的类别数量相同。这使我们能够以高效且有意义的方式减少维度。

数据应用

DFM 将应用于与之前展示的相同数据集。这里有个好消息:我们直接有两个明显的类别:宏观经济学和金融。

Python 的 statsmodels 库中包含了一个 DFM 模型:DynamicFactorMQ。为了计算模型,需要几个参数。首先,显然是我们旨在减少的初始数据。其次,一个将每个变量与其类别关联起来的字典(从技术上讲,每个变量可以属于多个类别,但我们在这里不讨论这种情况)。

factors = dict()
for macro_variable in list(macro_variables.values()):
    factors[macro_variable] = ["Macro"]
for finance_variable in list(finance_variables.values()):
    factors[finance_variable] = ["Finance"]

然后,我们定义与每个因子 ft 相关的 VAR 模型的滞后阶数,即多少个时间步骤向后影响因子的当前状态。在我们的案例中,一个滞后似乎足够。增加滞后显然会增加计算约束,但通过在每一步提供更长时间的信息,可以显著影响模型的效率。

最后,需要定义特定成分。这个成分表示向量 Xt 中不能通过 ft 的当前值和过去值解释的部分。这个成分可以看作是线性回归中的残差。它可以拟合为 AR(1)模型或白噪声。从经济学角度来看,这一选择是重要的:我们估计模型的残差是自回归的(即现值和过去值相关)还是独立同分布的?对于经济学研究来说,一个不相关的特定成分通常是不现实的,因为测量方法通常会引入相关误差。

factor_model = DynamicFactorMQ(u_data,
                    factors=factors,
                    factor_orders = {'Macro': 1, "Finance": 1},
                    idiosyncratic_ar1=True,
                    standardize=False)
model_results = factor_model.fit(disp=30)

方法比较

接下来的问题显然是:应该使用哪种方法?如预期的那样,这取决于我们要寻找的内容。

让我们总结一下每个模型的优缺点。

主成分分析(PCA)

  • 可以应用于任何类型的结构化数据

  • 计算时无需对数据有先验知识

  • 选择降维成分的经验法则

  • 无监督过程

动态因子模型(DFM)

  • 应用于时间序列数据

  • 对数据的定性知识,以确定嵌入在降维因子中的类别

  • 预先确定的降维成分数量

乍一看,PCA 似乎比 DFM 更受关注,但要做出决定还需要进一步观察。这两者之间的主要区别在于 DFM 能够提供其结果的有意义解释。

可读性

首先,我们来看一下创建的组件。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这两个图显示了每个模型中两个选择因子的演变。有趣的是,这两个模型似乎都将一个变量分离为趋势(蓝色)和另一个变量分离为波动性(橙色)。DFM 给我们提供了这一观察结果背后的含义:看到宏观变量(例如 GDP、房价等)随着时间的推移而增加并不奇怪。此外,金融变量被认为波动性更大。PCA 似乎捕捉到了相同类型的信息,但我们仍然只能对这种现象做出假设。DFM 在这一点上有优势。

准确性

让我们回到降维方法的目的:作为较低维度下原始数据的良好替代。因此,我们需要确保模型能够准确地重现原始数据。

Python 为这两种算法提供了一种便捷的方法来重现初始变量。对于 PCA,将数据转换为其降维空间后,inverse_transform 方法提供了由模型处理的每个初始变量的表示。DFM 模型将所有表示包含在其 fittedvalues 属性中。

#PCA
scores = pca.transform(u_data)
reconstruct = u_data + pca.inverse_transform(scores) - u_data

#DFM
model_results.fittedvalues

我们可以轻松绘制每个模型的数据表示。下面的图中,我们展示了失业率变量的一个例子。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这个例子中,DFM 显然更适合,因为它始终更接近原始数据,变化更少。为了进行更全面和定量的评估,让我们计算两个模型在整个数据集上的残差。

print(f"Residuals of DFM on global dataset: {np.round(np.abs(model_results.resid).sum().sum(), 2)}")
print(f"Residuals of PCA on global dataset: {np.round(np.abs(resid_pca).sum().sum(), 2)}")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

残差和

在重现初始数据方面,DFM 明显比 PCA 更具性能。模型中的分类和动态似乎准确捕捉了初始变量集的信息。

结论

我们比较了两种降维方法,各有优缺点。我们看到,在所呈现的情况下,DFM 模型更适合,但 PCA 也非常有价值。让我们总结一下:

何时偏好 PCA?

  • 数据中没有时间动态。

  • 初始数据没有明显的分类。

  • 对初始数据的定性知识很少。

何时偏好 DFM?

  • 时间动态是数据的一个重要特征。

  • 分析需要对降维组件的理解。

  • 数据的分类很容易找到。

总结来说,没有哪种算法在所有情况下都优于另一种。建模者的角色是评估每种情况中什么是最好的。此外,正如我们所见,两种模型都易于在 Python 上实现。实现这两者有助于增加对数据的理解,并带来更好的解决方案。

我希望这篇文章对你有所帮助,并能帮助你理解这两种模型之间的差异。请随时给我任何反馈或想法!

参考文献

Clark, Todd; Carriero, Andrea; Marcellino, Massimiliano, 2017, “Replication Data for: “Measuring Uncertainty and Its Impact on the Economy””, doi.org/10.7910/DVN/ENTXDD, Harvard Dataverse, V3

DINO — 计算机视觉的基础模型

原文:towardsdatascience.com/dino-a-foundation-model-for-computer-vision-4cb08e821b18

🚀Sascha 的论文俱乐部

自监督视觉变换器中的新兴特性,作者 M. Caron 等。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Sascha Kirch

·发表于 Towards Data Science ·13 分钟阅读·2023 年 9 月 27 日

计算机视觉正迎来一个令人兴奋的十年。来自自然语言领域的巨大成功被转移到视觉领域,包括引入 ViT(视觉变换器),最近大规模的自监督预训练技术在基础模型的名义下成为头条新闻。

今天我们将探讨一个名为 DINO(自DI蒸馏,NO 标签)的框架,它是建立在 ViTs 有趣特性基础上的视觉基础模型。它也是今天表现最佳的基础模型之一的前身:DINOv2

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源于 出版物,作者 Sascha Kirch

论文: 自监督视觉变换器中的新兴特性,作者 Mathilde Caron 等,2021 年 4 月 29 日

资源: GitHub博客文章

类别: 基础模型,计算机视觉,视觉变换器,知识蒸馏,相似性学习,自监督学习

其他详细讲解:

[BYOL] — [CLIP] — [GLIP] — [Segment Anything] — [DINO] — [Depth Anything] — [DDPM]

大纲

  1. 背景与背景

  2. 方法

  3. 实验

  4. 消融测试

  5. 结论

  6. 进一步阅读与资源

背景与背景

时间是 2021 年,准确地说是 4 月。自从发布了带有 Attention is All You Need 的 Transformer 模型已经过去了四年。自监督预训练在 NLP 中已经由 BERT 等模型长期实践,而“基础模型”这一术语在接下来的几个月中尚未被知晓,直到 关于基础模型的机遇与风险 的发布。六个月前,Vision Transformer (ViT) 首次发布在 arxiv 上,距离 ICLR 2021 还有一个月,它将在那里进行展示。

让我们稍微消化一下这个信息:ViT 于 2020 年 10 月在 arxiv.org 上首次发布,并在 2021 年 5 月的 ICLR2021 上进行了展示。DINO 于 2021 年 4 月在 arxiv 上发布。所以,实际在会议上展示前的一个月。这意味着他们只有 5 个月的时间,如果他们立即开始的话,来构思项目的想法、组建团队、奠定理论基础、训练模型、进行实验和消融测试,并撰写论文。难怪现在的博士生感到不断的焦虑。至少这就是我有时的感受 😅

尽管 ViT 与卷积网络非常具有竞争力,但它们在计算资源和训练数据量方面的要求很高。

DINO 的作者做出了一个简单的观察:变换器在 NLP 中的成功与自监督预训练相关,而目前视觉领域的自监督方法是基于卷积网络的,比如 BYOL。

## BYOL -对比自监督学习的替代方案

论文分析——《Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning》

[towardsdatascience.com

受到 BYOL 和 mean teacher 的启发,作者提出了一个框架来以自监督的方式训练 ViT,并发现:

  1. 自监督 ViT 特征明确包含场景布局,特别是对象边界。

  2. 自监督 ViT 特征在没有任何微调、线性分类器或数据增强的情况下,与基础的最近邻分类器 (k-NN) 一起表现尤为出色。

与 BYOL 和 mean teacher 相比,DINO 实现了一个知识蒸馏框架,包括一个学生模型和一个教师模型,作用于同一图像的不同视角,并采取额外措施应对相似性学习方法的固有不稳定性,其中解决方案通常会崩溃。

底层视觉变换器架构 (ViT) 的一个有趣发现是,当使用无监督学习技术进行训练时,其特征包含有关图像语义分割的显著信息。可以简单地可视化多头注意力层中选择的头部的自注意力图,如下方视频所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:选择的头部的自注意力图。来源

让我们深入探讨一下 DINO 实现其框架的方式,如何应对不稳定性,以及与以前的方法相比它的表现如何!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Sascha Kirch

由 Sascha Kirch 进行的论文解读

查看列表7 个故事!“DDPM — 去噪扩散概率模型”论文插图,作者:Sascha Kirch外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

方法

DINO 框架与其他相似性学习框架(如 BYOL 或 mean teacher)以及知识蒸馏具有相同的整体结构。让我们首先看看 DINO 是如何做到这一点的,并与其他框架进行区分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:DINO 架构。 来源 + Sascha Kirch 的注释

网络和更新规则

我们从中间开始。DINO 实现了两个具有完全相同架构但权重不同的网络。这些网络分别是学生网络和教师网络。学生网络通过反向传播进行训练,而教师网络则通过其自身权重和学生网络权重的指数移动平均来更新其权重。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

方程 1:教师权重的更新规则。 来源 + Sascha Kirch 的注释

骨干网络可以是 ResNet50 或 DeiT(这是为知识蒸馏而调整的 ViT)。一个基于 MLP 的投影头连接到骨干网络,以减少特征的维度,但在推理时会被移除。

很好,但用于推理的是哪个模型:学生还是教师? — 好问题,实际上论文中并没有提到这个问题的任何信息。直观上你可能会认为是学生,至少我最开始也是这样想的。但正如我们后续将看到的,教师在整个训练过程中表现优于学生。除了更好的性能之外,唯一的线索是,在代码实现中,教师检查点是用于例如 视频分割线性探测k-NN 的默认评估点。由于此参数可以更改,因此我不能给出确切的答案。

输入和输出

从输入图像 x 创建不同的视图 x1x2,方法是通过裁剪和应用图像增强,如 BYOL(例如色彩抖动、高斯模糊和太阳化)。用于裁剪的技术称为 multi-crop,通过生成不同大小的多个裁剪来节省内存,同时提供更多数据。小裁剪被称为局部视图,由 96x96 像素组成,这些视图专门输入到学生网络中。较大的裁剪被称为全局视图,由 224x224 像素组成,这些视图专门输入到教师网络中。正如我们在消融部分将看到的,训练过程中使用了 2 个全局视图和 10 个局部视图。

注意:论文对于多裁剪技术有点混乱,因为提供的伪代码和上面的图 3 所示的架构都没有反映出来。伪代码甚至建议 x1 和 x2 像在 BYOL 中一样输入到学生和教师中,这在使用多裁剪时并非如此。

与相似性学习的目标是最大化嵌入的相似性不同,DINO 最小化教师和学生输出分布之间的交叉熵。如下面的方程所示,交叉熵是对每对全局和局部视图计算的,然后汇总。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

方程 2:优化目标。 来源 + Sascha Kirch 的注解

模型的输出是什么? — 就像相似性学习中,学生和教师为给定的图像输出一个嵌入,而不是预测分数。就像在知识蒸馏中,输出通过 SoftMax 转换为概率分布。SoftMax 有一个温度参数,它控制结果分布的平滑或锐化。这个温度在知识蒸馏中起着关键作用,因为它可以控制从教师网络到学生网络转移一般知识和细粒度细节之间的平衡,使蒸馏过程对不同任务更有效。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:温度值对 SoftMax 输出的影响。 Sascha Kirch 的插图,使用 这个 Python 笔记本 创建

我为你创建了一个笔记本,以便你可以调查温度对结果分布的影响:

[## ML_Notebooks/Softmax_Temperature.ipynb 在 main 分支 · sascha-kirch/ML_Notebooks

机器学习相关笔记的集合用于共享。— ML_Notebooks/Softmax_Temperature.ipynb 在 main 分支 ·…

github.com](https://github.com/sascha-kirch/ML_Notebooks/blob/main/Softmax_Temperature.ipynb?source=post_page-----4cb08e821b18--------------------------------)

避免崩溃

如前所述,学生和教师具有完全相同的架构。这种设置是不稳定的(如果没有采取对策),可能会导致崩溃解决方案,即所有特征都映射到潜在空间中的某个区域,例如最坏情况下的一个点。BYOL 通过为其中一个模型引入额外的预测头来解决这个问题,从而引入了不对称性。由于 DINO 具有对称模型,因此需要另一种技巧:中心化和锐化。两者仅应用于教师网络。中心化是一种技术,通过向教师输出添加偏置项c来防止潜在空间中的单一维度主导,即g(x) = g(x)+c

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

方程 3:中心化项的更新规则。来源 + Sascha Kirch 的注释

虽然中心化具有积极效果,但它也鼓励输出崩溃为均匀分布。锐化具有相反的效果,因此应用两者平衡它们的效果并稳定训练。通过使用较小的温度来实现锐化(见图 3),教师的 SoftMax 温度比学生的低。

为了避免方程 3 中的超参数m和教师的温度崩溃是至关重要的。在附录部分的消融研究中,作者展示了m=0.9…0.999的效果最佳,并且温度值在预热期间从0.04线性增加到0.07

DINO 是做什么的?知识蒸馏还是相似性学习?

答案是两者兼有!

虽然知识蒸馏通常是将知识从已经训练好的、更大且更准确的教师模型蒸馏到较小的学生模型中,但它也可以看作是一种相似性学习,因为它鼓励学生网络生成与教师相似的预测。在相似性学习中,两个模型通常是联合训练的,并且通常对齐它们的潜在空间预测,而不是概率分布。

由于 DINO 的作者将他们的目标表述为知识蒸馏,让我们看看与“标准”知识蒸馏相比的一些差异:

  1. DINO 的教师不是事先可用的,而是与学生一起“训练”的。它甚至可以被认为是一种共同蒸馏,因为知识也从学生蒸馏到教师。

  2. DINO 的教师和学生不是对相同的输入进行操作,而是对裁剪到不同尺寸的图像的不同视图进行操作。

  3. DINO 在两个模型的 SoftMax 中使用不同的温度来进行锐化。

  4. DINO 计算的是嵌入的温度缩放 SoftMax 上的交叉熵,而不是预测分数。

它与知识蒸馏的相似之处在哪里?:

  1. DINO 由一个学生网络和一个教师网络组成,其中教师的表现优于学生,正如我们在实验中将看到的那样。

  2. DINO 不是最大化相似性度量,而是最小化温度缩放的 SoftMax 输出的交叉熵损失。

[## 每当 Sascha Kirch 发布新内容时都会收到邮件 🚀

每当 Sascha Kirch 发布新内容时都会收到邮件 🚀 想要了解更多深度学习相关的内容或只是保持最新动态……

medium.com](https://medium.com/@SaschaKirch/subscribe?source=post_page-----4cb08e821b18--------------------------------)

实验

论文展示了大量的实验。他们在 ImageNet 上预训练模型,ImageNet 是一个在表征学习中常用的数据集。

对于评估,常见的技术通常要么在冻结特征上训练线性分类器,要么对模型进行微调以适应新的下游任务,在这种情况下,模型的参数会被调整。

DINO 的作者声称这些技术对超参数非常敏感,这使得比较不公平且难以重现。因此,他们建议对预训练模型的特征使用简单的最近邻聚类算法。

ImageNet 上的线性和 k-NN 分类

在这个实验中,模型在 ImageNet 上的图像分类准确性上进行了测试。测试了多种自监督预训练模型,骨干网包括 ResNet 或 ViT。分类是通过线性探测或 k-NN 聚类完成的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 1:在 ImageNet 上的线性和 k-NN 分类。 来源 + Sascha Kirch 的注释

我认为主要的收获是:

  1. K-NN 在 ViT 特征上的表现优于 ResNet 特征。

  2. 在 ViT 中减少补丁大小比增加骨干网带来的改进更大,但代价是推理速度变慢。

视频实例分割

一个重要的实验是视频分割任务,因为论文讨论了 ViT 在用自监督方法训练时捕捉语义分割能力的特性。或者说这是论文所声称的 😁

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 2:视频实例分割。 来源 + Sascha Kirch 的注释

观察这些结果后,我觉得还缺少两个进一步的实验:

  1. 如果能看到在 DINO 框架下监督的 ResNet50 和自监督的 ResNet50 之间的对比,将会很好,这可以支持他们关于 ViT 优于 ResNet 架构的主张。

  2. 如果能看到相同的 ViT 骨干网在监督学习和自监督学习下的效果对比,将会非常棒,这样可以观察到对补丁大小和模型大小的影响。

不过正如我总是说的:提出问题很容易 😁 在实际项目中,作者们常常面临资源限制和项目截止日期,所以不可能涵盖每一个细节!

探索自注意力图

在这个实验中,作者调查了 ViT 的多头自注意力层中不同头部的自注意力图。他们可视化了 ViT-S/8 最后一层中选定头部的注意力图,精确来说是学习到的[CLS]令牌。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:来自选定头部的注意力图。 来源 + Sascha Kirch的注释

其他实验

在其他实验中,DINO 在与监督基线的比较中有所改进。这些任务包括图像检索和复制检测。

消融实验

在他们的消融研究中,作者对 ViT-S 模型进行了实验。

补丁大小的重要性

记住,视觉变换器输入的是一个补丁化的输入图像,将每个补丁转化为令牌,然后应用具有自注意力机制的变换器。这是 ViT 作者的一项技巧,用于减少性能权衡的计算需求,使变换器适用于图像数据。

DINO 声称,较小的补丁大小提高了性能,同时降低了吞吐量(每秒可以处理的图像数量),这正是 ViT 所声称的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:补丁大小对准确性和吞吐量的影响。 来源 + Sascha Kirch的注释

直观地说,这并不令人惊讶,因为你增加了输入分辨率,结果是需要处理更多的令牌,因此你得到一个更细粒度的注意力图。

不同的教师更新规则

DINO 中的教师通过计算从更新后的学生和当前教师的指数移动平均来更新。这就是他们所称的“动量编码器”方法。

使用动量编码器并绘制教师和学生在训练过程中的准确性,教师在整个过程中表现更好。由此我们可以假设:

  1. 教师可以为学生提供强有力的学习信号。

  2. 改进的学生由于 EMA 更新规则(共同蒸馏)使教师得到提升。

  3. 可以使用教师作为最终模型,该模型具有更好的性能,但与学生具有相同的架构,因此计算需求没有变化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:教师性能。 来源 + Sascha Kirch的注释

他们还实验了另外 3 种更新规则:将权重从学生复制到教师,使用优化器前一个迭代的学生权重,和使用前一个时代的学生权重。

多裁剪与时间和 GPU 内存

如前所述,DINO 输入相同图像的多个裁剪视图,并将全局视图输入到教师模型中,将局部视图输入到学生模型中。在这项消融实验中,作者试验了不同数量的局部视图,并报告了对性能、训练时间和每 GPU 峰值内存的影响。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 3:多裁剪与时间和 GPU 内存。来源 + Sascha Kirch的注释

避免崩溃

在这项消融实验中,作者评估了其稳定措施在避免崩溃解决方案中的作用:中心化和锐化。

为此,他们将交叉熵分解为熵项和 Kullback-Leibler(KL)散度项。KL 散度是两个概率分布差异的度量。如果 KL 为 0,则认为两个分布相等。

其直观的理解是:如果教师和学生的输出分布的 KL 散度在整个训练过程中保持不变,那么学生的权重更新就没有学习信号。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:崩溃解决方案分析。来源 + Sascha Kirch的注释

批量大小的影响

一个有趣的特性是,DINO 可以用较小的批量大小进行训练,而不会大幅下降性能。这实际上是 BYOL 的一个动机,DINO 基于此论文,减少了对批量大小的依赖,相比对比自监督学习方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 4:批量大小与准确率。来源 + Sascha Kirch的注释

类似 CLIP 和 GLIP 的对比方法提供了大量的负样本以避免崩溃解决方案。每次优化器更新步骤(因此每批次)的负样本越多,效果越好。

结论

总结来说,DINO 是一个知识蒸馏框架。它是一个视觉基础模型,利用了 ViTs 的有趣特性,并且是今天表现最好的基础模型之一 DINOv2 的前身。DINO 的框架由学生模型和教师模型组成,作用于相同图像的不同视图,并采取额外措施来处理相似性学习方法的固有不稳定性。实验表明,DINO 在各种任务上优于其他自监督预训练模型。

进一步阅读与资源

论文

与此同时,DINO 的改进版本已经发布:

  1. DINOv2: 在没有监督的情况下学习鲁棒的视觉特征

  2. Meta 的 DINOv2 博客文章

论文解读

你可能还会喜欢我其他的论文解读,涵盖了我们在本文中讨论的概念:

## CLIP 基础模型

论文总结—从自然语言监督中学习可转移的视觉模型

towardsdatascience.com ## GLIP:引入语言-图像预训练到目标检测

论文总结:基于语境的语言-图像预训练

towardsdatascience.com ## BYOL -对比自我监督学习的替代方案

论文分析—自我监督学习的新方法:Bootstrap Your Own Latent

towardsdatascience.com ## Segment Anything — 可提示的任意对象分割

论文讲解—Segment Anything

towardsdatascience.com

方向改善图学习

原文:towardsdatascience.com/direction-improves-graph-learning-170e797e94fe

有向图上的图神经网络

研究在异质图上进行消息传递时合理使用方向可以带来非常显著的提升。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 迈克尔·布朗斯坦

·发布于 Towards Data Science ·10 分钟阅读·2023 年 6 月 8 日

图神经网络(GNNs)在建模关系数据方面非常有效。然而,当前的 GNN 模型通常假设输入图是无向的,忽略了许多实际图(如社交网络、交通网络、交易网络和引用网络)固有的方向性。在这篇博文中,我们探讨了在异质图的背景下边的方向性影响,并概述了 Dir-GNN,一种针对有向图的全新消息传递方案,允许单独聚合进入和离开边。尽管其简单性,该方案在多个实际异质有向图上显著提高了性能。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于 Shutterstock。

本文由 埃曼纽尔·罗西 共同撰写,基于论文 E. Rossi et al., “边的方向性改善异质图上的学习” (2023) arXiv:2305.10498,与 贝特朗·夏尔潘捷 弗朗西斯科·迪·乔瓦尼 法布里齐奥·弗拉斯卡 斯特凡·古恩曼 [1] 合作完成。论文的代码可以在 这里找到。

许多有趣的实际图,例如在建模社交、交通、金融交易或学术引用网络时遇到的图,都是有向的。边的方向通常传达了关键的见解,否则如果仅考虑图的连接模式,这些见解将会丧失。

相反,大多数在各种图机器学习应用中取得显著进展的图神经网络(GNNs)假设输入图是无向的。多年来,使输入图成为无向图已变得非常普遍,以至于流行的 GNN 库之一 PyTorch-Geometric 在加载数据集时包含了一个通用工具函数,该函数会自动将图转换为无向图[2]。

对无向图的这种倾向源于 GNNs 的两个“原罪”。首先,无向图具有对称的拉普拉斯算子和正交特征向量,提供了傅里叶变换的自然推广,而早期的谱 GNN 依赖于此以正常运作。其次,早期用于基准测试 GNNs 的数据集主要是同质性图[3],如 Cora 和 Pubmed[4]。在这些数据集中,通过将定向图转换为无向图来忽略方向似乎是有利的,早期证据有助于巩固“无向”范式。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在同质性图(左)中,方向大多无用,这一观察导致了大多数当前的 GNNs 忽视了这一信息。相反,在异质性设置中(右),如果使用得当,方向性可以带来大幅收益(10%到 15%),正如我们在 Dir-GNN 框架中提出的那样。

我们在最近的论文[5]中挑战了这一现状,表明方向性可以在异质性设置中带来广泛的收益。

在定向图中测量同质性

图的同质性通常被测量为与节点本身具有相同标签的邻居的比例,平均遍及所有节点(节点同质性)。对于定向图,我们提出了加权节点同质性

h(S) = 1/n Σ ( Σ sᵤᵥ * I[yᵤ = yᵥ] ) / Σ sᵤᵥ

其中I表示指示函数,n是节点数量,S是一般邻接矩阵,可以选择𝐀或𝐀ᵀ,或者更高阶矩阵,例如𝐀𝐀ᵀ或𝐀²(对于定向图),或对称矩阵𝐀ᵤ= (𝐀+ 𝐀ᵀ) / 2 及其高阶对应矩阵𝐀ᵤ²(如果图被视为无向图)。

即使当 1-hop 邻居存在异质性[6]时,情况也可能在转到更远的节点时发生变化。与无向图相比,定向图中有四种不同的 2-hops,分别由矩阵𝐀²、(𝐀ᵀ)²、𝐀𝐀ᵀ和𝐀ᵀ𝐀表示,这些矩阵可以体现出不同程度的(加权)同质性。

由于 GNNs 通过多跳聚合进行操作,它们可以利用图中任何 2-hop(甚至更远的跳数)的同质性。为了获得一个全面的度量来捕捉 GNN 原则上可以利用的最大同质性,我们引入了有效同质性的概念,定义为图中任何跳数的最大加权节点同质性。

从经验上看,当将图转为无向图时,有向同质数据集的有效同质性保持不变。相反,在异质图中,这种转换平均减少了约 30%的有效同质性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们比较了多种同质和异质数据集的有向和无向扩散矩阵的加权同质性。对于异质数据集,有向图的有效同质性*(h*⁽ᵉᶠᶠ⁾)比无向图的*(h*⁽ᵉᶠᶠ⁾*)*大得多,表明有效利用方向性可能带来潜在的收益。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在合成实验中,我们再次观察到,有向的随机块模型图的有效同质性始终高于其无向对应物。有趣的是,对于较少同质的图,这一差距会扩大。

一个玩具示例

特别地,我们观察到𝐀𝐀ᵀ和𝐀ᵀ𝐀在异质图中始终出现为“最同质的矩阵”。

为了提供一个直观的理解,想象我们正在尝试预测一篇特定学术论文的出版年份,例如 Kipf & Welling 2016 年 GCN 论文,给定有向引用网络和其他论文的出版年份。考虑两种不同的 2 跳关系:一种是查看我们关注的论文 v 引用的论文的引用(由矩阵 𝐀² 的 v 行表示),另一种是查看引用与我们论文相同来源的论文(由(𝐀𝐀ᵀ)表示)。

在第一种情况下(𝐀²),我们从 GCN 论文开始,并跟随其引用两次。我们最终找到了一篇 1998 年由 Frasconi et al. 发表的论文。这篇较旧的论文并没有提供很多关于我们的 GCN 论文发布时间的有用信息,因为它时间跨度过长。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

有向引用网络的玩具示例。

在第二种情况下(𝐀𝐀ᵀ),我们从 GCN 论文开始,跟随一个引用,然后返回到引用相同来源的论文,例如 2017 年的 GAT 论文。这篇论文与我们的 GCN 论文出版年份更接近,因此提供了更好的线索。更一般地,分享更多引用的节点,如我们第二个例子中的节点,在𝐀𝐀ᵀ中的分数更高,因此对我们的最终预测贡献更大。

现在,考虑一个无向 2 跳关系(𝐀ᵤ²),这只是四种可能的 2 跳矩阵的平均值。这包括我们的第一种类型(如 Frasconi et al.),这并不是非常有用。因此,高度有用的𝐀𝐀ᵀ矩阵被较少信息的矩阵(如𝐀²)稀释,导致一个较少同质的运算符,从而总体上导致一个较不可靠的预测器。

虽然我们在示例中使用了引用网络,但这种直觉具有更广泛的适用性。在社交网络中,例如,影响者的特征更可能与那些有很多共同关注者的用户类似,由 𝐀ᵀ𝐀 表示。类似地,在交易网络中,两个账户向同一组账户汇款(由 𝐀𝐀ᵀ 捕获),很可能表现出类似的行为。

Dir-GNN:有向图神经网络

为了有效利用方向性,我们提出了 有向图神经网络(Dir-GNN)框架,它通过对节点的入邻居和出邻居进行独立聚合,将 MPNNs 扩展到有向图:

m⁽ᵏ⁾ᵢₙ = AGGᵢₙ({{x⁽ᵏ⁻¹⁾, x⁽ᵏ⁻¹⁾) : (u,v) ∈ E }})

m⁽ᵏ⁾ₒᵤₜ = AGGₒᵤₜ({{x⁽ᵏ⁻¹⁾, x⁽ᵏ⁻¹⁾) : (v,u) ∈ E }})

x⁽ᵏ⁾ = COM(x⁽ᵏ⁻¹⁾, m⁽ᵏ⁾ᵢₙ, m⁽ᵏ⁾ₒᵤₜ)

其中,聚合映射 AGGᵢₙ 和 AGGₒᵤₜ,以及组合映射 COM 是可学习的(通常是一个小的神经网络)。重要的是,AGGᵢₙ 和 AGGₒᵤₜ 可以拥有独立的参数集,以允许对入边和出边进行不同的聚合 [7]。

有趣的是,这种程序模式类似于经典 Weisefiler-Lehman 图同构测试(1-WL)对有向图的自然扩展 [8]。这一联系非常重要:在区分能力方面,我们证明了 Dir-GNN 严格比标准 MPNNs 更强大,后者要么将图转换为无向图,要么仅沿边的方向传播消息。

我们的框架也很灵活:定义特定架构(如 GCN、GraphSAGE 或 GAT)的有向对应物很容易。例如,我们可以定义 Dir-GCN 为:

𝐗⁽ᵏ⁾ = σ(𝐒ₒᵤₜ𝐗⁽ᵏ⁻¹⁾𝐖ₒᵤₜ⁽ᵏ⁾ + (𝐒ₒᵤₜ)ᵀ𝐗⁽ᵏ⁻¹⁾𝐖ᵢₙ⁽ᵏ⁾)

其中 𝐒ₒᵤₜ= Dₒᵤₜ⁻¹ᐟ² 𝐀 Dᵢₙ⁻¹ᐟ²,Dᵢₙ 和 Dₒᵤₜ 分别表示对角线的入度和出度矩阵。

我们还展示了 Dir-GNN 在多层迭代应用时,能导致更具同质性的聚合。与其他模型不同,Dir-GNN 可以访问四个 2-hop 矩阵 𝐀²、(𝐀ᵀ)²、𝐀𝐀ᵀ 和 𝐀ᵀ𝐀,并学会对它们进行不同的加权。相比之下,操作在无向图上的模型仅能访问 𝐀ᵤ²,而仅沿入边或出边传播信息的模型分别限于 (𝐀ᵀ)² 和 𝐀²。

由于 Dir-GNN 对两个方向的独立聚合,因此它是唯一一个在 𝐀𝐀ᵀ 和 𝐀ᵀ𝐀 上操作的模型,我们已经证明这两个矩阵是最具同质性的,因此最可靠的预测器。

实验结果

我们首先在一个需要方向信息的合成任务上比较了 GraphSAGE 及其有向扩展(Dir-SAGE)。结果确认,只有 Dir-SAGE(in+out)在访问到入边和出边的情况下,能够几乎完美地解决该任务。作用于无向图版本的模型表现得与随机情况相当,而仅对入边或出边的模型表现相似,准确率约为 75%。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在检查 GraphSAGE 及其 Dir-扩展在一个需要方向性信息的合成任务上的表现时,只有利用双向信息的 Dir-SAGE (in+out)才能解决该任务。

我们进一步通过消融研究验证了我们的方法,将 GCN、GraphSAGE 和 GAT 基础模型与它们的 Dir-扩展进行比较。在异质数据集上,使用方向性在所有三个基础 GNN 模型中带来了异常大的准确率提升(10%到 20%绝对提升)。此外,Dir-GNN 击败了专门为异质图设计的最先进模型。

这些结果表明,当存在时,使用边的方向可以显著提高异质图上的学习效果。相比之下,忽略方向性是非常有害的,即使是复杂的架构也无法弥补信息的丧失。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在异质图上,通过明智地使用方向性取得了新的最先进结果。

另一方面,在同质数据集上使用方向性则表现不变(甚至稍有负面影响)。这与我们的发现一致,即在我们的框架中使用方向性通常会增加异质数据集的有效同质性,而对同质数据集几乎没有影响。

总之,我们的论文展示了在 GNN 中利用方向性的好处,特别是在异质图的情况下。我们希望这些发现能引发范式转变,将方向性提升为 GNN 中的一等公民。简而言之,在使图无向之前三思而后行!

[1] 本帖标题“方向提升图学习”是对 J. Gasteiger、S. Weissenberger 和 S. Günnemann 的先前工作Diffusion improves graph learning(2019 年)的故意戏仿,该工作展示了基于扩散的图重连方案(DIGL)在同质环境中提高了 GNN 的性能。在这里,我们关注的是异质情况。

[2] 这个 Pytorch-Geometric 例程 用于加载存储在 npz 格式中的数据集。它将一些定向数据集,如 Cora-MLCiteseer-Full,自动转换为非定向版本,并且没有选项获取定向版本。

[3] 同质性 指的是节点具有类似属性(通常是标签,有时是特征)趋向于连接在一起的假设。在同质图中,一个节点的邻域看起来就像是一个节点本身,通常允许通过对邻居的简单聚合(例如,平均)来预测节点的属性。违反这一假设的图称为异质性图

[4] Cora 数据集由 Andrew McCallum 于 1990 年代末期引入,它对 GNNs 的意义相当于 MNIST Digits 数据集对 CNNs 的意义。

[5] E. Rossi 等人,“边的方向性改善了在异质图上的学习”(2023)arXiv:2305.10498。

[6] 异质性并不一定是本质上不好的。例如,考虑以下具有三类(蓝色、橙色、绿色)的玩具定向图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

查看不同邻接矩阵的兼容性矩阵(兼容性矩阵中的位置ij 捕获了从标签为 i 的节点到标签为 j 的节点的边的比例,按给定邻接矩阵加权)。同质邻接矩阵在其兼容性矩阵的对角线上的质量更高,因为它包含了标签相同的节点之间的边。而在我们的示例中,定向(𝐀)和非定向(𝐀ᵤ)的一跳都是极度异质的,而定向的两跳(𝐀𝐀ᵀ 和 𝐀ᵀ𝐀)比非定向的两跳(𝐀ᵤ²)更具同质性。

[7] 重要的是要注意,我们不是第一个处理定向图并提出单独聚合入邻居和出邻居的人。然而,我们的贡献在于提供了对定向图的更全面的处理,包括一个通用框架(Dir-GNN)、关于方向性好处的全面实证证据,特别是在异质性背景下,以及分析定向图模型表达能力的起点。有关相关先前工作的更详细概述,请参阅我们论文中的“相关工作”部分 [5]。

[8] 尽管已经提出了多个关于WL 测试在有向图上的扩展,但 M. Grohe、K. Kersting、M. Mladenov 和 P. Schweitzer 讨论的变体,色彩细化及其应用中,“An Introduction to Lifted Probabilistic Inference”(2021),MIT Press,将入邻居和出邻居分开处理。

我们感谢 Christopher Morris 和 Chaitanya K. Joshi 的深刻讨论,并指出了相关工作。有关图上深度学习的更多文章,请参见 Michael 的 其他文章 在 Towards Data Science 中, 订阅 他的文章和 YouTube 频道,获取 Medium 会员资格,或在 Twitter上关注他

Dirichlet 分布:基础直观理解及 Python 实现

原文:towardsdatascience.com/dirichlet-distribution-the-underlying-intuition-and-python-implementation-59af3c5d3ca2

关于 Dirichlet 分布你需要知道的一切

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Reza Bagheri

·发布于 Towards Data Science ·27 分钟阅读·2023 年 8 月 1 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: pixabay.com/vectors/cubes-dice-platonic-solids-numbers-160400/

Dirichlet 分布是贝塔分布的一种推广。在贝叶斯统计中,它通常用作多项式分布的共轭先验,因此可以用来建模概率随机向量的不确定性。它具有广泛的应用,包括贝叶斯分析、文本挖掘、统计遗传学和非参数推断。本文对 Dirichlet 分布进行了直观介绍,并展示了它与多项式分布的联系。此外,还展示了如何在 Python 中建模和可视化 Dirichlet 分布。

定义

假设连续随机变量 X₁, X₂, …Xₖ (k≥2) 形成随机向量 X,定义为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们还定义了向量 α 如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现在,如果随机向量 X 具有参数 αDirichlet 分布,则它具有以下联合 PDF:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

函数 B(α) 称为 多变量 贝塔函数,其定义为

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中 Г(x) 是伽马函数。如果随机向量 X 具有参数 α 的 Dirichlet 分布,则记作 X ~ Dir(α)。多变量贝塔函数包含在联合概率密度函数(PDF)中用于归一化。联合 PDF 应该在其定义域上积分为 1:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,我们有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于方程 1,随机变量X₁、X₂、…Xₖ的取值应满足以下条件,以使f(x)>0:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这些条件定义了 Dirichlet 分布的支持X的支持及其分布的支持是所有x的集合(X可以取的值),其中f(x)>0。如果Xk个元素,具有 Dirichlet 分布的X的支持是一个k-1 维的单纯形。单纯形是由于方程 3 的约束而形成的有界线性流形。单纯形是三角形概念的高维推广。因此,k-1 维单纯形可以被视为一个位于k维空间中的k-1 维三角形。

例如,如果k=2,那么X的支持是图 1(左)中显示的 1 维单纯形。它是一条直线,触及每个坐标轴,距离原点 1 个单位。对于这条线上的每一点,我们有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对于k=3,那么X的支持是图 1(右)中显示的 2 维单纯形。现在它是一个触及每个坐标轴的三角形,距离原点 1 个单位。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1(作者提供)

对于这个三角形表面上的每一个点,我们有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

让随机向量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

具有参数的 Dirichlet 分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

然后,可以证明X的均值如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

也可以证明:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

直觉

如前所述,Dirichlet 分布通常作为多项分布的共轭先验。因此,为了理解其背后的直觉,我们首先需要回顾多项分布。假设离散随机向量X定义为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

让向量p为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

那么X被称为具有参数np的多项分布,如果它具有以下联合 PMF:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

多项分布可以用于建模一个k面骰子。假设我们有一个k面骰子,并将其掷n次。让pᵢ表示获得第i面的概率,并让随机变量Xᵢ表示第i面观察到的总次数(i=1…k)。那么随机向量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

具有参数n的多项分布

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这一点在图 2 中展示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2(作者提供的图片)

现在假设我们不知道向量 p 中的 pᵢ 值。因此,我们不知道每一面 k-面骰子的概率,我们想通过观察 n 次掷骰子的结果来推断它。p 的元素表示一些互斥事件的概率,因此我们应该有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p 的值可以使用 贝叶斯方法 推断。在这里,我们假设未知的概率向量 p 由连续随机向量 P 表示。P 的概率分布称为 先验分布。先验分布表示对估计参数 P 的先验知识或假设。在掷骰子之后,我们可以分析观察到的数据,并使用这些数据更新我们对 P 的信念。因此,我们得到一个新的 P 分布,这称为 后验分布。后验分布是通过用观察数据更新先验概率分布得到的。

请记住,随机向量X中的随机变量 Xᵢ 代表观察到的第 i 面的总次数。如果我们知道 p 的值,我们可以使用以下条件概率计算在 n 次掷骰子后观察到 X₁=m₁, X₂=m₂, … Xₖ=mₖ 的概率:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这个条件概率给出了在 n 次掷骰子后观察到每一面骰子特定次数的概率,假设我们知道 P 的真实值。如前所述,P 的概率分布是我们的先验分布。我们用 *f_*P(p) 来表示这个分布的联合概率密度函数。现在,我们可以使用贝叶斯定理将先验和后验联合概率密度函数连接起来:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这里 *f_*P|X (p|X=m) 是后验分布的联合概率密度函数。这个分布在观察到 X 后更新我们对 P 的信念。我们也称 P(X=m|p) 为似然函数,它可以写成一个已知 p 值的多项式分布的概率质量函数(方程 4):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

贝叶斯定理的分母是 X=m 的概率,称为 X 的边际概率质量函数:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

请注意,这与 p 的真实值无关。现在我们假设先验分布是具有参数 α₁ 的 Dirichlet 分布。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

记住,具有 Dirichlet 分布的随机变量应遵循方程 3 中的条件,这些条件与方程 5 的条件完全相同。实际上,方程 3 中的条件允许我们使用 Dirichlet 分布来表示互斥事件的概率的随机变量。

现在我们可以使用贝叶斯规则(方程 6)来写:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这里 c 是一个不依赖于 pᵢ 值的常数。后验联合 PDF 应该被归一化,因此我们有以下条件:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过将方程 7 和 8 与方程 1 和 2 进行比较,我们得出结论,后验分布是一个参数为的 Dirichlet 分布

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

c 仅仅是其归一化因子,我们得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最后,我们可以写:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

所以,如果我们假设先验分布是 Dirichlet 分布,那么在观察到 X=m 之后的后验分布也是 Dirichlet 分布。我们只需将每个侧面观察到的数量 (mᵢ) 添加到先验分布中的相应参数 (αᵢ) 中,就能得到后验分布的参数。

在贝叶斯概率理论中,如果后验分布属于与先验分布相同的家族,那么先验和后验被称为 共轭分布。因此,我们得出结论,Dirichlet 分布是多项分布的共轭先验(图 3)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3(作者提供的图片)

Dirichlet 分布的一个特殊情况是当

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

然后我们得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这意味着

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与其 k-1 维单纯形上的均匀分布是一样的,因为联合 PDF 在单纯形上具有相同的值。

在 Python 中建模和可视化

我们可以使用 scipy 库在 Python 中对 Dirichlet 分布进行建模。在 scipy 中,Dirichlet 分布可以通过对象 dirichlet 创建。该对象接受参数 alpha,该参数对应于方程 1 中的 α。我们也可以将 alpha 传递给此对象的方法。方法 pdf() 还接受参数 x,该参数对应于方程 1 中的 x,并返回 x 处分布的联合 PDF。我们还可以使用方法 mean()var() 计算分布的均值和方差。例如,设:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现在我们想计算 X 的均值及其在 [0.5, 0.3, 0.2]ᵀ 的联合 PDF,使用以下代码片段:

from scipy.stats import dirichlet
dist = dirichlet([5, 5, 5])
print("PDF at [5,5,5]: ",dist.pdf([0.5, 0.3, 0.2]))
print("Mean of disitrubtion: ", dist.mean())
PDF at [5,5,5]:  5.1081030000000025
Mean of disitrubtion:  [0.33333333 0.33333333 0.33333333]

如果 x 的值在单纯形之外,pdf() 会抛出错误:

# This results in an error
dist.pdf([0.5, 0.3, 0.3])

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以可视化 k=2 和 3 时的 Dirichlet 联合 PDF(kX 的元素个数)。如前所述,当我们在 X 中有 3 个随机变量(具有 Dirichlet 分布)时,单纯形是一个二维三角形(图 1)。我们可以计算这个单纯形表面上联合 PDF 的轮廓,并用 重心坐标 在二维图中绘制(图 4)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4(作者提供的图像)

重心坐标是一个点在仿射空间中相对于单纯形的坐标。它们可以提供一个点相对于直线、三角形或四面体的位置,而不是全局笛卡尔坐标。在 k 维笛卡尔坐标系中,一个点的坐标可以表示为 k-1 维单纯形的边的归一化加权平均。这些权重给出了该点相对于单纯形的重心坐标。考虑图 5 中显示的二维空间。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5(作者提供的图像)

一维单纯形是端点 [0,1] 和 [0,1] 之间的线段。在这个单纯形上的任意点 p 的坐标可以表示为端点坐标的归一化加权平均:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这里 λ₁ 是 p 到端点 [0,1] 的距离除以单纯形的长度 (L)。类似地,λ₂ 是 p 到端点 [1,0] 的距离除以 L。权重 λ₁ 和 λ₂ 是 p 相对于该单纯形的重心坐标,并且由于端点距离原点只有一个单位,它们与笛卡尔坐标具有相同的值。

接下来,考虑图 6 中显示的二维单纯形。这个单纯形是由端点 [1,0,0]、[0,1,0] 和 [0,0,1] 组成的三角形。该单纯形上点 p 的坐标等于这些端点坐标的归一化加权平均:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6(作者提供的图像)

在这个三角形中,每个节点代表一个坐标轴 x₁、x₂ 或 x₃。假设我们要计算 x₁ 的值。设每条边的长度为 L(这是一个等边三角形)。为了得到 x₁ 的值,我们绘制一条经过 p 并且与不经过 x₁ 所代表的节点的边(这里是 xx₃)平行的直线。这条直线将其余的每一边(xx₂ 和 xx₃)分成两个线段。在这些边上,不包含节点 x₁ 的线段长度是 λL(见图 6)。我们可以类似地计算 λ₂ 和 λ₃ 的值。

现在,我们创建一个 Python 函数来绘制二维单纯形上 Dirichlet 分布的联合 PDF 的等高线。清单 1 导入了我们后续需要的所有库,并在二维图上定义了这个三角形单纯形的边。这些边存储在列表edges中。请注意,这个二维单纯形现在绘制在二维屏幕上,因此所有的边都是二维的。然而,点的重心坐标仍然是这些边的笛卡尔坐标的加权平均值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这里的H是三角形的高度(图 7)。

我们使用matplotlib.tri库来创建一个三角网格。数组normal_vecs保存了这个三角形每条边的法向量(每条边的法向量都垂直于该边)。

# Listing 1

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
from scipy.stats import dirichlet, multinomial, beta
from math import pi
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
%matplotlib inline

H = np.tan(pi/3)*0.5
edges = np.array([[0, 0], [1, 0], [0.5, H]])
shifted_edges = np.roll(edges, 1, axis=0)
triangle = tri.Triangulation(edges[:, 0], edges[:, 1])

# For each edge of the triangle, the pair of other edges
edge_pairs = [edges[np.roll(range(3), -i)[1:]] for i in range(3)]
# The normal vectors for each side of the triangle
normal_vecs = np.array([[pair[0,1] - pair[1,1],
              pair[1,0] - pair[0,0]] for pair in edge_pairs])

在清单 2 中,函数cart_to_bc()将点的二维笛卡尔坐标转换为相对于edges中定义的二维三角形的重心坐标。

# Listing 2

def cart_to_bc(coords):
    '''Converts 2D Cartesian coordinates to barycentric'''
    bc_coords = np.sum((np.tile(coords, (3, 1))-shifted_edges)*normal_vecs,
                axis=1) / np.sum((edges-shifted_edges)*normal_vecs, axis=1)
    return np.clip(bc_coords, 1.e-10, 1.0 - 1.e-10)

def bc_to_cart(coords):
    '''Converts barycentric coordinates to 2D Cartesian'''
    return (edges * coords.reshape(-1, 1)).sum(axis=0) 

图 7 展示了如何进行这些计算以计算λ₃(作为示例)。如图所示,三角形的一条边(x₁)位于二维笛卡尔坐标系统的原点。我们可以用向量xp来表示这个三角形上的点p。从几何学中,我们知道

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中n是边xx₂的法向量。因此,如果我们知道点p的笛卡尔坐标、三角形的边以及每条边的法向量,我们就可以轻松计算出点p的重心坐标。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7(作者提供的图片)

需要注意的是,这个函数并不总是返回准确的重心坐标。如果重心坐标超出了区间 [1e-10 -10, 1–1e-10],则会使用numpy中的clip()函数将其裁剪到区间边界。原因将在后文中解释。

我们还有函数bc_to_cart(),它将这个三角形的重心坐标转换为笛卡尔坐标。点p的笛卡尔坐标等于三角形边的笛卡尔坐标的加权平均值,而重心坐标只是这些权重:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最后,列表 3 定义了函数plot_contours(),该函数绘制 Dirichlet 分布在这个三角形上的联合 PDF 等高线。此函数在笛卡尔 2D 空间上创建一个三角网格。接下来,计算网格上每个点的重心坐标。然后,它使用重心坐标计算该点的联合 PDF。在计算完三角形上所有点的联合 PDF 后,绘制等高线。请注意,三角网格上的某些点可能稍微超出简单边界。这意味着该点的 x₁+x₂+x₃ 可能稍微小于零或大于 1。将这样的点传递给 dirichlet 对象的 pdf() 方法会引发错误。因此,我们在 cart_to_bc() 中裁剪重心坐标以避免此错误。

# Listing 3

def plot_contours(dist, nlevels=200, subdiv=8, ax=None):
    refiner = tri.UniformTriRefiner(triangle)
    mesh = refiner.refine_triangulation(subdiv=subdiv)
    pdf_vals = [dist.pdf(cart_to_bc(coords)) for coords in zip(mesh.x, mesh.y)]
    if ax:
        contours = ax.tricontourf(mesh, pdf_vals, nlevels, cmap='jet')
        ax.set_aspect('equal')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, H)
        ax.set_axis_off()
    else:
        contours = plt.tricontourf(mesh, pdf_vals, nlevels, cmap='jet')
        plt.axis('equal')
        plt.xlim(0, 1)
        plt.ylim(0, H)
        plt.axis('off')
    return contours

让我们尝试 plot_contours()。我们首先绘制等高线

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如前所述,由于联合 PDF 在单纯形上具有相同的值,因此它与其 2D 单纯形上的均匀分布相同。列表 4 绘制了联合 PDF 的等高线,结果图如图 8 所示。

# Listing 4

plt.figure(figsize=(10, 10))
contours = plot_contours(dirichlet([1, 1, 1]))
v = np.linspace(0, 3, 2, endpoint=True)
plt.colorbar(contours, ticks=[1,2,3], fraction=0.04, pad=0.1)
plt.text(0-0.02, -0.05, "$p_1$", fontsize=22)
plt.text(1-0.02, -0.05, "$p_2$", fontsize=22)
plt.text(0.5-0.02, H+0.03, "$p_3$", fontsize=22)
plt.title("Dir([1,1,1])", fontsize=22)
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8

如您所见,联合 PDF 在整个单纯形上具有相同的值。接下来,我们绘制等高线

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作为列表 5 中的第二个示例。结果如图 9 所示。

# Listing 5

plt.figure(figsize=(10, 10))
contours = plot_contours(dirichlet([5, 5, 5]))
plt.colorbar(contours, fraction=0.04, pad=0.1)
plt.text(0-0.02, -0.05, "$p_1$", fontsize=22)
plt.text(1-0.02, -0.05, "$p_2$", fontsize=22)
plt.text(0.5-0.02, H+0.03, "$p_3$", fontsize=22)
plt.title("Dir([5,5,5])", fontsize=22)
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9

影响 α 对联合 PDF 的影响

我们还可以创建联合 PDF 表面的 3D 图。这里我们假设 2D 单纯形位于 XY 平面,Z 轴给出 PDF 的值。列表 6 中的函数plot_surface()生成这样的图。

# Listing 6

def plot_surface(dist, ax, nlevels=200, subdiv=8, log_plot=False, **args):
    refiner = tri.UniformTriRefiner(triangle)
    mesh = refiner.refine_triangulation(subdiv=subdiv)
    pdf_vals = [dist.pdf(cart_to_bc(coords)) for coords in zip(mesh.x, mesh.y)]
    pdf_vals = np.array(pdf_vals, dtype='float64')
    if log_plot:
        pdf_vals = np.log(pdf_vals)
    ax.plot_trisurf(mesh.x, mesh.y, pdf_vals, linewidth=1, **args)

列表 7 使用此函数绘制了具有不同参数的 Dirichlet 分布的联合 PDF。图形如图 10 所示。

# Listing 7

fig = plt.figure(figsize=(15, 10))
ax1 = fig.add_subplot(231, projection='3d')
ax2 = fig.add_subplot(232, projection='3d')
ax3 = fig.add_subplot(233, projection='3d')
ax4 = fig.add_subplot(234, projection='3d')
ax5 = fig.add_subplot(235, projection='3d')
ax6 = fig.add_subplot(236, projection='3d')

ax = [ax1, ax2, ax3, ax4, ax5, ax6]
params = [[1,1,1], [1,7,1], [0.65,7,1], [5,5,5], [30,30,30], [5, 5, 30]]

for i in range(6):
    plot_surface(dirichlet(params[i]), ax[i],
                 antialiased=False, color='yellow')
    ax[i].view_init(35, -135)
    ax[i].set_title("Dir({})".format(params[i]), fontsize=16)
    ax[i].zaxis.set_rotate_label(False) 
    ax[i].set_zlabel("$f_\mathregular{P}(\mathregular{p})$", fontsize=16,
                     weight="bold", style="italic", labelpad=5, rotation=90)
    ax[i].set_xlim([-0.15, 1.1])
    ax[i].set_ylim([-0.15, 1.1])
    if i>2:
        ax[i].set_zlim([0, 65])
    ax[i].xaxis.set_ticklabels([])
    ax[i].yaxis.set_ticklabels([])
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    if i==0:
        ax[i].text(-0.15, -0.07, 2, "$p_1$", fontsize=14)
        ax[i].text(1.07, 0.03, 2, "$p_2$", fontsize=14)
        ax[i].text(0.5, H+0.15, 2, "$p_3$", fontsize=14)
    else:
        ax[i].text(-0.15, -0.07, 0, "$p_1$", fontsize=14)
        ax[i].text(1.07, 0.03, 0, "$p_2$", fontsize=14)
        ax[i].text(0.5, H+0.15, 0, "$p_3$", fontsize=14)

plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10

这些图可以帮助您理解 α 对联合 PDF 形状的影响。具有 Dirichlet 分布的随机变量 p₁、p₂ 和 p₃ 可以表示 3 个相互排斥事件的概率。因此,单纯形的每条边表示这些事件中的一个,相应的 αᵢ 就像是该事件发生概率的权重。

如前所述,α=[1 1 1]ᵀ 意味着我们在单纯形上有均匀分布。这里,PDF 的值在单纯形上处处为 2,因此联合 PDF 具有平坦的表面。当 αᵢ 相对于其他元素增加时,这意味着第 i 个事件发生的机会更高,因为与其他事件相比,它被观察得更多(这里我们可以假设我们从 Dir([1 1 1]ᵀ) 作为先验分布开始)。一个例子是图 10 中 Dir([1 7 1]ᵀ) 的图形。现在,表面在单纯形的边缘附近升高,表示该事件。

当总和α₁+α₂+α₃增加时,意味着观察的总数量增加了。这将减少我们对P的分布的不确定性,并使 Dirichlet 分布的联合 PDF 看起来更尖锐。正如你在图 10 中看到的,Dir([30 30 30]ᵀ)相比 Dir([5 5 5]ᵀ)要尖锐得多。然而,两者在边缘上看起来都是对称的。因为所有事件被观察的次数相同。当某个αᵢ相对于其他值变大时,联合 PDF 的峰值会向表示该事件的边缘移动。这在 Dir([5 5 30]ᵀ)中得到了体现。这里第三个事件的权重(α₃)较大,意味着第三个事件被观察得更多,因此发生的概率更高。

请注意,所有的α元素应大于零,因此我们不能给事件分配零权重。然而,如果我们设置αᵢ<1,则相应事件的权重会显著下降。这在图 10 中的 Dir([0.65 7 1]ᵀ)的图示中得到了体现。如果你将其与 Dir([1 7 1]ᵀ)的图示进行比较,你会发现为了得到非零 PDF,p₁的重心坐标应非常小。这几乎像在p₂和p₃上有一个 1 维的简单形体。

列表 8 绘制了 Dirichlet 分布的联合 PDF 的对数尺度图(以更好地展示联合 PDF 表面的变化)。结果如图 11 所示。

# Listing 8

fig = plt.figure(figsize=(15, 10))

ax1 = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')

ax = [ax1, ax2]
params = [[0.2, 0.2, 0.2], [0.8,0.8,0.8], [0.2,0.5,1]]

for i in range(2):
    plot_surface(dirichlet(params[i]), ax[i], log_plot=True, cmap='jet')
    ax[i].view_init(10, -135)
    ax[i].set_title("Dir({})".format(params[i]), fontsize=20)
    ax[i].zaxis.set_rotate_label(False) 
    ax[i].set_zlabel("$log(f_\mathregular{P}(\mathregular{p}))$",
                     fontsize=18, weight="bold", style="italic",
                     labelpad=5, rotation=90)
    ax[i].set_xlim([-0.15, 1.1])
    ax[i].set_ylim([-0.15, 1.1])
    ax[i].set_zlim([0, 17])
    ax[i].xaxis.set_ticklabels([])
    ax[i].yaxis.set_ticklabels([])
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    ax[i].text(-0.09, -0.07, 0, "$p_1$", fontsize=14)
    ax[i].text(1.07, 0.03, 0, "$p_2$", fontsize=14)
    ax[i].text(0.5, H+0.22, 0, "$p_3$", fontsize=14)

plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 11

请注意,当所有αᵢ小于 1 时,联合 PDF 有一个凸面的表面。PDF 在三角形简单形体的边缘和侧面几乎非常小。它几乎像在三角形的边上有三个 1 维的简单形体。因此,具有这种分布的先验表示一个设置,其中一个或两个pᵢ非常小,它们对应的事件发生的概率很小。

通过比较 Dir([0.2 0.2 0.2]ᵀ)和 Dir([0.8 0.8 0.8]ᵀ),你会发现增加αᵢ的值倾向于使联合 PDF 的表面变平。因此,它减少了边缘和侧面的联合 PDF 值,并增加了在简单形体中部区域的值。

最后,需要注意的是,Dirichlet 分布的参数也可以是非整数的。但例如 Dir([1.65 6 20]ᵀ)是什么意思呢?在这里,我们可以将参数的小数部分分配给先验分布。例如,我们可以将其写成 Dir([0.65+1 1+5 7+13]ᵀ)。这意味着我们从 Dir([0.65 1 7]ᵀ)作为先验分布开始(Dir([0.65 1 7]ᵀ)的联合 PDF 如图 10 所示)。选择这个先验分布意味着我们最初认为p₁几乎为零,它对应的事件发生的可能性非常小。然后我们观察到第一个事件只发生了一次,而第二个和第三个事件分别发生了 5 次和 13 次。这些数字被加到先验分布的参数中,形成了后验分布。

Python 中的贝叶斯推断

现在我们可以绘制轮廓图了,我们可以使用狄利克雷分布来推断多项式分布的参数分布。假设我们有一个 3 面的骰子(当然,它也可以是一个 6 面的骰子,只是上面有 3 个标签(1、2 和 3),每个标签出现在两个面上)。设获得面i的概率为pᵢXᵢ表示观察到面i的总次数(i=1…3)。如前所述,随机向量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

具有参数n的多项式分布

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

设实际的p值为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,这不是一个公平的骰子!我们可以使用scipy中的多项式对象来建模这个分布。以下代码片段显示了掷这个骰子 10 次的结果:

p_act = np.array([0.6, 0.2, 0.2])
sample = multinomial.rvs(n=10, p=p_act, random_state=1)
sample
array([6, 3, 1])

因此,如果我们掷 10 次,我们会得到以下观察结果:

  • 面 1:6 次

  • 面 2:3 次

  • 面 3:仅 1 次

当然,这些是一些随机事件,因此如果我们在rvs()中更改random_state,我们可以获得不同的观察结果(我们固定random_state以使这个特定的观察结果可重复)。

现在假设我们不知道每一面的概率,因此实际的向量p(如方程 10 所示)的值未知。然而,我们仍然可以掷这个骰子n次并观察结果,因此,我们知道Xᵢ的值。如果我们假设未知的概率向量p由随机向量P表示,我们可以使用狄利克雷分布来推断掷骰子后的P的概率分布。

列表 9 首先生成掷骰子n次的结果并存储在m中。然后计算f_P|X (p|X=m),这是后验分布的联合 PDF,并在二维简单形上绘制其轮廓。我们尝试了 5 个不同的n值,范围从 3 到 10000,图 12 展示了这些图形。我们以 Dir([1 1 1]ᵀ)作为P的先验分布。因此,最初我们对P有均匀分布,其中不同的P值是同样可能的。因此,我们对P有最大的未知。

用于生成观察数据的实际P值(方程 10)在这些图中用白色标记显示。随着n的增加,我们获得了更多的观察数据,我们对P的未知性减少。通过增加n,狄利克雷分布从最初的均匀分布变得更加尖锐,并更接近表示***p_***act 的白色标记。

# Listing 9

p_act_coords = bc_to_cart(p_act)

alpha_prior = [1, 1, 1]
number_rolls = [3, 15, 50, 500, 10000]
num_cols = 2

fig, axes = plt.subplots(3, num_cols, figsize=(16, 25))
plt.subplots_adjust(wspace=0.2, hspace=0.05)

contours = plot_contours(dirichlet(alpha_prior), ax=axes[0, 0])
axes[0, 0].set_title("Prior distribution", fontsize=22, pad=50)
axes[0, 0].scatter(p_act_coords[0],
                   p_act_coords[1],
                   s=300, color='white',
                   marker='+')
axes[0, 0].text(0-0.02, -0.05, "$p_1$", fontsize=16)
axes[0, 0].text(1-0.02, -0.05, "$p_2$", fontsize=16)
axes[0, 0].text(0.5-0.02, H+0.05, "$p_3$", fontsize=16)
divider = make_axes_locatable(axes[0, 0])
cax = divider.append_axes('right', size='2%', pad=0.2)
cbar = fig.colorbar(contours, cax=cax)

for i in range(1, 6):
    m= multinomial.rvs(n=number_rolls[i-1], p=p_act, random_state=0)
    contours = plot_contours(dirichlet(m + alpha_prior),
                             ax=axes[i // num_cols, i % num_cols])
    axes[i//num_cols, i%num_cols].set_title("n={}".format(number_rolls[i-1]),
                                            fontsize=22, pad=50)
    axes[i//num_cols, i%num_cols].scatter(p_act_coords[0],
                                          p_act_coords[1],
                                          s=300, color='white',
                                          marker='+')
    axes[i//num_cols, i%num_cols].text(0-0.02, -0.05,
                                       "$p_1$", fontsize=16)
    axes[i//num_cols, i%num_cols].text(1-0.02, -0.05,
                                         "$p_2$", fontsize=16)
    axes[i//num_cols, i%num_cols].text(0.5-0.02, H+0.05,
                                         "$p_3$", fontsize=16)
    divider = make_axes_locatable(axes[i // num_cols, i % num_cols])
    cax = divider.append_axes('right', size='2%', pad=0.2)
    cbar = fig.colorbar(contours, cax=cax)

plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 12

与贝塔分布的关系

让随机向量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

具有参数的狄利克雷分布

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于方程 1,X的联合 PDF 是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于我们有

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以从 PDF 中去除 x₂:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如你所见,X₁ 和 X₂ 的联合 PDF 仅是 x₁ 的函数。因此,随机向量 X 由单一随机变量 X₁ 决定,这意味着上式右侧也是随机变量 X₁ 的 PDF。所以,我们可以写成:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

具有这种 PDF 的连续随机变量称为具有参数 α₁ 和 α₂ 的 贝塔分布,我们用 X₁ ~ Beta(α₁, α₂) 来表示。类似地,我们可以用 x₂ 表达 PDF:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,我们得出结论 X₂ ~ Beta(α₂, α₁),并得出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

X₁ 和 X₂ 的分布被称为 X边际分布。当 α 仅有两个元素,并且我们仅考虑 X 中的一个随机变量时,贝塔分布是狄利克雷分布的特例。因此,它是一个单变量分布。列表 10 绘制了 Dir([5 1]ᵀ) 的联合 PDF 以及其边际分布的 PDF:Beta(5,1) 和 Beta(1,5)。这些图示见图 13。

# Listing 10

N = 1000
simplex_edges = np.array([[1,0], [0,1]])
tol=1e-6
gamma1 = np.linspace(tol, 1-tol, N)
gamma2 = 1-gamma1
bc_coords = np.stack((gamma1, gamma2), axis=-1)
cart_coords = gamma1.reshape(-1,1)*simplex_edges[0] + \
              gamma2.reshape(-1,1)*simplex_edges[1]
alpha = [5, 1]
pdf = [dirichlet(alpha).pdf(x) for x in bc_coords]

x = np.arange(0, 1.01, 0.01)
param_list = [(1,1), (2,2), (5,1)]
beta_dist1 = beta.pdf(x=x, a=alpha[0], b=alpha[1])
beta_dist2 = beta.pdf(x=x, a=alpha[1], b=alpha[0])

fig = plt.figure(figsize=(15, 15))
plt.subplots_adjust(wspace=0.2, hspace=0.1)
gs = gridspec.GridSpec(2, 2, width_ratios=[2.5, 1],
                       height_ratios=[1, 2.5])
ax1 = fig.add_subplot(221, projection='3d')
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)

ax1.plot(simplex_edges[:,0], simplex_edges[:,1],
         [0,0], color = 'gray', label='1-d Simplex')
ax1.plot(cart_coords[:,0], cart_coords[:,1], pdf, color = 'black',
         label='Dir([{},{}])'.format(alpha[0], alpha[1]))
ax1.plot(x, [0]*len(x), beta_dist1, color = 'blue',
         label='Beta({},{})'.format(alpha[0], alpha[1]))
ax1.plot([0]*len(x), x, beta_dist2, color = 'green',
         label='Beta({},{})'.format(alpha[1], alpha[0]))

ax1.view_init(25, -135)
ax1.set_xlabel("$x_1$", fontsize=18)
ax1.set_ylabel("$x_2$", fontsize=18, labelpad= 9)
ax1.set_zlabel("$f_\mathregular{X}(\mathregular{x})$", fontsize=18,
               weight="bold", style="italic",
               labelpad= 2, rotation = 45)
ax1.set_xlim([0, 1])
ax1.set_ylim([0, 1])
ax1.set_zlim([0, 6])
ax1.grid(False)
ax1.legend(loc='best', fontsize= 14)

ax2.plot(x, beta_dist1, label='Beta({},{})'.format(alpha[0],
         alpha[1]), linewidth=2, color='blue')
ax2.set_xlabel('$x_1$', fontsize=18)
ax2.set_ylabel('$f_{X_1}(x_1)$', fontsize=18)
ax2.legend(loc='upper left', fontsize= 16)
ax2.set_xlim([0,1])
ax2.tick_params(axis='both', which='major', labelsize=12)

ax3.plot(x, beta_dist2, label='Beta({},{})'.format(alpha[0],
         alpha[1]), linewidth=2, color='blue')
ax3.set_xlabel('$x_2$', fontsize=18)
ax3.set_ylabel('$f_{X_2}(x_2)$', fontsize=18)
ax3.legend(loc='upper right', fontsize= 16)
ax3.set_xlim([0,1])
ax3.tick_params(axis='both', which='major', labelsize=12)
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 13

请注意,方程 11 中 Dir([αα₂]ᵀ) 的联合 PDF 与方程 12 中其边际分布 (Beta(α₁, α₂)) 的 PDF 相同。然而,它们并不表示相同的分布。前者是随机向量 X 的联合 PDF,而后者是随机变量 X₁ 的 PDF。如图 13 所示,边际分布的 PDFs 是联合 PDF 在由坐标轴 (x₁, f(x)) 和 (x₂, f(x)) 形成的平面上的投影。

请记住,多项分布可以用于建模一个 k 面的骰子。当 k=2 时,骰子变成了硬币。现在 X₁ 可以表示在 n 次掷硬币中的正面总数。类似地,X₂ 表示反面总数。从方程 4,我们得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于 x₁+x₂=np₁+p₂=1,我们可以从上述方程中消去 p₂ 和 x₂:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现在随机向量 X 由单一随机变量 X₁ 决定,这意味着上式右侧也是随机变量 X₁ 的 PDF。因此,X₁ 的 PDF 可以写成:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这是 binomial 分布的 PDF。binomial 分布是 multinomial 分布的特例,当随机向量 X 只有一个元素时(即 multinomial 分布的边际分布)。因此 X₁ 具有参数 np₁ 的 binomial 分布。类似地,X₂ 具有参数 np₂ 的 binomial 分布。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于 beta 分布和 binomial 分布分别是 Dirichlet 和 multinomial 分布的特例,它们仍然是共轭分布。实际上,beta 分布是 binomial 分布的共轭先验,如图 14 所示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 14(作者提供的图片)

假设我们有一枚硬币,其正面朝上的概率为未知的 p。令随机变量 P 表示未知的概率 p,随机变量 X 表示 n 次掷硬币中正面的总数。假设 P 的概率分布是 Beta(a, b)(这是我们的先验分布)。现在如果我们掷硬币 n 次,观察到 X=k,则 P 的后验分布是 Beta(a+k, b+n-k)。

聚合性质

令随机向量 X 具有以下 Dirichlet 分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们从X中移除随机变量 XᵢX_j,并将 Xᵢ+X_j 插入到任意位置,得到的结果随机向量称为 X’。可以证明,X’ 具有以下 Dirichlet 分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,为了创建新 Dirichlet 分布中参数的向量,首先,我们移除 XᵢX_j 对应的参数 (αᵢα_j),然后在 Xᵢ +X_j 被插入到 X 的相同位置时,插入 αᵢ+α_jαᵢ+α_jXᵢ +X_j 在它们对应的向量中的索引是相同的)。聚合性质的证明见附录。

我们来看一个例子。设 X ~ Dir([1 5 3]ᵀ)。利用聚合性质,我们有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

列表 11 显示了所有这些分布的联合 PDF。该图如图 15 所示。在此图中,每个聚合随机向量 [Xᵢ X_j+X_k]ᵀ 具有 1 维单纯形。这里,我们假设该单纯形沿着经过 Xᵢ 的三角形的高度。

# Listing 11

N = 1000
alpha = [1, 5, 3]

edges_marg_x1 = np.array([[0,0], [0.75,0.5*np.cos(pi/6)]])
edges_marg_x2 = np.array([[1,0], [0.25,0.5*np.cos(pi/6)]])
edges_marg_x3 = np.array([[0.5,H], [0.5,0]])
tol=1e-6
gamma1 = np.linspace(tol, 1-tol, N)
gamma2 = 1-gamma1
bc_coords = np.stack((gamma1, gamma2), axis=-1)
marg_x1_cart_coords = gamma1.reshape(-1,1)*edges_marg_x1[0] + \
                      gamma2.reshape(-1,1)*edges_marg_x1[1]
marg_x2_cart_coords = gamma1.reshape(-1,1)*edges_marg_x2[0] + \
                      gamma2.reshape(-1,1)*edges_marg_x2[1]
marg_x3_cart_coords = gamma1.reshape(-1,1)*edges_marg_x3[0] + \
                      gamma2.reshape(-1,1)*edges_marg_x3[1]

alpha_agg1 = [alpha[0], alpha[1]+alpha[2]]
alpha_agg2 = [alpha[1], alpha[0]+alpha[2]]
alpha_agg3 = [alpha[2], alpha[0]+alpha[1]]

pdf1 = [dirichlet(alpha_agg1).pdf(x) for x in bc_coords]
pdf2 = [dirichlet(alpha_agg2).pdf(x) for x in bc_coords]
pdf3 = [dirichlet(alpha_agg3).pdf(x) for x in bc_coords]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
plot_surface(dirichlet(alpha), ax, antialiased=False,
             color='yellow', alpha=0.15)
ax.plot([1,0.5], [0, H], [0, 0], "--", color='black')

ax.plot(marg_x1_cart_coords[:,0], marg_x1_cart_coords[:,1],
        pdf1, color = 'black', zorder=10,
        label="$[x_1, x_2+x_3]$ ~ Dir([{},{}])".format(alpha_agg1[0],
        alpha_agg1[1]))
ax.plot(marg_x2_cart_coords[:,0], marg_x2_cart_coords[:,1],
        pdf2, color = 'blue', zorder=12,
        label="$[x_2, x_1+x_3]$ ~ Dir([{},{}])".format(alpha_agg2[0],
        alpha_agg2[1]))
ax.plot(marg_x3_cart_coords[:,0], marg_x3_cart_coords[:,1],
        pdf3, color = 'red', zorder=10,
        label="$[x_3, x_1+x_2]$ ~ Dir([{},{}])".format(alpha_agg3[0],
        alpha_agg3[1]))

ax.view_init(30, -130)
ax.set_title("Dir([{},{},{}])".format(alpha[0], alpha[1],
             alpha[2]), fontsize=18)
ax.zaxis.set_rotate_label(False) 
ax.set_zlabel("$f_\mathregular{X}(\mathregular{x})$", fontsize=18,
               weight="bold", style="italic", labelpad=15)
ax.set_zlim([0, 17])

ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.legend(loc='best', fontsize=15)

ax.text(-0.06, -0.03, 0, "$x_1$", fontsize=17)
ax.text(1.03, 0.03, 0, "$x_2$", fontsize=17)
ax.text(0.5, H+0.09, 0, "$x_3$", fontsize=17)

plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 15

边际分布

现在,利用聚合性质,当 X 具有超过 2 个元素时,我们可以找到 Dirichlet 分布的边际分布。令 X 具有 Dirichlet 分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以对除 X₁ 外的 X 中的所有元素重复应用聚合性质,得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以将前面的方程写成

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

更一般地,我们可以为每个元素X写出相同的方程:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,根据方程 13,每个 Xᵢ 的边际分布是以下贝塔分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这篇文章中,我们回顾了狄利克雷分布。我们展示了它是多项分布的共轭先验,并且由于这一重要特性,它可以用来推断多项分布的参数。我们还展示了如何在 Python 中对其进行建模以及如何可视化其联合 PDF。最后,我们看到贝塔分布和狄利克雷分布之间的联系,并展示了狄利克雷分布是贝塔分布在更高维度上的推广。

我希望你喜欢阅读这篇文章。如果你有任何问题或建议,请告诉我。本文中的所有代码清单可以从 GitHub 上以 Jupyter Notebook 形式下载,网址为:

github.com/reza-bagheri/probability_distributions/blob/main/dirichlet_distribution.ipynb

通过物理信息神经网络和符号回归发现微分方程

原文:towardsdatascience.com/discovering-differential-equations-with-physics-informed-neural-networks-and-symbolic-regression-c28d279c0b4d

一个逐步代码实现的案例研究

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Shuai Guo

·发表于Towards Data Science ·阅读时间 25 分钟·2023 年 7 月 28 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Steven Coffey拍摄,来源于Unsplash

微分方程作为一个强大的框架,用于捕捉和理解物理系统的动态行为。通过描述变量之间如何变化,它们提供了对系统动态的见解,并允许我们对系统未来的行为进行预测。

然而,我们在许多实际系统中面临的一个共同挑战是,它们的控制微分方程通常仅部分已知,未知的方面以几种方式表现出来:

  • 微分方程的参数是未知的。例如在风工程中,流体动力学的控制方程已被很好地建立,但与湍流流动相关的系数非常不确定。

  • 微分方程的函数形式是未知的。例如,在化学工程中,由于速率决定步骤和反应途径的不确定性,速率方程的确切函数形式可能没有完全理解。

  • 函数形式参数都是未知的。一个典型的例子是电池状态建模,其中常用的等效电路模型仅部分捕捉了电流-电压关系(因此缺失物理的函数形式是未知的)。此外,模型本身包含未知的参数(即电阻和电容值)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1. 许多实际动态系统的控制方程仅部分已知。(图片由本博客作者提供)

对主方程的这种部分了解阻碍了我们对这些动力系统的理解和控制。因此,根据观察数据推断这些未知组件成为动力系统建模中的关键任务。

广义而言,使用观察数据恢复动力系统的主方程的过程属于系统识别的范畴。一旦发现这些方程,我们可以轻松地利用这些方程预测系统的未来状态,告知系统的控制策略,或通过分析技术进行理论研究。

最近,Zhang et al.(2023)提出了一种有前景的策略,该策略利用物理信息神经网络(PINN)和符号回归来发现常微分方程(ODEs)系统中的未知量。虽然他们的重点是发现用于阿尔茨海默病建模的微分方程,但他们提出的解决方案对一般动力系统也具有潜力。

在这篇博客文章中,我们将更深入地了解作者提出的概念,并动手重现论文中的一个案例研究。为此,我们将从零开始构建一个 PINN,利用PySR 库进行符号回归,并讨论获得的结果。

如果你对物理信息神经网络的最佳实践感兴趣,欢迎查看我的博客系列:

物理信息神经网络:以应用为中心的指南

揭示物理信息神经网络的设计模式

记住这一点,让我们开始吧!

目录

· 1. 案例研究 · 2. 为什么传统方法不够有效? · 3. PINN 在系统识别中的应用(理论) · 4. PINN 在系统识别中的应用(代码)

∘ 4.1 定义架构

∘ 4.2 定义 ODE 损失

∘ 4.3 定义梯度下降步骤

∘ 4.4 数据准备

∘ 4.5 PINN 训练

· 5. 符号回归

∘ 5.1 PySR 库

∘ 5.2 实施

∘ 5.3 识别结果

·6. 总结

· 参考文献

1. 案例研究

让我们开始介绍我们旨在解决的问题。在这篇博客中,我们将重现Zhang et al原始论文中的第一个案例研究,即从数据中发现 Kraichnan-Orszag 系统。该系统由以下 ODEs 描述:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

具有初始条件 u₁(0)=1,u₂(0)=0.8,u₃(0)=0.5。Kraichnan-Orszag 系统通常用于湍流研究和流体动力学研究,其目标是对湍流及其结构和动态发展理论见解。

为了模拟一个典型的系统识别设置,我们假设我们对控制常微分方程的了解仅限于部分已知。具体来说,我们假设我们对 u₁ 和 u₂ 的微分方程一无所知。此外,我们假设我们只知道 u₃ 的微分方程右侧是 u₁ 和 u₂ 的线性变换。然后,我们可以将常微分方程系统重写如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中 f₁ 和 f₂ 代表未知函数,ab 是未知参数。我们的目标是校准 ab 的值,并估计 f₁ 和 f₂ 的解析函数形式。 本质上,我们正面临一个具有未知参数和函数形式的复杂系统识别问题。

2. 为什么传统方法会失败?

在传统的系统识别范式中,我们通常使用数值方法(例如,欧拉法、龙格-库塔法等)来模拟和预测系统状态 u₁、u₂ 和 u₃。 然而,这些方法从根本上受限,因为它们通常需要完整的控制微分方程形式,并且无法处理微分方程仅部分已知的情况。

在方程参数未知的情况下,传统方法通常诉诸于优化技术,其中对参数进行初步猜测,然后通过迭代过程来优化,以最小化观察数据与数值求解器预测数据之间的差异。由于每次优化迭代都需要运行一次数值求解器,这种方法虽然可行,但计算开销可能非常大。

请注意,上述讨论仅描述了校准未知参数的情况。当我们需要估计微分方程中的未知函数时,问题变得更加复杂。从理论上讲,我们可以采用类似的方法,即在优化之前对未知函数的形式做出假设。然而,如果我们走这条路,会立即出现问题:如果我们假设一个过于简单的形式,我们面临欠拟合的风险,这可能导致较大的预测误差。另一方面,如果我们假设一个过于复杂的形式(例如,具有许多可调参数),我们面临过拟合的风险,这可能导致较差的泛化性能。

总之,传统方法在处理部分已知微分方程时面临重大挑战:

1️⃣ 传统数值方法依赖于具有完整控制微分方程的形式来进行模拟。

2️⃣ 将传统数值方法与优化算法结合可以解决参数估计问题,但通常代价很高。

3️⃣ 对于嵌入微分方程中的未知函数进行估计时,传统方法可能会得到对假设函数形式高度敏感的结果,这会导致欠拟合或过拟合的风险。

鉴于这些挑战,传统方法在处理未知参数和函数形式共存的系统识别问题时往往效果不佳。这自然引出了物理信息神经网络(PINNs)的话题。在下一节中,我们将看到 PINN 如何有效地解决传统方法面临的挑战。

3. PINN 在系统识别中的应用(理论)

物理信息神经网络(简称 PINN)是Raissi 等人在 2019 年提出的一个强大概念。PINN 的基本思想,像其他物理信息机器学习技术一样,是创建一个混合模型,其中在模型训练中利用了观察数据和已知的物理知识(以微分方程形式表示)。PINN 最初被设计为一个高效的 ODE/PDE 求解器。然而,研究人员很快认识到 PINN 在解决逆问题和系统识别问题上(可以说)具有更大的潜力。

在接下来的内容中,我们将逐一解释如何利用 PINN 克服我们在上一节讨论的挑战。

1️⃣ 传统数值方法依赖于拥有完整形式的主控微分方程来进行模拟。

📣PINN 的响应:与传统方法不同,我能够处理部分已知的微分方程,因此不受完整方程的限制来进行模拟。

从外部角度看,PINN 仅仅类似于一个传统的神经网络模型,该模型将时间/空间坐标(例如,txy)作为输入,并输出我们试图模拟的目标量(例如,速度u,压力p,温度T等)。然而,使 PINN 与传统 NN 不同的是,在 PINN 中,微分方程作为训练过程中的约束。具体来说,PINN 引入了一个额外的损失项,用于计算主控微分方程的残差,该残差通过将预测量代入主控方程计算得到。通过优化这个损失项,我们有效地使训练后的网络意识到潜在的物理规律。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2. 物理信息神经网络将微分方程纳入损失函数中,因此有效地使训练后的网络意识到潜在的物理规律。(图像由本博客作者提供)

由于微分方程仅用于构建损失函数,因此它们对 PINN 模型结构没有影响。这实际上意味着我们在训练时不需要对微分方程有完全的了解。即使我们只知道方程的一部分,这些知识仍然可以被纳入以强制输出遵循已知的物理规律。这种适应知识完整度不同的灵活性相比传统数值方法具有显著优势。

2️⃣ 结合传统数值方法与优化算法可以解决参数估计问题,但通常代价较高。

📣PINN 的回应:我可以提供一种计算上高效的替代方案来估计未知参数。

与将参数估计视为单独优化任务的传统方法不同,PINNs 将这一过程无缝地集成到模型训练阶段。在 PINNs 中,未知参数被简单地视为额外的可训练参数,这些参数在训练过程中与其他神经网络的权重和偏差一起优化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3. 未知参数与 PINN 的权重和偏差一起优化。在训练结束时,我们得到的最终值 ab 作为未知参数的估计值。(图片由本博客作者提供)

此外,PINNs 充分利用现代深度学习框架来执行训练。这允许快速计算所需的梯度(即通过自动微分),以用于高级优化算法(例如 Adam),从而大大加速了参数估计过程,尤其是对于高维参数空间的问题。这些因素使得 PINNs 成为参数估计问题的一个有竞争力的替代方案。

3️⃣ 对于嵌入微分方程中的未知函数,传统方法可能会得到对假设函数形式高度敏感的结果,这会产生欠拟合或过拟合的风险。

📣PINN 的回应:未知函数可以通过额外的神经网络有效地参数化,这些神经网络可以与我一起训练,就像之前的参数估计场景一样。

我们可以用独立的神经网络来逼近未知函数,然后将它们集成到主 PINN 模型中。就像在之前的参数估计场景中一样,我们可以将这些额外的神经网络视为需要估计的大量未知参数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4。未知函数可以通过一个独立的神经网络进行参数化,并与原始 PINN 一起训练。ODE/PDE 残差损失项对辅助神经网络进行正则化,以满足控制方程。这样,辅助神经网络可以直接从数据中自动学习最佳的函数形式。(图像来源于本博客作者)

在训练过程中,这些辅助神经网络的权重和偏差将与原始 PINN 同时训练,以最小化损失函数(数据损失 + ODE 残差损失)。通过这种方式,这些辅助神经网络可以直接从数据中学习最佳的函数形式。通过消除对函数形式进行风险假设的需要,这种策略有助于缓解欠拟合和过拟合的问题。

总结来说,PINN 的优势在于其能够处理部分已知的微分方程,并有效地从数据中学习未知参数和函数形式。这种多功能性使其与传统方法区别开来,因此成为系统识别任务的有效工具。

在下一节中,我们将开始处理我们的案例研究,并将理论转化为实际代码。

4. PINN 用于系统识别(代码)

在本节中,我们将实现一个 PINN(在 TensorFlow 中)来解决我们的目标案例研究。让我们从导入必要的库开始:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

import tensorflow as tf
from tensorflow import keras
tf.random.set_seed(42)

4.1 定义架构

对于主要的 PINN,我们使用一个神经网络来预测u,其具有 1 维输入(即t)和 3 维输出(u₁、u₂和u₃)。此外,如前一节所讨论的,我们使用一个辅助神经网络来逼近未知函数f₁和f₂,该网络具有 4 维输入(即tu₁、u₂和u₃)和 2 维输出(f₁和f₂)。整体 PINN 的架构如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5。所使用的 PINN 模型的架构。(图像来源于本博客作者)

值得再次强调的是,需要向辅助神经网络提供所有可用的特征(在我们当前的情况下,tu₁、u₂和u₃),因为我们不知道f₁和f₂的确切函数形式。在训练过程中,辅助神经网络将以数据驱动的方式自动确定哪些特征是必要的/重要的。

首先,让我们定义一个预测u的神经网络。在这里,我们使用两个隐藏层,每个层配备 50 个神经元和双曲正切激活函数:

def u_net(u_input):
    """Definition of the network for u prediction.

    Args:
    ----
    u_input: input for the u-net

    Outputs:
    --------
    output: the output of u-net
    """

    hidden = u_input
    for _ in range(2):
        hidden = tf.keras.layers.Dense(50, activation="tanh")(hidden)
    output = tf.keras.layers.Dense(3)(hidden)

    return output

接下来,我们定义一个预测f的辅助神经网络。我们采用相同的网络架构:

def f_net(f_inputs, a_init=None, b_init=None):
    """Definition of the network for f prediction.

    Args:
    ----
    f_inputs: list of inputs for the f-net
    a_init: initial value for parameter a
    b_init: initial value for parameter b

    Outputs:
    --------
    output: the output of f-net
    """

    hidden = tf.keras.layers.Concatenate()(f_inputs)
    for _ in range(2):
        hidden = tf.keras.layers.Dense(50, activation="tanh")(hidden)
    output = tf.keras.layers.Dense(2)(hidden)
    output = ParameterLayer(a_init, b_init)(output)

    return output

在上述代码中,我们将ab添加到神经网络模型参数的集合中。这样,ab可以与神经网络的其他权重和偏差一起优化。我们通过定义一个自定义层ParameterLayer实现了这一目标:

class ParameterLayer(tf.keras.layers.Layer):

    def __init__(self, a, b, trainable=True):
        super(ParameterLayer, self).__init__()
        self._a = tf.convert_to_tensor(a, dtype=tf.float32)
        self._b = tf.convert_to_tensor(b, dtype=tf.float32)
        self.trainable = trainable

    def build(self, input_shape):
        self.a = self.add_weight("a", shape=(1,), 
                                 initializer=tf.keras.initializers.Constant(value=self._a),
                                 trainable=self.trainable)
        self.b = self.add_weight("b", shape=(1,), 
                                 initializer=tf.keras.initializers.Constant(value=self._b),
                                 trainable=self.trainable)

    def get_config(self):
        return super().get_config()

    @classmethod
    def from_config(cls, config):
        return cls(**config)

注意,这一层除了引入这两个参数作为模型属性外没有其他作用。

最后,我们将 u-net 和 f-net 结合在一起,定义完整的 PINN 架构:

def create_PINN(a_init=None, b_init=None, verbose=False):
    """Definition of a PINN.

    Args:
    ----
    a_init: initial value for parameter a
    b_init: initial value for parameter b
    verbose: boolean, indicate whether to show the model summary

    Outputs:
    --------
    model: the PINN model
    """

    # Input
    t_input = tf.keras.Input(shape=(1,), name="time")

    # u-NN
    u = u_net(t_input)

    # f-NN
    f = f_net([t_input, u], a_init, b_init)

    # PINN model
    model = tf.keras.models.Model(inputs=t_input, outputs=[u, f])

    if verbose:
        model.summary()

    return model

在上述代码中,我们将输入 tu-net 输出 u₁, u₂, 和 u₃ 进行串联,然后输入到 f-net 中。此外,我们在整体 PINN 模型中输出 uf。虽然在实际应用中只需要 u(因为 u 是我们的建模目标),但后续 f 的预测会变得有用,以提取其分析函数形式(见第五部分)。

4.2 定义 ODE 损失

接下来,我们定义计算 ODE 残差损失的函数。回顾一下,我们的目标 ODEs 是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,我们可以按如下方式定义函数:

@tf.function
def ODE_residual_calculator(t, model):
    """ODE residual calculation.

    Args:
    ----
    t: temporal coordinate
    model: PINN model

    Outputs:
    --------
    ODE_residual: residual of the governing ODE
    """

    # Retrieve parameters
    a = model.layers[-1].a
    b = model.layers[-1].b

    with tf.GradientTape() as tape:
        tape.watch(t)
        u, f = model(t)

    # Calculate gradients
    dudt = tape.batch_jacobian(u, t)[:, :, 0]
    du1_dt, du2_dt, du3_dt = dudt[:, :1], dudt[:, 1:2], dudt[:, 2:]

    # Compute residuals
    res1 = du1_dt - f[:, :1]
    res2 = du2_dt - f[:, 1:]
    res3 = du3_dt - (a*u[:, :1]*u[:, 1:2] + b)
    ODE_residual = tf.concat([res1, res2, res3], axis=1)

    return ODE_residual

虽然上述代码大部分是自解释的,但有几个问题值得提及:

  • 我们使用了 tf.GradientTape.batch_jacobian()(而不是通常的 GradientTape.gradient())来计算 u₁, u₂ 和 u₃ 相对于 t 的梯度。GradientTape.gradient() 在这里不起作用,因为它计算的是 du₁/dt + du₂/dt + du₃/dt。我们也可以在这里使用 GradientTape.jacobian() 来计算每个输出值相对于每个输入值的梯度。有关更多细节,请参见 官方页面

  • 我们使用了 @tf.function 装饰器将上述 Python 函数转换为 TensorFlow 图。这是有用的,因为梯度计算可能非常昂贵,使用图模式执行可以显著加速计算。

4.3 定义梯度下降步骤

接下来,我们配置了计算总损失相对于参数(网络权重和偏差,以及未知参数 ab)的梯度的逻辑。这对于执行模型训练的梯度下降是必要的:

@tf.function
def train_step(X_ODE, X, y, IC_weight, ODE_weight, data_weight, model):
    """Calculate gradients of the total loss with respect to network model parameters.

    Args:
    ----
    X_ODE: collocation points for evaluating ODE residuals
    X: observed samples
    y: target values of the observed samples
    IC_weight: weight for initial condition loss
    ODE_weight: weight for ODE loss
    data_weight: weight for data loss
    model: PINN model

    Outputs:
    --------
    ODE_loss: calculated ODE loss
    IC_loss: calculated initial condition loss
    data_loss: calculated data loss
    total_loss: weighted sum of ODE loss, initial condition loss, and data loss
    gradients: gradients of the total loss with respect to network model parameters.
    """

    with tf.GradientTape() as tape:
        tape.watch(model.trainable_weights)

        # Initial condition prediction
        y_pred_IC, _ = model(tf.zeros((1, 1)))

        # ODE residual
        ODE_residual = ODE_residual_calculator(t=X_ODE, model=model)

        # Data loss
        y_pred_data, _ = model(X)

        # Calculate loss
        IC_loss = tf.reduce_mean(keras.losses.mean_squared_error(tf.constant([[1.0, 0.8, 0.5]]), y_pred_IC))
        ODE_loss = tf.reduce_mean(tf.square(ODE_residual))
        data_loss = tf.reduce_mean(keras.losses.mean_squared_error(y, y_pred_data))

        # Weight loss
        total_loss = IC_loss*IC_weight + ODE_loss*ODE_weight + data_loss*data_weight

    gradients = tape.gradient(total_loss, model.trainable_variables)

    return ODE_loss, IC_loss, data_loss, total_loss, gradients

在上述代码中:

  1. 我们考虑三个损失项:初始条件损失 IC_loss、ODE 残差损失 ODE_loss 和数据损失 data_lossIC_loss 通过将模型预测的u(t=0)与已知的u初始值进行比较来计算,ODE_loss 通过调用我们之前定义的 ODE_residual_calculator 函数来计算,而数据损失则是通过将模型预测值(即 u₁, u₂, u₃)与它们的观测值进行简单比较来计算的。

  2. 我们将总损失定义为 IC_lossODE_lossdata_loss 的加权和。通常,权重控制在训练过程中对各个损失项的重视程度。在我们的案例研究中,将它们全部设置为 1 就足够了。

4.4 数据准备

在本小节中,我们讨论了如何组织数据以进行 PINN 模型训练。

回忆一下,我们的总损失函数包含 ODE 残差损失和数据损失。因此,我们需要生成时间维度上的配点(用于评估 ODE 损失)和配对输入(t)-输出(u)的监督数据。

# Set batch size
data_batch_size = 100
ODE_batch_size = 1000

# Samples for enforcing ODE residual loss
N_collocation = 10000
X_train_ODE = tf.convert_to_tensor(np.linspace(0, 10, N_collocation).reshape(-1, 1), dtype=tf.float32)
train_ds_ODE = tf.data.Dataset.from_tensor_slices((X_train_ODE))
train_ds_ODE = train_ds_ODE.shuffle(10*N_collocation).batch(ODE_batch_size)

# Samples for enforcing data loss
X_train_data = tf.convert_to_tensor(u_obs[:, :1], dtype=tf.float32)
y_train_data = tf.convert_to_tensor(u_obs[:, 1:], dtype=tf.float32)
train_ds_data = tf.data.Dataset.from_tensor_slices((X_train_data, y_train_data))
train_ds_data = train_ds_data.shuffle(10000).batch(data_batch_size)

在上面的代码中,我们在目标时间域[0, 10]内分配了 10000 个等间距的配点。为了方便数据损失计算,我们预生成了配对输入(t)-输出(u)数据集u_obs,其第一列为时间坐标,其余三列分别表示 u₁、u₂ 和 u₃。u_obs包含 1000 个数据点,计算方式如下代码:

# Set up simulation
u_init = [1, 0.8, 0.5]
t_span = [0, 10]
obs_num = 1000

# Solve ODEs
u_obs = simulate_ODEs(u_init, t_span, obs_num)

其中 simulate_ODEs 是 ODE 求解器,它在给定初始条件和模拟域的情况下模拟u轨迹:

def simulate_ODEs(u_init, t_span, obs_num):
    """Simulate the ODE system and obtain observational data. 

    Args:
    ----
    u_init: list of initial condition for u1, u2, and u3
    t_span: lower and upper time limit for simulation
    obs_num: number of observational data points

    Outputs:
    --------
    u_obs: observed data for u's
    """

    # Target ODEs
    def odes(t, u):
        du1dt = np.exp(-t/10) * u[1] * u[2]
        du2dt = u[0] * u[2]
        du3dt = -2 * u[0] * u[1]
        return [du1dt, du2dt, du3dt]

    # Solve ODEs
    t_eval = np.linspace(t_span[0], t_span[1], obs_num)
    sol = solve_ivp(odes, t_span, u_init, method='RK45', t_eval=t_eval)

    # Restrcture solution
    u_obs = np.column_stack((sol.t, sol.y[0], sol.y[1], sol.y[2]))

    return u_obs

下图展示了目标u的轮廓。请注意,我们已经抽取了 1000 个等间距的 (tu₁)、(tu₂) 和 (tu₃) 数据对(包含在u_obs中),作为数据损失计算的监督数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6. 我们当前研究的 ODE 的输出轮廓。(图像由本博客作者提供)

4.5 PINN 训练

以下代码定义了主要的训练和验证逻辑:

# Set up training configurations
n_epochs = 1000
IC_weight= tf.constant(1.0, dtype=tf.float32)   
ODE_weight= tf.constant(1.0, dtype=tf.float32)
data_weight= tf.constant(1.0, dtype=tf.float32)
a_list, b_list = [], []

# Initial value for unknown parameters
a_init, b_init = -1, 1

# Set up optimizer
optimizer = keras.optimizers.Adam(learning_rate=2e-2)

# Instantiate the PINN model
PINN = create_PINN(a_init=a_init, b_init=b_init)
PINN.compile(optimizer=optimizer)

# Configure callbacks
_callbacks = [keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=100),
             tf.keras.callbacks.ModelCheckpoint('PINN_model.h5', monitor='val_loss', save_best_only=True)]
callbacks = tf.keras.callbacks.CallbackList(
                _callbacks, add_history=False, model=PINN)

# Start training process
for epoch in range(1, n_epochs + 1):  
    print(f"Epoch {epoch}:")

    for (X_ODE), (X, y) in zip(train_ds_ODE, train_ds_data):

        # Calculate gradients
        ODE_loss, IC_loss, data_loss, total_loss, gradients = train_step(X_ODE, X, y, IC_weight, 
                                                                         ODE_weight, data_weight, PINN)
        # Gradient descent
        PINN.optimizer.apply_gradients(zip(gradients, PINN.trainable_variables))

    # Parameter recording
    a_list.append(PINN.layers[-1].a.numpy())
    b_list.append(PINN.layers[-1].b.numpy())

    ####### Validation
    val_res = ODE_residual_calculator(tf.reshape(tf.linspace(0.0, 10.0, 1000), [-1, 1]), PINN)
    val_ODE = tf.cast(tf.reduce_mean(tf.square(val_res)), tf.float32)

    u_init=tf.constant([[1.0, 0.8, 0.5]])
    val_pred_init, _ = PINN.predict(tf.zeros((1, 1)))
    val_IC = tf.reduce_mean(tf.square(val_pred_init-u_init))

    # Callback at the end of epoch
    callbacks.on_epoch_end(epoch, logs={'val_loss': val_IC+val_ODE})

    # Re-shuffle dataset
    train_ds_data = tf.data.Dataset.from_tensor_slices((X_train_data, y_train_data))
    train_ds_data = train_ds_data.shuffle(10000).batch(data_batch_size) 

    train_ds_ODE = tf.data.Dataset.from_tensor_slices((X_train_ODE))
    train_ds_ODE = train_ds_ODE.shuffle(10*N_collocation).batch(ODE_batch_size) 
  • 正如之前讨论的,我们将不同损失组件的权重设置为 1。

  • 我们将 ab 的初始猜测设置为-1 和 1,分别。回忆一下,这些值与它们的真实值不同,真实值分别为-2 和 0。

  • 为了验证,我们将 ODE 残差损失和初始条件损失相加,作为最终的验证损失。请注意,我们在这里不考虑数据损失,因为我们假设没有额外的配对 tu 数据集用于验证目的。计算出的验证损失用于调整学习率。

下图展示了损失收敛曲线。我们可以看到所有三个损失组件都正确收敛,这表明训练已满意完成。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7. 损失收敛图。(图像由本博客作者提供)

下图展示了预测的u与通过 ODE 求解器计算的真实值之间的比较。在这里,我们还可以看到 PINN 能够准确地解决我们的目标 ODE。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8. 预测的u与 ODE 求解器计算的真实值的比较。

然而,训练 PINN 并不是我们的最终目标。相反,我们更感兴趣的是估计我们目标 ODE 中嵌入的未知数。让我们从参数估计开始。下图描绘了 ab 的演变。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9. 未知参数 a 和 b 迅速脱离了指定的初始值,并收敛到它们的真实值。这表明所采用的 PINN 策略能够对 ODE 系统进行参数估计。(图片由本博客作者提供)

我们可以清楚地看到,随着训练的进行,ab的值迅速收敛到各自的真实值。这表明我们的 PINN 策略在参数估计方面是有效的。

除了未知参数外,我们还通过训练好的辅助f-网络获得了未知函数f₁和f₂的估计值。为了检查f₁和f₂的近似精度,我们可以将它们与计算得到的 du₁/dt 和 du₂/dt 进行比较,如下代码所示:

X_test = np.linspace(0, 10, 1000).reshape(-1, 1)
X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)

with tf.GradientTape() as tape:
    tape.watch(X_test)
    u, f = PINN(X_test)

# Calculate gradients
dudt = tape.batch_jacobian(u, X_test)[:, :, 0]
du1_dt, du2_dt = dudt[:, :1], dudt[:, 1:2]

# Visualize comparison
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

ax[0].scatter(du1_dt.numpy().flatten(), f[:, 0].numpy())
ax[0].set_xlabel('$du_1$/dt', fontsize=14)
ax[0].set_ylabel('$f_1$', fontsize=14)

ax[1].scatter(du2_dt.numpy().flatten(), f[:, 1].numpy())
ax[1].set_xlabel('$du_2$/dt', fontsize=14)
ax[1].set_ylabel('$f_2$', fontsize=14)

for axs in ax:
    axs.tick_params(axis='both', which='major', labelsize=12)
    axs.grid(True)

plt.tight_layout()

从下图中我们可以清楚地看到,f-网络的预测完全符合控制 ODE,这与之前观察到的 ODE 残差非常小的情况一致。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10. 计算出的导数与预测的f函数值的比较。

尽管我们可以用f-网络准确地逼近未知函数f₁和f₂,但归根结底,f-网络是一个黑箱神经网络模型。自然地,我们会想问:这些估计函数的确切功能形式是什么?这个答案可以为我们提供对潜在物理过程的更深入理解,并帮助我们将结果推广到其他类似的问题。

那么,我们如何从训练好的神经网络模型中提取这些精确的功能形式呢?我们将在下一节中探讨这个问题。

5. 符号回归

符号回归是一种强大的监督学习技术,可以用来发现最适合给定数据集的潜在数学公式。正如其名称所示,这项技术包括两个关键组成部分:符号回归

  • 符号指的是使用符号表达式来建模输入输出关系,例如,“+”表示加法,“-”表示减法,“cos”表示余弦函数等。符号回归方法不是拟合预定义模型(例如,多项式模型等),而是通过整个潜在符号表达式的空间进行搜索,以找到最佳拟合。

  • 回归指的是创建一个模型以预测输出变量的过程,该过程基于输入变量,从而捕捉它们之间的潜在关系。尽管“回归”一词可能会让人联想到线性回归,但在符号回归的背景下,它并不局限于任何特定的模型形式,而是可以采用各种数学运算符和结构。

在这一部分,我们将实现符号回归技术,将学习到的 f-网络提炼成可解释且紧凑的数学表达式,这与张等人在他们的原始论文中提出的策略一致。我们将首先介绍将用于符号回归的库 PySR。随后,我们将应用这个库解决我们的课题,并讨论超参数的选择。最后,我们将分析获得的结果。

5.1 PySR 库

PySR 是一个开源 Python 库,旨在提供实用的高性能科学符号回归。它使用先进的 evolutionary 优化算法在简单解析表达式的空间中搜索,以获得准确且可解释的模型,从而将预测误差和模型复杂度共同最小化。

尽管 PySR 暴露了一个类似于 scikit-learn 风格的简单 Python 前端 API,但其后台是用纯 Julia 编写的,库名为 SymbolicRegression.jl。这为用户提供了定制操作符和优化损失函数的灵活性,同时享有高计算性能。有关 PySR 工作原理的更多细节,请参见这篇论文

要开始使用 PySR,你需要首先安装 Julia。然后运行

pip3 install -U pysr

然后通过

python3 -m pysr install

或者在 IPython 中调用

import pysr
pysr.install()

PySR 也可以通过 conda 或 docker 安装。请查看安装页面以获取更多细节。

5.2 实施

接下来,我们应用 PySR 库将学习到的 f-网络提炼成可解释且紧凑的数学表达式。首先,我们需要生成符号回归学习的数据集:

t = np.linspace(0, 10, 10000).reshape(-1, 1)
u, f = PINN.predict(t, batch_size=12800)

# Configure dataframe
df = pd.DataFrame({
    't': t.flatten(),
    'u1': u[:, 0],
    'u2': u[:, 1],
    'u3': u[:, 2],
    'f1': f[:, 0],
    'f2': f[:, 1]
})
df.to_csv('f_NN_IO.csv', index=False)

请注意,对于我们当前的问题,符号回归学习的输入是 tu₁、u₂ 和 u₃,输出是 f₁ 和 f₂。这是因为在我们的目标 ODE 中,我们假设 f₁=f₁(tu₁、u₂、u₃) 和 f₂=f₂(tu₁、u₂、u₃)。我们保存了生成的数据框(见下图)以备后用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 11. 生成的符号回归学习数据框。(图片来源:本博客作者)

生成数据集后,我们就可以使用 PySR 进行符号回归了。请注意,建议在终端中运行 PySR 代码,而不是在 Jupyter Notebook 中。尽管 PySR 支持 Jupyter Notebook,但在终端环境中的打印(例如,搜索进度、当前最佳结果等)效果要更好。

按照 scikit-learn 风格,我们首先定义一个模型对象:

from pysr import PySRRegressor

model = PySRRegressor(
    niterations=20,  
    binary_operators=["+", "*"],
    unary_operators=[
        "cos",
        "exp",
        "sin",
        "inv(x) = 1/x",
    ],
    extra_sympy_mappings={"inv": lambda x: 1 / x},
    loss="L1DistLoss()",
    model_selection="score",
    complexity_of_operators={
        "sin": 3, "cos": 3, "exp": 3,
        "inv(x) = 1/x": 3
    }
)

以下是指定超参数的详细信息:

  • niterations:算法运行的迭代次数。通常,较大的迭代次数会产生更好的结果,但代价是更高的计算成本。然而,由于 PySR 允许提前终止搜索任务,好的做法是将 niterations 设置为一个非常大的值并保持优化进行。一旦识别出的方程看起来令人满意,就可以提前停止任务。

  • binary_operators:用于搜索的二元运算符字符串列表。PySR 支持的内置二元运算符包括 +-*/^greatermodlogical_orlogical_and

  • unary_operators:用于搜索的一元运算符列表。注意,一元运算符只接受单个标量作为输入。内置的一元运算符包括 negsquarecubeexpabsloglog10log2log1psqrtsincostansinhcoshtanhatanasinhacoshatanh_clip(=atanh((x+1)%2 - 1))、erferfcgammareluroundfloorceilroundsign。注意,要提供自定义运算符,我们需要将“myfunction(x) = …”传递给运算符列表,就像我们用“inv(x) = 1/x”做的那样。

  • extra_sympy_mappings:提供自定义的 binary_operatorsunary_operators 在 julia 字符串中与 sympy 中相同运算符的映射。这在导出结果时非常有用。

  • loss:指定元素级损失函数的 Julia 代码字符串(如在 LossFunctions.jl 中定义)。常用的损失包括 L1DistLoss()(绝对距离损失)、L2DistLoss()(最小二乘损失)、HuberLoss()(用于抗离群值的 Huber 损失函数)。损失函数指定了符号回归搜索的优化目标。

  • model_selection:从每个复杂度的最佳表达式列表中选择最终表达式的标准。score 意味着候选模型将根据最高得分进行选择,得分定义为 -Δlog(loss)/ΔC,其中 C 代表表达式的复杂度,Δ 表示局部变化。因此,如果一个表达式在稍高的复杂度下具有更好的损失,则更受青睐。

  • complexity_of_operators:默认情况下,所有运算符的复杂度为 1。要更改默认复杂度设置并优先考虑不同的运算符,我们可以提供一个字典,键为运算符字符串,值为其对应的复杂度级别。在我们当前的案例中,我们将所有一元运算符的复杂度级别设置为 3,这也在 Zhang 等人的原始论文中采用。

值得一提的是,PySRRegressor 提供了许多其他超参数,用于设置算法、数据预处理、停止标准、性能和并行化、监控、环境和结果导出。有关控制符号回归搜索的所有选项的完整列表,请查看 PySRRegressor 参考页面

5.3 识别结果

在指定模型对象后,我们可以用三行代码启动拟合过程(用于提炼 f₁ 的解析形式):

df = pd.read_csv('f_NN_IO.csv')
X = df.iloc[:, :4].to_numpy()
f1 = df.loc[:, 'f1'].to_numpy()

model.fit(X, f1)

在脚本运行时,你应该能够看到进度条和当前最佳方程,如下图所示。注意 x0、x1、x2 和 x3 分别对应 tu₁、u₂ 和 u₃。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一旦优化任务完成,终端中将出现候选方程列表:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果我们根据 评分值 对方程进行排名,可以看到排名前三的方程是:

  • uuexp( -0.1053391 t )

  • 0.60341805 uu

  • uu

回忆一下我们真实的 ODE 是

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

令人印象深刻的是,PySR 准确地识别出了基本输入(即,它识别出 u₁ 在 f₁ 中不起作用),并发现了一个接近 f₁ 真实表达式的解析表达式(排名第一的结果)。

我们对 f₂ 进行了相同的分析。优化结果如下图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这次,我们注意到 f₂ 的真实表达式,即 f₂=uu₃,仅作为第二好的(按评分计算)方程出现。然而,请注意,最佳方程,即 u₃,其得分仅比第二好的高一点。另一方面,uu₃ 的损失值比单独使用 u₃ 低一个数量级。这些观察结果表明,在实际操作中,我们需要领域知识/经验来做出明智的决定,以判断追求高准确度所带来的复杂性是否值得。

6. 关键要点

在这篇博客文章中,我们探讨了从观测数据中发现微分方程的问题。我们遵循了 Zhang 等人提出的策略,将其实现为代码,并应用于一个案例研究。以下是关键要点:

1️⃣ 物理信息神经网络 (PINN) 是一个多用途的工具,用于进行系统识别,特别是在对控制微分方程只有部分信息已知的情况下。通过同化观察数据和现有的物理知识,PINN 不仅能有效估计未知参数,还能估计未知函数,如果我们采用用辅助神经网络对未知函数进行参数化的技巧,并与主 PINN 一起联合训练。这些因素共同作用,相比传统的系统识别方法,具有显著的优势。

2️⃣ 符号回归是一种强大的工具,用于揭开学习神经网络的黑箱。通过利用先进的进化算法在整个符号表达式空间中进行搜索,符号回归能够提取出可解释且紧凑的解析表达式,这些表达式可以准确描述隐藏的输入输出关系。这个知识蒸馏过程在实践中受到高度赞赏,因为它能有效增强我们对基础系统动态的理解。

在我们结束这篇博客之前,有几点在将 PINN+符号回归应用于实际问题时值得考虑:

1️⃣ 不确定性量化 (UQ)

在这篇博客中,我们假设我们观察到的 u₁、u₂ 和 u₃ 数据是无噪声的。然而,这种假设通常不成立,因为实际的动态系统中的观察数据很容易被噪声污染。因此,我们系统识别结果的 准确性可靠性 都会受到影响。因此,一个关键方面是考虑在我们的系统识别工作流中进行不确定性量化。像贝叶斯神经网络和 蒙特卡洛模拟 这样的技术可以合理地考虑观察数据中的噪声,并提供对预测的置信区间的估计。

2️⃣ 符号回归的敏感性

一般来说,符号回归得到的结果可能对所使用的损失函数、提供的单一和二元运算符候选项以及定义的运算符复杂度敏感。例如,在我尝试重现 Zhang 等人发布的结果时,尽管我采用了完全相同的设置(据我所知),但我未能获得与原始论文中所示的 f₂ 完全一致的前 3 个方程。这种不匹配可能有几个因素:首先,进化优化技术本质上是随机的,因此结果可能在不同的运行中有所不同。其次,第一阶段训练的 PINN 可能不同,因此生成的数据集(即 tu₁,u₂,u₃ → f₁,f₂)也不同,从而影响了符号回归的结果。

总的来说,这些观察结果表明,符号回归的结果不应盲目接受。相反,依赖领域知识/理解来批判性地评估识别出的方程的合理性至关重要。

如果你觉得我的内容有用,可以在这里请我喝咖啡🤗 非常感谢你的支持!

你可以在这里找到带有完整代码的伴随笔记本和脚本💻

要学习物理信息神经网络的最佳实践,请参阅:解开物理信息神经网络设计模式的奥秘

要了解更多关于物理信息运算符学习的内容,请参阅:通过物理信息深度运算符学习。

随时可以订阅我的新闻通讯或在Medium上关注我。

参考资料

[1] Zhang 等,结合 PINN 与符号回归发现阿尔茨海默病的反应扩散模型。arXiv,2023。

[2] Cranmer 等,使用 PySR 和 SymbolicRegression.jl 进行可解释的机器学习科学研究。arXiv,2023。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值