当我们只取一个参数w的时候,b不存在,那么就观察损失函数的情况
这里引用滑块的函数用来表是不同的w的值对函数的影响
from matplotlib.widgets import Slider
import matplotlib.pyplot as plt
# 创建一个图形窗口和子图
fig, ax = plt.subplots()
# 定义滑块的初始值和范围
initial_value = 0.0
min_value = -50
max_value = 500
# 创建一个滑块控件
slider_ax = plt.axes([0.25, 0.1, 0.65, 0.03]) #窗口的参数
slider = Slider(slider_ax, 'w', min_value, max_value, valinit=initial_value)
'''
slider_ax:滑动窗口的大小
min_value:窗口的达到的最小值
max_value:窗口的达到的最大值
valinit=(具体的值):窗口的初始位置
'''
# 定义更新函数
def update(val):
# 获取滑块的值
slider_value = slider.val
# 打印滑块的值
print("Slider value:", slider_value)
# 将更新函数与滑块控件关联
slider.on_changed(update)
# 显示图形
plt.show()
如图当滑动蓝色的矩形的时候,会输出对应的坐标轴上的值
首先初始化数据,并绘制初始的图:
# 初始化数据
x_train = np.array([1, 2, 3, 4, 5])
y_train = np.array([200, 400, 500, 400, 500])
#创建一个窗口和两个子图
fig,ax = plt.subplots(1,2)
plt.subplots_adjust(bottom = 0.25)
#设置初始的w的值,因为绘制的是b = 0的情况(没有b),所有w就置为0
initial_w = 0.0
#绘制初始的预测结果
y = initial_w * x_train
line, = ax[0].plot(x_train, y, c='b', label='预测')
ax[0].scatter(x_train,y_train,marker = 'X',c = 'r')
#二次函数(J函数)
w_range = np.arange(-50,501,1)
y_range = compute(x_train,y_train,w_range)
#创建函数图像
lin ,= ax[1].plot(w_range,y_range,label = 'f = wx')
ax[1].set_xlabel('W')
ax[1].set_ylabel('dis')
ax[0].legend()
ax[1].legend()
这里创建了一个画布的两个子图,分别表示初始预测的和对应的损失函数
之后定义一个函数计算不同w值对应的损失函数
def compute(x,y,w_array):
arr = np.zeros_like(w_array)
for j , w in enumerate(w_array):
m = x.shape[0]
tol = 0
for i in range(m):
tol += (w * x[i] - y[i]) ** 2
arr[j] = (1/(2 * m)) * tol
return arr
然后就是怎么去滑动:
#更新函数
def update(val):
w = slider.val # 获取滑块的值
# slider.val是用于获取滑块当前所处位置对应的数值的属性
#当滑块到某一个数值的时候,通过slider.val来获取这个数值然后进行数据的更新和绘图
y = w * x_train #获得y的最新的值
line.set_ydata(y) # 更新图上的y值坐标
# 删除上一次的差值记录(每次更新前都清空两个子图的所有内容,保证不会被上次的内容干扰到本次的绘图)
for arrow in ax[0].patches:
arrow.remove()
for text in ax[0].texts:
text.remove()
for arrow in ax[1].patches:
arrow.remove()
for text in ax[1].texts:
text.remove()
# 以上常在需要动态更新图表时使用,
# 特别是在与滑块(Slider)、按钮(Button)、单选按钮(RadioButtons)等
# Matplotlib的交互式控件结合使用时
for i in range(len(x_train)):
#计算误差坐标
dx = x_train[i] * 0.01
dy = w * y_train[i] * 0.01
# 设置误差线
ax[0].annotate('', xy=(x_train[i], w * x_train[i]), xytext=(x_train[i], y_train[i]),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
# 设置误差
error = np.abs(y_train[i] - w * x_train[i])
ax[0].text(x_train[i] + dx, w * x_train[i] + dy, 'er:' + str(error), fontsize=8)
# 在update函数中,获取与 w 最接近的索引
nearest_index = np.abs(w_range - w).argmin()
# 使用最接近的索引来确定箭头的位置
ax[1].annotate('', xy=(w, 0), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].annotate('', xy=(0, y_range[nearest_index]), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].text(w, y_range[nearest_index], f'point:{(w, y_range[nearest_index])}', fontsize=8)
fig.canvas.draw_idle() # 重画图像
当我们去滑动窗口的时候,要保证当前的图上只有当前的结果,不能有上次的结果去影响
所以每次都要清空图上的内容:
# 删除上一次的差值记录(每次更新前都清空两个子图的所有内容,保证不会被上次的内容干扰到本次的绘图)
for arrow in ax[0].patches:
arrow.remove()
for text in ax[0].texts:
text.remove()
for arrow in ax[1].patches:
arrow.remove()
for text in ax[1].texts:
text.remove()
#以上常在需要动态更新图表时使用,
#特别是在与滑块(Slider)、按钮(Button)、单选按钮(RadioButtons)等
# Matplotlib的交互式控件结合使用时
之后就是设计具体的滑动效果:
dx = x_train[i] * 0.01 # 设置箭头位置
dy = w * x_train[i] * 0.01
#dx 和 dy 分别表示箭头在x和y方向上的偏移量。
# 设置误差线
ax[0].annotate('', xy=(x_train[i], w * x_train[i]), xytext=(x_train[i], y_train[i]),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
#使用 ax[0].annotate() 方法绘制误差箭头,该箭头从当前样本点的真实值位置指向当前样本点的预测值位置。
# 这里使用了 arrowprops 参数来设置箭头的样式和属性。
# 设置误差
error = np.abs(y_train[i] - w * x_train[i])
ax[0].text(x_train[i] + dx, w * x_train[i] + dy, 'er:'+str(error), fontsize=8)
#使用 ax[0].text() 方法在误差箭头附近添加误差标签,标签内容为当前样本点的误差值。
在Matplotlib中,ax.annotate()
是一个功能强大的函数,用于在图上添加注释。这个函数允许你在一个特定的数据坐标 (xy)
上放置文本,并可以指定一个箭头从 (xy)
指向另一个位置 (xytext)
。这对于在图上解释或突出显示特定数据点非常有用。
ax.annotate()
的基本语法如下:
ax.annotate(text, xy, xytext=None, xycoords='data', textcoords=None,
arrowprops=None, **kwargs)
参数解释:
text
:注释的文本内容。xy
:被注释的数据点坐标,通常是一个元组,如(x, y)
。xytext
:注释文本的位置坐标,也是一个元组。如果为None
,则文本将直接位于xy
指定的位置。xycoords
:xy
坐标系的类型,如'data'
(数据坐标)、'figure'
(图形坐标)等。默认为'data'
。textcoords
:xytext
坐标系的类型。如果为'offset points'
,则xytext
将被解释为从xy
的偏移量(以点为单位)。arrowprops
:一个字典,用于定义箭头的属性,如箭头样式、颜色、线宽等。**kwargs
:其他关键字参数,如文本的字体、颜色等。
在你给出的代码中,ax[0].annotate()
被用来从预测值 (x_train[i], w * x_train[i])
指向实际值 (x_train[i], y_train[i])
的位置,并显示一个箭头。这里 xy
是预测值的位置,而 xytext
实际上是实际值的位置(尽管通常 xytext
用于指定注释文本的位置,但在这里它实际上被用作箭头的起点)。
为了更清晰地显示箭头的起点和终点,你可能想要将 xytext
设置为一个稍微偏离实际值 y_train[i]
的位置,以避免箭头与实际值的数据点重叠。同时,通过 dx
和 dy
的使用,你已经尝试对文本位置进行了微调。
最后,arrowprops
字典定义了箭头的样式和属性,如箭头样式 arrowstyle='->'
、连接样式 connectionstyle='arc3'
、颜色 color='gray'
和线宽 lw=1
。
在Matplotlib中,ax.text()
方法用于在坐标轴上添加文本。这个方法允许你指定文本的位置、内容以及其他属性,如字体、颜色等。
ax.text()
的基本语法如下
ax.text(x, y, s, fontdict=None, **kwargs)
x
,y
:文本的位置坐标,通常是数据坐标(即x轴和y轴上的值)。s
:要添加的文本内容。fontdict
:一个字典,用于定义文本的字体属性,如字体名称、大小、颜色等。如果未提供,则使用当前的默认字体属性。**kwargs
:其他关键字参数,用于定义文本的其他属性,如旋转角度、对齐方式等。
在你给出的代码中,ax[0].text()
被用来在每个数据点上添加表示误差的文本。具体地,你计算了每个数据点的实际值与预测值之间的误差,并将误差的绝对值以文本形式显示在图上。文本的位置是通过在数据点的x坐标和预测值的y坐标上添加小的偏移量(dx
和 dy
)来确定的,以避免文本与数据点重叠。
eg:这是一个示例,展示了如何使用 ax.text()
在图上添加文本:
import matplotlib.pyplot as plt
import numpy as np
# 假设有一些数据点
x_train = np.linspace(0, 10, 10)
y_train = 3 * x_train + 2 # 假设的线性关系
w = 2.5 # 假设的权重
# 创建一个图形和坐标轴
fig, ax = plt.subplots(1, 1)
# 绘制数据点
ax.scatter(x_train, y_train, label='Data')
# 绘制线性拟合线(这里只是一个示例,实际上应该使用拟合得到的权重)
ax.plot(x_train, w * x_train, 'r-', label='Fit')
# 添加文本以显示误差
for i in range(len(x_train)):
# 计算误差
error = np.abs(y_train[i] - w * x_train[i])
# 为文本添加小的偏移量以避免重叠
dx = x_train[i] * 0.01
dy = (w * x_train[i]) * 0.01
# 添加文本
ax.text(x_train[i] + dx, w * x_train[i] + dy, f'er:{error:.2f}', fontsize=8)
# 添加图例和标签
ax.legend()
ax.set_xlabel('X')
ax.set_ylabel('Y')
# 显示图形
plt.show()
在这个示例中,我们为每个数据点添加了一个文本标签,显示该点的预测误差。注意我们使用了 f'er:{error:.2f}'
来格式化文本,这样误差值将以两位小数的形式显示。
在损失函数上用w_range中与w最接近的元素的索引来进行数据的更新:
# 在update函数中,获取与 w 最接近的索引
nearest_index = np.abs(w_range - w).argmin()
# 使用最接近的索引来确定箭头的位置
ax[1].annotate('', xy=(w, 0), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].annotate('', xy=(0, y_range[nearest_index]), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].text(w, y_range[nearest_index], f'point:{(w, y_range[nearest_index])}', fontsize=8)
fig.canvas.draw_idle() # 重画图像
为什么要w_range中与w最接近的元素的索引呢?
在NumPy中,np.abs(w_range - w).argmin()
这个表达式是用来找到数组 w_range
中与给定值 w
最接近的元素的索引。这个表达式的工作原理可以分为几个步骤来解释:
-
计算差值:
w_range - w
会对w_range
数组中的每个元素与w
做差,得到一个新的数组,其中的每个元素是原数组元素与w
的差值。 -
取绝对值:
np.abs(...)
会取上一步得到的差值数组的绝对值。这是必要的,因为我们关心的是差值的大小,而不是差值的正负。取绝对值可以确保我们找到的是距离w
最近的数值,而不是仅仅是小于或大于w
的数值。 -
找到最小值的索引:
.argmin()
方法会返回绝对值数组中最小元素的索引。这个索引对应于w_range
中与w
最接近的元素。
所以,np.abs(w_range - w).argmin()
整体上就是在找到 w_range
中与 w
数值上最接近的元素所对应的索引位置。这种方法在计算上非常高效,因为它利用了NumPy的向量化操作,避免了显式的循环遍历数组。
最后关于几个小的方面:
完整代码:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
def compute(x,y,w_array):
arr = np.zeros_like(w_array)
for j , w in enumerate(w_array):
m = x.shape[0]
tol = 0
for i in range(m):
tol += (w * x[i] - y[i]) ** 2
arr[j] = (1/(2 * m)) * tol
return arr
#更新函数
def update(val):
w = slider.val # 获取滑块的值
# slider.val是用于获取滑块当前所处位置对应的数值的属性
#当滑块到某一个数值的时候,通过slider.val来获取这个数值然后进行数据的更新和绘图
y = w * x_train #获得y的最新的值
line.set_ydata(y) # 更新图上的y值坐标
# 删除上一次的差值记录(每次更新前都清空两个子图的所有内容,保证不会被上次的内容干扰到本次的绘图)
for arrow in ax[0].patches:
arrow.remove()
for text in ax[0].texts:
text.remove()
for arrow in ax[1].patches:
arrow.remove()
for text in ax[1].texts:
text.remove()
# 以上常在需要动态更新图表时使用,
# 特别是在与滑块(Slider)、按钮(Button)、单选按钮(RadioButtons)等
# Matplotlib的交互式控件结合使用时
for i in range(len(x_train)):
#计算误差坐标
dx = x_train[i] * 0.01
dy = w * y_train[i] * 0.01
# 设置误差线
ax[0].annotate('', xy=(x_train[i], w * x_train[i]), xytext=(x_train[i], y_train[i]),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
# 设置误差
error = np.abs(y_train[i] - w * x_train[i])
ax[0].text(x_train[i] + dx, w * x_train[i] + dy, 'er:' + str(error), fontsize=8)
# 在update函数中,获取与 w 最接近的索引
nearest_index = np.abs(w_range - w).argmin()
# 使用最接近的索引来确定箭头的位置
ax[1].annotate('', xy=(w, 0), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].annotate('', xy=(0, y_range[nearest_index]), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].text(w, y_range[nearest_index], f'point:{(w, y_range[nearest_index])}', fontsize=8)
fig.canvas.draw_idle() # 重画图像
# 初始化数据
x_train = np.array([1, 2, 3, 4, 5])
y_train = np.array([200, 400, 500, 400, 500])
#创建一个窗口和两个子图
fig,ax = plt.subplots(1,2)
plt.subplots_adjust(bottom = 0.25)
#设置初始的w的值,因为绘制的是b = 0的情况(没有b),所有w就置为0
initial_w = 0.0
#绘制初始的预测结果
y = initial_w * x_train
line, = ax[0].plot(x_train, y, c='b', label='预测')
ax[0].scatter(x_train,y_train,marker = 'X',c = 'r')
#二次函数(J函数)
w_range = np.arange(-50,501,1)
y_range = compute(x_train,y_train,w_range)
#创建函数图像
lin ,= ax[1].plot(w_range,y_range,label = 'f = wx')
ax[1].set_xlabel('W')
ax[1].set_ylabel('dis')
ax[0].legend()
ax[1].legend()
# 创建一个滑块来表示不同的w的值,J函数的值
slider_ax = plt.axes([0.25,0.1,0.65,0.03])
slider = Slider(slider_ax,'w',-50,500,valinit=initial_w)
#将更新函数与滑块关联
slider.on_changed(update)
plt.show()
结果:
原版:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
"""
>>> 对于函数图像,w值改变,两点值改变,直线也改变,从w,b改变,但我们并没用改变b。
"""
# 二次函数
def compute(x, y, w_array):
arr = np.zeros_like(w_array) # 创建一个与 w_array 相同形状的数组,用于存储每个 w 对应的误差
for j, w in enumerate(w_array):
m = x.shape[0] # 获取样本数量
tol = 0
for i in range(m):
tol += (w * x[i] - y[i]) ** 2
arr[j] = (1 / (2 * m)) * tol
return arr
#更新函数
def update(val):
w = slider.val # 获取滑块的值
#slider.val是用于获取滑块当前所处位置对应的数值的属性
'''
当滑块到某一个数值的时候,通过slider.val来获取这个数值然后进行数据的更新和绘图
'''
y = w * x_train #获得y的最新的值
line.set_ydata(y) # 更新图上的y值坐标
# 删除上一次的差值记录(每次更新前都清空两个子图的所有内容,保证不会被上次的内容干扰到本次的绘图)
for arrow in ax[0].patches:
arrow.remove()
for text in ax[0].texts:
text.remove()
for arrow in ax[1].patches:
arrow.remove()
for text in ax[1].texts:
text.remove()
#以上常在需要动态更新图表时使用,
#特别是在与滑块(Slider)、按钮(Button)、单选按钮(RadioButtons)等
# Matplotlib的交互式控件结合使用时
for i in range(len(x_train)):
# 误差坐标
dx = x_train[i] * 0.01 # 设置箭头位置
dy = w * x_train[i] * 0.01
#dx 和 dy 分别表示箭头在x和y方向上的偏移量。
# 设置误差线
ax[0].annotate('', xy=(x_train[i], w * x_train[i]), xytext=(x_train[i], y_train[i]),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
#使用 ax[0].annotate() 方法绘制误差箭头,该箭头从当前样本点的真实值位置指向当前样本点的预测值位置。
# 这里使用了 arrowprops 参数来设置箭头的样式和属性。
# 设置误差
error = np.abs(y_train[i] - w * x_train[i])
ax[0].text(x_train[i] + dx, w * x_train[i] + dy, 'er:'+str(error), fontsize=8)
#使用 ax[0].text() 方法在误差箭头附近添加误差标签,标签内容为当前样本点的误差值。
'''
ax[1].annotate('', xy=(w, 0), xytext=(w, y_range[int(w)]),
arrowprops=dict(arrowstyle='<-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].annotate('', xy=(0, y_range[int(w)]), xytext=(w, y_range[int(w)]),
arrowprops=dict(arrowstyle='<-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
'''
# 在update函数中,获取与 w 最接近的索引
nearest_index = np.abs(w_range - w).argmin()
# 使用最接近的索引来确定箭头的位置
ax[1].annotate('', xy=(w, 0), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].annotate('', xy=(0, y_range[nearest_index]), xytext=(w_range[nearest_index], y_range[nearest_index]),
arrowprops=dict(arrowstyle='-', connectionstyle='arc3', color='gray', lw=1),
fontsize=8)
ax[1].text(w, y_range[nearest_index], f'point:{(w, y_range[nearest_index])}', fontsize=8)
fig.canvas.draw_idle() # 重画图像
# 初始化数据
x_train = np.array([1, 2, 3, 4, 5])
y_train = np.array([200, 400, 500, 400, 500])
# 创建一个图形窗口和子图
fig, ax = plt.subplots(1, 2)
plt.subplots_adjust(bottom=0.25)
# 设置初始的 w 值
initial_w = 0.0
# 绘制初始的预测结果
y = initial_w * x_train
line, = ax[0].plot(x_train, y, c='b', label='Prediction')
ax[0].scatter(x_train, y_train, marker='x', c='r')
# 二次函数
w_range = np.arange(-50, 501, 1)
y_range = compute(x_train, y_train, w_range)
# 创建二次函数图像
lin, = ax[1].plot(w_range, y_range, label='f = wx')
ax[1].set_xlabel('W')
ax[1].set_ylabel('dis')
ax[0].legend()
ax[1].legend()
#通常是右上角添加一个图例,其中包含了你为各个绘图元素指定的标签
'''compute loss
def compute_a(x, w):
m = x.shape[0]
y_temp = w * x
res = 0
for i in range(len(y_temp)):
res += (1 / (2 * m)) * (y_temp[i] - y_train[i]) ** 2
return res
line_dis = ax[1].plot(initial_w, compute_a(x_train, initial_w), c='b', label='distance')
'''
# 创建一个滑块控件
slider_ax = plt.axes([0.25, 0.1, 0.65, 0.03])
slider = Slider(slider_ax, 'w', -50, 500, valinit=initial_w)
# 将更新函数与滑块控件关联
slider.on_changed(update)
# 显示图形
plt.show()
结果与上面大致一样