import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# torchvision 数据集的输出是范围在[0,1]之间的 PILImage,我们将他们转换成归一化范围为[-1,1]之间的张量 Tensors
# Compose()串联对图像的操作
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
# shuffle 每一个epoch过程中会打乱数据顺序,重新随机选择
# num_workers 线程数,默认为0,在主线程运行
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
pytorch 图像分类器代码详解
最新推荐文章于 2024-08-22 14:30:41 发布
本文详细讲解了如何使用PyTorch构建一个图像分类器,从数据预处理到模型搭建,再到训练和评估,涵盖了神经网络基础知识和PyTorch关键API的运用。
摘要由CSDN通过智能技术生成