【无标题】

可视化tensor类型或者numpy类型的图片

说明:在很多时候我们需要显示tensor的图片类型,但是在实际的程序中,数据的类型有很多种,其中包括含有batch_size的四维数据,还有没有batch_size的三维数据,同时图片的通道C也可能在不同维度,所以为了方便可视化,我写了一个函数方便,减少大家可视化的麻烦

import matplotlib.pylab as plt
def visualize_image(image_data):
    """
    可视化图像数据。
    参数:
    - image_data: 输入的图像数据,可以是Tensor类型(HxWx3, 3xHxW, BxHxWx3, Bx3xHxW)或者ndarray类型(HxWx3, 3xHxW, BxHxWx3, Bx3xHxW)

    注意:
    如果输入是Tensor,确保使用 `.detach().cpu().numpy()` 将其转换为NumPy数组。
    """
    if isinstance(image_data, torch.Tensor):
        image_data = image_data.detach().cpu().numpy()

    if image_data.ndim == 3:
        if image_data.shape[2] == 3:
            # HxWx3 or 3xHxW
            plt.imshow(image_data)
            plt.axis('off')
            plt.show()
        elif image_data.shape[0] == 3:
            # 3xHxW
            plt.imshow(image_data.transpose(1, 2, 0))
            plt.axis('off')
            plt.show()
    elif image_data.ndim == 4:
        if image_data.shape[3] == 3:
            # BxHxWx3 or Bx3xHxW
            B = image_data.shape[0]
            images_per_row = 5
            rows = (B + images_per_row - 1) // images_per_row
            figsize = (images_per_row * 4, rows * 4)
            _, axs = plt.subplots(rows, images_per_row, figsize=figsize)

            for i in range(B):
                ax = axs[i // images_per_row, i % images_per_row]
                ax.imshow(image_data[i])
                ax.axis('off')
            plt.show()
            # 隐藏多余的子图
            for j in range(B, rows * images_per_row):
                axs.flatten()[j].axis('off')
        elif image_data.shape[1] == 3:
            image_data = image_data.transpose(0, 2, 3, 1)
            B = image_data.shape[0]
            images_per_row = 5
            rows = (B + images_per_row - 1) // images_per_row
            figsize = (images_per_row * 4, rows * 4)
            _, axs = plt.subplots(rows, images_per_row, figsize=figsize)

            for i in range(B):
                ax = axs[i // images_per_row, i % images_per_row]
                ax.imshow(image_data[i])
                ax.axis('off')

            # 隐藏多余的子图
            for j in range(B, rows * images_per_row):
                axs.flatten()[j].axis('off')
            plt.show()
  • 9
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值