《PyTorch深度学习实践》P2线性模型 作业:mplot3d绘制三维图像及mershgrid()生成网格坐标详解,绘制y=x*w+b损失图像

总是使用AI工具,Pytorch没有详细的学过,所以攻略了一下,找到b站的up主刘二大人,准备跟着他的《PyTorch深度学习实践》学一遍。将每节的重点和习题做一下,督促自己每天刷一节课程完成任务,同时给在学的朋友们提供帮助!

目录

mpl_toolkits.mplot3d

numpy.mershgrid()

绘制y=x*w+b损失图像


mpl_toolkits.mplot3d

官方帮助文档:mpl_toolkits.mplot3d — Matplotlib 3.9.2 documentation

介绍:mplot3dMatplotlib 库中的一个子模块,专门用于绘制三维图形。它提供了三维绘图的功能,如三维散点图、线图、曲面图等。

使用方法:步骤、语法如下:

  • 导入Axes3D,创建三维坐标图(matplotlib 3.1.0之后不需要显式导入)
  • 创建Mathplotlib的figure和3D subplot
  • 生成图像
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

 三维散点图、线图、曲面图、条形图绘制:

1、三维散点图绘制

import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

# 创建数据
x = np.random.rand(100)
y = np.random.rand(100)
z = np.random.rand(100)

# 创建图形对象
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 绘制三维散点图
ax.scatter(x, y, z)

# 显示图形
plt.show()

2、三维线图

参数化函数如代码,数学公式如下:

x=(z^2+1) \cdot \sin \theta

y=(z^2+1) \cdot \cos \theta

z=z

import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

# 创建数据
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)

# 创建图形对象
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 绘制三维线图
ax.plot(x, y, z)

# 显示图形
plt.show()

3、三维曲面图

函数公式:z=\sin (\sqrt{x^2+y^2})

import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

# 创建数据
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
x, y = np.meshgrid(x, y)
z = np.sin(np.sqrt(x**2 + y**2))

# 创建图形对象
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 绘制三维曲面图
ax.plot_surface(x, y, z, cmap='viridis')

# 显示图形
plt.show()

4、三维条形图

import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

# 创建数据
_x = np.arange(4)
_y = np.arange(3)
_xx, _yy = np.meshgrid(_x, _y)
x, y = _xx.ravel(), _yy.ravel()
z = np.zeros_like(x)
dx = dy = 0.8
dz = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

# 创建图形对象
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 绘制三维条形图
ax.bar3d(x, y, z, dx, dy, dz, shade=True)

# 显示图形
plt.show()


numpy.mershgrid()

介绍:numpy.meshgrid() 是 NumPy 库中用于生成坐标网格的函数。它通常用于多维空间中变量的笛卡尔积表示,常用于计算数学和可视化中。meshgrid 接受一个或多个 1D 数组作为输入,并返回对应的多维网格矩阵,使得每个输出数组对应一个维度中的坐标。

numpy.meshgrid(*xi, indexing='xy', sparse=False, copy=True)
  • xi: 一维数组或多个数组,代表不同的坐标轴。
  • indexing: 控制网格的索引类型。可以是 'xy''ij'
    • 'xy':用于二维的笛卡尔坐标系,其中第一个数组代表 x 坐标,第二个数组代表 y 坐标。适合 2D 平面。
    • 'ij':表示矩阵索引法。对任意维度都适用,尤其适合高维数组。
  • sparse: 如果为 True,则生成稀疏网格,节省内存。对于大网格或需要节约资源时很有用。
  • copy: 如果为 True,则会返回新对象;否则返回视图以提高效率。
  • 返回值:函数返回与输入维度相同的n个数组,每个数组的形状与输入数组构成的网格形状相匹配。

示例:

import numpy as np

x = np.array([1, 2, 3])
y = np.array([4, 5])
X, Y = np.meshgrid(x, y)

print(X)
# 输出:
# [[1 2 3]
#  [1 2 3]]

print(Y)
# 输出:
# [[4 4 4]
#  [5 5 5]]

绘制y=x*w+b损失图像

import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

x_data = [1.0, 2.0, 3.0]
y_data = [3.0, 5.0, 7.0]  # 原函数为 y = 2x + 1

def forward(x, w, b):
    return w * x + b

def loss(x, y, w, b):
    y_pred = forward(x, w, b)
    return (y_pred - y) ** 2

w_values = np.arange(0.0, 4.1, 0.1)  # w 从 0 到 4
b_values = np.arange(0.0, 4.1, 0.1)  # b 从 0 到 4
W, B = np.meshgrid(w_values, b_values)

mse_values = np.zeros_like(W)  
for i in range(len(x_data)):
    mse_values += loss(x_data[i], y_data[i], W, B)

mse_values /= len(x_data) 

# 创建三维图像
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
surface = ax.plot_surface(W, B, mse_values, cmap='viridis')         #绘制曲面图
fig.colorbar(surface, ax=ax, shrink=0.5, aspect=5)      #添加色柱

ax.set_xlabel('w')
ax.set_ylabel('b')
ax.set_zlabel('Loss')

plt.show()

朋友们有什么建议或疑问可以在评论区给出,或者是私信我!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值