单变量线性回归
#导入模块
import matplotlib.pyplot as plt
#载入数据集
file = open("ex1data1.txt")
#数据划分,得到x,y两个列表数据集
data = file.readlines() #读取文本,将字符串格式的数据传入列表data
x = []
y = []
for num in data:
num = num.split(',') #将字符串按照 ‘ ,’切片为列表。
x.append(float(num[0])) #x_data 存放切片后的第一列数据【0】
y.append(float(num[1]))
#获取数据集的大小
m = len(x)
#设置学习率alpha
a = 0.01
#梯度下降循环主体
none = True
count = 0 #循环次数
theta = [0,0] #创建一个数组,存放theta0和theta1两个数据
while none:
#超过循环次数后,跳出循环
count+=1
if count>=10000: #最多150次循环
none = False
wucha = [0,0] #存放两个误差数据
for i in range(m):
wucha[0] += theta[0]+theta[1]*x[i]-y[i]
wucha[1] += (theta[0]+theta[1]*x[i]-y[i])*x[i]
#迭代主体
theta[0] = theta[0] - (a/m)*wucha[0]
theta[1] = theta[1] - (a/m)*wucha[1]
profit = [] #利润
for i in range(m):
outs = theta[0]+theta[1]*x[i] #通过迭代得到的一次线性回归方程
profit.append(outs)
#设置x,y轴的范围
plt.axis([4,24,-5,25])
#设置x,y,title
plt.xlabel('Population')
plt.ylabel('Profit')
plt.title('Predicted Profit')
#绘制散点图和一次函数
plt.scatter(x,y)
plt.plot(x,profit,'r')
#显示散点图
plt.legend(["red","Blue"])#设置图例
plt.show()