【ML】numpy meshgrid函数使用说明(全网最简单版)

12 篇文章 0 订阅
2 篇文章 0 订阅

meshgrid的作用?

首先要明白numpy.meshgrid()函数是为了画网格,(对就是画格子,至于格子怎么用,那要看实际场景了,我们这里只关心怎么画格子)

怎么使用(举例说明)

为了方便大家理解,我以结果反推的方式进行讲解,这样更直观。先看下图:
假如我们要得到这样一个网格图(注意坐标):
在这里插入图片描述

手工描点(帮助理解)

  1. 先找到坐标x=1,然后分别画出(1,5),(1,6),(1,7)
  2. 再找到坐标x=2,然后分别画出(2,5),(2,6),(2,7)
  3. 以此类推即可

我们可以得到:x=[1,2,3,4],y=[5,6,7]
做个笛卡尔积即可得到所有点。所以我们可以有以下代码:

x_component = np.array([1,2,3,4])
y_component = np.array([5,6,7])
x,y = np.meshgrid(x_component,y_component)

输出结果:

x=[[1 2 3 4]
 [1 2 3 4]
 [1 2 3 4]]
y=[[5 5 5 5]
 [6 6 6 6]
 [7 7 7 7]]

输出结果有点不好理解。这是啥???,但是我们观察规律,如果我们把x,y两个矩阵当做两张图片叠加在一起是什么效果?
示意图:

[[1 5    2 5    3 5    4 5]
 [1 6    2 6    3 6    4 6]
 [1 7    2 7    3 7    4 7]]

然后上下翻转一下:

[[1 7    2 7    3 7    4 7]
 [1 6    2 6    3 6    4 6]
 [1 5    2 5    3 5    4 5]]

这不是跟图上的坐标一模一样嘛!!!

怎么画三维?

先看图(目标):
在这里插入图片描述

x_component = np.array([1,2,3,4])
y_component = np.array([5,6,7])
z_component = np.array([8,9])
x,y,z = np.meshgrid(x_component,y_component,z_component)

输出(怎么理解?叠加法!!!):

x= [[[1 1]
  [2 2]
  [3 3]
  [4 4]]

 [[1 1]
  [2 2]
  [3 3]
  [4 4]]

 [[1 1]
  [2 2]
  [3 3]
  [4 4]]]
y= [[[5 5]
  [5 5]
  [5 5]
  [5 5]]

 [[6 6]
  [6 6]
  [6 6]
  [6 6]]

 [[7 7]
  [7 7]
  [7 7]
  [7 7]]]
z= [[[8 9]
  [8 9]
  [8 9]
  [8 9]]

 [[8 9]
  [8 9]
  [8 9]
  [8 9]]

 [[8 9]
  [8 9]
  [8 9]
  [8 9]]]

一维展开后是什么效果?

# 二维展开成一维,展开函数ravel(), 拼接函数c_可见后续博客
xy = np.c_[xv.ravel(),yv.ravel()]
# 三维展开成一维
xyz = np.c_[xv.ravel(),yv.ravel(),zv.ravel()]

展开后即可得到二维和三维的所有坐标

附画图代码

二维图:

#二维图
import numpy as np
x_component = np.array([1,2,3,4])
y_component = np.array([5,6,7])
xv,yv = np.meshgrid(x_component,y_component)

import matplotlib.pyplot as plt
str_label = '({x_label}, {y_label})'
fig = plt.figure(figsize=(5,5))
plt.axis([0,5,4,8])

xy = np.c_[xv.ravel(),yv.ravel()]
for point in xy:
    x = point[0]
    y = point[1]
    color = 'r' if y==5 else ('b' if y==6 else 'g')
    plt.scatter(x, y, c=color)
    plt.annotate(str_label.format(x_label=x,y_label=y),xy = (x, y), xytext = (x+0.1, y+0.1))
                
plt.show()

三维图:

# 3维图
import numpy as np
x_component = np.array([1,2,3,4])
y_component = np.array([5,6,7])
z_component = np.array([8,9])
xv,yv,zv = np.meshgrid(x_component,y_component,z_component)

import matplotlib.pyplot as plt
fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(projection='3d')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

xyz = np.c_[xv.ravel(),yv.ravel(),zv.ravel()]
for point in xyz:
    x = point[0]
    y = point[1]
    z = point[2]
    color = 'r' if z == 8 else 'b'
    ax.scatter(x, y, z, c=color)
plt.show()
  • 5
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值