吴恩达机器学习ex1——通过人口预测小摊经济状况

吴恩达机器学习ex1

完整代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#获取一个文本的数据(即本例中的数据集)
path = r'D:\Project\Pycharm Project\py2022\ExData\ex1data1.txt'#数据集的存储地址
data = pd.read_csv(path, header = None, names = ['Population', 'Profit'])#定义获取数据的形式,一列为人口,一列为经济利润
data.head()#head方法用来显示表格前五行的数据
#显示并检查数据
data.plot(kind='scatter', x='Population', y='Profit', figsize=(12,8))
plt.show()

# 接下来是定义计算代价函数
def cost_func(x, y, w, b):
    #获得每一点数据的误差(成本)
    cost_matrix = np.power(((x * w + b) - y), 2)  # 这个就是矩阵相乘得到一个n*1维的矩阵与n*1维的y矩阵进行相减的平方
    return np.sum(cost_matrix) / (2 * len(x))  # 利用sum函数将cost矩阵所有元素加起来,再除以2m(m为总共的数据数量)

# 计算迭代之后的w和b参数,w为斜率,b为截距,alpha为学习率
def compute_new_wb(x, y, w, b, alpha):
    m = len(x)
    dj_dw = 0
    dj_db = 0 #用于计算梯度

    for i in range(m):
        #老参数w和b所计算的所有y值(y0,y1,y2 ……)
        f_wb_i = w * x[i] + b
        #偏导之后的第i个和项(具体为什么这样可以看吴恩达老师讲解线性回归的视频)
        dj_dw_i = (f_wb_i - y[i]) * x[i]
        dj_db_i = (f_wb_i - y[i])
        ##进行累加
        dj_dw = dj_dw + dj_dw_i
        dj_db = dj_db + dj_db_i

    dj_dw = (1 / m) * dj_dw
    dj_db = (1 / m) * dj_db
	#此时的dj_dw,dj_db就为老参数情况下的偏导
    return w - alpha * dj_dw, b - alpha * dj_db	#返回新的w和b

#获取X矩阵和Y矩阵,为了方便计算成本函数,进行转置变成了列向量
X = np.matrix(data['Population'].values).T
Y = np.matrix(data['Profit'].values).T

print(compute_new_wb(X, Y, 0, 0, 0.001)) #测试一下w=0,b=0迭代之后的w和b
#初始将w和b设为1,学习率设为0.01
w = 1
b = 1
alpha = 0.01
#经过1500次迭代
for i in range(1500):
    w, b = compute_new_wb(X, Y, w, b, alpha)


print('最后计算出的w和b:', w, ',', b)
print('最后的成本函数', cost_func(X, Y, w, b))

# 接下来就是画图
# 在人口的最小值和最大值之间取100个点,返回一个矩阵
x = np.linspace(data.Population.min(), data.Population.max(), 100)
#w[0, 0]为数值,w[0][0]则为列表
y = w[0, 0] * x + b[0, 0] # 获得y矩阵

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x, y, 'r', label='Prediction')#使用plot方法绘制预测直线
ax.scatter(data.Population, data.Profit, label='Training Data')
ax.legend(loc = 2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()
#%%预测人口3500 和 7000时的经济收益
prediction1 = w[0, 0] * 3.5 + b[0, 0]
print("人口为3500时,小吃摊经济规模:", prediction1)
prediction2 = w[0, 0] * 7.0 + b[0, 0]
print("人口为7000时,小吃摊经济规模:", prediction2)

代码输出

打印出来的数据集
在这里插入图片描述
整个数据集分布图(注意:数据集必须下载到自己的本地,可以搜索其他博主https://blog.csdn.net/qq_39435411/article/details/109763239
在这里插入图片描述
在这里插入图片描述
你也可以一开始改变w和b的值,例如w = 0,b = 0,不过最后预测出来的会有很小的误差
在这里插入图片描述

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值