机器学习基础
问题导入:什么是机器学习?
机器学习就是指计算机从数据中学习到规律,从而做出预测
举个例子:计算机视觉中常见的任务——对猫狗图片进行分类,我们很难编写算法直接对图片进行预测来判断它是猫还是狗。于是人们想到了数据驱动的方法,让计算机从大量带标签的图片中学习规律,计算机学习到其中的规律后,我们输入一张新的图片给计算机时,它就可以准确地预测这张图片显示的是猫还是狗。
机器学习两个关键因素
-
大量可学习的数据,如带标签的图片
-
学习的主体,一般称为模型
如何理解模型?
可以把模型看作是一个映射函数,它包含一些参数,这些参数可以与输入计算得到一个输出,我们一般称为预测结果
所谓模型学习的过程,就是模型修正其参数,改善其映射关系的过程
以预测图片是猫还是狗为例,步骤如下:
(1)创建模型
(2)输入一张带标签的图片
(3)使用此模型对次图片做出预测
(4)将预测结果与实际标签比较,产生的差距一般称为损失
(5)以减小损失为优化目标,根据损失优化模型参数
(6)循环重复(2)-(5)步
线性回归
根据受教育年限和平均收入之间对应关系的数据集,当给定一个人的受教育年限时,这个模型能预测其收入
代码思路如下:
- 导包
import torch
import pandas as pd
import matplotlib.pyplot as plt
- 读取并观察数据集
data = pd.read_csv("./datasets/Income1.csv")
print(data.head(3)) # 打印前三行
print(data.info()) # 查看数据基本数据情况
# 分析数据集组成/数据集可视化
plt.scatter(data.Education,data.Income)
plt.xlabel('Education')
plt.ylabel('Income')
plt.show()
- 数据预处理
# 数据预处理
X = torch.from_numpy(data.Education.to_numpy().reshape(-1,1)).type(torch.FloatTensor)
Y = torch.from_numpy(data.Income.to_numpy().reshape(-1,1)).type(torch.FloatTensor)
print(X.size(),Y.size())
- 创建模型
# 定义模型
from torch import nn
class EIModel(nn.Module):
def __init__(self):
super(EIModel,self).__init__() # 继承父类的属性
self.Linear = nn.Linear(in_features = 1, out_features = 1) # 创建线性层 Linear L需要大写
def forward(self,inputs):
logits = self.Linear(inputs)
return logits
# 模型初始化
model = EIModel()
loss_fn = nn.MSELoss() # 均方差误差损失函数
opt = torch.optim.SGD(model.parameters(),lr = 0.0001) # 初始化优化器
- 模型训练
# 模型训练
for epoch in range(5000): # 对全部数据训练5000次
for x,y in zip(X,Y): # 同时对X和Y迭代
y_pred = model(x) # 调用model得到预测输出y_pred
loss =loss_fn(y_pred,y) # 根据模型预测输出与实际的值y计算损失
opt.zero_grad() # 将累计的梯度置为0
loss.backward() # 反向传播损失,计算损失与模型参数之间的梯度
opt.step()
print('Down!')
- 判断模型的性能
print(list(model.named_parameters()))
这里模型有两个参数,分别是权重和偏置,y=wx+b中的w和b*
# 绘制原数据分布散点图
plt.scatter(data.Education,data.Income,label='real data')
# 绘制预测的直线
plt.plot(X,model(X).detach().numpy(),c='r',label = 'predicted line')
plt.xlabel('Education')
plt.ylabel('Income')
plt.legend() # 自动添加默认样式的图片
plt.show()
小结
- 本章简要介绍了机器学习基础知识,并演示pytorch创建和训练模型的整个流程
- 版权所有:《pytorch 深度学习简明实战》作者:日月光华
- 本篇仅作为笔记,主要目的是为了日后复习,如果有小伙伴想讨论相关知识,可以直接评论