pytorch debug 常用工具

可视化

张量可视化

import torch
from torchvision.transforms.functional import to_pil_image
from PIL import Image
def tensor_to_pil(tensor):
	# 8bit的图像,可能失真
    # 确保tensor是在CPU上
    tensor = tensor.cpu()
    # 如果tensor有一个批次维度,去除它
    if tensor.dim() == 4 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    # 转换为PIL图像
    pil_image = to_pil_image(tensor)
    pil_image.show()
    # 返回PIL图像
    return pil_image

tensor_to_pil( ).show()

转换成np,高精度可视化

import torch
import numpy as np
import matplotlib.pyplot as plt
def visualize_tensor(tensor):
    # 确保tensor在CPU上
    tensor = tensor.cpu()
    # 如果tensor有一个批次维度,去除它
    if tensor.dim() == 4 and tensor.shape[0] == 1:
        tensor = tensor.squeeze()
    # 将tensor转换为numpy数组,并将其数据类型转换为高精度浮点数
    np_array = tensor.detach().numpy().astype(np.float32)
    dims = np_array.ndim
    if dims == 3:
        # 使用 transpose() 函数将 channels 维度移到最后
        np_array = np.transpose(np_array, (1, 2, 0))
    # 使用matplotlib进行可视化
    plt.imshow(np_array, cmap='gray')  # 使用灰度图进行显示,你可以根据需要更改这个值
    plt.colorbar()  # 显示色轴
    plt.show()
    return np_array

深度图可视化


def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Converts a depth map to a color image.

    Args:
        value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
        vmax (float, optional):  vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
        invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
        gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
        value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.

    Returns:
        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize
    vmin = np.percentile(value[mask],2) if vmin is None else vmin
    vmax = np.percentile(value[mask],85) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # squeeze last dim if it exists
    # grey out the invalid values

    value[invalid_mask] = np.nan
    cmapper = matplotlib.cm.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
        # value = value / value.max()
    value = cmapper(value, bytes=True)  # (nxmx4)

    # img = value[:, :, :]
    img = value[...]
    img[invalid_mask] = background_color

    #     return img.transpose((2, 0, 1))
    if gamma_corrected:
        # gamma correction
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    return img


自动辨识图像格式可视化

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def convert_to_numpy(image_input):
    """
    自动检测输入图像类型,并将其转换为NumPy数组。
    """
    if isinstance(image_input, np.ndarray):
        # 输入已经是NumPy数组,直接返回
        return image_input
    elif 'Tensor' in str(type(image_input)):
        # 输入是Tensor类型
        # 检查是否需要转换(依赖于Tensor所属的库,如PyTorch, TensorFlow等)
        if hasattr(image_input, 'detach'):
            # 假设是PyTorch Tensor
            image_input = image_input.detach().cpu().numpy()
        else:
            # 假设是TensorFlow Tensor或其他框架的Tensor
            image_input = image_input.numpy()
        # 如果Tensor有通道维度在最前面(如CHW),则需要转换为HWC
        if image_input.ndim == 3 and image_input.shape[0] in (1, 3):
            image_input = image_input.transpose(1, 2, 0)
    elif isinstance(image_input, Image.Image):
        # 输入是Pillow图像,转换为NumPy数组
        image_input = np.array(image_input)
    else:
        raise TypeError("Unsupported image type")
    
    # 如果图像是单通道的,且在最后一个维度(例如HxWx1),去掉该维度
    if image_input.ndim == 3 and image_input.shape[-1] == 1:
        image_input = image_input.squeeze(-1)
    image_np = image_input 
    if image_np.ndim == 3 and image_np.shape[-1] == 3:
        plt.imshow(image_np)
    else:
        plt.imshow(image_np, cmap='viridis')
    plt.title(title)
    plt.axis('off')
    plt.show()


def visualize_image(image_np, title="Image"):
    """
    可视化NumPy格式的图像
    """
    if image_np.ndim == 3 and image_np.shape[-1] == 3:
        plt.imshow(image_np)
    else:
        plt.imshow(image_np, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()

# 示例使用
# image_tensor, image_np, image_pil 分别代表Tensor, NumPy数组, Pillow图像的输入
# 将它们转换为NumPy数组
# image_np = convert_to_numpy(image_tensor)
# image_np = convert_to_numpy(image_np)
# image_np = convert_to_numpy(image_pil)

# # 可视化图像
# visualize_image(image_np)

可视化已经图像信息

def draw_np(pic_np):
    pic_np = np.squeeze(pic_np)
    plt.imshow(pic_np)
    # 隐藏坐标轴
    plt.axis('on')
    # 显示数据标尺
    plt.colorbar()
    # 显示图像
    plt.show()

def get_image_info(image):
    # 获取图像的模式、格式和尺寸
    mode = image.mode
    format_ = image.format
    size = image.size

    # 根据图像模式推断每个通道的位数
    if mode in ("1", "L", "P"):
        bits_per_channel = 8  # 通常是8位
    elif mode == "RGB":
        bits_per_channel = 8  # 通常是8位,3通道
    elif mode == "RGBA":
        bits_per_channel = 8  # 通常是8位,4通道
    elif mode == "I":
        bits_per_channel = 32 # 整数像素模式
    elif mode == "F":
        bits_per_channel = 32 # 浮点像素模式
    else:
        bits_per_channel = 'unknown'  # 未知或不常见的模式

    # 计算总位数
    total_bits = image.getbands().__len__() * bits_per_channel

    # 打印图像信息
    print(f"Image mode: {mode}")
    print(f"Image format: {format_}")
    print(f"Image size: {size}")
    print(f"Bits per channel: {bits_per_channel}")
    print(f"Total bits per pixel: {total_bits}")


#%%
import numpy as np

def get_array_info(np_array):
    """
    获取并打印NumPy数组的详细信息。

    参数:
    np_array: NumPy数组。
    """
    # 获取数组的形状
    shape = np_array.shape

    # 获取数组的总元素数量
    size = np_array.size

    # 获取数组的数据类型
    dtype = np_array.dtype

    # 获取数组单个元素的大小(以字节为单位)
    itemsize = np_array.itemsize

    # 获取数组的维度数量
    ndim = np_array.ndim

    # 获取数组的总字节数
    nbytes = np_array.nbytes

    # 打印数组信息
    print(f"Array Shape: {shape}")
    print(f"Array Size: {size}")
    print(f"Array Data Type: {dtype}")
    print(f"Item Size: {itemsize} bytes")
    print(f"Array Dimensions: {ndim}")
    print(f"Total Bytes: {nbytes} bytes")
def read_pic(path_pic):
    # 加载图像
    image = Image.open(path_pic)
    print(image.size)
    print(image.format)
    return image

def pic_to_np(pic):
    np_depth = np.array(pic)
    return np_depth

def draw_np(pic_np):
    pic_np = np.squeeze(pic_np)
    plt.imshow(pic_np)
    # 隐藏坐标轴
    plt.axis('on')
    # 显示数据标尺
    plt.colorbar()
    # 显示图像
    plt.show()
    
def pic_info(path):
    raw_image = read_pic(path)
    raw_np = pic_to_np(raw_image)
    get_image_info(raw_image)
    get_array_info(raw_np)
    raw_image.show()
    draw_np(raw_np)
  • 7
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值