本项目以 PyTorch 为框架，进行 MNIST 图像分类任务：
CNN:
#coding = utf-8
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch import optim
from torch.autograd import Variable
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
epochs = 2          # number of training epochs
batch_size = 100    # samples per mini-batch drawn by train_loader
lr = 0.01           # learning rate (presumably handed to an optimizer below — confirm)

# Directory holding the MNIST files. NOTE: the "minst" spelling is kept on
# purpose so data already downloaded into this directory is still found.
mnist_root = './minst_data'

# Download only when the data directory is missing or empty. A hard-coded
# False fails on a fresh checkout, because torchvision will not load a
# dataset that is absent locally when download=False. When the directory
# already has content, this stays False and no download is attempted.
download_mnist = not os.path.isdir(mnist_root) or not os.listdir(mnist_root)

# Training split, with ToTensor applied to every sample.
train_data = torchvision.datasets.MNIST(
    root=mnist_root,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=download_mnist,
)

# Sanity check: set to True to print the raw data/target sizes and display
# the first training image titled with its target value.
plot_example = False
if plot_example:
    print(train_data.data.size())
    print(train_data.targets.size())
    plt.imshow(train_data.data[0].numpy(), cmap='gray')
    plt.title('%i' % train_data.targets[0])
    plt.show()

# Shuffled mini-batches over the training split.
train_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# Test split. No transform is set here; its raw tensors are presumably
# consumed directly further down (e.g. via test_data.data) — confirm.
test_data = torchvision.datasets.MNIST(root=mnist_root, train=False, download=download_mnist)
test_x = Variabl