mindspore打卡第三课模型定义和训练全流程

mindspore打卡第三课模型定义和训练全流程

import mindspore
from mindspore import nn, ops


class Network(nn.Cell):
    ### 先定义类合参数 需要初始化实例
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )
#######初始化实例后可以调用类内的参数合 函数方法
    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits
    
model = Network()
print(model)    
    

Network<
      (flatten): Flatten<>
      (dense_relu_sequential): SequentialCell<
        (0): Dense<input_channels=784, output_channels=512, has_bias=True>
        (1): ReLU<>
        (2): Dense<input_channels=512, output_channels=512, has_bias=True>
        (3): ReLU<>
        (4): Dense<input_channels=512, output_channels=10, has_bias=True>
        >
      >

###模型传入数据后得到结果
X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
print("logits",logits)
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

logits [[-0.00105553  0.00748482  0.00836606  0.01366857  0.00911757 -0.00338881
       0.01666504  0.00012192  0.0045125  -0.00740504]]
    Predicted class: [6]

## 创建3通道的模拟数据
input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)
##先过一个打平层看看维度变换不 nn.Flatten层,将28x28的2D张量转换为784大小的连续数组
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)
print("ft:",flat_image)

(3, 28, 28)
    (3, 784)
    ft: [[1. 1. 1. … 1. 1. 1.]
     [1. 1. 1. … 1. 1. 1.]
     [1. 1. 1. … 1. 1. 1.]]

## 看看全连接层的维度变换
layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)
##激活函数
print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

(3, 20)
    Before ReLU: [[ 0.93543625  0.19160181  1.1723751  -0.8474578  -0.1709224  -0.14004312
       0.5371042   0.13461939 -0.67513317 -0.06042651  0.79395443  0.04006821
      -0.58641505 -0.9113059  -1.3169372   0.59977484  0.02021017 -0.4086219
      -0.37906158  0.28552184]
     [ 0.93543625  0.19160181  1.1723751  -0.8474578  -0.1709224  -0.14004312
       0.5371042   0.13461939 -0.67513317 -0.06042651  0.79395443  0.04006821
      -0.58641505 -0.9113059  -1.3169372   0.59977484  0.02021017 -0.4086219
      -0.37906158  0.28552184]
     [ 0.93543625  0.19160181  1.1723751  -0.8474578  -0.1709224  -0.14004312
       0.5371042   0.13461939 -0.67513317 -0.06042651  0.79395443  0.04006821
      -0.58641505 -0.9113059  -1.3169372   0.59977484  0.02021017 -0.4086219
      -0.37906158  0.28552184]]
    
    
    After ReLU: [[0.93543625 0.19160181 1.1723751  0.         0.         0.
      0.5371042  0.13461939 0.         0.         0.79395443 0.04006821
      0.         0.         0.         0.59977484 0.02021017 0.
      0.         0.28552184]
     [0.93543625 0.19160181 1.1723751  0.         0.         0.
      0.5371042  0.13461939 0.         0.         0.79395443 0.04006821
      0.         0.         0.         0.59977484 0.02021017 0.
      0.         0.28552184]
     [0.93543625 0.19160181 1.1723751  0.         0.         0.
      0.5371042  0.13461939 0.         0.         0.79395443 0.04006821
      0.         0.         0.         0.59977484 0.02021017 0.
      0.         0.28552184]]

##结构组合 可以循环 可以循环出来吗吗  ,一步到位
seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)


# 关于“结构组合能否循环出来”的问题,如果你的意思是能否动态地构建或重复使用这样的结构,答案是肯定的。你可以通过循环或者其他编程结构动态地创建nn.Sequential的内容。例如,如果你想要根据某个条件创建多个相同结构的模块,可以这样做:

# Python
# import torch.nn as nn

# # 假设我们想要重复堆叠相同的块n_times次
# n_times = 3
# layers = []

# for _ in range(n_times):
#     layers.append(nn.Linear(100, 100))  # 添加一个全连接层
#     layers.append(nn.ReLU())           # 添加ReLU激活函数

# seq_modules = nn.Sequential(*layers)  # 使用星号(*)展开列表作为参数

# # 然后可以使用seq_modules处理输入
# input_data = torch.randn(32, 100)  # 假设的输入数据
# output = seq_modules(input_data)
# 在这个示例中,通过循环,我们创建了一个包含多个相同结构(线性层+ReLU)的序列模型。这种方式允许根据需要灵活地构建复杂的网络结构。

