Demo9 :多分类问题
来源:B站刘二大人
说明:
- softmax的输入不需要做非线性变换。也就是说softmax之前不再需要激活函数(relu)。softmax两个作用:1.如果在进行softmax前的input有负数,通过指数变换,得到正数。2.所有类的概率求和为1。
- y的标签编码方式是one-hot编码:只有一位是1,其他位为0。(算法的输入仍为原始标签,只是经过算法后变成one-hot编码)
- 多分类问题,标签y的类型是LongTensor。比如说0-9分类问题,如果y = torch.LongTensor([3]),对应的one-hot是[0,0,0,1,0,0,0,0,0,0].(这里要注意:如果使用了one-hot,标签y的类型是LongTensor,课程中糖尿病数据集中的target的类型是FloatTensor)
- CrossEntropyLoss === Softmax +Log+ NLLLoss。也就是说只需要对线性层最后一层先进行SoftMax处理,再进行log操作,然后使用NLLLoss就行了,具体看代码
过程如下:
代码说明:
- torch.max的返回值有两个:第一个是每一行的最大值是,第二个是每一行最大值的下标(索引)。
传送1. torch.max函数用法
传送2. torch.max()使用讲解 - 在test中涉及到
with torch.no_grad() :
torch.no_grad() 详解
Python中with的用法 - 代码中出现了"_" 。 Python中各种下划线的操作
# 8 多分类问题
import torch
from torchvision import transforms #
from torchvision import datasets # 视觉领域基础数据集包,可下载经典数据集如minist等
from torch.utils.data import DataLoader # 第八讲中详细说明
import torch.nn.functional as F
import torch.optim as optim
# prepare dataset
# 开始用tensor,需要导入transforms包
batch_size = 64
# 需要对数据进行归一化,均值和方差处理
# (0.1307,), (0.3081,) 解释:minist数据集是灰度图,灰度图的方差和标准差是这两个固定值
# RGB图像的话就是这两组:([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(