Usage
All of the datasets packaged in torchvision.datasets are subclasses of torch.utils.data.Dataset: each one implements the __getitem__ and __len__ methods, so each one can be loaded with torch.utils.data.DataLoader. Taking the datasets.MNIST class as an example, its parameters and usage are as follows:
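Because of this, any of these datasets can be indexed and measured directly. A minimal sketch (the "mnist-data" path is just an example, matching the usage later in this section):

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# "mnist-data" is an example root path; download=True fetches the files on first use
mnist = datasets.MNIST("mnist-data", train=True, download=True,
                       transform=transforms.ToTensor())

print(len(mnist))        # __len__: 60000 training samples
img, label = mnist[0]    # __getitem__: returns an (image, target) pair
print(img.shape, label)  # torch.Size([1, 28, 28]) and an integer class label

# any Dataset subclass plugs straight into DataLoader
loader = DataLoader(mnist, batch_size=32, shuffle=True)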
CLASS torchvision.datasets.MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
- root (string): the root directory of the dataset, under which the MNIST/processed/training.pt and MNIST/processed/test.pt subpaths exist
- train (bool, optional): if True, create the dataset from training.pt; otherwise create it from test.pt
- download (bool, optional): if True, download the dataset from the internet and place it in the root directory; if the dataset has already been downloaded, it is not downloaded again
- transform (callable, optional): a function that takes a PIL image and returns a transformed version (see the sketch after this list)
- target_transform (callable, optional): a function that takes the target and returns a transformed version
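Both transform and target_transform accept any callable. A hedged sketch: normalizing the image and one-hot encoding the integer label (the 0.1307/0.3081 mean and standard deviation are the commonly quoted MNIST statistics, an assumption here rather than something this section prescribes):

import torch
from torchvision import datasets, transforms

# transform: applied to each PIL image; convert to tensor, then normalize
# (0.1307,)/(0.3081,) are the commonly quoted MNIST mean/std -- an assumption
image_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# target_transform: applied to each integer label; one-hot encode it
target_tf = transforms.Lambda(
    lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
)

mnist = datasets.MNIST("mnist-data", train=True, download=True,
                       transform=image_tf, target_transform=target_tf)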
After datasets.MNIST("mnist-data") downloads the MNIST dataset, the generated directory structure is as follows.
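The exact layout depends on the torchvision version; with versions that store the processed .pt files referenced above, it looks roughly like this (a sketch, not verbatim output; newer versions keep only the raw/ files):

mnist-data/
└── MNIST/
    ├── raw/
    │   ├── train-images-idx3-ubyte
    │   ├── train-labels-idx1-ubyte
    │   ├── t10k-images-idx3-ubyte
    │   └── t10k-labels-idx1-ubyte
    └── processed/
        ├── training.pt
        └── test.pt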
Code Example
The following is complete code that uses the datasets.MNIST() class to load the MNIST dataset and run training:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.Sq1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),  # output: (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # (16, 14, 14)
        )
        self.Sq2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),  # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.Sq1(x)
        x = self.Sq2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32 * 7 * 7)
        output = self.out(x)
        return output
def train():
    epochs = 1
    mnist_net = CNN()
    mnist_net.train()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(mnist_train, batch_size=5, shuffle=True)
    for epoch in range(epochs):
        total_loss = 0.0  # accumulate per-batch losses separately from the loss tensor
        for batch_X, batch_Y in train_loader:
            optimizer.zero_grad()
            outputs = mnist_net(batch_X)
            loss = loss_fn(outputs, batch_Y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"epoch {epoch}: total loss {total_loss:.4f}")
    return mnist_net


if __name__ == '__main__':
    train()
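Evaluation follows the same pattern, with train=False to load the test split. A minimal sketch reusing the CNN and imports above (the accuracy loop is an addition, not part of the original code):

def test(mnist_net):
    mnist_net.eval()  # switch off training-specific behavior
    mnist_test = datasets.MNIST("mnist-data", train=False, download=True,
                                transform=transforms.ToTensor())
    test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)
    correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for batch_X, batch_Y in test_loader:
            preds = mnist_net(batch_X).argmax(dim=1)
            correct += (preds == batch_Y).sum().item()
    print(f"test accuracy: {correct / len(mnist_test):.4f}")

Since train() above returns the trained network, test(train()) runs both phases in sequence.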