python求heatmaps图像上的极大值点

需求:给定10个关键点的heatmaps(大小为10\times h \times w),对其沿着第0个通道求和后得到总的heatmap(大小为h \times w),求heatmap上的极大值点(即关键点)。

代码如下:

import os
import SimpleITK as sitk
import numpy as np
from datasets.utils import generate_heatmaps
import matplotlib.pyplot as plt
from skimage.feature import peak_local_max

def generate_heatmaps(image, spacing, gt_coords, sigma=3.0):
    '''
    generate the heat maps according to the physical distance
    :param image: a numpy array with shape of (h, w)
    :param spacing: a numpy array, i.e., np.array([w_size, h_size])
    :param gt_coords: a numpy array with shape of (point_num, 2)
    :param sigma:
    :return: a numpy array with shape of (point_num, h, w)
    '''
    coord = np.where(image < np.inf)
    # 注意需要反转横纵坐标
    coord = np.stack(coord[::-1], axis=1).reshape(image.shape[0], image.shape[1], 2)

    dist = []
    for point in gt_coords:
        d = (((coord - point) * spacing) ** 2).sum(axis=-1)
        dist.append(np.exp(-d / (2.0 * (sigma ** 2))))
        # dist.append((((coord - point) * spacing) ** 2).sum(dim=-1).sqrt())
    dist = np.stack(dist, axis=0)

    return dist.astype(np.float32)

if __name__ == '__main__':
    data_dir = '/home/pangshumao/data/Spine_Localization_PIL/in'
    kp_dir = '/home/pangshumao/data/Spine_Localization_PIL/in/keypoints'

    kp_data = np.load(os.path.join(kp_dir, 'Case300.npz'))
    coords = kp_data['coords'] # a numpy array with shape of (10, 2), each row is the coordinate of a keypoint like (x, y)
    pixelspacing = kp_data['pixelspacing'] # a numpy array with shape of (2, )

    names = ['L1', 'L2', 'L3', 'L4', 'L5',
             'L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']

    reader = sitk.ImageFileReader()
    reader.SetFileName(os.path.join(data_dir, 'weak_supervised_MR', 'Case300.dcm'))
    image = reader.Execute()
    image_array = sitk.GetArrayFromImage(image) # 1, h, w

    mr_image = image_array[0]
    heatmaps = generate_heatmaps(mr_image, pixelspacing, coords, sigma=3.0) # a numpy array with shape of (10, h, w)

    heatmap = np.sum(heatmaps, axis=0)

    h_c = peak_local_max(heatmap, min_distance=10) # a numpy array with shape of (n, 2), each row is the coordinate of a keypoint like (y, x)

    # display results
    fig, axes = plt.subplots(1, 3, figsize=(14, 6), sharex=True, sharey=True)
    ax = axes.ravel()
    ax[0].imshow(mr_image, cmap=plt.cm.gray)

    for i in range(coords.shape[0]):
        ax[0].scatter(x=coords[i, 0], y=coords[i, 1], c='r')
        ax[0].text(x=coords[i, 0] + 5, y=coords[i, 1] + 5, s=names[i], c='w')
    ax[0].axis('off')
    ax[0].set_title('Original keypoints')

    ax[1].imshow(heatmap)
    ax[1].axis('off')
    ax[1].set_title('heatmap')

    ax[2].imshow(heatmap, cmap=plt.cm.gray)
    for i in range(h_c.shape[0]):
        ax[2].scatter(x=h_c[i, 1], y=h_c[i, 0], c='r')
        ax[2].text(x=h_c[i, 1] + 5, y=h_c[i, 0] + 5, s=str(i), c='w')
    ax[2].axis('off')
    ax[2].set_title('Local maximum')

    fig.tight_layout()

    plt.show()

结果如图所示,其中右图的标号[0,9]表示求得的极值点出现的顺序。

参考:https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_peak_local_max.html

 


 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值