Jittor MNIST图片识别模型训练+测试

1、先创建一个model.py

import jittor as jt
from jittor import nn, Module
import numpy as np
import sys, os
import random
import math
from jittor import init

class Model(Module):
    """Small CNN that maps 3-channel 28x28 images to 10 class scores.

    Pipeline: conv(3->32) -> ReLU -> conv(32->64) -> BatchNorm -> ReLU
    -> pool(kernel 2, stride 2) -> flatten -> fc(64*12*12 -> 256) -> ReLU
    -> fc(256 -> 10).

    Both convolutions use a 3x3 kernel, stride 1 and no padding. A 28x28
    input therefore becomes 26x26, then 24x24, then 12x12 after pooling.
    That is where the 64 * 12 * 12 input width of ``fc1`` comes from.

    Keep the attribute names stable: parameter files written with
    ``model.save`` and read with ``load_parameters`` are keyed by them.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: nn.Conv(in_channels, out_channels, kernel, stride).
        self.conv1 = nn.Conv(3, 32, 3, 1)
        self.conv2 = nn.Conv(32, 64, 3, 1)
        self.bn = nn.BatchNorm(64)

        self.max_pool = nn.Pool(2, 2)
        self.relu = nn.Relu()  # one instance reused for every activation
        # Classifier head on the flattened 64x12x12 feature map.
        flat_features = 64 * 12 * 12
        self.fc1 = nn.Linear(flat_features, 256)
        self.fc2 = nn.Linear(256, 10)

    def execute(self, x):
        """Return unnormalized class scores (logits) of shape [batch, 10]."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.bn(self.conv2(x)))
        x = self.max_pool(x)
        batch = x.shape[0]
        x = jt.reshape(x, [batch, -1])
        return self.fc2(self.relu(self.fc1(x)))

2、创建train.py来进行训练、保存模型

import jittor as jt
from jittor import nn, Module
import numpy as np
import sys, os
import random
import math
from jittor import init
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import pylab as pl

jt.flags.use_cuda = 0  # 0 = run on the CPU; set jt.flags.use_cuda = 1 to run on the GPU

def train(model, train_loader, optimizer, epoch, losses, losses_idx, log_interval=10):
    """Run one training epoch and record the loss of every iteration.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: iterable of (inputs, targets) batches that supports len().
        optimizer: Jittor optimizer, driven by calling ``optimizer.step(loss)``.
        epoch: zero-based epoch index, used for the global iteration index
            and in the progress message.
        losses: list that receives one float loss per batch (appended in place).
        losses_idx: list that receives the matching global iteration index
            (appended in place), for plotting ``losses`` against iterations.
        log_interval: print a progress line every ``log_interval`` batches.
    """
    model.train()
    num_batches = len(train_loader)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        outputs = model(inputs)
        loss = nn.cross_entropy_loss(outputs, targets)
        optimizer.step(loss)
        # .item() works whether the loss is a 0-d or a 1-element Var; the
        # original .data[0] indexing fails on a 0-d array. It also fetches
        # the value once per batch instead of twice.
        loss_value = loss.item()
        losses.append(loss_value)
        losses_idx.append(epoch * num_batches + batch_idx)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx, num_batches,
                    100. * batch_idx / num_batches, loss_value))

def test(model, val_loader, epoch):
    """Evaluate classification accuracy on ``val_loader``.

    Prints the accuracy of every batch, then the overall accuracy.

    Args:
        model: trained network; it is switched to evaluation mode here.
        val_loader: iterable of (inputs, targets) batches that supports len().
        epoch: epoch index shown in the progress messages.

    Returns:
        The overall accuracy (correct / total), or ``None`` if the loader
        yielded no samples.
    """
    model.eval()
    total_acc = 0
    total_num = 0
    num_batches = len(val_loader)
    # Inference only: no gradient graph is needed.
    with jt.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            batch_size = inputs.shape[0]
            outputs = model(inputs)
            pred = np.argmax(outputs.data, axis=1)
            acc = np.sum(targets.data == pred)  # number of correct predictions
            total_acc += acc
            total_num += batch_size
            acc = acc / batch_size
            print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(epoch, \
                        batch_idx, num_batches, 100. * float(batch_idx) / num_batches, acc))
    if total_num == 0:
        # Avoid a ZeroDivisionError on an empty loader.
        print('Total test acc = n/a (no test samples)')
        return None
    overall_acc = total_acc / total_num
    print('Total test acc =', overall_acc)
    return overall_acc

