转onehot格式
1. 首先用到 scatter 函数（带下划线的 scatter_ 是原地操作，会直接修改调用它的张量）
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
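## 下面第一个参数 0 为 dim：index 中的值表示写入的行号，列号不变，即 self[index[i][j]][j] = src[i][j]；若 dim=1，则 index 中的值表示列号，即 self[i][index[i][j]] = src[i][j]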
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
[ 0.0000, 0.0000, 0.0000, 1.2300]])
In []: z = torch.zeros(2, 4).scatter_(1, torch.tensor([[1, 2], [1, 3]]), 1.23)
In []: z
Out[]:
tensor([[0.0000, 1.2300, 1.2300, 0.0000],
[0.0000, 1.2300, 0.0000, 1.2300]])
具体代码为
# One-hot encode class indices with scatter_: for every row i, write 1 at
# column label[i] of an all-zero (N, num_classes) float tensor.
import torch

label = torch.tensor([0, 0, 1])
num_classes = 3
# torch.zeros instead of an uninitialized FloatTensor followed by zero_().
onehot = torch.zeros(label.shape[0], num_classes)
# scatter_ along dim 1 needs an index with the same number of dims as
# `onehot`, so turn the 1-D labels into an (N, 1) column; -1 infers N.
onehot.scatter_(1, label.reshape(-1, 1), 1)
计算 cross entropy（交叉熵）
import torch
from torch import nn
# Toy batch of 3 samples over 3 classes.
# `lable` (sic, read under this name further down) holds each row's target
# class index; `fc_out` holds the raw logits fed to the losses below
# (presumably a final fully-connected layer's output). It deliberately
# includes an extreme logit (-865) and a near-zero one (1e-14).
lable = torch.tensor([0, 0, 1])
fc_out = torch.tensor([
    [2.5, -2, 0.8989],
    [3, 0.8, -865],
    [1e-14, 2, 4.9],
])
class CrossEntropyLoss(nn.Module):
    """Hand-written cross-entropy loss over raw logits, averaged over the batch.

    Intended to match ``torch.nn.CrossEntropyLoss()`` with its defaults
    (``reduction="mean"``, no class weights, no label smoothing) for integer
    class-index targets.
    """

    def __init__(self):
        super().__init__()

    def forward(self, fc_out, label):
        """Return the mean negative log-probability of the target classes.

        Args:
            fc_out: 2-D tensor of logits, shape ``(N, C)``; any C is accepted.
            label: int64 tensor of ``N`` class indices in ``[0, C)``.

        Returns:
            0-dim tensor ``-(1/N) * sum_i log_softmax(fc_out)[i, label[i]]``.
        """
        num_samples = fc_out.shape[0]
        # Target indices as an (N, 1) column, the shape gather() needs for dim=1.
        target = label.reshape(num_samples, 1)
        # log_softmax uses the log-sum-exp trick, so a target probability that
        # underflows to 0 in float32 still yields a finite loss, whereas
        # log(softmax(...)) would produce inf.
        log_probs = torch.log_softmax(fc_out, dim=1)
        # gather picks log_probs[i, label[i]]. This equals summing
        # one_hot(label) * log_probs per row, without building the one-hot
        # tensor or hard-coding the number of classes.
        picked = log_probs.gather(1, target)
        return -picked.sum() / num_samples
# Compare PyTorch's built-in loss with the hand-written CrossEntropyLoss
# defined above, on the same logits and targets; the two printed values are
# expected to agree up to floating-point rounding.
loss, loss1 = torch.nn.CrossEntropyLoss(), CrossEntropyLoss()
l, l2 = loss(fc_out, lable), loss1(fc_out, lable)
print(l)
print(l2)