本文为🔗365天深度学习训练营内部限免文章
参考本文所写记录性文章,请在文章开头保留以下内容
🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍦 参考文章:365天深度学习训练营-第2周:彩色识别(训练营内部成员可读)
🍖 原作者:K同学啊|接辅导、项目定制
导包
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torchinfo import summary
import warnings
设置数据集
train_ds = torchvision.datasets.CIFAR10('data',
train=True,
transform=torchvision.transforms.ToTensor(),
download = True)
test_ds = torchvision.datasets.CIFAR10('data',
train=False,
transform=torchvision.transforms.ToTensor(),
download = True)
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,
batch_size=batch_size,
shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,
batch_size=batch_size)