问题描述:
nn.SoftmaxCrossEntropyWithLogits
【操作步骤&问题现象】
使用nn.SoftmaxCrossEntropyWithLogits 这个loss时,发现模型输出为80个类的logits,而数据集提供的label类别有400个,依然可以使用model.train训练,这是正常的么?不应该是logits和label的类别数量不一致然后无法训练么?还是说它会对logits进行补0?
解答:
nn.SoftmaxCrossEntropyWithLogits API文档里没有要求输入shape一致。从官网例子来看,sparse为True时,logits shape为(N, C),labels shape为(N)
比较与tf.nn.softmax_cross_entropy_with_logits的功能差异 — MindSpore master documentation
这里比较了该算子与TensorFlow的差异