centerNet中计算分类target的方法就是使用高斯公式
对于每一个gt_box,我们首先得到中心点,然后就可以根据这个中心点和特征图的长宽计算置信度图了
最后算出来的置信度是以中心点为圆心,整个特征图外接圆的半径为半径的高斯分布置信度
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
import time
# 计算矩形外接圆的半径
# R = sqrt(h^2 + w^2)/2
def get_radius(height, width):
r = math.sqrt(height**2 + width**2) / 2
return int(r)
def gaussian2d(shape, sigma=1):
shape[0] = int(shape[0])
shape[1] = int(shape[1])
mask_gaussian = torch.zeros(shape[0], shape[1])
x, y = shape[0]//2, shape[1]//2
quanzhong = 1/(math.sqrt(2*3.1415)*sigma)
fenmu = 2*sigma**2
for i in range(shape[0]):
for j in range(shape[1]):
value = math.exp(-((i-x)**2+(j-y)**2)/fenmu)
mask_gaussian[i][j] = quanzhong*value
return mask_gaussian
# heatmap说特征图,center记录的是中心点的坐标
# radius是半径,k是高斯值被放大的倍数
def draw_gaussian_mask(heatmap, center, radius, k=1):
center_x = center[0]
center_y = center[1]
height = heatmap.shape[0]
width = heatmap.shape[1]
diameter = radius*2 + 1
gaussian = gaussian2d([diameter, diameter], sigma=diameter/6)
# 计算边界,防止越界
left = min(center_x, radius)
right = min(width - center_x, radius+1)
top = min(center_y, radius)
bottom = min(height-center_y, radius+1)
# 将相应区域取出来,这里修改mask_heatmap的时候heatmap也会相应改变
masked_heatmap = heatmap[center_y-top:center_y+bottom, center_x-left:center_x+right]
masked_gaussian = gaussian[(radius-top):(radius+bottom),(radius-left):(radius+right)]
# 更新网络
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap.numpy())
plt.imshow(heatmap)
plt.colorbar()
plt.show()
return heatmap
if __name__ == "__main__":
st = time.time()
heatmap = torch.zeros(224,224)
center = [125,125]
radius = get_radius(heatmap.shape[0], heatmap.shape[1])
heatmap = draw_gaussian_mask(heatmap, center, radius)
print("cost time:%.4fs"%(time.time()-st))