import numpy as np
# y = wx + b
# Compute the average loss (mean squared error) of the line over all points.
def compute_error_for_line_given_points(b, w, points):
    totalError = 0
    for i in range(0, len(points)):  # range() is half-open: start included, stop excluded
        x = points[i, 0]
        y = points[i, 1]
        totalError += (y - (w * x + b)) ** 2
    return totalError / float(len(points))
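For reference, the quantity this function returns is the mean squared error of the line y = wx + b over the N points:

\mathrm{MSE}(b, w) = \frac{1}{N} \sum_{i=1}^{N} \bigl( y_i - (w x_i + b) \bigr)^2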
# Perform a single gradient-descent update of b and w.
def step_gradient(b_current, w_current, points, learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        b_gradient += 2 / N * ((w_current * x) + b_current - y)
        w_gradient += 2 / N * ((w_current * x) + b_current - y) * x
    new_b = b_current - learningRate * b_gradient
    new_w = w_current - learningRate * w_gradient
    return [new_b, new_w]
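The two accumulators above are exactly the partial derivatives of the MSE with respect to b and w; differentiating the loss term by term gives:

\frac{\partial\,\mathrm{MSE}}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} \bigl( (w x_i + b) - y_i \bigr), \qquad \frac{\partial\,\mathrm{MSE}}{\partial w} = \frac{2}{N} \sum_{i=1}^{N} \bigl( (w x_i + b) - y_i \bigr)\, x_i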
# Run gradient descent for num_iterations steps.
def gradient_descent_runner(points, starting_b, starting_w, learning_rate, num_iterations):
    b = starting_b
    w = starting_w
    for i in range(num_iterations):  # iterate num_iterations times, not once per data point
        b, w = step_gradient(b, w, np.array(points), learning_rate)
    return [b, w]
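Since points is already a NumPy array, the per-point loop in step_gradient can also be written with vectorized array operations. A minimal sketch (step_gradient_vectorized is my name for it, not part of the original):

def step_gradient_vectorized(b_current, w_current, points, learning_rate):
    x, y = points[:, 0], points[:, 1]           # first column is x, second is y
    residual = (w_current * x + b_current) - y  # per-point prediction error
    b_gradient = 2 * residual.mean()            # d(MSE)/db
    w_gradient = 2 * (residual * x).mean()      # d(MSE)/dw
    return [b_current - learning_rate * b_gradient,
            w_current - learning_rate * w_gradient]

It is a drop-in replacement for step_gradient and avoids the Python-level loop.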
def run():
    points = np.genfromtxt("data.csv", delimiter=",")
    learning_rate = 0.0001
    initial_b = 0
    initial_w = 0
    num_iterations = 100
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}"
          .format(initial_b, initial_w,
                  compute_error_for_line_given_points(initial_b, initial_w, points)))
    print("Running...")
    [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
    print("After {0} iterations b = {1}, w = {2}, error = {3}"
          .format(num_iterations, b, w,
                  compute_error_for_line_given_points(b, w, points)))
run()
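To watch convergence rather than just the final numbers, the runner can also record the loss after every step. A small variant built on the functions above (gradient_descent_with_history is a name I made up):

def gradient_descent_with_history(points, b, w, learning_rate, num_iterations):
    history = []
    for _ in range(num_iterations):
        b, w = step_gradient(b, w, np.array(points), learning_rate)
        history.append(compute_error_for_line_given_points(b, w, points))
    return b, w, history  # history[i] is the MSE after step i + 1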
Run output: (omitted)

Contents of data.csv:
32.50235,31.70701
53.4268,68.7776
61.53036,62.56238
47.47564,71.54663
59.81321,87.23093
55.14219,78.21152
52.2118,79.64197
39.29957,59.17149
48.10504,75.33124
52.55001,71.30088
45.41973,55.16568
54.35163,82.47885
44.16405,62.00892
58.16847,75.39287
56.72721,81.43619
48.95589,60.7236
44.6872,82.8925
60.29733,97.3799
45.61864,48.84715
38.81682,56.87721
66.18982,83.87856
65.41605,118.5912
47.48121,57.25182
41.57564,51.39174
51.84519,75.38065
59.37082,74.76556
57.31,95.45505
63.61556,95.22937
46.73762,79.05241
50.55676,83.43207
52.224,63.35879
35.56783,41.41289
42.43648,76.61734
58.16454,96.76957
57.50445,74.08413
45.44053,66.58814
61.89622,77.76848
33.09383,50.71959
36.43601,62.12457
37.67565,60.81025
44.55561,52.68298
43.31828,58.56982
50.07315,82.90598
43.87061,61.42471
62.99748,115.2442
32.66904,45.57059
40.1669,54.08405
53.57508,87.99445
33.86421,52.72549
64.70714,93.57612
38.11982,80.16628
44.50254,65.10171
40.59954,65.5623
41.72068,65.28089
51.08863,73.43464
55.0781,71.13973
41.37773,79.10283
62.4947,86.52054
49.20389,84.7427
41.10269,59.35885
41.18202,61.68404
50.18639,69.8476
52.37845,86.09829
50.13549,59.10884
33.64471,69.89968
39.5579,44.86249
56.13039,85.49807
57.36205,95.53669
60.26921,70.25193
35.67809,52.72173
31.58812,50.39267
53.66093,63.6424
46.68223,72.24725
43.10782,57.81251
70.34608,104.2571
44.49286,86.64202
57.50453,91.48678
36.93008,55.23166
55.80573,79.55044
38.95477,44.84712
56.90121,80.20752
56.8689,83.14275
34.33312,55.72349
59.04974,77.63418
57.78822,99.05141
54.28233,79.12065
51.08872,69.5889
50.28284,69.5105
44.21174,73.68756
38.00549,61.3669
32.94048,67.17066
53.69164,85.6682
68.76573,114.8539
46.23097,90.12357
68.31936,97.91982
50.03017,81.53699
49.23977,72.11183
50.03958,85.23201
48.14986,66.22496
25.12848,53.45439