def main(model_path=None):
    """Train the MNIST model, evaluate it after each epoch, save it, and plot the loss.

    Args:
        model_path: file the trained parameters are saved to. Defaults to
            ``mnist_model.pkl`` in the directory containing this script.
    """
    batch_size = 64
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 5
    losses = []
    losses_idx = []
    if model_path is None:
        # Portable default instead of a hard-coded, user-specific absolute path.
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mnist_model.pkl')

    train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(batch_size=batch_size, shuffle=True)
    # batch_size=1 made evaluation slow and printed one line per test image.
    # test() switches the model to eval mode, so predictions should not
    # depend on the evaluation batch size.
    val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=batch_size, shuffle=False)
    model = Model()
    optimizer = nn.SGD(model.parameters(), learning_rate,
                       momentum=momentum, weight_decay=weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch, losses, losses_idx)
        test(model, val_loader, epoch)

    # Save before plotting. pl.show() blocks until the plot window is closed,
    # so saving first guarantees the weights reach disk even if the script
    # is interrupted or the plot fails.
    model.save(model_path)
    print('Model saved to', model_path)

    pl.plot(losses_idx, losses)
    pl.xlabel('Iterations')
    pl.ylabel('Train_loss')
    pl.show()

if __name__ == '__main__':
    main()

3、运行train.py文件结果如图(testdata为自己创建的测试数据文件夹、mnist_model.pkl为模型文件)

4、创建test.py,加载本地图片对模型进行测试。

from datetime import date
from matplotlib import pyplot as plt
import jittor as jt
from numpy.core.fromnumeric import shape
from numpy.lib.type_check import imag
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import cv2 
import os
from PIL import Image
import matplotlib.pyplot as plt

""" 
#加载MNIST模型库中的图片
model_path = '/home/lizhi528/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl'
new_model = Model()
new_model.load_parameters(jt.load(model_path))
val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=1, shuffle=False)
data_iter = iter(val_loader)
val_data, val_label = next(data_iter)# 
outputs = new_model(val_data)
prediction = np.argmax(outputs.data, axis=1)
print(val_label.data)
print(prediction) 
"""


def ImageClassification(imagPath, model, show=True):
    """Classify a single image file with ``model`` and print the predicted class.

    The image is resized to 28x28 and kept as 3 channels, because conv1 of
    Model expects 3 input channels. It is scaled to [0, 1] and laid out as
    [1, C, H, W].

    Args:
        imagPath: path of the image file to classify.
        model: trained Model; it is switched to evaluation mode here.
        show: if True, display the original image first and wait for a key press.

    Returns:
        The predicted class index as an int.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``imagPath``.
    """
    image = cv2.imread(imagPath)  # HxWx3 array in BGR order, or None on failure
    if image is None:
        raise FileNotFoundError('Cannot read image: {}'.format(imagPath))

    if show:
        cv2.imshow("img", image)
        # waitKey(0) waits for a key press; closing the window with the
        # mouse does not return from it.
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # OpenCV loads BGR. Convert to RGB, assumed to match the channel order of
    # the PIL-based training pipeline (TODO confirm).
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (28, 28))  # -> (28, 28, 3)
    print(image.shape)
    image = image / 255.0                # pixel values [0, 255] -> [0, 1]
    image = image.transpose(2, 0, 1)     # HWC -> CHW
    image = jt.float32(image)            # numpy array -> Jittor Var
    image = image.unsqueeze(dim=0)       # add batch dim -> [1, C, H, W]
    # NOTE(review): MNIST digits are light strokes on a dark background; a
    # dark-on-light photo may need inverting (1.0 - image) first. Confirm
    # against your own test images.

    # Use eval mode, as train.py's test() does, so BatchNorm does not
    # normalize with the statistics of this single image.
    model.eval()
    with jt.no_grad():
        outputs = model(image)
    prediction = np.argmax(outputs.data, axis=1)
    print(prediction)
    return int(prediction[0])

def main(model_path=None, img_path=None):
    """Load the trained model and classify one local test image.

    Args:
        model_path: saved parameter file. Defaults to ``mnist_model.pkl``
            in this script's directory.
        img_path: image to classify. Defaults to ``testdata/0.jpg``
            in this script's directory.
    """
    # Resolve defaults relative to this script instead of a hard-coded home directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    if model_path is None:
        model_path = os.path.join(base_dir, 'mnist_model.pkl')
    if img_path is None:
        img_path = os.path.join(base_dir, 'testdata', '0.jpg')

    # Load the model and switch it to inference mode.
    model = Model()
    model.load_parameters(jt.load(model_path))
    model.eval()
    # Classify the local image.
    ImageClassification(img_path, model)

if __name__ == '__main__':
    main()

5、测试结果如图:

 6、用cv2展示图片时,cv2.waitKey(0)会一直等待键盘按键:直接点击关闭弹窗并不能让程序继续运行,需要按任意键,或把等待时间设为大于0的毫秒数(之后可调用cv2.destroyAllWindows()关闭窗口)。而用plt展示图片就没有这个问题,关闭弹窗后程序会继续运行。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值