动手学深度学习(pytorch版本) | 第三章:深度学习基础

本文深入探讨了线性回归和softmax回归在机器学习中的应用。线性回归用于连续值预测,而softmax回归则用于多分类任务。介绍了softmax回归的网络结构、交叉熵损失函数以及模型评估指标。通过实例讲解了数据预处理和Fashion-MNIST数据集的使用。此外,还概述了softmax回归的实现步骤。
摘要由CSDN通过智能技术生成



3.1 线性回归

线性回归输出是一个连续值,因此适用于回归问题
softmax回归则适用于分类问题

3.1.1 线性回归的节本要素

  1. 模型定义

  2. 模型训练
    训练数据——>损失函数——>优化算法

  3. 模型预测

3.1.2 线性回归的表示方法

3.4 softmax回归

  • 简介
  • softmax回归跟线性回归一样将输入特征与权重做线性叠加。softmax回归的输出值个数等于标签里的类别数。因为一共有4种特征和3种输出动物类别,所以权重包含12个标量(带下标的w)、偏差包含3个标量(带下标的b)

  • 网络类型:
  • softmax回归同线性回归一样,也是一个单层神经网络。由于每个输出结果的计算都要依赖于所有的输入,所以softmax回归的输出层也是一个全连接层。

    softmax运算符(softmax operator)将输出值变换成值为正且和为1的概率分布

    3.4.5 交叉熵损失函数

    用损失函数来衡量两个概率分布差异的测量函数。

  • 平方损失函数:
  • 过于严格:( || yˆ(i)−y(i) ||^2 ) /2
  • 交叉熵函数:
  • ![交叉熵公式](https://img-blog.csdnimg.cn/4acb36ccefaa46c2ab8b081cec60f3a0.png#pic_center) **交叉熵只关心对正确类别的预测概率**,因为只要其值足够大,就可以确保分类结果正确。但是遇到一个样本有多个标签时,例如图像里含有不止一个物体时,我们并不能做这一步简化。

    假设训练数据集的样本数为n,交叉熵损失函数定义为
    交叉熵损失函数

    3.4.6 模型预测及评价

    准确率(accuracy):
    正确预测数量 / 总预测数量。

    3.4.7 小结

  • softmax回归适用于分类问题。它使用softmax运算输出类别的概率分布。
  • softmax回归是一个单层神经网络,输出个数等于分类问题中的类别个数。
  • 交叉熵适合衡量两个概率分布的差异。
  • 3.5 图像分类数据集(Fashion—MNIST)

    torchvision包
    它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

    1. torchvision.datasets: 加载数据的函数或者数据集的接口
    
    2. torchvision.models: 常用的模型结构(包含预训练模型),包含AlexNet、VGG、ResNet等;
    
    3. torchvision.transforms: 常用的图片转换(裁剪旋转等)
    
    4. torchvision.utils: 其他的一些方法。
    

    3.5.1 获取数据集

    import torch 
    import torchvishion
    import torchvishion.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import sys
    sys.path.append("..")  # 为了导入上层目录的d21zh_pytorch
    import d21zh_pytorch as d21
    

    使所有数据转换为Tensor:
    transform = transforms.ToTensor()

    **transforms.ToTensor()**将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor。
    通道数C,图像的高度 H,宽度 W

    3.6 softmax回归从零开始实现

    二、使用步骤

    1.引入库

    代码如下(示例):

    2.读入数据

    代码如下(示例):

    data = pd.read_csv(
        'https://labfile.oss.aliyuncs.com/courses/1283/adult.data.csv')
    print(data.head())
    

    该处使用的url网络请求的数据。


    总结

    提示:这里对文章进行总结:
    例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值