import torch
import numpy as np
class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Class-weighted cross-entropy loss.

    Thin wrapper around :class:`torch.nn.CrossEntropyLoss` that takes the
    per-class ``weight`` tensor as its first positional argument.

    The network feeding this loss must have NO final nonlinearity. Inputs are
    raw, unnormalised logits, because ``CrossEntropyLoss`` applies
    ``log_softmax`` internally; a softmax in the model would apply it twice.

    Args:
        weight: Optional 1-D tensor with one rescaling weight per class.
            ``None`` weights all classes equally. Its dtype/device should
            match the logits.
        **kwargs: Forwarded unchanged to ``torch.nn.CrossEntropyLoss``
            (e.g. ``ignore_index``, ``reduction``, ``label_smoothing``).
    """

    def __init__(self, weight=None, **kwargs):
        # Hand the weight to the parent. This registers it as a buffer once,
        # so it follows .to()/.cuda() and state_dict. The parent's
        # ignore_index / reduction options are honoured as well.
        super().__init__(weight=weight, **kwargs)

    def forward(self, logit, target):
        """Return the weighted cross-entropy between ``logit`` and ``target``.

        Args:
            logit: Raw scores of shape (N, C) or (N, C, d1, d2, ...).
            target: Typically int64 class indices of shape (N,) or
                (N, d1, d2, ...).

        Returns:
            The loss, reduced according to ``self.reduction`` (a 0-dim tensor
            for the default ``"mean"``).
        """
        # Delegate to the parent rather than constructing a new
        # CrossEntropyLoss on every call. This reuses the registered weight
        # buffer and respects every option given to __init__.
        return super().forward(logit, target)
if __name__ == "__main__":
    # Smoke test: a batch of 2 "images", 13 classes, 320x640 pixels.
    num_classes = 13

    # Uniform [0, 1) values stand in for raw network logits.
    # logit.shape: [2, 13, 320, 640], dtype float64 (from NumPy).
    logit = torch.tensor(np.random.random((2, num_classes, 320, 640)))

    # One integer class label per pixel.
    # target.shape: [2, 320, 640]
    target = torch.tensor(
        np.random.randint(0, num_classes, size=(2, 320, 640))
    ).long()

    # One weight per class. The dtype must match the logits (float64 here).
    weight = torch.tensor(
        [0.1, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.7],
        dtype=logit.dtype,
    )

    criterion = WeightedCrossEntropyLoss(weight)
    # Call the module itself (not .forward) so any registered hooks run.
    loss = criterion(logit, target)
    print(loss)
# WeightedCrossEntropyLoss source code
# (Latest recommended article published 2023-08-24 15:00:03)