(3, 10)

input_image.shape

(3, 28, 28)


# 假设我们想要重复堆叠相同的块n_times次
n_times = 3
layers = []

for _ in range(n_times):
    layers.append(nn.Dense(28, 28))  # 添加一个全连接层nn.Dense
    layers.append(nn.ReLU())           # 添加ReLU激活函数

seq_modules = nn.SequentialCell(*layers)  # 使用星号(*)展开列表作为参数

# 然后可以使用seq_modules处理输入
input_data =input_image  # 假设的输入数据
output = seq_modules(input_data)

print(output.shape)

(3, 28, 28)

print(logits)
softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)
print("---最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。--")
print(pred_probab)

[[-0.4724075   0.05367208  0.0397298   0.25943297  0.06311138 -0.00053869
       0.35859025 -0.2345797   0.30336326 -0.39170516]
     [-0.4724075   0.05367208  0.0397298   0.25943297  0.06311138 -0.00053869
       0.35859025 -0.2345797   0.30336326 -0.39170516]
     [-0.4724075   0.05367208  0.0397298   0.25943297  0.06311138 -0.00053869
       0.35859025 -0.2345797   0.30336326 -0.39170516]]
    —最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。–
    [[0.06034356 0.10211848 0.1007046  0.12544842 0.10308698 0.09672994
      0.13852516 0.07654525 0.13108225 0.06541532]
     [0.06034356 0.10211848 0.1007046  0.12544842 0.10308698 0.09672994
      0.13852516 0.07654525 0.13108225 0.06541532]
     [0.06034356 0.10211848 0.1007046  0.12544842 0.10308698 0.09672994
      0.13852516 0.07654525 0.13108225 0.06541532]]

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer NAME: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

Model structure: Network<
      (flatten): Flatten<>
      (dense_relu_sequential): SequentialCell<
        (0): Dense<input_channels=784, output_channels=512, has_bias=True>
        (1): ReLU<>
        (2): Dense<input_channels=512, output_channels=512, has_bias=True>
        (3): ReLU<>
        (4): Dense<input_channels=512, output_channels=10, has_bias=True>
        >
      >
    
    
    Layer NAME: dense_relu_sequential.0.weight
    Size: (512, 784)
    Values : [[-4.0671071e-05 -3.8874722e-03  1.4607443e-03 … -1.1891589e-02
       2.0101516e-02  3.4556145e-04]
     [-9.9748727e-03  4.1558472e-03 -1.1567972e-03 …  4.2272024e-03
       5.0952132e-03  1.1591410e-02]] 
    
    Layer NAME: dense_relu_sequential.0.bias
    Size: (512,)
    Values : [0. 0.] 
    
    Layer NAME: dense_relu_sequential.2.weight
    Size: (512, 512)
    Values : [[ 0.02005285  0.01688561 -0.00053313 … -0.00234611  0.00209523
      -0.01069442]
     [-0.01309605  0.01469714  0.00323318 … -0.00890452  0.01209573
      -0.0013482 ]] 
    
    Layer NAME: dense_relu_sequential.2.bias
    Size: (512,)
    Values : [0. 0.] 
    
    Layer NAME: dense_relu_sequential.4.weight
    Size: (10, 512)
    Values : [[-0.01001023 -0.01732783 -0.01063744 …  0.01359949 -0.01130515
       0.01796171]
     [-0.01353195 -0.00085292  0.00219263 … -0.00905481  0.01607992
       0.00187309]] 
    
    Layer NAME: dense_relu_sequential.4.bias
    Size: (10,)
    Values : [0. 0.]

这段描述展示了一个神经网络模型的结构及其部分权重参数。该模型包含一个名为 flatten 的层(用于将输入数据展平为一维向量)以及一个复合结构 dense_relu_sequential,后者是由几个全连接层(Dense)和激活函数(ReLU)组成的序列。具体解析如下:

整体结构

  • Network Structure: 描述了整个模型的结构,其中包含两个主要部分:
      - Flatten<>: 一个层,负责将输入数据从多维(例如,图像的高×宽×通道)转换为一维向量,以便输入到后续的全连接层。
      - dense_relu_sequential: 一个SequentialCell,里面包含了多个层,这些层顺序执行。具体包含:
        - 3个全连接层 (Dense),每个后面紧跟着一个ReLU激活函数 (ReLU), 以及最后一个不跟随激活函数的全连接层用于输出。
        - 第一个 Dense 层将输入从784维(假设是28x28像素的灰度图像展平后的维度)映射到512维。
        - 接下来的两个 Dense 层都保持了512的输入输出维度。
        - 最后一个 Dense 层将512维映射到10维,这通常用于分类任务(例如,MNIST数据集的手写数字识别,共10个类别)。

权重参数示例

  • dense_relu_sequential.0.weight: 第一个全连接层的权重矩阵,尺寸为 (512, 784),表示从784维输入映射到512维输出的权重。
  • dense_relu_sequential.0.bias: 第一个全连接层的偏置向量,尺寸为 (512,),每个元素对应一个输出神经元的偏置。
  • 类似地,dense_relu_sequential.2.weightbias,以及 dense_relu_sequential.4.weightbias 分别展示了第二个和第三个全连接层的权重和偏置参数。最后一个全连接层(到输出类别)的权重和偏置也给出了,尺寸分别是 (10, 512) 和 (10,),对应于将512维特征映射到10个类别输出。

参数初始化

  • 权重矩阵的初始值看起来是随机生成的,这是典型的神经网络初始化策略,旨在打破对称性,促进模型学习。
  • 偏置项初始化为零,这是一种常见的做法,尤其是在使用某些激活函数(如ReLU)时,因为偏置项可以通过学习调整以应对偏移。

注意

  • 所示的权重值仅为示例中的前几行,实际权重矩阵会有更多行和列,包含了从输入到输出的完整映射关系。
  • 所有的偏置项都初始化为0,这表明模型在训练初期不依赖于偏置来推动激活,偏置值会在训练过程中根据反向传播和梯度下降等优化算法更新。

所谓的权重 只要会矩阵乘法 就知道维度应该怎么变换 对吗

是的,理解权重矩阵在神经网络中的作用确实与掌握矩阵乘法紧密相关。权重矩阵定义了从一层到下一层的变换,它决定了输入数据如何被转换以提取更高级别的特征。以下是几个关键点来帮助理解这一过程:

  1. 维度变换规则:
       - 输入到隐藏层: 如果输入数据是 (n) 维向量,并且我们想要映射到 (m) 个隐藏单元,那么第一个全连接层(Dense layer)的权重矩阵 (W) 将是 (m \times n) 的形状。这是因为每一个隐藏单元都会与输入层的每个特征相连接,权重矩阵的每一行代表一个隐藏单元与输入特征之间的连接强度。
       
       - 隐藏层到输出层: 类似地,如果隐藏层有 (m) 个单元,而输出层需要产生 (k) 个类别概率(例如,在分类任务中),那么接下来的权重矩阵将是 (k \times m) 的尺寸。这表示每个输出单元由前一层的所有隐藏单元通过加权求和得到其净输入。

  2. 矩阵乘法操作:
       - 在神经网络的前向传播过程中,当输入数据(视为一个向量)乘以权重矩阵时,实际上是将输入空间线性变换到一个新的特征空间。这个操作符合矩阵乘法的规则,即如果矩阵 (A) 是 (a \times b),矩阵 (B) 是 (b \times c),那么 (AB) 的运算结果是一个 (a \times c) 的矩阵,其中 (c) 维度对应到下一层的特征数量。

  3. 偏置项的作用:
       - 偏置项(bias)通常是以一个向量形式添加到每层的净输入上,目的是为了能够调整该层输出的总体水平。虽然它不改变输出的维度,但对模型的学习能力和表现至关重要,因为它可以帮助模型拟合数据的偏移。

综上所述,理解矩阵乘法对于设计和解释神经网络中的权重变换是基础且必要的,它直接关系到网络结构中各层之间如何传递信息和进行特征提取。

所以模型的训练本质是在训练模型的权重参数吗
是的,模型训练的本质就是在优化模型的权重参数(以及偏置项)。在神经网络和其他机器学习模型中,权重参数决定了模型如何对输入数据进行变换和处理,进而作出预测或分类。训练过程的目标是通过一系列算法(最常见的是梯度下降及其变种)调整这些权重,使得模型在训练数据上尽可能准确地预测输出,同时希望能泛化到未见过的数据上。

具体来说,训练流程包括以下关键步骤:

  1. 前向传播:输入数据通过网络,每一层的输出都是前一层输出与该层权重矩阵相乘的结果(加上偏置项),并通过激活函数转换。这一过程产生一个预测输出。

  2. 计算损失:预测输出与真实标签之间的差距通过损失函数(或成本函数)量化。损失函数衡量了模型预测的“错误程度”。

  3. 反向传播:计算损失函数关于模型权重的梯度,即损失函数对每个权重的偏导数。这一步利用链式法则从输出层反向直到输入层,逐层计算误差的贡献。

  4. 权重更新:根据梯度下降的原则,模型的权重会被更新,通常是朝着梯度的负方向,以期望在下一轮迭代中减小损失。更新公式通常包含学习率参数,它决定了权重调整的步长。

  5. 迭代训练:上述过程会重复进行多次(即多个epoch),每次迭代可能使用全部或部分训练数据(批次),直至达到预设的停止条件,比如迭代次数、模型性能提升的阈值等。

因此,训练模型实质上是在不断地调整权重参数,使其在训练数据上表现得更好,同时也试图学到数据中的潜在规律,以便在新的、未见过的数据上也能做出准确预测。这个过程是机器学习的核心,是模型如何从原始数据中“学习”的关键所在。

假设已经通过反向传播得到梯度,矩阵的模型的权重是怎么更新的 请用数据举例说明这个过程

当然,让我们通过一个具体的3x3权重矩阵的例子来说明如何利用反向传播得到的梯度来更新权重。假设我们有一个简单的神经网络层,该层的权重矩阵 (W) 是一个3x3的矩阵,表示为:

[ W = 
\begin{bmatrix}
w_{11} & w_{12} & w_{13} \
w_{21} & w_{22} & w_{23} \
w_{31} & w_{32} & w_{33}
\end{bmatrix}
]

进一步假设我们已经通过反向传播计算得到了关于这个权重矩阵 (W) 的梯度矩阵 (\frac{\partial L}{\partial W}),同样是一个3x3的矩阵,表示为:

[ \frac{\partial L}{\partial W} = 
\begin{bmatrix}
g_{11} & g_{12} & g_{13} \
g_{21} & g_{22} & g_{23} \
g_{31} & g_{32} & g_{33}
\end{bmatrix}
]

其中,(g_{ij}) 表示权重 (w_{ij}) 对损失函数 (L) 的偏导数,也就是权重矩阵中每个元素的梯度。

更新步骤

假设我们的学习率(learning rate)为 (\eta = 0.01),则权重矩阵 (W) 的更新公式为:

[ W_{new} = W_{old} - \eta \cdot \frac{\partial L}{\partial W} ]

这意味着我们将用当前的权重矩阵减去它的梯度乘以学习率来得到新的权重矩阵。

具体计算

为了具体说明,我们选取一些假设的梯度值。假设 (\frac{\partial L}{\partial W}) 中的元素分别为:

[ \frac{\partial L}{\partial W} = 
\begin{bmatrix}
-0.02 & 0.03 & -0.01 \
0.05 & -0.04 & 0.02 \
-0.03 & 0.01 & 0.04
\end{bmatrix}
]

那么,更新后的权重矩阵 (W_{new}) 的计算过程如下:

[ W_{new} = 
\begin{bmatrix}
w_{11} & w_{12} & w_{13} \
w_{21} & w_{22} & w_{23} \
w_{31} & w_{32} & w_{33}
\end{bmatrix}

  • 0.01 \times
    \begin{bmatrix}
    -0.02 & 0.03 & -0.01 \
    0.05 & -0.04 & 0.02 \
    -0.03 & 0.01 & 0.04
    \end{bmatrix}
    ]

[ W_{new} = 
\begin{bmatrix}
w_{11} + 0.0002 & w_{12} - 0.0003 & w_{13} + 0.0001 \
w_{21} - 0.0005 & w_{22} + 0.0004 & w_{23} - 0.0002 \
w_{31} + 0.0003 & w_{32} - 0.0001 & w_{33} - 0.0004
\end{bmatrix}
]

这里,我用了“(w_{ij})”来代表原始的权重值,实际上,你需要用具体的数值替换这些符号来进行实际的计算。例如,如果原始的 (w_{11} = 2.5),则更新后的值 (w_{11,new} = 2.5 + 0.0002 = 2.5002)。

请注意,上述计算是基于简化和假设的数值,实际应用中,权重矩阵和梯度矩阵的具体数值会根据模型的具体结构、输入数据和损失函数的具体计算而来。

模型的参数量如何确定

这个模型的参数量(即权重和偏置项的总数)可以通过累加每个层中权重和偏置项的数量来确定。根据给出的信息,我们可以逐层计算:

  1. Layer dense_relu_sequential.0:
       - 权重(weight)大小为 (512 \times 784),因此有 (512 \times 784 = 401,408) 个参数。
       - 偏置(bias)大小为 (512),因此有 (512) 个参数。

  2. Layer dense_relu_sequential.2:
       - 权重大小为 (512 \times 512),因此有 (512 \times 512 = 262,144) 个参数。
       - 偏置大小为 (512),因此有 (512) 个参数。

  3. Layer dense_relu_sequential.4:
       - 权重大小为 (10 \times 512),因此有 (10 \times 512 = 5,120) 个参数。
       - 偏置大小为 (10),因此有 (10) 个参数。

现在,我们将所有层的参数数量相加:

  • 总权重参数 = (401,408 + 262,144 + 5,120 = 668,672)
  • 总偏置参数 = (512 + 512 + 10 = 1,034)

所以,该模型总共有 (668,672 + 1,034 = 669,706) 个参数。

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter


x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias
print(w.shape,"\n",b.shape)

(5, 3) 
     (3,)

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

loss = function(x, y, w, b)
print(loss)

1.0487777

# #help( ops.binary_cross_entropy_with_logits(x, y, ops.ones_like(y), ops.ones_like(y)))
# 至于后面的两个参数 ops.ones_like(z):

# 在一般的二元交叉熵实现中,通常不直接需要这样的参数。这些参数可能意在代表某种形式的平滑(smoothing)或权重(weights),但直接传递与z形状相同的全1张量作为默认行为并不常见。
# 可能的平滑效果: 有时候为了防止log(0)的情况(当预测概率恰好为0或1时导致的数学计算问题),会稍微"平滑"预测概率或标签。但这通常涉及在概率上加一个小的正值(如1e-6),而不是传递全1张量。

# 权重: 如果是用来给不同样本的损失赋予不同的权重,通常这些权重会根据实际情况定制,比如根据样本的难易程度或类别不平衡问题来设定,并非简单地全为1。

# 因此,如果这不是特定框架或实现的特殊要求(比如某些自定义或特定场景下的二叉交叉熵变体),那么这两个ops.ones_like(z)参数可能不是必要的,或者它们的意图没有清晰地从代码中体现出来。标准的二元交叉熵损失计算通常仅需 logits (z) 和 真实标签 (y)。如果这些参数确实代表某种特殊处理(如自定义的平滑技术或权重策略),则需要查阅相关框架或函数的具体文档来明确其功能。
grad_fn = mindspore.grad(function, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
    [[ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]))

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z  ##同时输出z后对导数有影响  微分函数会求所有输出项对参数的导数
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
    [[ 1.09277666e+00,  1.00154400e+00,  1.31337512e+00],
     [ 1.09277666e+00,  1.00154400e+00,  1.31337512e+00],
     [ 1.09277666e+00,  1.00154400e+00,  1.31337512e+00],
     [ 1.09277666e+00,  1.00154400e+00,  1.31337512e+00],
     [ 1.09277666e+00,  1.00154400e+00,  1.31337512e+00]]), Tensor(shape=[3], dtype=Float32, value= [ 1.09277666e+00,  1.00154400e+00,  1.31337512e+00]))

在自动微分的上下文中,如MindSpore或其他深度学习框架中,函数的返回值会影响梯度计算的行为,尤其是当这些返回值(如这里的z)与计算图中的其他变量存在依赖关系时。具体到您的例子中,function_with_logits不仅返回了损失(loss),还返回了模型的logits(z)。这样做可以间接影响到梯度计算过程,原因在于:

为何输出z会影响求导?

  1. 计算图结构的变化:在自动微分系统中,每次调用操作(如ops.matmulops.binary_cross_entropy_with_logits等)都会在后台构建一个计算图。这个计算图记录了数据流动和变换的过程,是后续求导的基础。当你改变函数返回的内容时,实际上是在告诉系统哪些部分的计算图是“有用的”,从而影响了哪些部分会被纳入梯度计算。

  2. 梯度累积的目标:当函数只返回损失(loss)时,自动微分系统知道它需要计算的是使得这个损失最小化的梯度,即针对wb的梯度。而一旦同时返回z,系统可能会尝试计算关于输入x, y, w, b使得z和损失同时达到某种优化状态的梯度,尽管在这个特定的例子中,直接关注于损失是最常见的需求。

举例说明

假设有一个简化的例子,我们仅考虑函数返回损失L和不返回z的情况:

def loss_only(x, y, w, b):
    z = x * w + b
    L = (z - y)**2  # 简化示例,用平方误差代替交叉熵
    return L

在这种情况下,如果你对(w, b)求导,自动微分系统将计算dL/dwdL/db,这是用来更新权重以减少损失的梯度。

然而,如果函数同时返回z

def loss_and_z(x, y, w, b):
    z = x * w + b
    L = (z - y)**2
    return L, z

此时,虽然直观上我们仍主要关心如何通过调整wb来减小L,但由于z也被显式包含在返回值中,理论上自动微分系统会为所有输入(x, y, w, b)计算梯度,即使最终可能只关心(w, b)的梯度。这种情况下,如果求导索引指定为(w, b),结果应与仅返回L的情况相同,但框架内部的计算逻辑或效率可能会有所区别,尤其是在复杂模型或特定框架实现细节中。

结论

总之,虽然在这个特定的例子中同时输出z理论上不应直接影响到对(w, b)求导的结果(如果求导时明确指定了对哪些参数感兴趣),但它确实改变了计算图的结构,可能影响到框架内部的优化路径或资源消耗。实践中,除非有特殊目的(比如想要同时分析或利用z的梯度信息),通常只返回损失函数即可,以保持计算的高效和简洁。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
    [[ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]))

Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)

grads, (z,) = grad_fn(x, y, w, b)
print(grads, z) #前述章节主要根据计算图对应的函数介绍了MindSpore的函数式自动微分

(Tensor(shape=[5, 3], dtype=Float32, value=
    [[ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01])) [-0.95275986 -5.370149    2.753761  ]

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z
    
# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()


# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params()) 
#由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。
#对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。
loss, grads = grad_fn(x, y)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
    [[ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01],
     [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 9.27767083e-02,  1.54396100e-03,  3.13375115e-01]))

print(model.trainable_params())

[Parameter (name=w, shape=(5, 3), dtype=Float32, requires_grad=True), Parameter (name=b, shape=(3,), dtype=Float32, requires_grad=True)]

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

##数据准备
def datapipe(path, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

train_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

#网络合模型
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()


##超参
epochs = 3
batch_size = 64
learning_rate = 1e-2

## 实例化损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

# > 在训练过程中,通过微分函数可计算获得参数对应的梯度,将其传入优化器中即可实现参数优化,具体形态如下:
# >
# > grads = grad_fn(inputs)
# >
# > optimizer(grads)

## 开始训练
# Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)  ###前向过程调用loss函数计算loss
    return loss, logits

# Get gradient function 求梯度这里把前向计算(包含lossfuntion)和优化器中需要优化求导参数一起传入
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# Define function of one-step training
#训练步骤是 求得损失和更新权重一起的
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)  #求梯度这里把前向计算(包含lossfuntion)和优化器中需要优化求导参数一起传入  所以训练步骤是 求得损失和更新权重一起的
    optimizer(grads) ##更新权重
    return loss

def train_loop(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()  ##训练模式
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        if batch==0:
            print("chuanru:",batch, (data, label))
        loss = train_step(data, label) ##传入数据  得到每个batch的loss

        if batch % 100 == 0: ##多少步后打印
            loss, current = loss.asnumpy(), batch
            print(f"batch % 100 == 0 loss current : {loss:>7f}  [{current:>3d}/{size:>3d}]")

            
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)  ##推理模式
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        if test_loss==0:
            print("data, label chuanru:",data, label)        
        pred = model(data)  ###用模型来预测
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy() ###单独loss函数计算loss
        correct += (pred.argmax(1) == label).asnumpy().sum() 
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")            

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)
    
    file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 128MB/s]
    Extracting zip file…
    Successfully downloaded / unzipped to ./

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset)
    test_loop(model, test_dataset, loss_fn)
print("Done!")

Epoch 1
    -------------------------------
    chuanru: 0 (Tensor(shape=[64, 1, 28, 28], dtype=Float32, value=
    [[[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     …
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]]]), Tensor(shape=[64], dtype=Int32, value= [5, 7, 6, 6, 2, 6, 8, 6, 0, 1, 8, 1, 4, 3, 8, 8, 1, 2, 3, 5, 5, 0, 0, 5, 
     3, 1, 1, 4, 8, 0, 5, 4, 9, 7, 5, 4, 3, 6, 0, 3, 6, 6, 8, 3, 0, 0, 0, 0, 
     6, 6, 3, 2, 3, 6, 1, 6, 4, 3, 2, 7, 7, 6, 8, 2]))
    batch % 100 == 0 loss current : 2.316737  [  0/938]
    batch % 100 == 0 loss current : 1.621314  [100/938]
    batch % 100 == 0 loss current : 0.824378  [200/938]
    batch % 100 == 0 loss current : 0.611629  [300/938]
    batch % 100 == 0 loss current : 0.445010  [400/938]
    batch % 100 == 0 loss current : 0.469159  [500/938]
    batch % 100 == 0 loss current : 0.482957  [600/938]
    batch % 100 == 0 loss current : 0.353181  [700/938]
    batch % 100 == 0 loss current : 0.387352  [800/938]
    batch % 100 == 0 loss current : 0.228070  [900/938]
    data, label chuanru: [[[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     …
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]] [5 4 2 7 4 0 8 2 1 2 9 3 5 4 3 2 3 7 2 9 0 2 9 0 9 0 0 2 1 2 2 0 5 9 0 9 3
     1 6 1 3 9 4 0 6 2 1 3 1 6 4 9 4 0 6 3 3 9 6 7 3 1 7 2]
    Test: 
     Accuracy: 91.1%, Avg loss: 0.315627 
    
    Epoch 2
    -------------------------------
    chuanru: 0 (Tensor(shape=[64, 1, 28, 28], dtype=Float32, value=
    [[[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     …
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]]]), Tensor(shape=[64], dtype=Int32, value= [6, 3, 4, 6, 7, 0, 9, 2, 0, 8, 3, 3, 3, 7, 6, 8, 7, 2, 5, 6, 7, 3, 4, 8, 
     8, 5, 1, 6, 9, 4, 4, 7, 7, 8, 6, 0, 1, 7, 1, 9, 5, 4, 7, 3, 8, 1, 5, 4, 
     7, 9, 0, 1, 4, 8, 8, 1, 3, 2, 1, 1, 6, 0, 2, 8]))
    batch % 100 == 0 loss current : 0.310537  [  0/938]
    batch % 100 == 0 loss current : 0.211063  [100/938]
    batch % 100 == 0 loss current : 0.282362  [200/938]
    batch % 100 == 0 loss current : 0.495924  [300/938]
    batch % 100 == 0 loss current : 0.169102  [400/938]
    batch % 100 == 0 loss current : 0.166638  [500/938]
    batch % 100 == 0 loss current : 0.357913  [600/938]
    batch % 100 == 0 loss current : 0.381598  [700/938]
    batch % 100 == 0 loss current : 0.232926  [800/938]
    batch % 100 == 0 loss current : 0.100311  [900/938]
    data, label chuanru: [[[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     …
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]] [8 2 5 0 5 1 8 8 5 2 7 7 0 7 5 2 6 8 1 1 3 2 4 4 1 3 0 8 2 9 8 2 4 1 2 4 6
     2 0 5 2 4 4 7 0 2 1 1 2 3 4 5 2 1 9 6 2 2 6 2 9 7 1 6]
    Test: 
     Accuracy: 92.7%, Avg loss: 0.254948 
    
    Epoch 3
    -------------------------------
    chuanru: 0 (Tensor(shape=[64, 1, 28, 28], dtype=Float32, value=
    [[[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     …
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]],
     [[[-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       …
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01],
       [-4.24212962e-01, -4.24212962e-01, -4.24212962e-01 … -4.24212962e-01, -4.24212962e-01, -4.24212962e-01]]]]), Tensor(shape=[64], dtype=Int32, value= [7, 1, 4, 5, 0, 1, 3, 0, 2, 9, 5, 7, 1, 5, 0, 3, 8, 8, 9, 2, 3, 0, 6, 5, 
     3, 4, 2, 7, 1, 2, 8, 7, 8, 4, 2, 2, 3, 2, 0, 2, 4, 1, 4, 1, 6, 5, 1, 2, 
     2, 2, 3, 7, 7, 4, 1, 2, 1, 3, 1, 0, 7, 7, 9, 7]))
    batch % 100 == 0 loss current : 0.314135  [  0/938]
    batch % 100 == 0 loss current : 0.367540  [100/938]
    batch % 100 == 0 loss current : 0.210689  [200/938]
    batch % 100 == 0 loss current : 0.248528  [300/938]
    batch % 100 == 0 loss current : 0.182782  [400/938]
    batch % 100 == 0 loss current : 0.346134  [500/938]
    batch % 100 == 0 loss current : 0.191530  [600/938]
    batch % 100 == 0 loss current : 0.064461  [700/938]
    batch % 100 == 0 loss current : 0.173537  [800/938]
    batch % 100 == 0 loss current : 0.240742  [900/938]
    data, label chuanru: [[[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     …
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]
    
    
     [[[-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       …
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]
       [-0.42421296 -0.42421296 -0.42421296 … -0.42421296 -0.42421296
        -0.42421296]]]] [9 4 3 2 8 2 1 0 2 3 3 8 1 2 4 3 0 5 9 1 8 7 8 3 1 4 2 5 3 0 2 7 4 9 0 7 6
     5 4 9 7 9 5 6 4 5 1 3 6 9 8 6 8 3 2 5 6 1 3 6 1 4 1 9]
    Test: 
     Accuracy: 93.7%, Avg loss: 0.211543 
    
    Done!

# import numpy as np
# import mindspore
# from mindspore import nn
# from mindspore import Tensor

# def network():
#     model = nn.SequentialCell(
#                 nn.Flatten(),
#                 nn.Dense(28*28, 512),
#                 nn.ReLU(),
#                 nn.Dense(512, 512),
#                 nn.ReLU(),
#                 nn.Dense(512, 10))
#     return model
# model = network()
mindspore.save_checkpoint(model, "model.ckpt")

#model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
#> `param_not_load`是未被加载的参数列表,为空时代表所有参数均加载成功。

[]

#model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

(1, 10)

print("wancheng  author:yangge yyp 2024-6-21 第三可 训练网络和训练流程")

wancheng  author:yangge yyp 2024-6-21 第三可 训练网络和训练流程

  • 16
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
工程训练大赛开环打卡赛题是关于AI凸轮设计的任务,我们可以在CSDN上找到相关的应用程序(App)。 AI凸轮设计是一项利用人工智能技术开发的凸轮设计方法,它可以优化凸轮的设计过程,提高凸轮在机械系统中的性能。这个赛题要求我们设计一个可以帮助工程师进行AI凸轮设计的应用程序。 在CSDN上搜索相关内容,我们可以找到很多与AI凸轮设计相关的技术文章和项目案例。这些资源可以帮助我们了解AI凸轮设计的基本原理和应用方法。 针对这个任务,我们可以开发一个名为"AI凸轮设计App"的移动应用程序。这个App将提供以下功能: 1. 数据收集与分析:用户可以输入相关机械系统的参数和要求,然后App会收集并分析这些数据,为凸轮设计提供参考。 2. AI凸轮设计算法:App内嵌有AI凸轮设计算法模块,可以自动计算和优化凸轮的设计。 3. 可视化展示:App可以将AI凸轮设计的结果以可视化的方式展示给用户,包括凸轮的形状、尺寸和运动轨迹等。 4. 参数调整和优化:用户可以通过调整输入参数来进一步优化凸轮设计,App会即时更新设计结果。 5. 数据导出:用户可以将AI凸轮设计的结果导出为文件,以供其他工程软件或设备使用。 通过这个"AI凸轮设计App",工程师可以更加高效地设计凸轮,减少人工设计的工作量,提高设计的准确性和效率。这个App还可以为工程师提供参考,帮助他们更好地理解AI凸轮设计的原理和应用。 总的来说,通过开发这款应用程序,我们可以方便工程师在移动设备上进行AI凸轮设计,提高设计效率和准确性,推动工程技术的发展和应用。这也符合工程训练大赛开展的宗旨和主题。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值