1. Linear Regression
Definition
Linear regression is a widely used statistical analysis method that applies regression analysis from mathematical statistics to determine the quantitative relationship between two or more interdependent variables. Its standard form is y = w'x + e, where e is an error term following a normal distribution with mean 0.
When a regression involves only one independent variable and one dependent variable, and their relationship can be approximated by a straight line, it is called simple (univariate) linear regression.
Solution approach
For the linear regression problem (y = wx + b), the solution proceeds as follows:
- Compute the current loss
  – use the squared error as the loss for each point
  – compute the loss point by point and accumulate it
  – average the total loss and record it as the loss under the current parameters
- Update the parameters to reduce the loss
  – update w and b via gradient descent
- Loop
  – repeat the steps above to keep driving the loss down
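The quantities computed in these steps can be written out explicitly (the standard mean-squared-error derivation, with η denoting the learning rate):

```latex
L(w, b) = \frac{1}{N} \sum_{i=1}^{N} \bigl( (w x_i + b) - y_i \bigr)^2

\frac{\partial L}{\partial w} = \frac{2}{N} \sum_{i=1}^{N} \bigl( (w x_i + b) - y_i \bigr) x_i
\qquad
\frac{\partial L}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} \bigl( (w x_i + b) - y_i \bigr)

w \leftarrow w - \eta \, \frac{\partial L}{\partial w}
\qquad
b \leftarrow b - \eta \, \frac{\partial L}{\partial b}
```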
2. Step-by-Step Walkthrough
Task 1: Computing the Loss
def compute_error(points, w, b):
    total_error = 0
    for point in points:
        x = point[0]
        y = point[1]
        # compute the squared error for this point
        total_error += (y - (w*x + b)) ** 2
    # average loss over all points
    return total_error / float(len(points))
This function accumulates the squared error of the fit at each point and returns the average as the loss.
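As a sanity check, the same loss can be computed in vectorized form with NumPy (a sketch; `compute_error_vec` is a hypothetical name, not part of the original code):

```python
import numpy as np

def compute_error_vec(points, w, b):
    # vectorized equivalent of the per-point loop above
    pts = np.asarray(points, dtype=float)
    x, y = pts[:, 0], pts[:, 1]
    # mean squared error over all points
    return float(np.mean((y - (w * x + b)) ** 2))

# points lying exactly on y = 2x + 1 give zero loss at (w=2, b=1)
points = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]]
print(compute_error_vec(points, 2.0, 1.0))  # 0.0
```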
Task 2: Gradient Descent
def step_gradient(b_current, w_current, points, learning_rate):
    # gradients
    b_gradient = 0
    w_gradient = 0
    point_number = float(len(points))
    for point in points:
        x = point[0]
        y = point[1]
        # grad_b = 2(wx+b-y)
        b_gradient += 2 * ((w_current * x + b_current) - y) / point_number
        # grad_w = 2(wx+b-y)*x
        w_gradient += 2 * ((w_current * x + b_current) - y) * x / point_number
    # update w and b
    new_b = b_current - (learning_rate * b_gradient)
    new_w = w_current - (learning_rate * w_gradient)
    return [new_b, new_w]
This part computes the gradients and updates the parameters via gradient descent.
For the current parameters, it takes the partial derivative of the loss with respect to each parameter and moves that parameter in the direction that decreases the loss.
To keep the step from being too large, the partial derivatives are averaged over all points and scaled by the chosen learning rate, which limits how far each update moves.
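One way to convince yourself these partial derivatives are correct is to compare them against central finite differences (a quick self-contained check; `loss` and `analytic_grads` mirror the formulas in `step_gradient`):

```python
def loss(points, w, b):
    # mean squared error, as in compute_error
    return sum((y - (w * x + b)) ** 2 for x, y in points) / len(points)

def analytic_grads(points, w, b):
    # the same gradient formulas used in step_gradient
    n = len(points)
    gw = sum(2 * ((w * x + b) - y) * x for x, y in points) / n
    gb = sum(2 * ((w * x + b) - y) for x, y in points) / n
    return gw, gb

points = [(1.0, 2.0), (2.0, 2.5), (3.0, 4.0)]
w, b, eps = 0.5, 0.1, 1e-6
gw, gb = analytic_grads(points, w, b)
# central finite differences approximate the same partials
gw_num = (loss(points, w + eps, b) - loss(points, w - eps, b)) / (2 * eps)
gb_num = (loss(points, w, b + eps) - loss(points, w, b - eps)) / (2 * eps)
print(abs(gw - gw_num) < 1e-6, abs(gb - gb_num) < 1e-6)  # True True
```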
Task 3: The Loop
def gradient_descent_runner(initial_b, initial_w, num_iterations, points):
    b = initial_b
    w = initial_w
    learning_rate = 0.0001
    # update repeatedly
    for i in range(num_iterations):
        b, w = step_gradient(b, w, points, learning_rate)
        print("After {0} iterations b = {1}, w = {2}, error = {3}".
              format(i + 1, b, w, compute_error(points, w, b))
              )
    return [b, w]
This repeatedly executes the steps above, steadily driving the loss down.
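A quick way to see the loop working is to run it on synthetic data drawn from a known line and confirm the loss shrinks (a self-contained sketch re-using the same update rule; `step` and `mse` are illustrative names):

```python
def step(points, w, b, lr):
    # one gradient-descent update, as in step_gradient
    n = len(points)
    gw = sum(2 * ((w * x + b) - y) * x for x, y in points) / n
    gb = sum(2 * ((w * x + b) - y) for x, y in points) / n
    return w - lr * gw, b - lr * gb

def mse(points, w, b):
    return sum((y - (w * x + b)) ** 2 for x, y in points) / len(points)

# noiseless data from y = 3x + 2
points = [(float(x), 3.0 * x + 2.0) for x in range(10)]
w = b = 0.0
initial = mse(points, w, b)
for _ in range(1000):
    w, b = step(points, w, b, lr=0.01)
print(mse(points, w, b) < initial)  # True
```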
Task 4: Running
def run(initial_b, initial_w, points, num_iterations):
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}".
          format(initial_b, initial_w, compute_error(points, initial_w, initial_b))
          )
    print("Running...")
    [last_b, last_w] = gradient_descent_runner(initial_b, initial_w, num_iterations, points)
    error = compute_error(points, last_w, last_b)
    print("After {0} iterations b = {1}, w = {2}, error = {3}".
          format(num_iterations, last_b, last_w, error)
          )
The external entry point.
if __name__ == '__main__':
    import numpy as np
    # load the data points (same format as in the class version below)
    points = np.genfromtxt("data.csv", delimiter=" ")
    run(initial_b=0, initial_w=0, points=points, num_iterations=1000)
3. Class Encapsulation
Wrapping the code above in a class gives the following:
import numpy as np

class LinearRegression:
    def __init__(self, data, learning_rate, num_iterations, initial_b, initial_w):
        """
        :param data: path to the data points (CSV file)
        :param learning_rate: learning rate
        :param num_iterations: number of training iterations
        :param initial_b: initial b
        :param initial_w: initial w
        """
        # points
        self.points = np.genfromtxt(data, delimiter=" ")
        self.point_number = float(len(self.points))
        # hyperparameters
        self.learning_rate = learning_rate
        self.num_iterations = num_iterations
        # initial values
        self.initial_b = initial_b
        self.initial_w = initial_w
        self.error = self.compute_error(self.initial_w, self.initial_b)
        self.last_b = initial_b
        self.last_w = initial_w
    def run(self):
        print("Starting gradient descent at b = {0}, w = {1}, error = {2}".
              format(self.initial_b, self.initial_w, self.error)
              )
        print("Running...")
        [self.last_b, self.last_w] = self.gradient_descent_runner()
        self.error = self.compute_error(self.last_w, self.last_b)
    def compute_error(self, w, b):
        total_error = 0
        for point in self.points:
            x = point[0]
            y = point[1]
            # compute the squared error for this point
            total_error += (y - (w * x + b)) ** 2
        # average loss over all points
        return total_error / self.point_number
    def step_gradient(self, b_current, w_current):
        b_gradient = 0
        w_gradient = 0
        for point in self.points:
            x = point[0]
            y = point[1]
            # grad_b = 2(wx+b-y)
            b_gradient += 2 * ((w_current * x + b_current) - y) / self.point_number
            # grad_w = 2(wx+b-y)*x
            w_gradient += 2 * ((w_current * x + b_current) - y) * x / self.point_number
        # update w and b
        new_b = b_current - (self.learning_rate * b_gradient)
        new_w = w_current - (self.learning_rate * w_gradient)
        return [new_b, new_w]
    def gradient_descent_runner(self):
        b = self.initial_b
        w = self.initial_w
        # update repeatedly
        for i in range(self.num_iterations):
            b, w = self.step_gradient(b, w)
            print("After {0} iterations b = {1}, w = {2}, error = {3}".
                  format(i + 1, b, w, self.compute_error(w, b))
                  )
        return [b, w]
if __name__ == '__main__':
    linear = LinearRegression(data="data.csv", learning_rate=0.0001, num_iterations=1000, initial_b=0, initial_w=0)
    linear.run()
    print("Final b = {0}, w = {1}, error = {2}".
          format(linear.last_b, linear.last_w, linear.error)
          )
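For this one-variable problem the least-squares solution also exists in closed form, which makes a useful cross-check on what gradient descent converges toward (a sketch using NumPy's `np.polyfit`; synthetic data stands in for `data.csv` here):

```python
import numpy as np

# synthetic points from y = 1.5x + 0.5 with no noise
x = np.arange(20, dtype=float)
y = 1.5 * x + 0.5

# closed-form least squares: a degree-1 polynomial fit returns [slope, intercept]
w_fit, b_fit = np.polyfit(x, y, 1)
print(round(w_fit, 6), round(b_fit, 6))  # 1.5 0.5
```

With a small enough learning rate and enough iterations, the `w` and `b` printed by `gradient_descent_runner` should approach these closed-form values.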