pyTorch学习笔记(8)--手写体识别
1.加载数据预处理
1.1需要导入的包和函数
from torchvision.datasets import MNIST
import torchvision
from torch.utils.data import DataLoader
import torch
from torch.autograd import Variable
import torch.nn as nn
import numpy
import os
1.2 加载数据及处理
首先获取MNIST 数据集,如果没有就下载,设置download=True。如果下载了设置为False,然后train=True,设置这是训练集,transform=torchvision.transforms.ToTensor(),是将图片数据做一些处理和转化为tensor,
train_set=MNIST('./mnist',train=True,
transform=torchvision.transforms.ToTensor(),
download=True)# 数据集目录nmist, 训练集 ,转为tensor ,没有数据就下载
test_set=MNIST('./mnist',train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
1.3封装数据
将数据封装到这个DataLoader里边,即一个迭代器每次迭代除出16张图片,返回图片和标签。
train_data = DataLoader(train_set, batch_size=16, shuffle=True) #每次16 张图片,顺序打乱
test_data = DataLoader(test_set, batch_size=16, shuffle=False)
2.设置神经网络
设置好神经网络后,返回一个(batch,class_num),返回一多少batch(这里是16张图片)class_num是分类类别是多少这里是10,返回一个这样的矩阵。
class mncnn(nn.Module):
def __init__(self):
super(mncnn,self).__init__()
self.cov1=nn.Sequential(
nn.Conv2d(1,16,5,1,2), #16*28*28
nn.ReLU(),
nn.MaxPool2d(kernel_size=2) # 16*14*14
)
self.cov2=nn.Sequential(
nn.Conv2d(16,32,5,1,2),# 32*14*14
nn.ReLU(),
nn.MaxPool2d(2) # 32*7*7
)
self.out=nn.Linear(32*7*7,10