In this part of the exercise, the task is to predict house prices. The input has two features, the size of the house and the number of bedrooms; the output variable is the house price.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
data = pd.read_csv(r'D:\yuxin\data_sets\ex1data2.txt', header=None, names=['size','rooms','price'])  # raw string avoids backslash-escape issues in Windows paths
X1 = data['size']
X2 = data['rooms']
y = data['price']
Feature scaling: for most machine learning and optimization algorithms, rescaling the features to the same range produces a better-performing model. Take gradient descent as an example: suppose there are two features, the first ranging from 1 to 10 and the second from 1 to 10000. Since the cost function is the mean squared error, gradient descent will be clearly dominated by the second feature simply because its range is much larger, and multivariate gradient descent will converge very slowly.
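The effect described above can be quantified by the condition number of X^T X, which governs how slowly gradient descent converges. A minimal sketch with synthetic data (the ranges 1-10 and 1-10000 mirror the example above; all other values are hypothetical):

```python
import numpy as np

# Two features on very different scales, as in the example above (1-10 vs 1-10000)
rng = np.random.default_rng(3)
x1 = rng.uniform(1, 10, 100)
x2 = rng.uniform(1, 10000, 100)

X_raw = np.c_[x1, x2]
X_scaled = (X_raw - X_raw.mean(axis=0)) / X_raw.std(axis=0)

# A huge condition number means an elongated cost surface and slow convergence
print(np.linalg.cond(X_raw.T @ X_raw))        # very large
print(np.linalg.cond(X_scaled.T @ X_scaled))  # close to 1
```

After standardization the cost surface becomes nearly circular, so gradient descent can take large, well-aimed steps.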
Two feature-scaling methods are common: normalization and standardization. Normalization uses the feature's minimum and maximum to rescale it into the [0, 1] interval. For many machine learning algorithms, standardization may work better: it uses the feature's mean and standard deviation to rescale the feature to a standard normal distribution with mean 0 and variance 1.
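The two methods side by side on a tiny hypothetical array:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Min-max normalization: rescale into [0, 1]
x_minmax = (x - x.min()) / (x.max() - x.min())

# Z-score standardization: zero mean, unit variance
x_std = (x - x.mean()) / x.std()

print(x_minmax)                   # endpoints map to 0 and 1
print(x_std.mean(), x_std.std())  # ~0.0 and 1.0
```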
# Mean normalization (alternative): subtract the mean, divide by the range
# X1_norm = (X1 - np.mean(X1))/(max(X1) - min(X1))
# X2_norm = (X2 - np.mean(X2))/(max(X2) - min(X2))
X1_norm = (X1 - np.mean(X1))/np.std(X1)
X2_norm = (X2 - np.mean(X2))/np.std(X2)
X1_norm = X1_norm.values.reshape(-1, 1)  # (-1, 1) infers the row count instead of hard-coding 47
X2_norm = X2_norm.values.reshape(-1, 1)
std = np.std(y)
mean = np.mean(y)
y = (y - mean)/std  # standardizing the labels keeps gradient descent from diverging; it does not change the final predictions, but they must be de-standardized afterwards
y = y.values.reshape(-1, 1)
X1_norm = np.insert(X1_norm, 0, values=1, axis=1)  # prepend the bias column of ones
X = np.c_[X1_norm, X2_norm]                        # design matrix: [1, size, rooms]
X1_norm.shape,X2_norm.shape,X.shape,y.shape
((47, 2), (47, 1), (47, 3), (47, 1))
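The comment above notes that predictions on standardized labels must be de-standardized. The inverse transform is just the scaling applied in reverse; a minimal round-trip sketch (the sample values are hypothetical, chosen to resemble the price range in ex1data2):

```python
import numpy as np

y_raw = np.array([399900.0, 329900.0, 369000.0])

mu, sigma = y_raw.mean(), y_raw.std()
y_scaled = (y_raw - mu) / sigma  # standardize: subtract mean, divide by std
y_back = y_scaled * sigma + mu   # invert: multiply by std, add the mean back

print(np.allclose(y_back, y_raw))  # True
```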
Normal equation method
theta1 = np.linalg.inv(X.T @ X) @ X.T @ y
theta1
array([[-2.08166817e-17],
[ 8.84765988e-01],
[-5.31788197e-02]])
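Explicitly inverting X.T @ X can be numerically fragile when the matrix is ill-conditioned; np.linalg.lstsq (or np.linalg.pinv) solves the same least-squares problem more robustly. A sketch on synthetic data (the design matrix and coefficients here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.c_[np.ones(47), rng.standard_normal((47, 2))]  # bias + 2 standardized features
true_theta = np.array([0.0, 0.88, -0.05])
y = X @ true_theta

theta_inv = np.linalg.inv(X.T @ X) @ X.T @ y       # explicit normal equation
theta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)  # least-squares solver

print(np.allclose(theta_inv, theta_lstsq))  # True
```

On well-conditioned data like this both give the same answer; lstsq simply degrades more gracefully when X.T @ X is nearly singular.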
Gradient descent method
J = []  # records the cost at every iteration

def cost(X, y, theta):
    j = np.sum((X @ theta.T - y) ** 2) / (2 * len(X))
    return j

def gradient(X, y, theta, a, n):
    for i in range(n):
        theta = theta - (a / len(X)) * ((X @ theta.T - y).T @ X)
        j = cost(X, y, theta)
        J.append(j)
    return theta[0]
theta0 = np.zeros((1, 3))
a = 1    # learning rate
n = 100  # number of iterations
cost(X,y,theta0)
0.5
theta = gradient(X,y,theta0,a,n)
theta,theta.shape
(array([-1.00983052e-16, 8.84765988e-01, -5.31788197e-02]), (3,))
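The loop above always runs a fixed number of iterations. It can instead stop as soon as the cost improvement becomes negligible; a sketch of that variant on synthetic data (the tolerance, learning rate, and data are assumptions, not values from the exercise):

```python
import numpy as np

def cost(X, y, theta):
    return np.sum((X @ theta.T - y) ** 2) / (2 * len(X))

def gradient_tol(X, y, theta, a, max_iter=1000, tol=1e-10):
    """Gradient descent that stops early once the cost improvement falls below tol."""
    prev = cost(X, y, theta)
    for _ in range(max_iter):
        theta = theta - (a / len(X)) * ((X @ theta.T - y).T @ X)
        cur = cost(X, y, theta)
        if prev - cur < tol:
            break
        prev = cur
    return theta

rng = np.random.default_rng(1)
X = np.c_[np.ones(47), rng.standard_normal((47, 2))]
y = X @ np.array([[0.0], [0.9], [-0.05]])
theta = gradient_tol(X, y, np.zeros((1, 3)), a=0.5)
print(cost(X, y, theta))  # very small: the fit has converged
```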
plt.figure()
plt.plot(J)  # the cost curve should fall steeply and then level off as iteration converges
X@theta*std + mean  # de-standardize the predictions back to the original price scale
De-standardizing the predicted values gives exactly the same result as training on the unscaled labels, which confirms that standardizing the labels does not affect the final predictions. The benefit is that gradient descent will not fail to converge because of large differences in magnitude during iteration.
array([356283.1103389 , 286120.93063401, 397489.46984812, 269244.1857271 ,
472277.85514636, 330979.02101847, 276933.02614885, 262037.48402897,
255494.58235014, 271364.59918815, 324714.54068768, 341805.20024107,
326492.02609913, 669293.21223209, 239902.98686016, 374830.38333402,
255879.96102141, 235448.2452916 , 417846.48160547, 476593.38604091,
309369.11319496, 334951.62386342, 286677.77333009, 327777.17551607,
604913.37413438, 216515.5936252 , 266353.01492351, 415030.01477434,
369647.33504459, 430482.39959029, 328130.30083656, 220070.5644481 ,
338635.60808944, 500087.73659911, 306756.36373941, 263429.59076914,
235865.87731365, 351442.99009906, 641418.82407778, 355619.31031959,
303768.43288347, 374937.34065726, 411999.63329673, 230436.66102696,
190729.36558116, 312464.00137413, 230854.29304902])
X@theta  # the same values: this output is from a run where y was left unscaled
array([356283.1103389 , 286120.93063401, 397489.46984812, 269244.1857271 ,
472277.85514636, 330979.02101847, 276933.02614885, 262037.48402897,
255494.58235014, 271364.59918815, 324714.54068768, 341805.20024107,
326492.02609913, 669293.21223209, 239902.98686016, 374830.38333402,
255879.96102141, 235448.2452916 , 417846.48160547, 476593.38604091,
309369.11319496, 334951.62386342, 286677.77333009, 327777.17551607,
604913.37413438, 216515.5936252 , 266353.01492351, 415030.01477434,
369647.33504459, 430482.39959029, 328130.30083656, 220070.5644481 ,
338635.60808944, 500087.73659911, 306756.36373941, 263429.59076914,
235865.87731365, 351442.99009906, 641418.82407778, 355619.31031959,
303768.43288347, 374937.34065726, 411999.63329673, 230436.66102696,
190729.36558116, 312464.00137413, 230854.29304902])
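The claim that label standardization leaves the de-standardized predictions unchanged can be checked directly: because the design matrix contains a bias column, the intercept absorbs the mean shift and the slopes absorb the rescaling. A verification sketch with synthetic price-like targets (all data here is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(2)
X = np.c_[np.ones(47), rng.standard_normal((47, 2))]
y = 340000 + 120000 * rng.standard_normal(47)  # hypothetical price-like targets

# Fit on the raw labels
theta_raw = np.linalg.inv(X.T @ X) @ X.T @ y
pred_raw = X @ theta_raw

# Fit on standardized labels, then undo the scaling on the predictions
mu, sigma = y.mean(), y.std()
y_s = (y - mu) / sigma
theta_s = np.linalg.inv(X.T @ X) @ X.T @ y_s
pred_s = (X @ theta_s) * sigma + mu

print(np.allclose(pred_raw, pred_s))  # True
```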