AnchorGenerator类简介

Pytorch中的AnchorGenerator

在pytorch中, AnchorGenerator主要用于生成候选框,该类存储在torchvision/models/detection/rpn.py中。

#创建AnchorGenerator类
from torchvision.models.detection.rpn import AnchorGenerator
generator = AnchorGenerator()

该类继承于nn.Module, 因此包含forward属性
要获得generator首先要有一个3通道的图像,以及图像对应的features
首先生成一幅伪图像

#利用随机函数创建图像, 生成1幅600×800的3通道图像
import torch
image = torch.randn(1,3, 600, 800)   

然后生成该图像对应的features

#利用resnet与金字塔的结合模型生成图像的特征图 
import torchvision
from torchvision.models.detection.backbone_utils import BackboneWithFPN

#创建一个50层的resnet
resnet = torchvision.models.resnet.resnet50(pretrained=False)    
#给该resnet加上FeaturePyramidNetword
#layer1..layer4分别是resnet中的卷积层的属性名
return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
#inchannels_stage2是layer1..layer4卷积层输出的通道数
in_channels_stage2 = resnet.inplanes // 8
in_channels_list = [
        in_channels_stage2,
        in_channels_stage2 * 2,
        in_channels_stage2 * 4,
        in_channels_stage2 * 8,
    ]
out_channels = 256  #金字塔层输出的通道数
#生成与金字塔相融合的resnet网络模型
resnet_fpn = BackboneWithFPN(resnet, return_layers, in_channels_list, out_channels)

至此我们创建好了用于生成特征向量的的模型。
接下来将image转成向量

#获得特征向量
features = resnet_fpn(image)

下面来看一下features的样子

print(type(features))
for key in features:
	print(key,"的大小:",features[key].shape)

得到的输出结果如下:

<class 'collections.OrderedDict'>
0 的大小: torch.Size([1, 256, 150, 200])
1 的大小: torch.Size([1, 256, 75, 100])
2 的大小: torch.Size([1, 256, 38, 50])
3 的大小: torch.Size([1, 256, 19, 25])
pool 的大小: torch.Size([1, 256, 10, 13])

可以看到输出的feature是OrderedDict类型,共包含了5个特征图
下面将图像和特征图输入到AnchorGenerator中
由于AnchorGenerator.foword的输入参数分别是ImageList和List类型,我们需要对上面对应的变量进行转换

from torchvision.models.detection.image_list import ImageList
#对image进行转换
#ImageList第一个参数是tensor类型,第二个参数是List[Tuple[int, int]]类型
#当包含多幅图像时,第二个参数应分别累出每幅图像的大小:list([(600,800),(600,800),(600,800),])
imglist = ImageList(image, list([(600,800)]))

#转换feature为list类型
flist = list(features.values())

最后看一下生成anchors

#最后将imglist和flist输入AnchorGenerator
anchors = generator(imglist,flist)

print(type(anchors))
print(len(anchors))
print(anchors[0][1:10,:])

输出结果:

<class 'list'>
1
tensor([[-64., -64.,  64.,  64.],
        [-45., -91.,  45.,  91.],
        [-87., -45.,  95.,  45.],
        [-60., -64.,  68.,  64.],
        [-41., -91.,  49.,  91.],
        [-83., -45.,  99.,  45.],
        [-56., -64.,  72.,  64.],
        [-37., -91.,  53.,  91.],
        [-79., -45., 103.,  45.]])

为什么anchor中好多负数?

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值