简单的LeNet网络模型
torchvision.datasets
torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。
以下是torchvision的构成:
- torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
- torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
- torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
- torchvision.utils: 其他的一些有用的方法。
这里只讲下面需要使用的部分:
下面我们介绍一下使用torchvision.datasets
中自带的Fashion-MNIST数据集。
这个数据集包含10种样本,按照标签从小到大的顺序对应的图像分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。训练集中每种样本6000张图片,所以训练集总共6w张图片,测试集中每种样本1000张图片,所以测试集总共1w张图片。每幅图片都是28×28的像素数组,