本文将介绍机器学习中较为常见也较为简单的线性回归模型,以帮助读者快速了解以及快速上手。
1.理论支撑
本文所使用的数据集为如下样例:
Experience | Salary |
1.1 | 25068 |
1.2 | 32661 |
1.6 | 35111 |
1.6 | 43569 |
1.7 | 34264 |
1.7 | 46598 |
1.8 | 37035 |
2.2 | 35545 |
现有二元关系工作经验和薪资,我们将通过其数据的数学逻辑构建线性回归模型,本文将用到
普通最小二乘法(Ordinary Least Squares)作为构建精确的线性回归模型的支持。
1.1 OLS
OLS 的目标是找到一条直线,使得数据点到该直线的垂直距离的平方和最小化。
以满足斜率b1
, 截距b0
满足
这是线性回归分析中常用的方法,以拟合一个模型,以实现对预测数据进行较为精确的预测,同时也满足了数据集的因变量和自变量的大致关系。
2. 数据处理
2.1 使用到以下库
- Numpy
- Pandas
- Matplotlib
- Sklearn
2.2 通过 pandas 读取csv文件
datasets = pd.read_csv('./Salary-1.csv')
2.3 将因变量工作经验用 x 来表示,自变量薪资用 y 来表示
iloc[:,:]
按位置索引来访问和操作数据
x = datasets.iloc[:,:-1].values
y = datasets.iloc[:,-1].values
2.4 分割训练集和验证集
test_size 是指测试集占原有数据集的比重,random_state则是随机种子,可以理解为打乱抽取
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
3.训练模型
3.1 初始化模型
使用 sklearn.linear_model 中的 LinearRegression
lr = LinearRegression()
3.2 进行训练
lr.fit(x_train,y_train)
4.通过图表对比模型预测效果
为了直观表示出数据与我们的简单线性回归模型之间的联系,我们将原始数据x_test,y_test通过散点图的形式表示,将预测数据通过直线表示
plt.scatter(x_test,y_test,color='red')
plt.plot(x_test,lr.predict(x_test),color='blue')
plt.xlabel('experience')
plt.ylabel('salary')
plt.title('relation')
plt.show()
比较图如下:
由于测试集数据较少,模型的直观性不高,我们换回训练集对比效果:
随着数据的增加,我们可以看到已经根据该数据集成功拟合了简单的线性回归模型。
简单的线性回归模型旨在拟合一个能预测较为准确的模型,当然了,这并不是百分百,而是处于一定范围之间的准确率。能大致预测出结果。
以上就是对机器学习中的简单线性回归模型的分享,欢迎各位的交流!