在使用 PyTorch 进行图像处理和深度学习任务时,transforms.ToTensor()
是一个常用的工具。它可以将图像数据转换为 PyTorch 的张量格式,为后续的神经网络处理做好准备。尽管这个方法很常用,但很多人可能并不完全理解它具体做了哪些操作。本文将深入解析 transforms.ToTensor()
的具体作用和工作原理。
transforms.ToTensor()
的核心功能
transforms.ToTensor()
主要有三个核心功能:
- 图像格式转换:将图像从 H x W x C 格式(高度 x 宽度 x 通道)转换为 C x H x W 格式。
- 数据类型转换:将图像数据的像素值从
uint8
类型转换为float32
类型。 - 归一化:将像素值从
[0, 255]
的范围缩放到[0.0, 1.0]
的范围。
详细操作步骤
1. 图像格式转换
图像数据通常以高度(H)、宽度(W)和通道(C)的顺序存储。在大多数图像处理库(如 PIL 和 OpenCV)中,图像的默认格式是 H x W x C。然而,PyTorch 期望输入的张量格式是 C x H x W,即通道维度在最前面。
- 输入格式(常见的图像库):H x W x C
- 输出格式(PyTorch 张量):C x H x