MindSpore 是华为推出的深度学习框架,官网为:https://www.mindspore.cn/
本代码主要参考官方快速入门教程的代码,并增加了将模型导出为 ONNX 文件的步骤。
MindSpore 在模型搭建上的语法与 PyTorch 基本相同,
区别在于它区分了"网络"和"模型":模型主要用来训练和预测,而网络只是单纯的网络结构;网络可以用来导出模型文件,但预测只能使用模型。
训练代码如下:
# -*- coding: utf-8 -*-
import os
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model,export,load_checkpoint
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
import numpy as np
import mindspore.dataset as ds
from mindspore.train.callback import Callback
# Download and unpack the MNIST archive before running this script:
# https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zip
train_data_path = "./datasets/MNIST_Data/train"  # training split directory
test_data_path = "./datasets/MNIST_Data/test"  # test split directory
mnist_path = "./datasets/MNIST_Data"  # dataset root directory
model_path = "./models/ckpt/"  # directory handed to ModelCheckpoint for the saved .ckpt files
# Dataset definition
def create_dataset(data_path, batch_size=128, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST input pipeline used for both training and evaluation.

    Args:
        data_path (str): Directory handed to ``ds.MnistDataset``.
        batch_size (int): Number of samples per batch; an incomplete final
            batch is dropped.
        repeat_size (int): Number of times the batched dataset is repeated.
        num_parallel_workers (int): Worker count passed to every ``map`` call.

    Returns:
        The dataset after the label/image maps, shuffle, batch and repeat
        stages have been attached.
    """
    # Normalisation constants -- presumably the usual MNIST mean/std; they
    # are applied as x * (1 / std) + (-mean / std).
    mean = 0.1307
    std = 0.3081

    # Image transforms, listed in the order they are applied.
    image_ops = (
        CV.Resize((32, 32), interpolation=Inter.LINEAR),  # resize to 32x32
        CV.Rescale(1.0 / 255.0, 0.0),                     # scale by 1/255
        CV.Rescale(1 / std, -mean / std),                 # normalise
        CV.HWC2CHW(),                                     # HWC -> CHW layout
    )

    pipeline = ds.MnistDataset(data_path)

    # Labels are cast to int32 before the image column is processed.
    pipeline = pipeline.map(operations=C.TypeCast(mstype.int32),
                            input_columns="label",
                            num_parallel_workers=num_parallel_workers)

    # One map stage per image transform, exactly as a chain of map() calls.
    for image_op in image_ops:
        pipeline = pipeline.map(operations=image_op,
                                input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # Shuffle, then batch (dropping the remainder), then repeat.
    pipeline = pipeline.shuffle(buffer_size=10000)
    pipeline = pipeline.batch(batch_size, drop_remainder=True)
    return pipeline.repeat(repeat_size)
# Custom callback: records the loss of every step and periodically evaluates accuracy
class StepLossAccInfo(Callback):
    """
    Collect per-step loss values and periodic evaluation accuracy.

    Args:
        model: Object whose ``eval(dataset, dataset_sink_mode=False)`` returns
            a dict containing the key ``"Accuracy"``.
        eval_dataset: Dataset passed to ``model.eval``.
        step_loss (dict): Must hold the lists ``"step"`` and ``"loss_value"``;
            one stringified entry is appended to each after every step.
        steps_eval (dict): Must hold the lists ``"step"`` and ``"acc"``; one
            entry is appended to each whenever an evaluation runs.
        steps_per_epoch (int): Number of steps in one epoch, used to turn
            (epoch, step) into a running step index. Defaults to 1875, the
            value this script was written for.
        eval_interval (int): Evaluate every ``eval_interval`` steps.
            Defaults to 125.

    Raises:
        ValueError: If ``steps_per_epoch`` or ``eval_interval`` is not positive.
    """

    def __init__(self, model, eval_dataset, step_loss, steps_eval,
                 steps_per_epoch=1875, eval_interval=125):
        super().__init__()
        if steps_per_epoch <= 0:
            raise ValueError("steps_per_epoch must be a positive integer")
        if eval_interval <= 0:
            # Guards the modulo in step_end against a ZeroDivisionError.
            raise ValueError("eval_interval must be a positive integer")
        self.model = model
        self.eval_dataset = eval_dataset
        self.step_loss = step_loss
        self.steps_eval = steps_eval
        self.steps_per_epoch = steps_per_epoch
        self.eval_interval = eval_interval

    def step_end(self, run_context):
        """Record the loss of the finished step and evaluate on schedule."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # NOTE(review): this formula assumes cur_step_num restarts at every
        # epoch; for a single-epoch run the result is the same either way.
        # Confirm against the MindSpore version in use before training longer.
        cur_step = (cur_epoch - 1) * self.steps_per_epoch + cb_params.cur_step_num
        self.step_loss["loss_value"].append(str(cb_params.net_outputs))
        self.step_loss["step"].append(str(cur_step))
        if cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])
