深度学习-torch-mnist手写识别分类
前言
记录一些上深度学习课的一次作业,记录会出现的问题,然后如何解决。因为这些库都是有点老的,所以各种报错。
一、需要哪些准备
1.python3.6
2.matplotlib 画图
3.numpy 计算
4.torch 深度学习框架
2、3直接pip就可以了,
4可能不知道什么装,是需要安装什么搭配,是CPU还是GPU。
这里我附上链接或者,官方 这里根据介绍装什么
然后在cmd 输入pip install torch1.8.0+cpu torchvision0.9.0+cpu torchaudio===0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
二、代码
1.引入库
代码如下(示例):
import numpy as np
import torch
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
def data_tf(x):
x = np.array(x, dtype='float32') / 255
x = (x - 0.5) / 0.5 # normalization #转化为-1到1
x = x.reshape((-1,)) # flatten #拉成一行 维度转化
x = torch.from_numpy(x)
return x
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) #训练集
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True) #测试集
train_data = DataLoader(train_set, batch_size=