#!/usr/bin/env python
# -*- coding: utf-8 -*-
import time
import torch
import torch
# Experiment: compare torch.nn.BCELoss on per-element probabilities with
# torch.nn.CrossEntropyLoss evaluated on the same probabilities/targets.
# `a` holds probabilities in (0, 1); `b` holds the matching 0/1 targets.
v = 0.5  # 1-0.0001
v1 = v - 0.01
a = torch.tensor([[v, v1, v], [v, v1, v]], dtype=torch.float32)
b = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32)
loss_fn = torch.nn.BCELoss()  # reduce=False, size_average=False)
x = loss_fn(a, b).item()  # mean BCE over all elements
print(x)

# CrossEntropyLoss expects an (N, C) score matrix and an (N,) class-index
# target. Reshaping `a` to (3, 2) while `b` stays (2, 3) raises a shape error,
# and probabilities are not the scores CE expects. Instead, treat each
# probability p as a two-class distribution [1 - p, p] and pass its log:
# log_softmax leaves a log-distribution unchanged, so CE becomes
# -log(p) for target 1 and -log(1 - p) for target 0, which is exactly BCE.
probs = a.reshape(-1)
scores = torch.stack([1 - probs, probs], dim=1).log()  # shape (N, 2)
print(scores)
targets = b.reshape(-1).long()  # shape (N,), class indices 0/1
cross = torch.nn.CrossEntropyLoss(reduction='sum')
aaa = cross(scores, targets)
# With reduction='sum' this equals the mean BCE above times targets.numel().
print(aaa)
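# NOTE: the outputs below are not produced by the code above; they appear to
# be recorded results from a separate sorting/masking experiment.
# Input values:
# tensor([[0.1933, 0.1425, 0.8572, 0.0224, 0.3811],
# [0.6134, 0.9766, 0.6086, 0.0163, 0.1514]])
# Descending argsort of each row (indices of the largest to smallest value):
# tensor([[2, 4, 0, 1, 3],
# [1, 0, 2, 4, 3]])
# Argsort of the argsort, i.e. the descending rank of each element:
# tensor([[2, 3, 0, 4, 1],
# [1, 0, 2, 4, 3]])
# In the first row, element 0 has rank 2 and element 1 has rank 3 (0-based, descending order).
# Mask of values greater than 0.5:
# tensor([[False, False, True, False, False],
# [True, True, True, False, False]])
# Values selected by the mask:
# tensor([0.8572, 0.6134, 0.9766, 0.6086])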