这段代码计算交差熵的代码分析: def cross_entropy(y_hat, y): return -torch.log(y_hat[list(range(len(y_hat))), y]).mean() cross_entropy(y_hat, y)
分类操作利用one-hot编码,如一个正确的分类P用[0,0,1]等方式表示,预测该利率用Q表示[0.2,0.4,0.4],那么,
PlogQ = 0*log0.2+0*log0.4+1*log0.4 表示。
-torch.log(y_hat[list(range(len(y_hat))), y]).mean()中的参数部分可认为是利用上面方式查找非零对应的项,如:
y = torch.tensor([0, 2]) y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]) y_hat[[0, 1], y] -> y_hat[[0, 1], [0,2]] 输出:tensor([0.1000, 0.5000])
y_hat[[0, 1], y],前面[0,1]可理解为在y_hat[0]和y_hat[1]中的数据,[0,2]分别对应前面两个张量中对应的编号。
#0.1是y_hat[0]在[0.1,0.3,0.6]里第y[0] 0.5是y_hat[1]在[0.3,0.,0.5]里第y[2]