PyTorch handwritten digit recognition on the MNIST dataset (minimal version)

A beginner-friendly PyTorch example of handwritten digit recognition on MNIST. The model is a single linear layer whose output is passed through a softmax to produce class probabilities; the script also displays the images together with the predicted digits.
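Concretely, each 28×28 image is flattened into a 784-dimensional row vector x, and the model computes y_hat = softmax(xW + b), where W has shape (784, 10) and b has shape (10,); the ten entries of y_hat are the predicted probabilities for the digits 0-9. The code below implements exactly this, first with manually managed parameters and then with torch.nn.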

import torchvision
import torch
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt

trans = transforms.ToTensor()
train_data = torchvision.datasets.MNIST(
    root="./MNIST", train=True, transform=trans, download=True)   # set download=False if the data is already in ./MNIST
test_data = torchvision.datasets.MNIST(
    root="./MNIST", train=False, transform=trans, download=True)

BATCH_SIZE = 256  

train_dataloader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = data.DataLoader(test_data, batch_size=BATCH_SIZE)

def show_img(imgs, labels, nrow, ncol, scale=1.5):
    # Plot an nrow x ncol grid of images with their labels as titles.
    fig, axes = plt.subplots(nrow, ncol, figsize=(ncol * scale, nrow * scale))
    axes = axes.flatten()
    for i, (img, ax) in enumerate(zip(imgs, axes)):
        ax.imshow(img.squeeze(), cmap="Greys")
        # labels may be raw label tensors or ready-made title strings
        ax.set_title(str(labels[i].item()) if torch.is_tensor(labels[i]) else labels[i])
        ax.axis('off')
    plt.show()

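# Preview one batch: pull 256 images and labels from the training loader and display the first 18.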
x, y = next(iter(train_dataloader))
show_img(x,y,2,9)

class model(torch.nn.Module):
    # A single linear layer: 784 flattened pixels -> 10 class scores (logits).
    def __init__(self, in_dims, out_dim):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_dims, out_dim),
        )

    def forward(self, x):
        return self.layers(x)

def softmax(x):
    # Convert each row of scores into probabilities that sum to 1.
    x_exp = torch.exp(x)
    x_exp_sum = x_exp.sum(1, keepdim=True)
    return x_exp / x_exp_sum

def cross_entropy(y_hat, y):
    # Negative log probability assigned to the true class of each example.
    return -torch.log(y_hat[range(len(y_hat)), y])

def accuracy(y_hat, y):
    # Fraction of predictions whose argmax matches the true label.
    pred = y_hat.argmax(1)
    cmp = (pred == y)
    print("Accuracy: %.3f" % (cmp.sum().item() / len(y)))

# aa = torch.tensor([[0.1,0.2],[0.1,0.2]])
# bb = torch.tensor([1,1])
# accuracy(aa,bb)

# y = torch.tensor([0,2])
# y_hat = torch.tensor([[0.1,0.2,0.7],[0.2,0.4,0.4]])
# print(y_hat[[0,1,],y])
# print(cross_entropy(y_hat,y))

# Approach 1: build and train the model manually (explicit w, b, softmax and cross-entropy)
w = torch.normal(0, 1, (28*28, 10), requires_grad=True)   # weight matrix, 784 x 10
b = torch.normal(0, 1, (10,), requires_grad=True)         # bias vector, one per class
n_epoch = 3
optimizer = torch.optim.SGD([w, b], lr=0.001)
for epoch in range(n_epoch):
    for img, label in train_dataloader:
        img = img.flatten(1)                 # (batch, 1, 28, 28) -> (batch, 784)
        y_hat = torch.matmul(img, w) + b     # linear scores
        y_hat = softmax(y_hat)               # scores -> probabilities
        loss = cross_entropy(y_hat, label)
        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()
    # Report accuracy on one test batch per epoch; the break keeps it fast.
    for img, label in test_dataloader:
        img = img.flatten(1)
        y_hat = torch.matmul(img, w) + b
        y_hat = softmax(y_hat)
        accuracy(y_hat, label)
        break
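
The loop above only checks a single test batch. If you want the accuracy over the whole test set, a minimal sketch (an assumption, not part of the original script; it reuses the w, b and test_dataloader defined above and would run right after the training loop) could look like this:

# Sketch (assumption): accuracy of the manual model over the full test set.
with torch.no_grad():
    correct, total = 0, 0
    for img, label in test_dataloader:
        logits = torch.matmul(img.flatten(1), w) + b
        correct += (logits.argmax(1) == label).sum().item()
        total += len(label)
    print("Full test accuracy: %.3f" % (correct / total))
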
# Approach 2: build and train the same model with torch.nn and CrossEntropyLoss
n_epoch = 3
my_model = model(28*28, 10)
optimizer = torch.optim.SGD(my_model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss(reduction='sum')   # applies log-softmax internally, so the model outputs raw logits
for epoch in range(n_epoch):
    my_model.train()
    for img, label in train_dataloader:
        img = img.flatten(1)
        pred = my_model(img)
        loss = criterion(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    my_model.eval()
    # Report accuracy on one test batch per epoch.
    for img, label in test_dataloader:
        with torch.no_grad():
            img = img.flatten(1)
            pred = my_model(img)
            pred = pred.softmax(1)   # logits -> probabilities (does not change the argmax)
            accuracy(pred, label)
            break
     
def pred_img(count):  # predict with the manually trained parameters (w, b) and visualize the results
    for img, label in train_dataloader:
        img_flatten = img.flatten(1)
        y_hat = torch.matmul(img_flatten, w) + b
        y_hat = softmax(y_hat)
        accuracy(y_hat, label)
        pred = y_hat.argmax(1)
        # Mark wrong predictions with an extra 'x' in the title.
        title = ['Pred:' + str(int(i)) + '\nTrue:' + str(int(j)) + ('\nx' if int(i) != int(j) else '')
                 for i, j in zip(pred, label)]
        show_img(img, title, 3, 9, scale=2.5)
        count -= 1
        if count == 0:
            return

pred_img(1)
def pred_img(count):  # same as above, but predict with the torch.nn model (redefines pred_img)
    for img, label in train_dataloader:
        my_model.eval()
        img_flatten = img.flatten(1)
        y_hat = my_model(img_flatten)
        y_hat = softmax(y_hat)
        accuracy(y_hat, label)
        pred = y_hat.argmax(1)
        title = ['Pred:' + str(int(i)) + '\nTrue:' + str(int(j)) + ('\nx' if int(i) != int(j) else '')
                 for i, j in zip(pred, label)]
        show_img(img, title, 3, 9, scale=2.5)
        count -= 1
        if count == 0:
            return

pred_img(1)