这是对pga.csv中的数据进行分析,建立数据模型,得到下面的公式
紧接着,我们开始根据公式编写代码
# Gradient descent for single-variable linear regression.  2018.3.15
import numpy as np


def cost(theta0, theta1, x, y):
    """Mean-squared-error cost J = sum((h - y)^2) / (2m) for h = theta0 + theta1*x."""
    m = len(x)
    total = 0
    for i in range(m):
        h = theta1 * x[i] + theta0  # hypothesis for the i-th sample
        total += (h - y[i]) ** 2
    return total / (2 * m)


def partial_cost_theta1(theta0, theta1, x, y):
    """Partial derivative of the cost with respect to theta1 (the slope)."""
    h = theta0 + theta1 * x
    return ((h - y) * x).sum() / x.shape[0]


def partial_cost_theta0(theta0, theta1, x, y):
    """Partial derivative of the cost with respect to theta0 (the intercept)."""
    h = theta0 + theta1 * x
    return (h - y).sum() / x.shape[0]


def gradient_descent(x, y, alpha=0.1, theta0=0, theta1=0):
    """Fit y ~ theta0 + theta1 * x by batch gradient descent.

    Iterates until the cost changes by less than a small threshold between
    two consecutive iterations, or until the epoch cap is reached.

    Parameters:
        x, y: 1-D numpy arrays (or pandas Series) of equal length.
        alpha: learning rate.
        theta0, theta1: initial parameter values.

    Returns a dict with the fitted 'theta0', 'theta1' and the list of
    per-iteration 'costs'.
    """
    max_epochs = 1000              # hard cap on the number of iterations
    counter = 0                    # current iteration count
    # BUG FIX: the original computed cost(theta1, theta0, pga.distance,
    # pga.accuracy) -- arguments swapped AND the pga globals used instead of
    # the x/y parameters.  Use the parameters, in the correct order.
    c = cost(theta0, theta1, x, y)
    costs = [c]                    # record the cost at every iteration
    convergence_thres = 0.000001   # stop once the cost barely changes
    cprev = c + 10                 # previous cost; seeded so the loop starts
    while (np.abs(cprev - c) > convergence_thres) and (counter < max_epochs):
        cprev = c
        # One gradient step: move each parameter against its partial derivative.
        update0 = alpha * partial_cost_theta0(theta0, theta1, x, y)
        update1 = alpha * partial_cost_theta1(theta0, theta1, x, y)
        theta0 -= update0
        theta1 -= update1
        # BUG FIX: re-evaluate the cost on x/y, not on the pga globals.
        c = cost(theta0, theta1, x, y)
        costs.append(c)
        counter += 1
    return {'theta0': theta0, 'theta1': theta1, 'costs': costs}


if __name__ == "__main__":
    # File I/O and plotting only run as a script, so the functions above
    # stay importable without side effects.
    import pandas as pd
    import matplotlib.pyplot as plt

    pga = pd.read_csv("D:/pga.csv")  # TODO: make the path configurable
    # Standardize both columns to zero mean and unit variance.
    pga.distance = (pga.distance - pga.distance.mean()) / pga.distance.std()
    pga.accuracy = (pga.accuracy - pga.accuracy.mean()) / pga.accuracy.std()

    # BUG FIX: the original ran gradient_descent three separate times just to
    # print three fields; run once and reuse the result.
    result = gradient_descent(pga.distance, pga.accuracy)
    print("Theta0 =", result['theta0'])
    print("Theta1 =", result['theta1'])
    print("costs =", result['costs'])

    # Visualize convergence with a smaller learning rate.
    descend = gradient_descent(pga.distance, pga.accuracy, alpha=.01)
    plt.scatter(range(len(descend["costs"])), descend["costs"])
    plt.show()
pga.csv的数据集是这种格式,这里只罗列一部分
distance accuracy 0 0.314379 -0.707727 1 1.693777 -1.586669 2 -0.059695 -0.176699 3 -0.574047 0.372640