在瑞士卷数据集上使用python绘制测地线

本文介绍了如何使用Python在西瓜书中的流形学习中,通过欧氏距离计算近邻点并构建连接图,进而绘制测地线。步骤包括引入相关库、读取Swiss Roll数据、近邻查找、构建图和求最短路径,最后展示了实际的绘图结果。
摘要由CSDN通过智能技术生成

在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?本文将使用python简单的实现一下在瑞士卷数据集上测地线的绘制。

目录

前言

具体步骤

1.引入库

2.读入数据

3.绘图

4.测地线的绘制

4.1首先对每个点基于欧 氏距离找出其近邻点

4.2建立一个近邻连接图

4.3找出从源点到终点的最短路径

4.4绘制

5.结果展示

总结

源代码:


前言

 在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?

da0459354ced47ddae22fd55047276b1.png

首先看书上怎么说的

2ffc2dceae0c4f6ca2105bcccb2a83f2.png

 书上讲的很清楚了,求测地线的步骤大致为:

(1)首先对每个点基于欧 氏距离找出其近邻点

(2)建立一个近邻连接图,近邻点之间存在连接,而非近邻点之间不存在连接

(3)找出从源点到终点的最短路径,连接起来就是我们要的测地线了

所以接下来我将按照这个步骤一步步的实现它。

具体步骤

1.引入库

代码如下(示例):

import mat4py as mp
import numpy as np
# 载入必要库
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
%matplotlib inline
import pandas as pd
import networkx as nx  # 导入 NetworkX 工具包

from sklearn.neighbors import NearestNeighbors

2.读入数据

代码如下(示例):

from sklearn.datasets import make_swiss_roll
# 用make_swiss_roll得到渐变色
X, t = make_swiss_roll(n_samples=1000, noise=0.2, random_state=42)

3.绘图

我们看一下原始数据集在3维空间上的分布,可以看到这是一个流形。

# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)

5ef022d6f3044c6a92292707890cc55c.png


4.测地线的绘制

4.1首先对每个点基于欧 氏距离找出其近邻点

这里我们直接调用NearestNeighbors()方法计算就行了

返回值说明:

# 返回值indices:第0列元素为参考点的索引,后面是(n_neighbors - 1)个与之最近的点的索引
# 返回值distances:第0列元素为与自身的距离(为0),后面是(n_neighbors - 1)个与之最近的点与参考点的距离

# j 计算每个点的k近邻:
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(X)
    distances, indices = nbrs.kneighbors(X)

4.2建立一个近邻连接图

近邻点之间存在连接,而非近邻点之间不存在连接

初始化近邻矩阵:

dist_matrix=np.zeros((m,m))

 获取近邻矩阵:

 for i in range(m):
        for j in range(m):
            if j not in indices[i]:#若X[j]点不是X[i]的k近邻,则距离为0
                dist_matrix[i][j]=0
            else:#若X[j]点是X[i]的k近邻
                for index in range(len(indices[i])):#求X[j]到X[i]的距离
                    if indices[i][index]==j:
                        dist_matrix[i][j]=distances[i][index]
                        break

4.3找出从源点到终点的最短路径

这里可以使用NetworkX图去求

dfAdj = pd.DataFrame(dist_matrix)
G1 = nx.from_pandas_adjacency(dfAdj)  # 由 pandas 顶点邻接矩阵 创建 NetworkX 图
# 两个指定顶点之间的最短加权路径
minWPath = nx.bellman_ford_path(G1, source=source, target=target)  # 顶点 10 到 顶点 100 的最短加权路径
print("最短路径为:",minWPath)

d3b022bca565435c803ab5f4bd829202.png

4.4绘制

有了最短路径,把路径上的点连起来就可以进行绘制了

(1)获得坐标


def cedi_line(X):
    if len(X[0])==2:
        x=[]
        y=[]
        for i in minWPath:
            x.append(X[i,0])
            y.append(X[i,1])
        return x,y
    if len(X[0])==3:
        x=[]
        y=[]
        z=[]
        for i in minWPath:
            x.append(X[i,0])
            y.append(X[i,1])
            z.append(X[i,2])
        return x,y,z

(2)绘制

import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)

x,y,z=cedi_line(X)
ax.plot(x, y, z, label='parametric curve',color='red')

# 显示图例
ax.legend()

# 显示图形
plt.show()

5.结果展示

5746c7b85dc54f678b7561642d16a056.png

降成二维后,测地线的绘制

e5e3df95bfdc4d259dfd4b43de09baca.png

总结

以上就是今天要讲的内容,本文基于西瓜书上绘制测地线的方法进行了实现,至于有不有更简洁、更正确的画法,还请不吝赐教!

源代码:

本文参考的文章:

https://blog.csdn.net/youcans/article/details/116999881icon-default.png?t=N7T8https://blog.csdn.net/youcans/article/details/116999881

  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 9
    评论
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值