线性回归模型

准备工作

  1. 引入numpy包和matplotlib.pyplot包
  2. 设置图表背景
    import numpy as py
    import matplotlib.pyplot as plt
    plt.system.use("blm")
    printf(plt.style.available)#可以查看所有可用的样式

 描点

x=np.array([1,2,1.5,1.8,1.2])#X轴的点
y=np.array([3,4,3.4,3.6,3.3])#Y轴的点
plt.scatter(x,y,marker='x',c='r')#画出一些点
plt.title("A Try")
plt.xlabel('x')
plt.ylabel('y')
plt.show()#显示图标

 函数

1.根据x,w,b获得预测的y点

def getLine(x,w,b):
    t_wb=np.zeros(x.shape[0])
    for i in range(x.shape[0]):
        t_wb[i]=w*x[i]+b
    return t_wb#根据提供的X,w,b获得预测的Y

 shape函数是Numpy中的函数,它的功能是读取矩阵的长度。直接用.shape可以快速读取矩阵的形状

import numpy as np
x=np.array([[1,2,3],[4,5,6]])
print(x.shape)
(2, 3)

在这里(2,3)表示的是2行3列

cost function

def getCost(x,y,w,b):
    m=x.shape[0]
    cost=0
    for i in range(m):
          cost+=((w*x[i]+b)-y[i])**2
    cost/=(2*m)
    return cost  

根据此函数,获得当前w,b的cost value

获得偏导值

def getGradient(x,y,w,b):
    dj_dw=0
    dj_db=0
    for i in range(x.shape[0]):
        dj_dw+=(((w*x[i]+b)-y[i])*x[i])
        dj_db+=((w*x[i]+b)-y[i])
    return (dj_db/x.shape[0]),(dj_dw/x.shape[0])

梯度下降函数

def gradient_descent(x,y,w,b,alpha,interation,getCost,getGradient):
    for i in range(interation):
        cost=getCost(x,y,w,b)
        dj_db,dj_dw=getGradient(x,y,w,b)
        w-=alpha*dj_dw
        b-=alpha*dj_db
        if i%1000==0:
            print(f'cost={cost}')
            plt.plot(x,getLine(x,w,b),c='b',label='final line')#画出一条线
            plt.scatter(x,y,marker='x',c='r',label='real point')
            plt.title("Final")
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend(loc='upper left')#显示标释,第一个字母为upper lower center,第二个为left right center
            plt.show()
    return w,b

当然,在此函数里我添加了多次打印,用来表现函数图象的变化

打印

最后,我们用代码打印出最终的函数

r_wb1=getLine(x,w_final,b_final)
plt.plot(x,r_wb1,c='b',label='final line')#画出一条线
plt.scatter(x,y,marker='x',c='r',label='real point')
plt.title("Final")
plt.xlabel('x')
plt.ylabel('y')
plt.legend(loc='upper left')#显示标释,第一个字母为upper lower center,第二个为left right center
plt.show()

在上面的代码里:

plt.plot用来打印直线,提供起始点,color,label(名称)

plt.scatter用来打印点,提供点,marker(点的样子),color,label

plt.xlabel:给x轴命名  plt.xlabel同理

plt.legend(loc="  ")用来确定点或线的注释区的位置

  • 8
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值