源代码
def forward_propgation(X):
# z = w1 * x1s + w2 * x2s + b
# ndarray的dot函数:点乘运算
# ndarray的T属性:转置运算
Z = X.dot(W.T) + B
# a = 1 / (1 + np.exp(-z))
A = 1 / (1 + np.exp(-Z))
return A
def show_scatter_surface(xs,y,forward_propgation):
x = xs[:,0]
z = xs[:,1]
fig = plt.figure()
ax = Axes3D(fig)
fig.add_axes(ax)
ax.scatter(x, z, y)
x = np.arange(np.min(x),np.max(x),0.01)
z = np.arange(np.min(z),np.max(z),0.01)
x,z = np.meshgrid(x,z)
y = forward_propgation(x,z)
ax.plot_surface(x, z, y, cmap='rainbow')
plt.show()
出现的问题
问题1
Traceback (most recent call last): File "D:\work\python\lesson78\lesson7.py", line 31, in <module> plot_utils.show_scatter_surface(X, Y, forward_propgation) File "D:\work\python\lesson78\plot_utils.py", line 38, in show_scatter_surface y = forward_propgation(x,z) TypeError: forward_propgation() takes 1 positional argument but 2 were given
原因
在的 show_scatter_surface
函数中,错误发生在调用 forward_propagation(x, z)
这一行。错误提示是 TypeError: forward_propagation() takes 1 positional argument but 2 were given
,这意味着 forward_propagation
函数只接受一个参数,但传递了两个参数。
问题2
File "D:\work\python\lesson78\plot_utils.py", line 44, in show_scatter_surface
ax.plot_surface(x_vals, z_vals, y_vals, cmap='rainbow')
File "D:\work\python\venv\lib\site-packages\mpl_toolkits\mplot3d\axes3d.py", line 1700, in plot_surface
raise ValueError("Argument Z must be 2-dimensional.")
ValueError: Argument Z must be 2-dimensional.
原因
错误 "参数 Z 必须是二维的" 通常发生在尝试使用 Matplotlib 中的 plot_surface
绘制表面图时,输入数组 x_vals
、z_vals
和 y_vals
的形状与 y_vals
期望的二维形状不匹配。
在你的情况下,似乎 x_vals
、z_vals
和 y_vals
都是二维数组,但它们的形状可能存在不匹配。为了解决这个问题,你可以确保 x_vals
、z_vals
和 y_vals
都是形状一致的二维数组。
这个错误的原因是 plot_surface
函数期望输入是二维数组,其中 x_vals
和 z_vals
定义了网格,而 y_vals
包含了对应的高度信息。如果它们的形状不匹配,就会引发 "参数 Z 必须是二维的" 错误。
请确保在调用 plot_surface
之前,你的 x_vals
、z_vals
和 y_vals
具有相同的形状,并且是二维数组。
使用y_vals = y_vals.reshape(x_vals.shape)
y_vals = np.atleast_2d(y_vals)
这两行代码的目的是确保 y_vals
具有正确的形状和维度,以便与 plot_surface
函数一致。
-
y_vals = y_vals.reshape(x_vals.shape)
: 这一行使用reshape
函数将y_vals
重新塑造为与x_vals
相同的形状。plot_surface
函数期望输入是三维数组,其中x_vals
和z_vals
定义了网格,而y_vals
包含了对应的高度信息。通过使用reshape(x_vals.shape)
,确保y_vals
具有与x_vals
相同的形状,以便正确匹配网格。 -
y_vals = np.atleast_2d(y_vals)
: 这一行使用np.atleast_2d
函数确保y_vals
至少是一个二维数组。在某些情况下,forward_propagation
返回的结果可能是一维数组,但plot_surface
函数期望输入是二维数组。通过使用np.atleast_2d
,如果y_vals
是一维数组,它将被转换为一个行向量或列向量的二维数组。
这两行代码的目标是保证 y_vals
具有正确的形状和维度,以适应 Matplotlib 中的 plot_surface
函数。
修改后
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
def show_scatter_surface(xs, y, forward_propagation):
x = xs[:, 0]
z = xs[:, 1]
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, z, y)
x_vals = np.arange(np.min(x), np.max(x), 0.01)
z_vals = np.arange(np.min(z), np.max(z), 0.01)
x_vals, z_vals = np.meshgrid(x_vals, z_vals)
# Combine x_vals and z_vals into a 2D input feature array
input_features = np.column_stack((x_vals.flatten(), z_vals.flatten()))
y_vals = forward_propagation(input_features)
y_vals = y_vals.reshape(x_vals.shape) # Reshape to match the dimensions
# Make sure y_vals is a 2D array
y_vals = np.atleast_2d(y_vals)
ax.plot_surface(x_vals, z_vals, y_vals, cmap='rainbow')
plt.show()
# 示例用法
# 请确保定义了 forward_propagation 函数,并提供合适的 xs 和 y
# show_scatter_surface(xs, y, forward_propagation)