why? 主要原因是大多数可视化工具期望输入的格式数据为”NumPy“数组,这些工具在处理和显示图像时,通常不支持pytorch的’Tensor‘结构
-
Matplotlib 等库的兼容性:
Matplotlib
是一个常用的 Python 可视化库,它要求输入数据为’Numpy‘数组,传入Tensor会报错Tensor
是 PyTorch 特有的数据结构,而NumPy
数组是 Python 科学计算的标准数据格式,许多库都是围绕NumPy
设计和实现的。
-
数据格式的转换:
Tensor
在PyTorch中存储形式是(channels, height,width),而Numpy以(height, width, channels)表示,因此我们需要用’.transpose‘来调整为度顺序。‘transpose(1, 2,0)
-
数据类型的处理:
- 图像显示通常需要像素值(NumPy类型)在’[0, 255]‘或’[0, 1]之间,数据类型为float或int。而‘Tensor’数据类型和范围可能不同,因此在转换的同时需要做一些处理
-
便于后续操作:
NumPy
提供了丰富的数组操作功能,如裁剪、变换、过滤等,这些操作在图像处理和可视化中非常常见。将Tensor
转换为NumPy
数组后,可以更方便地进行这些操作。
-
常规处理操作
-
预处理:
- 将张量移动到CPU上,因为NumPy支持CPU
- 如果需要后使张量仍在GPU上操作,可以对Tensor创建副本,使得副本移动到CPU
- 将张量从计算图中分离(detach),相当于是变成了一个标量,避免影响反向传播计算梯度,因为Numpy不支持梯度计算
-
转换
- 需要在转换过程中去除数组中的单维度 例如,如果数组的形状是
(1, 3, 224, 224)
(批量大小为1),squeeze()
将其变为(3, 224, 224)
。这一步适用于去除可能的单通道维度。
- 需要在转换过程中去除数组中的单维度 例如,如果数组的形状是
-
后处理
- 将数组轴序调整为(h, w, c).
- 将数组像素值限制在[0, 1]之间,必须是[0,1],因为在数据增强、归一化在传递数据时统一
-
-
示例
-
def im_convert(tensor): image = tensor.to('cpu').clone().deatch() # 转移到cpu创建副本脱离 image = image.numpy().squeeze() # 转换为张量并压缩为3维数组 image = image.tanspose(1, 2, 0) # 通道调序 image = iamge.clip(0, 1) # 限制范围 # 可视化过程(可忽略) fig = plt.figure(figsize=(20, 12)) #创建图形窗口,图形尺寸时20x12 dataset = function/class('yourdata-path') # 根据你的数据路径获取数据集 dataloader = torch.utils.data.DataLoader # 获取数据加载器 (dataset = dataset, batch_size = 8, shuffle = False, num_workers = 0) dataiter = iter(dataloader) # 获取数据迭代器,是由数据加载器转换而来 inputs, classes = next(dataiter) # 从iter中获取一个批次数据 columns = 4 rows = 2 # 定义4*2网格 for idx in range (columns*rows): ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[]) #创建一个子图,idx + 1 是子图的索引(从 1 开始),隐藏x,y刻度 ax.set_title(num_to_class[int(classes[idx])]) # 设置标题,num_to_class是我之前网络中定义的一个映射,将类别索引转化为类别名称,classes是由迭代器产生的 plt.imshow(im_convert(inputs[idx])) # 将Tensor转变为NumPy,inputs是包含多张图像的张量,idx是循环参数 plt.imshow()