# /home/wuchenxi/mmdetection/mmdet/models/anchor_heads/anchor_head.py
from future import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, multiclass_nms, force_fp32)
from …builder import build_loss
from …registry import HEADS
@HEADS.register_module
class AnchorHead(nn.Module):
“”"Anchor-based head (RPN, RetinaNet, SSD, etc.).
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels of the feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
""" # noqa: W605
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],#1:2,1:1,2:1的anchor比例
anchor_strides=[4, 8, 16, 32, 64], #初始的base_size就是通过这个生成的,个数应和FPN的输出层数相等
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes #生成base_size部分代码
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
else:
self.cls_out_channels = num_classes
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.fp16_enabled =