# get_anchors: 先验框设计 (anchor / prior-box design)
import numpy as np
import matplotlib.pyplot as plt
import config
"""anchors即先验框"""
def get_anchors(shape, width, height):
    """Return all anchors for a feature grid, normalised to image size and clipped.

    Args:
        shape: grid size passed straight to ``shift``; shape[0] sets the number
            of cells along x and shape[1] the number along y.
        width: value used to normalise the x coordinates (columns 0 and 2).
        height: value used to normalise the y coordinates (columns 1 and 3).

    Returns:
        Array of ``[x1, y1, x2, y2]`` rows divided by ``(width, height,
        width, height)`` and clipped into ``[0, 1]``.
    """
    base_anchors = generate_anchors()
    boxes = shift(shape, base_anchors)
    # One broadcast division replaces four per-column divisions.
    scale = np.array([width, height, width, height])
    boxes = boxes / scale
    return np.clip(boxes, 0, 1)
def generate_anchors(sizes=None, ratios=None):
    """Build base anchor boxes centred on the origin.

    Every size is combined with every ratio. For a size ``s`` and ratio
    ``(rx, ry)``, the box has width ``s * rx`` and height ``s * ry``.

    Args:
        sizes: sequence of base box sizes. Defaults to
            ``config.anchor_box_scales`` (values defined outside this file).
        ratios: sequence of ``(rx, ry)`` pairs. Defaults to
            ``config.anchor_box_ratios``.

    Returns:
        Float array of shape ``(len(sizes) * len(ratios), 4)`` with rows
        ``[x1, y1, x2, y2]``. Ordering is ratio-major: all sizes for
        ``ratios[0]``, then all sizes for ``ratios[1]``, and so on.
    """
    if sizes is None:
        sizes = config.anchor_box_scales
    if ratios is None:
        ratios = config.anchor_box_ratios
    sizes = np.asarray(sizes, dtype=float).reshape(-1)
    ratios = np.asarray(ratios, dtype=float).reshape(-1, 2)
    num_sizes = len(sizes)
    num_anchors = num_sizes * len(ratios)

    # Repeat the size list once per ratio, and each ratio once per size.
    # This works for any number of sizes; the old loop hard-coded 3.
    tiled_sizes = np.tile(sizes, len(ratios))
    widths = tiled_sizes * np.repeat(ratios[:, 0], num_sizes)
    heights = tiled_sizes * np.repeat(ratios[:, 1], num_sizes)

    half_w = widths * 0.5
    half_h = heights * 0.5
    anchors = np.stack([-half_w, -half_h, half_w, half_h], axis=1)

    # Debug output, kept from the original; now reports the real count.
    print('生成的{}个不同尺度不同长宽比的anchors:\n'.format(num_anchors), anchors)
    return anchors
def shift(shape, anchors, stride=None):
    """Place a copy of every base anchor at the centre of every grid cell.

    Args:
        shape: grid size; shape[0] is the number of cells along x and
            shape[1] the number along y. NOTE(review): presumably
            (feature_width, feature_height); confirm against callers.
        anchors: array-like of shape ``(A, 4)`` holding ``[x1, y1, x2, y2]``
            boxes centred on the origin.
        stride: pixel distance between neighbouring cell centres. ``None``
            (default) reads ``config.rpn_stride`` at call time rather than
            freezing it at import time.

    Returns:
        Array of shape ``(shape[0] * shape[1] * A, 4)``. Rows are ordered by
        grid row (y), then grid column (x), then anchor index.
    """
    if stride is None:
        stride = config.rpn_stride
    centers_x = (np.arange(0, shape[0], dtype=float) + 0.5) * stride
    centers_y = (np.arange(0, shape[1], dtype=float) + 0.5) * stride
    # Default 'xy' indexing: flattening walks x fastest, then y.
    grid_x, grid_y = np.meshgrid(centers_x, centers_y)
    grid_x = np.reshape(grid_x, [-1])
    grid_y = np.reshape(grid_y, [-1])
    # One (cx, cy, cx, cy) offset per cell, shape (k, 4).
    shifts = np.stack([grid_x, grid_y, grid_x, grid_y], axis=1)

    num_of_anchors = np.shape(anchors)[0]
    k = np.shape(shifts)[0]
    # Broadcast (1, A, 4) + (k, 1, 4) -> (k, A, 4): every anchor at every cell.
    shifted_anchors = (np.reshape(anchors, [1, num_of_anchors, 4])
                       + np.reshape(shifts, [k, 1, 4]))
    shifted_anchors = np.reshape(shifted_anchors, [k * num_of_anchors, 4])
    # Debug output, kept from the original.
    print(shifted_anchors, np.shape(shifted_anchors))
    return shifted_anchors
if __name__ == '__main__':
    # Demo: 38x38 feature grid over a 600x600 image.
    feature_shape = (38, 38)
    image_width, image_height = 600, 600
    demo_anchors = get_anchors(feature_shape, image_width, image_height)
    print(np.shape(demo_anchors))
    print(demo_anchors)