Pytorch中的AnchorGenerator
在pytorch中, AnchorGenerator主要用于生成候选框,该类存储在torchvision/models/detection/rpn.py中。
#创建AnchorGenerator类
from torchvision.models.detection.rpn import AnchorGenerator
generator = AnchorGenerator()
该类继承于nn.Module, 因此包含forward属性
要获得generator首先要有一个3通道的图像,以及图像对应的features
首先生成一幅伪图像
#利用随机函数创建图像, 生成1幅600×800的3通道图像
import torch
image = torch.randn(1,3, 600, 800)
然后生成该图像对应的features
#利用resnet与金字塔的结合模型生成图像的特征图
import torchvision
from torchvision.models.detection.backbone_utils import BackboneWithFPN
#创建一个50层的resnet
resnet = torchvision.models.resnet.resnet50(pretrained=False)
#给该resnet加上FeaturePyramidNetword
#layer1..layer4分别是resnet中的卷积层的属性名
return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
#inchannels_stage2是layer1..layer4卷积层输出的通道数
in_channels_stage2 = resnet.inplanes // 8
in_channels_list = [
in_channels_stage2,
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = 256 #金字塔层输出的通道数
#生成与金字塔相融合的resnet网络模型
resnet_fpn = BackboneWithFPN(resnet, return_layers, in_channels_list, out_channels)
至此我们创建好了用于生成特征向量的的模型。
接下来将image转成向量
#获得特征向量
features = resnet_fpn(image)
下面来看一下features的样子
print(type(features))
for key in features:
print(key,"的大小:",features[key].shape)
得到的输出结果如下:
<class 'collections.OrderedDict'>
0 的大小: torch.Size([1, 256, 150, 200])
1 的大小: torch.Size([1, 256, 75, 100])
2 的大小: torch.Size([1, 256, 38, 50])
3 的大小: torch.Size([1, 256, 19, 25])
pool 的大小: torch.Size([1, 256, 10, 13])
可以看到输出的feature是OrderedDict类型,共包含了5个特征图
下面将图像和特征图输入到AnchorGenerator中
由于AnchorGenerator.foword的输入参数分别是ImageList和List类型,我们需要对上面对应的变量进行转换
from torchvision.models.detection.image_list import ImageList
#对image进行转换
#ImageList第一个参数是tensor类型,第二个参数是List[Tuple[int, int]]类型
#当包含多幅图像时,第二个参数应分别累出每幅图像的大小:list([(600,800),(600,800),(600,800),])
imglist = ImageList(image, list([(600,800)]))
#转换feature为list类型
flist = list(features.values())
最后看一下生成anchors
#最后将imglist和flist输入AnchorGenerator
anchors = generator(imglist,flist)
print(type(anchors))
print(len(anchors))
print(anchors[0][1:10,:])
输出结果:
<class 'list'>
1
tensor([[-64., -64., 64., 64.],
[-45., -91., 45., 91.],
[-87., -45., 95., 45.],
[-60., -64., 68., 64.],
[-41., -91., 49., 91.],
[-83., -45., 99., 45.],
[-56., -64., 72., 64.],
[-37., -91., 53., 91.],
[-79., -45., 103., 45.]])
为什么anchor中好多负数?