1、先创建一个model.py
import jittor as jt
from jittor import nn, Module
import numpy as np
import sys, os
import random
import math
from jittor import init
class Model(Module):
    """Small CNN classifier used for MNIST in this walkthrough.

    Expects a batch of 3-channel 28x28 images (conv1 takes 3 input channels and
    fc1's input size 64 * 12 * 12 only matches a 28x28 input) and produces
    10 logits per sample.

    NOTE(review): the attribute names (conv1, conv2, bn, fc1, fc2) are
    presumably used as parameter keys by model.save / load_parameters, so they
    are kept unchanged; the construction order is also kept so parameter
    initialisation happens in the same sequence.
    """

    def __init__(self):
        super().__init__()
        # Unpadded 3x3 convolutions, stride 1: spatial size 28 -> 26 -> 24.
        self.conv1 = nn.Conv(3, 32, 3, 1)
        self.conv2 = nn.Conv(32, 64, 3, 1)
        self.bn = nn.BatchNorm(64)
        # 2x2 pooling with stride 2: spatial size 24 -> 12.
        self.max_pool = nn.Pool(2, 2)
        self.relu = nn.Relu()
        self.fc1 = nn.Linear(64 * 12 * 12, 256)
        self.fc2 = nn.Linear(256, 10)

    def execute(self, x):
        """Forward pass: conv feature extractor followed by a 2-layer MLP head."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.bn(self.conv2(features)))
        features = self.max_pool(features)
        # Flatten every dimension except the leading batch dimension.
        flat = jt.reshape(features, [features.shape[0], -1])
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
2、创建train.py来进行训练、保存模型
import jittor as jt
from jittor import nn, Module
import numpy as np
import sys, os
import random
import math
from jittor import init
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import pylab as pl
# Device selection for the whole training script.
jt.flags.use_cuda = 0  # 0 runs on the CPU; set jt.flags.use_cuda = 1 to use the GPU
def train(model, train_loader, optimizer, epoch, losses, losses_idx, log_interval=10):
    """Run one training epoch and record the loss of every batch.

    Args:
        model: network to train; switched to training mode here.
        train_loader: iterable of (inputs, targets) batches that supports len().
        optimizer: Jittor-style optimizer; ``optimizer.step(loss)`` is called
            with the batch loss directly.
        epoch: zero-based epoch index, used in log lines and to build the
            global iteration index.
        losses: list that receives each batch's scalar loss (mutated in place).
        losses_idx: list that receives the matching global iteration index
            ``epoch * num_batches + batch_idx`` (mutated in place).
        log_interval: print a progress line every ``log_interval`` batches;
            must be positive. Defaults to 10, the previously hard-coded value.
    """
    model.train()
    num_batches = len(train_loader)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        outputs = model(inputs)
        loss = nn.cross_entropy_loss(outputs, targets)
        optimizer.step(loss)
        # .item() works whether the loss Var is 0-d or shape [1]
        # (loss.data[0] requires at least one dimension), and reading the
        # value once avoids fetching it from the device twice per batch.
        loss_value = loss.item()
        losses.append(loss_value)
        losses_idx.append(epoch * num_batches + batch_idx)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, num_batches,
                100. * batch_idx / num_batches, loss_value))
def test(model, val_loader, epoch):
    """Measure classification accuracy of ``model`` on ``val_loader``.

    Prints one line per batch with that batch's accuracy, then the overall
    accuracy.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        val_loader: iterable of (inputs, targets) batches that supports len();
            ``inputs.shape[0]`` is the batch size and ``targets.data`` holds
            the integer labels.
        epoch: epoch index, used only in the log lines.

    Returns:
        Overall accuracy (correct predictions / samples) as a float, or 0.0
        when the loader yields no samples.
    """
    model.eval()
    total_correct = 0
    total_num = 0
    num_batches = len(val_loader)
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        batch_size = inputs.shape[0]
        outputs = model(inputs)
        # Predicted class = index of the largest logit in each row.
        pred = np.argmax(outputs.data, axis=1)
        batch_correct = np.sum(targets.data == pred)
        total_correct += batch_correct
        total_num += batch_size
        batch_acc = batch_correct / batch_size
        print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(epoch, \
            batch_idx, num_batches, 100. * float(batch_idx) / num_batches, batch_acc))
    # Guard against an empty loader instead of raising ZeroDivisionError.
    total_acc = float(total_correct / total_num) if total_num else 0.0
    print('Total test acc =', total_acc)
    return total_acc
def main(model_path='/home/root/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl'):
    """Train the MNIST classifier, save it, then plot the training-loss curve.

    Args:
        model_path: file the trained parameters are written to. Defaults to
            the original hard-coded location. NOTE(review): the test.py script
            in this walkthrough loads from /home/lizhi528/..., a different
            directory, so keep the two paths in sync.
    """
    batch_size = 64
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 5
    losses = []
    losses_idx = []
    train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(batch_size=batch_size, shuffle=True)
    val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=1, shuffle=False)
    model = Model()
    optimizer = nn.SGD(model.parameters(), learning_rate,
                       momentum=momentum, weight_decay=weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch, losses, losses_idx)
        test(model, val_loader, epoch)
    # Save before plotting: pl.show() can block until the plot window is
    # closed, and the trained model should not be lost if the run is
    # interrupted there.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    model.save(model_path)
    pl.plot(losses_idx, losses)
    pl.xlabel('Iterations')
    pl.ylabel('Train_loss')
    pl.show()


if __name__ == '__main__':
    main()
3、运行train.py文件结果如图(testdata为自己创建的测试数据文件夹、mnist_model.pkl为模型文件)
4、创建test.py,加载本地图片对模型进行测试。
from datetime import date
from matplotlib import pyplot as plt
import jittor as jt
from numpy.core.fromnumeric import shape
from numpy.lib.type_check import imag
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import cv2
import os
from PIL import Image
import matplotlib.pyplot as plt
"""
#加载MNIST模型库中的图片
model_path = '/home/lizhi528/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl'
new_model = Model()
new_model.load_parameters(jt.load(model_path))
val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=1, shuffle=False)
data_iter = iter(val_loader)
val_data, val_label = next(data_iter)#
outputs = new_model(val_data)
prediction = np.argmax(outputs.data, axis=1)
print(val_label.data)
print(prediction)
"""
def ImageClassification(imagPath, model, show=True):
    """Classify a single image file with ``model`` and print the prediction.

    The image is read with OpenCV, optionally displayed, converted to RGB,
    resized to 28x28, scaled to [0, 1], reshaped to a [1, 3, 28, 28] float32
    Jittor Var and passed through the model.

    Args:
        imagPath: path of the image file to classify.
        model: trained Model; switched to evaluation mode here.
        show: if True (the default, matching the original behaviour), show
            the image with cv2.imshow and wait for a key press first.

    Returns:
        numpy array of predicted class indices (one entry for the single image).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``imagPath``.
    """
    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(imagPath)
    if image is None:
        raise FileNotFoundError('Cannot read image: {}'.format(imagPath))
    if show:
        cv2.imshow("img", image)
        # waitKey(0) waits indefinitely for a key press in the window;
        # closing the window with the mouse does not return from it.
        cv2.waitKey(0)
        cv2.destroyWindow("img")
        # Alternative without the key-press requirement (matplotlib returns
        # once its window is closed):
        #     img = Image.open(imagPath); plt.figure("Image"); plt.imshow(img); plt.show()
    # OpenCV loads channels as BGR. NOTE(review): the training data presumably
    # comes from Jittor's MNIST dataset as RGB, so match that order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (28, 28))  # -> (28, 28, 3)
    print(image.shape)
    image = image / 255.0  # pixel values [0, 255] -> [0, 1]
    image = image.transpose(2, 0, 1)  # HWC -> CHW
    image = jt.float32(image)  # numpy array -> Jittor Var
    image = image.unsqueeze(dim=0)  # add batch dimension -> [1, C, H, W]
    # Evaluation mode, as train.py's test() does before inference, so the
    # model's BatchNorm layer does not use statistics of this single image.
    model.eval()
    outputs = model(image)
    prediction = np.argmax(outputs.data, axis=1)
    print(prediction)
    return prediction
def main(model_path='/home/lizhi528/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl',
         img_path='/home/lizhi528/Python_Demo/JittorMNISTImageClassification/testdata/0.jpg'):
    """Load the trained model and classify one local image.

    Args:
        model_path: parameter file written by train.py's ``model.save``;
            defaults to the original hard-coded path.
        img_path: image to classify; defaults to the original sample image.
    """
    # Load the model.
    model = Model()
    model.load_parameters(jt.load(model_path))
    # A freshly built Model is used for inference only, so switch it to
    # evaluation mode (train.py's test() does the same before evaluating).
    model.eval()
    # Classify the local image.
    ImageClassification(img_path, model)


if __name__ == '__main__':
    main()
5、测试结果如图:
6、用cv2展示图片时,cv2.waitKey(0)会无限期等待按键:直接点击关闭弹窗并不能让程序继续运行,需要在窗口上按任意键,或改用cv2.waitKey(t)(t>0,单位为毫秒)让窗口停留指定时间后自动继续;而用plt展示图片就没有这个问题,关闭弹窗后程序会继续运行。