本文将介绍搭建一个简单的手写字体Mnist分类任务,了解一些基础的torch模块的使用,从较为简单的案例入手,为接下来的学习打下基础(其中不涉及卷积神经网络)。
#首先第一步读取数据
from pathlib import Path
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist" #mnist数据集文件路径
PATH.mkdir(parents=True, exist_ok=True)
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
下载好的mnist数据集为压缩包格式。
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[1].reshape((28, 28)), cmap="gray") #显示数据图片
print(x_train.shape) #打印数据的形状
(50000, 784)