基本介绍:
昇思框架是什么,
了解到还有oneDNN,Eigen库用于cpu推理
昇思框架(MindSpore)是华为推出的一款开源深度学习框架,旨在提升AI开发效率和部署性能。以下是昇思框架的一些主要特点和功能:
-
易于使用:
- 提供高层次的API,简化了模型的构建和训练过程。
- 支持动态图和静态图两种模式,方便开发者进行灵活的调试和优化。
-
高性能:
- 利用华为自研的Ascend AI处理器,实现高效的模型训练和推理。
- 支持分布式训练,能够在大规模集群上进行高效的模型训练。
-
全场景覆盖:
- 支持从云到边缘的全场景应用,适用于多种硬件平台,包括Ascend、GPU和CPU。
- 提供丰富的模型库和预训练模型,方便用户进行快速部署和迁移学习。
-
安全可靠:
- 提供全面的安全机制,保护用户的数据和模型安全。
- 通过了多项国际认证,确保框架的可靠性和安全性。
-
生态丰富:
- 拥有活跃的开源社区,提供丰富的技术支持和资源。
- 与多种开发工具和平台兼容,支持主流的AI开发流程。
昇思框架的推出,旨在帮助开发者更加高效地进行AI模型的开发、训练和部署,推动AI技术在各行各业的应用和发展。
类似于pytorch,我们可以使用昇思框架的api,搭建神经网络,进行训练和预测
导入依赖的包
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
下载数据集以及处理数据集就不看了,主要看网络搭建的api
mindspore.nn
类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell
类,并重写__init__
方法和construct
方法。__init__
包含所有网络层的定义,construct
中包含数据(Tensor)的变换过程。
# Define model
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512),
nn.ReLU(),
nn.Dense(512, 512),
nn.ReLU(),
nn.Dense(512, 10)
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
model = Network()
print(model)
Network< (flatten): Flatten<> (dense_relu_sequential): SequentialCell< (0): Dense<input_channels=784, output_channels=512, has_bias=True> (1): ReLU<> (2): Dense<input_channels=512, output_channels=512, has_bias=True> (3): ReLU<> (4): Dense<input_channels=512, output_channels=10, has_bias=True> > >
模型训练:
MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
- 定义正向计算函数。
- 使用value_and_grad通过函数变换获得梯度计算函数。
- 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
可以看到和pytorch的一些api是很类似的,熟悉pytorch的可以很快地切换
模型的保存:
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
参数加载:
# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)