神经网络模型是由神经网络层和Tensor操作构成的
class Network(nn.Cell): def __init__(self): super().__init__() self.flatten = nn.Flatten() self.dense_relu_sequential = nn.SequentialCell( nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"), nn.ReLU(), nn.Dense(512, 512, weight_init="normal", bias_init="zeros"), nn.ReLU(), nn.Dense(512, 10, weight_init="normal", bias_init="zeros") ) def construct(self, x): x = self.flatten(x) logits = self.dense_relu_sequential(x) return logits
在__init__中 定义好模型的参数 在construct中定义模型调用的顺序 通过类定义比单独函数调用更加通俗易懂