吴恩达机器学习题解ex1

单变量线性回归

#导入模块
import matplotlib.pyplot as plt
#载入数据集
file = open("ex1data1.txt")
#数据划分,得到x,y两个列表数据集
data = file.readlines() #读取文本,将字符串格式的数据传入列表data
x = []
y = []
for num in data:
    num = num.split(',')   #将字符串按照 ‘ ,’切片为列表。
    x.append(float(num[0]))    #x_data 存放切片后的第一列数据【0】
    y.append(float(num[1]))
#获取数据集的大小
m = len(x)
#设置学习率alpha
a = 0.01
#梯度下降循环主体
none = True
count = 0 #循环次数
theta = [0,0] #创建一个数组,存放theta0和theta1两个数据
while none:
    #超过循环次数后,跳出循环
    count+=1
    if count>=10000:  #最多150次循环
        none = False
    wucha = [0,0] #存放两个误差数据
    for i in range(m):
        wucha[0] += theta[0]+theta[1]*x[i]-y[i]
        wucha[1] += (theta[0]+theta[1]*x[i]-y[i])*x[i]
    #迭代主体
    theta[0] = theta[0] - (a/m)*wucha[0]
    theta[1] = theta[1] - (a/m)*wucha[1]
profit = []   #利润 
for i in range(m): 
   outs = theta[0]+theta[1]*x[i]   #通过迭代得到的一次线性回归方程
   profit.append(outs) 
#设置x,y轴的范围
plt.axis([4,24,-5,25])
#设置x,y,title
plt.xlabel('Population')
plt.ylabel('Profit')
plt.title('Predicted Profit')
#绘制散点图和一次函数
plt.scatter(x,y)
plt.plot(x,profit,'r')
#显示散点图
plt.legend(["red","Blue"])#设置图例
plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值