昇思25天学习打卡营第4天 |昇思MindSpore 网络构建学习与总结

MindSpore 提供了 nn.Cell 类来构建神经网络模型,这是所有网络的基类。网络模型由不同的子 Cell 组成,可以通过面向对象编程的方式进行构建和管理。

网络构建示例

我们将构建一个用于MNIST数据集分类的神经网络模型。以下是网络构建的主要步骤:

  1. 定义模型类:继承 nn.Cell 类。
  2. __init__ 方法中初始化网络层
  3. construct 方法中实现前向传播

代码实现

import mindspore
from mindspore import nn, ops

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

模型结构

层名称输入形状输出形状激活函数
Flatten(28, 28)(784,)
Dense (1)(784,)(512,)ReLU
Dense (2)(512,)(512,)ReLU
Dense (3)(512,)(10,)

模型输出

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
print(logits)

# 使用 Softmax 层获取预测概率
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

各层解析

我们分解上述网络中的每一层,依次观察其效果:

  1. Flatten 层:将 2D 图像展平为 1D 数组。

    flatten = nn.Flatten()
    flat_image = flatten(input_image)
    print(flat_image.shape)  # (3, 784)
    
  2. Dense 层:全连接层,进行线性变换。

    layer1 = nn.Dense(in_channels=28*28, out_channels=20)
    hidden1 = layer1(flat_image)
    print(hidden1.shape)  # (3, 20)
    
  3. ReLU 层:激活函数,增加非线性。

    hidden1 = nn.ReLU()(hidden1)
    print(hidden1)
    
  4. SequentialCell:按顺序组合多个层。

    seq_modules = nn.SequentialCell(
        flatten,
        layer1,
        nn.ReLU(),
        nn.Dense(20, 10)
    )
    logits = seq_modules(input_image)
    print(logits.shape)  # (3, 10)
    

参数查看

通过 model.parameters_and_names() 获取模型参数的详细信息。

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

对比分析

特性PyTorchMindSpore
基础类nn.Modulenn.Cell
前向传播方法forward()construct()
参数初始化手动指定提供多种初始化方式,如 weight_init
生态系统支持PyTorch 有广泛的社区和丰富的生态支持MindSpore 主要集中在华为生态系统
  • 7
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值