# Unreduced KL-divergence criterion. ``reduction="none"`` is the supported
# spelling of the deprecated ``reduce=False`` and yields one value per element.
loss = nn.KLDivLoss(reduction="none")
batch_size = 5
# KLDivLoss takes log-probabilities as its input and probabilities as its
# target, hence log_softmax for the first tensor and softmax for the second.
log_probs1 = F.log_softmax(torch.randn(batch_size, 10), dim=1)
probs2 = F.softmax(torch.randn(batch_size, 10), dim=1)
# Scale every element by 1 / batch_size; the result keeps shape (batch_size, 10).
# NOTE(review): if a single batch-averaged scalar was intended,
# ``reduction="batchmean"`` computes that directly -- confirm the intent.
loss(log_probs1, probs2) / batch_size
# Build the KL-divergence criterion without any reduction. The ``reduce``
# keyword is deprecated; ``reduction="none"`` is its direct equivalent.
loss = nn.KLDivLoss(reduction="none")
batch_size = 5
# Input must be log-probabilities, target must be probabilities; both are
# normalised along dim 1 of freshly drawn random (batch_size, 10) tensors.
log_probs1 = F.log_softmax(torch.randn(batch_size, 10), dim=1)
probs2 = F.softmax(torch.randn(batch_size, 10), dim=1)
# Element-wise loss divided by the batch size; shape stays (batch_size, 10).
# NOTE(review): use ``reduction="batchmean"`` instead if a scalar
# batch-averaged KL divergence is what is actually wanted -- confirm.
loss(log_probs1, probs2) / batch_size