网络构建
神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell
,它由不同的子Cell
构成。使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。
下面我们将构建一个用于Mnist数据集分类的神经网络模型。
import mindspore
from mindspore import nn, ops
定义模型类
当我们定义神经网络时,可以继承nn.Cell
类,在__init__
方法中进行子Cell的实例化和状态管理,在construct
方法中实现Tensor操作。
construct
意为神经网络(计算图)构建,相关内容详见使用静态图加速。
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
构建完成后,实例化Network
对象,并查看其结构。
model = Network()
print(model)
我们构造一个输入数据,直接调用模型,可以获得一个二维的Tensor输出,其包含每个类别的原始预测值。
model.construct()
方法不可直接调用。
X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits
在此基础上,我们通过一个nn.Softmax
层实例来获得预测概率。
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
模型层
本节中我们分解上节构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。
input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)
nn.Flatten
实例化nn.Flatten对象,将输入展平为一维Tensor。
flatten = nn.Flatten()
flattened_image = flatten(input_image)
print(flattened_image.shape)
nn.Dense
实例化nn.Dense对象,并应用于展平后的输入。
layer1 = nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros")
hidden1 = layer1(flattened_image)
print(hidden1.shape)
nn.ReLU
实例化nn.ReLU对象,并应用于隐藏层。
relu = nn.ReLU()
relu_output = relu(hidden1)
print(relu_output.shape)
nn.SequentialCell
使用nn.SequentialCell将多个层连接在一起。
model = nn.SequentialCell(
nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
)
print(model(input_image))
至此,我们已经构建了一个简单的神经网络模型,并展示了如何一步步地对输入进行处理。
Reference
教程来自:
https://gitee.com/mindspore/docs/blob/r2.3/tutorials/source_zh_cn/beginner/model.ipynb