import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['font.size']=20 #设置全局字体大小
fig=plt.figure(figsize=(10,5)) #创建figure对象
ax=fig.add_axes([0,0,1,1]) #创建axes对象
ax.spines['top'].set_color(None) #隐藏上坐标轴
ax.spines['right'].set_color(None) #隐藏又坐标轴
#--------------------------------------------------------------
x=np.linspace(1,10,100) #创建1到10,元素个数100的等差数组
y=100+50*x+np.random.rand(100)*10 #广播
def piana(a,b): #对a偏导
sum=0
for i in range(100):
sum+=a+b*x[i]-y[i]
return sum/100
def pianb(a,b): #对b偏导
sum=0
for i in range(100):
sum+=(a+b*x[i]-y[i])*x[i]
return sum/100
a=b=0 #设定a、b初值(起点)
rate=0.01 #设定学习速率(控制step)
while abs(piana(a,b))>0.01 or abs(pianb(a,b))>0.01: #梯度下降
tempa=a-rate*piana(a,b)
tempb=b-rate*pianb(a,b)
a=tempa
b=tempb
#--------------------------------------------------------------
plt.scatter(x, y) #绘制度xy散点图
ax.plot(x,a+b*x,'r-',linewidth=3) #绘制拟合函数
ax.text(6.3,250,'训练集:\n$x=np.linspace(1,10,100)$\n$y=100+50*x+np.random.rand(100)*10$') #指定位置插入文本
ax.annotate(f'拟合函数:$y={a}+{b}x$',xy=(8,a+b*8), xytext=(9,450),arrowprops=dict(connectionstyle='arc3,rad=.1',color='r',shrink = 0.1)) #添加注释
plt.show() #显示图片
效果: