当有多个参数时可以通过设置has_aux=True来排除辅助参数
通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。
CE 代表 cross_entropy 交叉熵
)
# Define model class Network(nn.Cell): def __init__(self): super().__init__() self.w = w self.b = b def construct(self, x): z = ops.matmul(x, self.w) + self.b return z # Instantiate model model = Network() # Instantiate loss function loss_fn = nn.BCEWithLogitsLoss() # Define forward function def forward_fn(x, y): z = model(x) loss = loss_fn(z, y) return loss, z grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params(), has_aux=True) loss, grads = grad_fn(x, y) print(grads)