1.总的程序
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 18 15:19:41 2021
@author: pony
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
#matplotlib inline
# 定义超参数
image_size = 28 #图像的总尺寸28*28
num_classes = 10 #标签的种类数
num_epochs = 20 #训练的总循环周期
batch_size = 64 #一个撮(批次)的大小,64张图片
# 加载MINIST数据,如果没有下载过,就会在当前路径下新建/data子目录,并把文件存放其中
# MNIST数据是属于torchvision包自带的数据,所以可以直接调用。
# 在调用自己的数据的时候,我们可以用torchvision.datasets.ImageFolder或者torch.utils.data.TensorDataset来加载
train_dataset = dsets.MNIST(root='./data', #文件存放路径
train=True, #提取训练集
transform=transforms.ToTensor(), #将图像转化为Tensor,在加载数据的时候,就可以对图像做预处理