PyTorch 深度学习实践 第九讲 ---多分类问题

Demo9 :多分类问题

来源:B站刘二大人
说明:

  1. softmax的输入不需要做非线性变换。也就是说softmax之前不再需要激活函数(relu)。softmax两个作用:1.如果在进行softmax前的input有负数,通过指数变换,得到正数。2.所有类的概率求和为1。
  2. y的标签编码方式是one-hot编码:只有一位是1,其他位为0。(算法的输入仍为原始标签,只是经过算法后变成one-hot编码)
  3. 多分类问题,标签y的类型是LongTensor。比如说0-9分类问题,如果y = torch.LongTensor([3]),对应的one-hot是[0,0,0,1,0,0,0,0,0,0].(这里要注意:如果使用了one-hot,标签y的类型是LongTensor,课程中糖尿病数据集中的target的类型是FloatTensor)
  4. CrossEntropyLoss === Softmax +Log+ NLLLoss。也就是说只需要对线性层最后一层先进行SoftMax处理,再进行log操作,然后使用NLLLoss就行了,具体看代码

过程如下:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码说明:

  1. torch.max的返回值有两个:第一个是每一行的最大值是,第二个是每一行最大值的下标(索引)。
    传送1. torch.max函数用法
    传送2. torch.max()使用讲解
  2. 在test中涉及到with torch.no_grad() :
    torch.no_grad() 详解
    Python中with的用法
  3. 代码中出现了"_" 。 Python中各种下划线的操作
# 8 多分类问题
import torch
from torchvision import transforms  # 
from torchvision import datasets  # 视觉领域基础数据集包,可下载经典数据集如minist等
from torch.utils.data import DataLoader  # 第八讲中详细说明
import torch.nn.functional as F
import torch.optim as optim

# prepare dataset

# 开始用tensor,需要导入transforms包
batch_size = 64

# 需要对数据进行归一化,均值和方差处理
# (0.1307,), (0.3081,) 解释:minist数据集是灰度图,灰度图的方差和标准差是这两个固定值
#  RGB图像的话就是这两组:([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值