核心代码（可以直接运行，查看各个张量的 Shape）
- 参考 PaddleOCR 中 `configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml` 对应的实现。
- 以下代码只给出了模型搭建、前向推理和损失计算的代码，前后处理代码均已省略。
import torch
from torch import nn
class DBHead(nn.Module):
    """Differentiable Binarization (DB) text-detection head.

    Two parallel branches map the input feature map to one channel at 4x the
    spatial size (3x3 conv keeps H, W; each of the two stride-2 transposed
    convs doubles them) and squash it with a sigmoid:

    * ``binarize`` predicts the shrink (probability) map;
    * ``thresh`` predicts the threshold map.

    When ``is_train`` is True the approximate binary map
    ``sigmoid(k * (shrink - thresh))`` is appended as a third channel.

    Args:
        in_channels: channels of the input feature map; each branch uses
            ``in_channels // 4`` hidden channels.
        k: amplification factor of the differentiable step function.
        is_train: if True, ``forward`` also returns the binary map.  This is
            fixed at construction time; ``self.training`` is not consulted.

    Shape:
        input ``(N, in_channels, H, W)`` -> output ``(N, C, 4H, 4W)`` with
        channel order (shrink, threshold[, binary]); ``C`` is 3 when
        ``is_train`` else 2.
    """

    def __init__(self, in_channels, k=50, is_train=False):
        super().__init__()
        self.k = k
        self.is_train = is_train
        # NOTE: the binarize branch keeps a bias on its first conv while the
        # thresh branch does not (a bias right before BatchNorm is redundant).
        # The asymmetry is kept so existing state_dicts still load.
        self.binarize = self._make_branch(in_channels, first_conv_bias=True)
        self.thresh = self._make_branch(in_channels, first_conv_bias=False)

    @staticmethod
    def _make_branch(in_channels, first_conv_bias):
        """Build conv3x3-BN-ReLU -> deconv2x2-BN-ReLU -> deconv2x2 -> sigmoid."""
        mid = in_channels // 4
        return nn.Sequential(
            nn.Conv2d(in_channels, mid, 3, padding=1, bias=first_conv_bias),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid, mid, 2, 2),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid, 1, 2, 2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the concatenated (shrink, threshold[, binary]) maps."""
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        maps = [shrink_maps, threshold_maps]
        if self.is_train:
            maps.append(self.step_function(shrink_maps, threshold_maps))
        return torch.cat(maps, dim=1)

    def step_function(self, x, y):
        """Differentiable approximation of the step ``x > y``.

        Equals ``1 / (1 + exp(-k * (x - y)))`` but is computed with
        ``torch.sigmoid``.  The explicit ``exp`` form overflows to ``inf`` for
        large ``k * (y - x)`` (easily reached in half precision), which turns
        the backward pass into ``0 * inf = NaN``.
        """
        return torch.sigmoid(self.k * (x - y))
class MaskL1Loss(nn.Module):
    """Masked L1 loss normalised by the mask's total weight.

    Computes ``sum(|pred - gt| * mask) / (sum(mask) + 1e-6)``; the small
    constant keeps the division finite when ``mask`` is all zeros.
    """

    def forward(self, pred, gt, mask):
        weighted_error = (pred - gt).abs() * mask
        denominator = mask.sum() + 1e-6
        ratio = weighted_error.sum() / denominator
        # ``ratio`` is already a 0-d tensor; mean() leaves its value unchanged.
        return ratio.mean()
class DiceLoss(nn.Module):
    """Masked Dice loss: ``1 - 2 * |P * G| / (|P| + |G|)``.

    All sums are taken over the elements weighted by ``mask``; ``1e-6`` is
    added to the denominator so an empty prediction/label pair stays finite.
    Returns a 0-d tensor.
    """

    def forward(self, pred, gt, mask):
        overlap = (pred * gt * mask).sum()
        pred_area = (pred * mask).sum()
        gt_area = (gt * mask).sum()
        denominator = pred_area + gt_area + 1e-6
        return 1 - 2.0 * overlap / denominator
class BalanceLoss(nn.Module):
    """Balance a main loss between positive and hard-negative pixels.

    Positive weights are ``gt * mask`` and negative weights
    ``(1 - gt) * mask``.  At most ``negative_ratio`` negatives per positive
    are kept: the negative-weighted losses are sorted in descending order and
    only the largest ``negative_count`` entries are summed (online hard
    example mining).

    NOTE(review): ``DiceLoss`` and ``MaskL1Loss`` both return a 0-d tensor,
    so ``positive * loss`` / ``negative * loss`` are the weight maps scaled
    by one scalar; the mining therefore ranks pixels by negative weight, not
    by a per-pixel loss.

    Args:
        negative_ratio: maximum number of negatives kept per positive.
        main_loss_type: ``'DiceLoss'`` (default) or ``'MaskL1Loss'``.

    Raises:
        ValueError: if ``main_loss_type`` is not a supported name.
    """

    def __init__(self, negative_ratio=3, main_loss_type='DiceLoss'):
        super().__init__()
        self.negative_ratio = negative_ratio
        self.eps = 1e-6
        if main_loss_type == 'DiceLoss':
            self.loss = DiceLoss()
        elif main_loss_type == 'MaskL1Loss':
            self.loss = MaskL1Loss()
        else:
            # Fail fast: otherwise ``self.loss`` would be missing and
            # forward() would crash later with an AttributeError.
            raise ValueError(
                f"Unsupported main_loss_type {main_loss_type!r}; "
                "expected 'DiceLoss' or 'MaskL1Loss'")

    def forward(self, pred, gt, mask=None):
        """Return the balanced loss as a 0-d tensor.

        Args:
            pred: predicted map.
            gt: label map; presumably binary {0, 1} — the positive/negative
                split assumes it (TODO confirm with the data pipeline).
            mask: optional weight map; ``None`` means every pixel counts.
        """
        if mask is None:
            mask = torch.ones_like(gt)
        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.sum())
        negative_count = int(min(negative.sum(),
                                 positive_count * self.negative_ratio))
        loss = self.loss(pred, gt, mask=mask)

        positive_loss = (positive * loss).sum()
        if negative_count > 0:
            negative_loss = (negative * loss).reshape(-1)
            # Keep only the hardest (largest) negatives.
            hardest = negative_loss.sort(descending=True)[0][:negative_count]
            return (positive_loss + hardest.sum()) / (
                positive_count + negative_count + self.eps)
        return positive_loss / (positive_count + self.eps)
if __name__ == '__main__':
    # Smoke test with random data: a 96-channel feature map at 240x240; the
    # head upsamples 4x, so every label/mask below is 960x960.
    torch.manual_seed(0)
    x = torch.randn(1, 96, 240, 240)
    # Shrink labels and both masks are binary {0, 1}, which is what
    # BalanceLoss's positive/negative split expects; threshold labels are
    # drawn from [0, 1) to match the sigmoid output range.
    label_threshold_map = torch.rand(1, 960, 960)
    label_threshold_mask = torch.randint(low=0, high=2, size=(1, 960, 960)).float()
    label_shrink_map = torch.randint(low=0, high=2, size=(1, 960, 960)).float()
    label_shrink_mask = torch.randint(low=0, high=2, size=(1, 960, 960)).float()

    # Pass k by keyword: DBHead's second positional argument is the
    # step-function amplification factor k (default 50), not a kernel size.
    model = DBHead(96, k=50, is_train=True)
    y = model(x)
    print('head output shape:', tuple(y.shape))

    l1_loss = MaskL1Loss()
    dice_loss = DiceLoss()
    balance_loss = BalanceLoss()  # DiceLoss-based, with hard-negative mining

    shrink_maps = y[:, 0, :, :]
    threshold_maps = y[:, 1, :, :]
    binary_maps = y[:, 2, :, :]

    loss_threshold_maps = l1_loss(threshold_maps,
                                  label_threshold_map,
                                  label_threshold_mask)
    loss_shrink_maps = balance_loss(shrink_maps,
                                    label_shrink_map,
                                    label_shrink_mask)
    loss_binary_maps = dice_loss(binary_maps,
                                 label_shrink_map,
                                 label_shrink_mask)

    # alpha weights the shrink-map loss, beta the threshold-map loss.
    alpha, beta = 5, 10
    loss_all = (alpha * loss_shrink_maps
                + beta * loss_threshold_maps
                + loss_binary_maps)
    print(loss_all)
- shrink_map
  作用：作为标签（gt），与模型预测结果（pred）计算差距（损失）。
- shrink_mask
  作用：让损失计算只聚焦于 mask 覆盖的部分，示例代码：`torch.abs(pred - gt) * mask`