本课主要还是了解用pytorch 进行cnn模型的构建,训练和评估。
- 导入所需函数
主要用到的是pytorch的nn包
次要matplotlib的pyplot可视化包
设置device - 载入CIFAR10图像数据集
分为80%训练数据和20%测试数据
可视化查看训练数据前20个记录 - 图像预处理
- 构造CNN
构造CNN
打印构造的模型情况 - 训练
- 测试
- 评估训练结果
1. 导入包
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torchinfo import summary
import warnings
# 设置Device
device = torch.device('cpu')
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
这段代码的含义是将图像数据进行预处理,包括将图像转换为张量(tensor)格式,并进行归一化处理。
具体来说,使用了transforms.Compose()函数将多个预处理操作组合在一起,
其中包括transforms.ToTensor()将图像转换为张量格式,
以及transforms.Normalize()对张量进行归一化处理。
其中,(0.5,0.5,0.5)和(0.5,0.5,0.5)分别表示三个通道的均值和标准差,用于对图像进行归一化处理。
这样做的目的是为了使得图像数据的分布更加均匀,有利于提高模型的训练效果。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,