心得体会
本次主要了解如何构建一个神经网络模型,通过Cell
基类实现对神经网络结构进行构建和管理。
模型定义
通过继承nn.Cell
类,在__init__
方法中进行子Cell的实例化和状态管理,在construct
方法中实现对Tensor操作。在这里,nn.Cell
类与pytorch中的nn.module
类相似,可以把construct
方法看做forward
方法。
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
模型层
常用的网络结构有如下:
- nn.Flatten:将输入的多维数据转换成一维数组。
- nn.Dense:全连接层,其使用权重和偏差对输入进行线性变换。
- nn.ReLU:非线性的激活函数。
- nn.SequentialCell:Cell容器,类似于pytorch中的
nn.Senquential
- nn.Softmax:将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。
axis
指定的维度数值和为1。