复习吴恩达的机器学习课程,想到以前好多代码都没有记录,索性这次记录一下。
使用梯度下降算法计算线性回归模型的权重,代码如下:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# 文件中只含有一个变量与一个预测值
path = '~/condaProject/WUENDA/work1/ex1data1.txt'
data = pd.read_csv(path, header=None, names=['Population', 'Profit'])
data.head()
# 可视化数据
data.plot(kind='scatter', x='Population', y='Profit', figsize=(12,8))
# 读取数据,数据处理,在数据最前面添加一列常数,在计算时充当常数项
data.insert(0, 'Ones', 1)
# 初始化X,y
cols = data.shape[1] #shape表示行列数,[0]行,[1]列
X = data.iloc[:, :-1] # 数据为二维数组,取所有的行,以及除了最后一列的所有列
y = data.iloc[:, cols-1 : cols] # 同上,不过是只取最后一列
X.head()
# 因为X , y是一个DataFrame,所以要进行运算就要重新转换成矩阵
Xnp = np.array(X.values)
ynp = np.array(y.values)
theta = np.array([0,0]).reshape(1,2) # 定义theta初始值为0,一行两列
# 使用power函数计算代价函数J(theta)的值,X为一个矩阵
# 计算公式为 J(theta)= (1/2m)* (theta0 + theta1*Xi - yi)i从1-m
def computeCost(X, y, theta):
# 在此数据集中,X为97*2,theta为1*2
inner = np.power((np.dot(X , theta.T) - y), 2) # power可以求次方得出一个矩阵;dot 矩阵相乘,此处得到一个值,而非一个矩阵
# a.dot(b) == np.dot(a,b) 矩阵a 乘 矩阵b ,一维数组时,ab位置无所谓
# print(inner)
return np.sum(inner) / (2 * len(X))
c = computeCost(Xnp, ynp, theta) # 没有使用梯度下降的误差值
print(c)
# 梯度下降算法
def gD(X, y, theta, alpha=0.01,iters=1000):
temp = np.array(np.zeros(theta.shape)) # 初始化参数矩阵
cost = np.zeros(iters) # 初始化一个数组,包含每次更新后的代价值
m = X.shape[0] # 样本数目
for i in range(iters):
temp = theta - (alpha/m) * (X.dot(theta.T) - y).T.dot(X)
theta = temp
cost[i] = computeCost(X, y, theta)
return theta, cost # 返回迭代一万次后的theta权重与迭代一万次的一万个损失值
final_theta, cost = gD(Xnp, ynp, theta)
final_cost = computeCost(Xnp, ynp, final_theta) # 算出的cost跟第一万次的cost一样
population = np.linspace(data.Population.min(), data.Population.max(), 97) # 人口数的一维数组,从小到大排列
profit = final_theta[0, 0] + final_theta[0, 1] * population # 得到预测出权重的数学模型
# 绘制图像
fig, ax = plt.subplots(figsize=(8,6))
ax.plot()
ax.plot(population, profit, 'r', label='Prediction')#最小损失直线
ax.scatter(data['Population'], data['Profit'], label='Training data')#散点
ax.legend(loc=4) # 4表示标签在右下角
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Prediction Profit by. Population Size')
plt.show()
以下为数据集,复制到txt文档即可。
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705