需求:给定10个关键点的heatmaps(大小为 $10 \times h \times w$),将其沿第0维(通道维)求和,得到总的heatmap(大小为 $h \times w$),再求该heatmap上的局部极大值点(即关键点)。
代码如下:
import os
import SimpleITK as sitk
import numpy as np
from datasets.utils import generate_heatmaps
import matplotlib.pyplot as plt
from skimage.feature import peak_local_max
def generate_heatmaps(image, spacing, gt_coords, sigma=3.0):
    """Generate one Gaussian heatmap per keypoint using physical (spacing-scaled) distance.

    NOTE: this definition shadows the ``generate_heatmaps`` imported from
    ``datasets.utils`` at the top of the file.

    :param image: numpy array with shape (h, w); only its shape is used.
    :param spacing: pixel size as a numpy array, i.e. ``np.array([w_size, h_size])``
        -- (x, y) order, matching ``gt_coords``.
    :param gt_coords: array-like with shape (point_num, 2); each row is a
        keypoint ``(x, y)`` in pixel units.
    :param sigma: standard deviation of the Gaussian, in the same physical
        units as ``spacing`` (offsets are multiplied by ``spacing`` before
        being compared against ``sigma``).
    :return: float32 numpy array with shape (point_num, h, w); map ``i`` equals
        ``exp(-d_i^2 / (2 * sigma^2))`` where ``d_i`` is the physical distance
        to keypoint ``i``. An empty ``gt_coords`` yields shape (0, h, w).
    """
    h, w = image.shape[0], image.shape[1]
    # Build the pixel grid from the shape alone. The former
    # ``np.where(image < np.inf)`` trick silently dropped NaN/inf pixels,
    # which then made the reshape to (h, w, 2) fail.
    rows, cols = np.indices((h, w))
    # (h, w, 2) grid in (x, y) order so it lines up with gt_coords and spacing.
    grid = np.stack((cols, rows), axis=-1)
    points = np.asarray(gt_coords, dtype=np.float64).reshape(-1, 2)
    spacing = np.asarray(spacing, dtype=np.float64)
    denom = 2.0 * sigma ** 2

    # One map at a time keeps peak memory at O(h * w) instead of O(n * h * w * 2).
    heatmaps = np.empty((points.shape[0], h, w), dtype=np.float32)
    for i, point in enumerate(points):
        sq_dist = (((grid - point) * spacing) ** 2).sum(axis=-1)
        heatmaps[i] = np.exp(-sq_dist / denom)
    return heatmaps
if __name__ == '__main__':
    data_dir = '/home/pangshumao/data/Spine_Localization_PIL/in'
    kp_dir = os.path.join(data_dir, 'keypoints')
    names = ['L1', 'L2', 'L3', 'L4', 'L5',
             'L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']

    # Ground-truth keypoints. The context manager closes the .npz file handle;
    # indexing an NpzFile reads the array fully, so the values outlive the block.
    with np.load(os.path.join(kp_dir, 'Case300.npz')) as kp_data:
        coords = kp_data['coords']  # expected (10, 2), each row is a keypoint (x, y)
        pixelspacing = kp_data['pixelspacing']  # expected shape (2,)

    reader = sitk.ImageFileReader()
    reader.SetFileName(os.path.join(data_dir, 'weak_supervised_MR', 'Case300.dcm'))
    # GetArrayFromImage returns (1, h, w) for this single-slice file; keep the slice.
    mr_image = sitk.GetArrayFromImage(reader.Execute())[0]

    heatmaps = generate_heatmaps(mr_image, pixelspacing, coords, sigma=3.0)  # (10, h, w)
    heatmap = heatmaps.sum(axis=0)  # (h, w)

    # Peaks are returned as (row, col) == (y, x).
    # num_peaks caps the result at one peak per keypoint; exclude_border=False
    # keeps keypoints lying within min_distance pixels of the image edge, which
    # the default exclude_border would silently discard.
    peaks = peak_local_max(heatmap, min_distance=10,
                           num_peaks=len(coords), exclude_border=False)

    # Display: original keypoints, summed heatmap, detected local maxima.
    fig, axes = plt.subplots(1, 3, figsize=(14, 6), sharex=True, sharey=True)
    ax_kp, ax_hm, ax_peaks = axes.ravel()

    ax_kp.imshow(mr_image, cmap=plt.cm.gray)
    for (x, y), name in zip(coords, names):
        ax_kp.scatter(x=x, y=y, c='r')
        ax_kp.text(x=x + 5, y=y + 5, s=name, c='w')
    ax_kp.set_title('Original keypoints')

    ax_hm.imshow(heatmap)
    ax_hm.set_title('heatmap')

    ax_peaks.imshow(heatmap, cmap=plt.cm.gray)
    # Label = position of the peak in the array returned by peak_local_max.
    for rank, (y, x) in enumerate(peaks):
        ax_peaks.scatter(x=x, y=y, c='r')
        ax_peaks.text(x=x + 5, y=y + 5, s=str(rank), c='w')
    ax_peaks.set_title('Local maximum')

    for ax in axes.ravel():
        ax.axis('off')
    fig.tight_layout()
    plt.show()
结果如图所示,其中右图中的标号0–9表示 peak_local_max 返回各极值点的顺序(按峰值强度从高到低排列)。
参考:https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_peak_local_max.html