I found that most MultiBoxLoss implementations online have a bug that makes training crash at random times.
For example, this code:
https://zhuanlan.zhihu.com/p/77868999
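As an illustration only, and not necessarily the bug fixed in the code below, here is one frequent source of intermittent failures in SSD-style MultiBoxLoss code. The total loss is divided by the number of positive (matched) priors, and a batch with no positives makes that count zero. The sketch is plain Python rather than PyTorch so it runs anywhere. `hard_negative_mining` and `multibox_style_loss` are hypothetical helper names, per-prior losses are assumed to be precomputed, and the default 3:1 negative-to-positive ratio follows the SSD paper.

```python
from typing import List


def hard_negative_mining(
    conf_loss: List[float], positive: List[bool], neg_pos_ratio: int = 3
) -> List[bool]:
    """Hypothetical helper: mark the highest-loss negative priors of one image."""
    num_pos = sum(positive)
    # At most neg_pos_ratio negatives per positive, capped by how many negatives exist.
    num_neg = min(neg_pos_ratio * num_pos, len(positive) - num_pos)
    # Rank only the negative priors, hardest (largest confidence loss) first.
    ranked = sorted(
        (i for i, is_pos in enumerate(positive) if not is_pos),
        key=lambda i: conf_loss[i],
        reverse=True,
    )
    selected = [False] * len(positive)
    for i in ranked[:num_neg]:
        selected[i] = True
    return selected


def multibox_style_loss(
    loc_loss: List[float],
    conf_loss: List[float],
    positive: List[bool],
    neg_pos_ratio: int = 3,
) -> float:
    """Hypothetical helper: SSD-style total loss from precomputed per-prior losses."""
    negative = hard_negative_mining(conf_loss, positive, neg_pos_ratio)
    # Localization loss counts only matched (positive) priors.
    loss_l = sum(loc for loc, is_pos in zip(loc_loss, positive) if is_pos)
    # Confidence loss counts positives plus the mined hard negatives.
    loss_c = sum(
        c for c, is_pos, is_neg in zip(conf_loss, positive, negative) if is_pos or is_neg
    )
    num_pos = sum(positive)
    # Guard: with zero positives, dividing by num_pos would be 0 / 0. Plain Python
    # raises ZeroDivisionError; PyTorch tensor division yields NaN instead.
    return (loss_l + loss_c) / max(num_pos, 1)
```

Because the number of mined negatives scales with the number of positives, a batch with no positives contributes a loss of 0 instead of NaN. In PyTorch the same guard on the denominator can be written as `num_pos.sum().clamp(min=1)`.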
Corrected code:
# -*- coding:utf-8 -*-
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import configparser
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
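from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors work directly
from focal_loss import FocalLoss  # separate focal_loss module, not part of PyTorch
from utils.box_utils import match, log_sum_exp  # prior/ground-truth matching, stable log-sum-exp
from data import cfg_mnet  # training configuration dict
GPU = cfg_mnet['gpu_train']  # whether training runs on the GPU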


class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confide