官方文档写的是 weight的形状为【C】
eg:5个类的话 weight=tensor[1,1,1,1,1]
然后看了一下示例代码:
`import paddle
import numpy as np
input_data = paddle.uniform([5, 100], dtype="float64")
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
weight_data = np.random.random([100]).astype("float64")
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
weight = paddle.to_tensor(weight_data)
ce_loss = paddle.nn.CrossEntropyLoss(weight=weight, reduction='mean')
output = ce_loss(input, label)
print(output)`
这里面,weight的shape是[100],但是这里不是只有5个类嘛。没有理解,请求帮助。谢谢!