详见代码:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import matplotlib.patches as patches
from PIL import Image
# Size of the raw input image: width W, height H (pixels).
W = 256
H = 256
# Total downsampling factor between the raw image and the feature map.
rpn_stride = 8  # times downsampling
# Width and height of the feature map produced by the conv backbone.
# Integer division: a feature map has a whole number of cells.
w = W // rpn_stride
h = H // rpn_stride
# Anchor shape settings. anchor() builds every box as
#   width = scale * sqrt(ratio),  height = scale / sqrt(ratio),
# so scale = sqrt(width * height) (NOT width + height) and ratio = width / height.
# Both are in raw-image pixel units, because anchor() adds the sizes directly to
# raw-image center coordinates.
scales = [3, 5, 9]  # sqrt(width * height) of each anchor
ratios = [0.5, 1, 2]  # width / height; 3 ratios
def anchor(w, h, rpn_stride, scales, ratios):
    """Generate anchor boxes for every feature-map cell, in raw-image coordinates.

    Args:
        w, h: width and height of the feature map (cells along x / along y).
        rpn_stride: downsampling factor. Feature-map cell (row i, col j) is
            mapped to the raw-image point (y=i*rpn_stride, x=j*rpn_stride),
            i.e. the top-left corner of the cell, which becomes the anchor center.
        scales: anchor scales. An anchor has width scale*sqrt(ratio) and
            height scale/sqrt(ratio), so scale == sqrt(width * height).
        ratios: anchor aspect ratios, width / height.

    Returns:
        numpy.ndarray of shape (k*h*w, 4) with k = len(scales) * len(ratios).
        Each row is one anchor laid out as [y1, x1, y2, x2]
        (top-left corner, then bottom-right corner).
        Row order:
        - Rows are grouped by anchor shape: all h*w positions of the first
          shape come first, then all positions of the next shape, and so on.
        - Within a group, positions run row by row (y outer, x inner).
        - Shapes are ordered by ratio, then by scale within each ratio.
    """
    # Every (scale, ratio) pair -> one anchor shape; ratio-major, scale-minor.
    scale_grid, ratio_grid = np.meshgrid(scales, ratios)
    flat_scales = scale_grid.ravel()
    root_ratio = np.sqrt(ratio_grid.ravel())
    shape_w = flat_scales * root_ratio
    shape_h = flat_scales / root_ratio
    n_shapes = shape_w.size

    # Feature-map cell positions projected back onto the raw image.
    col_px = np.arange(0, w) * rpn_stride
    row_px = np.arange(0, h) * rpn_stride
    grid_x, grid_y = np.meshgrid(col_px, row_px)
    cx = grid_x.ravel()
    cy = grid_y.ravel()
    n_points = cx.size

    # Shape-major layout: the full list of positions repeats once per shape,
    # while each shape's size is held constant across all positions.
    cy_all = np.tile(cy, n_shapes)
    cx_all = np.tile(cx, n_shapes)
    half_h = 0.5 * np.repeat(shape_h, n_points)
    half_w = 0.5 * np.repeat(shape_w, n_points)

    corners = [cy_all - half_h, cx_all - half_w, cy_all + half_h, cx_all + half_w]
    return np.stack(corners, axis=1)
anchors = anchor(w, h, rpn_stride, scales, ratios)
print(anchors.shape)

plt.figure(figsize=(10, 10))
# Load the image, convert OpenCV's BGR channel order to the RGB that
# matplotlib expects, and resize it to the raw input size (W, H) the
# anchors were generated for.
img = cv2.imread('timg.jpg', cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of inside cvtColor.
    raise FileNotFoundError("could not read image 'timg.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (W, H))
plt.imshow(img)
ax = plt.gca()
# anchor() returns rows of [y1, x1, y2, x2]; matplotlib's Rectangle takes the
# (x, y) corner followed by width (along x) and height (along y), so the
# coordinates must be reordered rather than passed through positionally.
for y1, x1, y2, x2 in anchors:
    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             edgecolor='r', facecolor='none')
    ax.add_patch(rect)
plt.show()
The image labeled with the generated anchor boxes: