Pytorch中的AnchorGenerator
在pytorch中, AnchorGenerator主要用于生成候选框,该类存储在torchvision/models/detection/rpn.py中。
#创建AnchorGenerator类
from torchvision.models.detection.rpn import AnchorGenerator
generator = AnchorGenerator()
该类继承于nn.Module, 因此包含forward属性
要获得generator首先要有一个3通道的图像,以及图像对应的features
首先生成一幅伪图像
#利用随机函数创建图像, 生成1幅600×800的3通道图像
import torch
image = torch.randn(1,3, 600, 800)
然后生成该图像对应的features
#利用resnet与金字塔的结合模型生成图像的特征图
import torchvision
from torchvision.models.detection.backbone_utils import BackboneWithFPN
#创建一个50层的resnet
resnet = torchvision.models.resnet.resnet50(pretrained=False)
#给该resnet加上FeaturePyramidNetword
#layer1..layer4分别是resnet中的卷积层的属性名
return_layers = {'layer1': '0', 'layer2': '1', 'layer