The following code is from Distraction-aware Shadow Detection:
import torch
import torch.nn as nn

class MyWcploss(nn.Module):
    """Weighted cross-entropy loss that balances positive and negative pixels."""
    def __init__(self):
        super(MyWcploss, self).__init__()

    def forward(self, pred, gt):
        epsilon = 1e-10
        # Pixel counts of the positive (shadow) and negative classes
        count_pos = torch.sum(gt) * 1.0 + epsilon
        count_neg = torch.sum(1. - gt) * 1.0
        # pos_weight = neg/pos ratio: up-weights the rarer positive pixels
        beta = count_neg / count_pos
        # Rescale so the up-weighting does not inflate the overall loss magnitude
        beta_back = count_pos / (count_pos + count_neg)
        # Sigmoid is applied inside BCEWithLogitsLoss, so pred stays as raw logits
        bce = nn.BCEWithLogitsLoss(pos_weight=beta)
        return beta_back * bce(pred, gt)
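As a sanity check, here is a small sketch (with made-up toy tensors) of how the weighting plays out on a heavily imbalanced mask. With 4 positive pixels out of 64, beta weights each positive 15x, so the positive and negative classes end up contributing equally to the total loss:

```python
import torch
import torch.nn as nn

# Hypothetical toy mask (made-up numbers): 4 positive pixels out of 64
gt = torch.zeros(64)
gt[:4] = 1.0
pred = torch.zeros(64)  # neutral logits: sigmoid(0) = 0.5 everywhere

epsilon = 1e-10
count_pos = torch.sum(gt) + epsilon   # 4
count_neg = torch.sum(1. - gt)        # 60
beta = count_neg / count_pos          # ~15: each positive weighted 15x
beta_back = count_pos / (count_pos + count_neg)  # 4/64 = 0.0625

loss = beta_back * nn.BCEWithLogitsLoss(pos_weight=beta)(pred, gt)
# 4 positives x 15 = 60 = number of negatives, so after weighting
# both classes contribute the same amount to the mean loss.
print(beta.item(), beta_back.item(), loss.item())
```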
The code above implements a class-balanced weighted binary cross-entropy loss. I didn't understand it on first reading, so I looked up BCEWithLogitsLoss and was puzzled by the pos_weight parameter; after reading the formula below, it became clear. For details, see the official PyTorch documentation for BCEWithLogitsLoss.
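Concretely, per the PyTorch docs, BCEWithLogitsLoss with pos_weight p computes, per element, ℓ_n = −[ p · y_n · log σ(x_n) + (1 − y_n) · log(1 − σ(x_n)) ], i.e. pos_weight only scales the positive term. A quick numerical check of that formula (toy tensors, arbitrary pos_weight value):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(6)                  # logits
y = (torch.rand(6) > 0.5).float()   # binary targets
p = torch.tensor(3.0)               # arbitrary pos_weight

# Built-in: mean over elements of the pos_weighted per-element BCE
loss_builtin = nn.BCEWithLogitsLoss(pos_weight=p)(x, y)

# Manual version of the docs formula:
# l_n = -[ p * y_n * log(sigmoid(x_n)) + (1 - y_n) * log(1 - sigmoid(x_n)) ]
s = torch.sigmoid(x)
loss_manual = -(p * y * torch.log(s) + (1 - y) * torch.log(1 - s)).mean()

print(torch.allclose(loss_builtin, loss_manual, atol=1e-6))
```

So in the paper's loss, beta = count_neg / count_pos is exactly this p, chosen per batch from the ground-truth mask.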