easy-fpn 源码解读（五）：RPN
`region_proposal_network.py` 代码解析
from typing import Tuple, List
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from bbox import BBox
from nms.nms import NMS
class RegionProposalNetwork(nn.Module):
def __init__(self, num_features_out: int, anchor_ratios: List[Tuple[int, int]], anchor_scales: List[int], pre_nms_top_n: int, post_nms_top_n: int):
super().__init__()
self._features = nn.Sequential(
nn.Conv2d(in_channels=num_features_out, out_channels=512, kernel_size=3, padding=1),
nn.ReLU()
)
self._anchor_ratios = anchor_ratios
self._anchor_scales = anchor_scales
num_anchor_ratios = len(self._anchor_ratios)
num_anchor_scales = len(self._anchor_scales)
num_anchors = num_anchor_ratios * num_anchor_scales
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
self._objectness = nn.Conv2d(in_channels=512, out_channels=num_anchors * 2, kernel_size=1)
self._transformer = nn.Conv2d(in_channels=512, out_channels=num_anchors * 4, kernel_size=1)
def forward(self, features: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor]:
features = self._features(features)
objectnesses = self._objectness(features)
transformers = self._transformer(features)