# Network definition
class mnist(nn.Cell):
    """
    Small convolutional classifier: two conv -> ReLU -> max-pool stages
    followed by three dense layers.

    Args:
        num_class (int): Size of the final dense layer's output. Default: 10.
    """

    def __init__(self, num_class=10):
        super().__init__()
        # Attribute names and creation order are kept stable because
        # checkpoints address parameters by these names.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5,
                               pad_mode='valid')
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=12, kernel_size=5,
                               pad_mode='valid')
        # 300 = 12 channels * 5 * 5, the flattened size for a 32x32 input:
        # 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5.
        self.fc1 = nn.Dense(300, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 60, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(60, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the output of the last dense layer."""
        # First convolution stage
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Second convolution stage
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Dense head
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
network = mnist()
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the model: wraps network, loss and optimizer for training/evaluation/prediction
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# Save a checkpoint every 375 steps and keep at most 16 checkpoint files
config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
# Checkpoint callback: files are written to model_path with the "mnist" prefix
ckpoint_cb = ModelCheckpoint(prefix="mnist", directory=model_path, config=config_ck)
# Use the declared path constants instead of repeating the literal paths
eval_dataset = create_dataset(test_data_path)
step_loss = {"step": [], "loss_value": []}
steps_eval = {"step": [], "acc": []}
# Collect the step, loss and accuracy information during training
step_loss_acc_info = StepLossAccInfo(model, eval_dataset, step_loss, steps_eval)
repeat_size = 1
ds_train = create_dataset(train_data_path, 32, repeat_size)
model.train(1, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info], dataset_sink_mode=False)
# Test: run a prediction on the first batch of the test set
ds_test = create_dataset(test_data_path).create_dict_iterator()
data = next(ds_test)
images = data["image"].asnumpy()
labels = data["label"].asnumpy()
print(model.predict(Tensor(data['image'])))
print(images.shape)
# Export the network as an ONNX file
# Build the checkpoint path from model_path so it is valid on every OS
# (a hard-coded backslash-separated path only resolves on Windows).
ckpt_file = os.path.join(model_path, "mnist-1_1875.ckpt")
load_checkpoint(ckpt_file, net=network)
print("load****")
# Sample input for the export; named so that it does not shadow the builtin `input`
dummy_input = np.ones([1, 1, 32, 32]).astype(np.float32)
# The 'minist' spelling is kept on purpose: the ONNX check loads "minist.onnx"
export(network, Tensor(dummy_input), file_name='minist', file_format='ONNX')
注意,数据集需要自己提前下载:https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zip
运行后就可以导出模型
运行结果:
可以打开onnx文件看看网络结构:
验证一下数据结果:
onnx:
import onnxruntime
import numpy as np
# Run the exported ONNX model on an all-ones 1x1x32x32 float32 tensor
sample = np.ones([1, 1, 32, 32]).astype(np.float32)
ort_session = onnxruntime.InferenceSession("minist.onnx")
# Feed the sample under the name of the model's first declared input
first_input = ort_session.get_inputs()[0]
results = ort_session.run(None, {first_input.name: sample})
print('onnx result is:', results)
运行结果:
mindspore:
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Tensor, Model
import numpy as np
class mnist(nn.Cell):
    """
    Small convolutional classifier: two conv -> ReLU -> max-pool stages
    followed by three dense layers.

    Args:
        num_class (int): Size of the final dense layer's output. Default: 10.
    """

    def __init__(self, num_class=10):
        super().__init__()
        # Attribute names and creation order are kept stable because
        # checkpoints address parameters by these names.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5,
                               pad_mode='valid')
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=12, kernel_size=5,
                               pad_mode='valid')
        # 300 = 12 channels * 5 * 5, the flattened size for a 32x32 input:
        # 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5.
        self.fc1 = nn.Dense(300, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 60, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(60, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the output of the last dense layer."""
        # First convolution stage
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Second convolution stage
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Dense head
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
network = mnist()
# Forward slashes keep the checkpoint path valid on both Windows and POSIX systems
ckpt_path = "models/ckpt/mnist-1_1875.ckpt"
# Read the trained parameters and copy them into the freshly built network
trained_ckpt = load_checkpoint(ckpt_path)
load_param_into_net(network, trained_ckpt)
# All-ones 1x1x32x32 float32 tensor; named so it does not shadow the builtin `input`
sample_input = np.ones([1, 1, 32, 32]).astype(np.float32)
model = Model(network, metrics={'acc'}, eval_network=network)
print(model.predict(Tensor(sample_input)))
运行结果:
运行结果一致,说明导出文件正确。