在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?所谓“条条大路通北京”,上一篇文章我们根据西瓜书实现了测地线的画法,但是作者觉得在求距离矩阵时自己的代码有些啰嗦,所以本文将针对求距离矩阵的方法进行改进。
文章目录
前言
在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?
首先看书上怎么说的
书上讲的很清楚了,求测地线的步骤大致为:
(1)首先对每个点基于欧 氏距离找出其近邻点
(2)建立一个近邻连接图,近邻点之间存在连接,而非近邻点之间不存在连接
(3)找出从源点到终点的最短路径,连接起来就是我们要的测地线了
所以接下来我将按照这个步骤一步步的实现它。
一、具体步骤
1.引入库
代码如下(示例):
import mat4py as mp
import numpy as np
# 载入必要库
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
%matplotlib inline
import pandas as pd
import networkx as nx # 导入 NetworkX 工具包
from sklearn.neighbors import NearestNeighbors
2.读入数据
代码如下(示例):
from sklearn.datasets import make_swiss_roll
# 用make_swiss_roll得到渐变色
X, t = make_swiss_roll(n_samples=1000, noise=0.2, random_state=42)
3.绘图
我们看一下原始数据集在3维空间上的分布,可以看到这是一个流形。
# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)
4.测地线的绘制
4.1调用ISOMAP方法获取距离矩阵
isomap = Isomap(n_components=3,n_neighbors=3)
data_2d = isomap.fit_transform(X)
dist_matrix = isomap.dist_matrix_
为什么可以根据ISOMAP方法求解近邻矩阵呢?由下面的ISOMAP算法求解流程就可知道了
4.2找出从源点到终点的最短路径
接下来就可以直接求最短路径了,比之前的方法简洁多了。(反正是能调库就调库吧,毕竟人家写的比我们好)
dfAdj = pd.DataFrame(dist_matrix)
G1 = nx.from_pandas_adjacency(dfAdj) # 由 pandas 顶点邻接矩阵 创建 NetworkX 图
# 两个指定顶点之间的最短加权路径
minWPath = nx.bellman_ford_path(G1, source=source, target=target) # 顶点 10 到 顶点 100 的最短加权路径
print("最短路径为:",minWPath)
4.3绘制
有了最短路径,把路径上的点连起来就可以进行绘制了
(1)获得坐标
if len(X[0])==2:
x=[]
y=[]
for i in minWPath:
x.append(X[i,0])
y.append(X[i,1])
return x,y
if len(X[0])==3:
x=[]
y=[]
z=[]
for i in minWPath:
x.append(X[i,0])
y.append(X[i,1])
z.append(X[i,2])
return x,y,z
(2)绘制
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)
x,y,z=cedi_line(X)
ax.plot(x, y, z, label='parametric curve',color='red')
# 显示图例
ax.legend()
# 显示图形
plt.show()
5.结果展示
降成二维后,测地线的绘制
总结
以上就是今天要讲的内容,本文基于西瓜书上绘制测地线的方法进行了实现,是对上一篇文章的改进,毕竟进无止境嘛。至于有不有更简洁、更正确的画法,还请不吝赐教!
源代码:
本文参考的文章: