导入库
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
加载数据
torchvision 中包含一些常用的计算机视觉的数据集 这里用的是FashionMNIST
首先分别定义并下载训练数据和测试数据
training_data = datasets.FashionMNIST(
root="E:/mywork/demo/data",
train=True,
download=True,
transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
root="E:/mywork/demo/data",
train=False,
download=True,
transform=ToTensor(),
)
然后定义batch_size, 并将数据载入到dataloader迭代器中
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
定义模型
要去定义一个神经网络,我们需要创建一个继承nn.Module的类。在类中的__init__()函数中定义网络的各个层,在forward函数中去定义数据经过整个网络的具体流程。
为了实现运算的加速,我们可以将数据移入GPU中计算
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device)