可视化tensor类型或者numpy类型的图片
说明:在很多时候我们需要显示tensor的图片类型,但是在实际的程序中,数据的类型有很多种,其中包括含有batch_size的四维数据,还有没有batch_size的三维数据,同时图片的通道C也可能在不同维度,所以为了方便可视化,我写了一个函数方便,减少大家可视化的麻烦
import matplotlib.pylab as plt
def visualize_image(image_data):
"""
可视化图像数据。
参数:
- image_data: 输入的图像数据,可以是Tensor类型(HxWx3, 3xHxW, BxHxWx3, Bx3xHxW)或者ndarray类型(HxWx3, 3xHxW, BxHxWx3, Bx3xHxW)
注意:
如果输入是Tensor,确保使用 `.detach().cpu().numpy()` 将其转换为NumPy数组。
"""
if isinstance(image_data, torch.Tensor):
image_data = image_data.detach().cpu().numpy()
if image_data.ndim == 3:
if image_data.shape[2] == 3:
# HxWx3 or 3xHxW
plt.imshow(image_data)
plt.axis('off')
plt.show()
elif image_data.shape[0] == 3:
# 3xHxW
plt.imshow(image_data.transpose(1, 2, 0))
plt.axis('off')
plt.show()
elif image_data.ndim == 4:
if image_data.shape[3] == 3:
# BxHxWx3 or Bx3xHxW
B = image_data.shape[0]
images_per_row = 5
rows = (B + images_per_row - 1) // images_per_row
figsize = (images_per_row * 4, rows * 4)
_, axs = plt.subplots(rows, images_per_row, figsize=figsize)
for i in range(B):
ax = axs[i // images_per_row, i % images_per_row]
ax.imshow(image_data[i])
ax.axis('off')
plt.show()
# 隐藏多余的子图
for j in range(B, rows * images_per_row):
axs.flatten()[j].axis('off')
elif image_data.shape[1] == 3:
image_data = image_data.transpose(0, 2, 3, 1)
B = image_data.shape[0]
images_per_row = 5
rows = (B + images_per_row - 1) // images_per_row
figsize = (images_per_row * 4, rows * 4)
_, axs = plt.subplots(rows, images_per_row, figsize=figsize)
for i in range(B):
ax = axs[i // images_per_row, i % images_per_row]
ax.imshow(image_data[i])
ax.axis('off')
# 隐藏多余的子图
for j in range(B, rows * images_per_row):
axs.flatten()[j].axis('off')
plt.show()