史上最易懂——一文详解线性回归算法的纯Python实现

本文作者:黄佳,极客时间专栏《零基础实战机器学习》作者,新加坡埃森哲公司高级顾问,人工智能专家,机器学习和云计算高级工程师,参与过公共事业、医疗、金融等多领域大型项目。著有《零基础学机器学习》,《SAP程序设计》,《SAP高级应用开发》,《SAP业务数据传输指南》。

写在前面

说到机器学习,大家可能会马上联想到艰深的算法,复杂的公式和高等数学。的确,算法和高等数学确实是机器学习时的基础知识储备。不过,我们也可以用比较浅显易懂的方法介绍一些机器学习相关的入门内容和基础算法。你会惊奇的发现入门机器学习并没有想象中那么高的门槛。

那么何为机器学习?

机器学习的关键内涵之一在于「利用计算机的运算能力从大量的数据中发现一个“函数”或“模型”,并通过它来模拟现实世界事物间的关系,从而实现预测或判断的功能。」 这个过程的关键是建立一个正确的模型,因此这个建模的过程就是机器的“学习”。

在这里插入图片描述
比如,一颗钻石的大小(自变量x1)、重量(自变量x2)、颜色(自变量x3)、密度(自变量x4)和它的价格(因变量y)的关系,就体现出了明显的相关性,如下图所示。Image这些自变量(x1, x2, x3, …, xn),在机器学习领域叫作特征(feature),因变量y ,在机器学习领域叫作标签(label)。「机器学习,就是在已知数据集的基础上,通过反复的计算,选择最贴切的函数(function)去描述数据集中特征和标签之间的关系」。

如果机器通过所谓的训练(training)找到了一个函数,对于已有的1000组钻石数据,它都能够根据钻石的各种特征,大致推断出其价格。那么,再给另一批同类钻石的大小、重量、颜色、密度等数据,就很有希望用同样的函数(模型)推断出这另一批钻石的价格。此时,已有的1000组有价格的钻石数据,就叫作「训练数据集」。另一批钻石数据,就叫作「测试数据集」。

通过机器学习模型不仅可以推测钻石价格,还可以实现影片票房预测、人脸识别、根据当前场景控制游戏角色的动作等诸多功能。

在这里插入图片描述

好了,那么关于机器学习最基本的知识我们就介绍到这里。今天我们将从无到有,来通过Python语言搭建一个非常简单的线性回归机器学习模型,并利用这个模型来预测一个零售网店的日销售额。

这次实战的数据集和源代码都可以在阿里云天池中下载:

访问下方地址,在阿里云的《零基础学机器学习》读书会中可以点击实践代码,即可获取全部代码和数据:

https://tianchi.aliyun.com/specials/promotion/activity/bookclub

机器学习项目实战框架

开始构建机器学习模型之前,我们要先简单介绍一下机器学习项目的实战框架,大致分为以下5个环节。

  1. 问题定义。
  2. 数据的收集和预处理。
  3. 模型(算法)的选择。
  4. 训练机器学习模型。
  5. 超参数调试和性能优化。

在这里插入图片描述

这5个环节,每一步的处理是否得当,都直接影响机器学习项目的成败。而且,这些步骤还需要在项目实战中以迭代的方式反复进行,以实现最优的效果。

1. 一个实际问题的定义

来看一下我们今天要解决的问题是什么?

小冰和朋友合伙开一个网店,这个店的基本情况是这样的:正式运营一年多,流量、 订单数和销售额都显著增长。经过一段时间的观察,小冰发现网店商品的销量和广告推广的力度息息相关。她在微信公众号推广,也通过微博推广,还在一些其他网站上面投放广告。当然,投入推广的资金越多,则商品总销售额越多。小冰问她的机器学习导师咖哥:“能不能通过机器学习算法,根据过去记录下来的广告投放金额和商品销售额,来预测在未来的某个节点,一个特定的广告投放金额对应能实现的商品销售额?”

在这里插入图片描述

那么基于这个问题来说呢,数据集的特征就是在各个平台上投放的广告金额(平台可以不止一个),而标签呢,也就是我们要预测的:商品销售额。

2. 数据收集和预处理

小冰已经把过去每周的广告投放金额和销售额数据整理成一个 Excel 表格,并保存为advertising.csv 文件,便于被Python读取。基本上每周的各种广告投放金额和商品销售额都记录在案。

在这里插入图片描述

这个数据记录是实现本课的机器学习项目的基础。没有准确的历史数据,我们是什么都做不了的。

显示数据

用Python代码来显示一下读入的数据:

import numpy as np # 导入NumPy 库
import pandas as pd # 导入Pandas 库
# 读入数据并显示前面几行的内容, 确保已经成功地读入数据
# 示例代码是Kaggle 的数据集读入文件, 如果在本机中则需要指定具体本地路径
# 如,当数据集和代码文件位于相同本地目录,路径名应为'./advertising.csv',或直接为'advertising.
# csv' 亦可
df_ads = pd.read_csv('../input/advertising-simple-dataset/advertising.csv')
df_ads.head()

在这里插入图片描述

相关分析

我们在这里会做一些相关分析,看看所选择的特征和标签之间的关系。相关性系数是一个-1 ~ 1 之间的值,正值表示正相关,负值表示负相关。数值越大,相关性越强。

如果a和b的相关性系数是1,则a和b总是相等的。如果a 和b 的相关性系数是0.9,则b 会显著地随着a 的变化而变化,而且变化的趋势保持一致。
如果a 和b 的相关性系数是0.3,则说明两者之间并没有什么明显的联系。

# 导入数据可视化所需要的库
import matplotlib.pyplot as plt #Matplotlib 为Python 画图工具库
import seaborn as sns #Seaborn 为统计学数据可视化工具库
# 对所有的标签和特征两两显示其相关性的热力图
sns.heatmap(df_ads.corr(), cmap="YlGnBu", annot = True)
plt.show() #plt 代表英文plot, 就是画图的意思

在这里插入图片描述

散点图

下面,通过散点图两两一组显示商品销售额和各种广告投放金额之间的对应关系,来将重点聚焦。散点图是回归分析中,数据点在直角坐标系平面上的分布图,它是相当有效的数据可视化工具。

# 显示销售额和各种广告投放金额的散点图
sns.pairplot(df_ads,
x_vars=['wechat', 'weibo', 'others'],
y_vars='sales',
height=4, aspect=1, kind='scatter')
plt.show()

代码运行之后输出的散点图清晰地展示出了销售额随各种广告投放金额而变化的大致趋势, 根据这个信息,就可以选择合适的函数对数据点进行拟合。

在这里插入图片描述

其实在这里,我们就基本上可以看出微信广告和销售额之间呈现的是一种线性关系,可以选择线性回归进行拟合,不过这是后话。

通过观察相关性和散点图,发现在本案例的3个特征中,微信广告投放金额和商品销售额的相关性比较高。因此,为了简化模型,我们将暂时忽略微博广告和其他类型广告投放金额这两组特征,只留下微信广告投放金额数据。这样,就把多变量的回归分析简化为单变量的回归分析。

数据集清洗和规范化

下面的代码把df_ads 中的微信公众号广告投放金额字段读入一个NumPy 数组X ,也就是清洗了其他两个特征字段,并把标签读入数组y :

X = np.array
  • 5
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值