🧠 你需要知道的核心点
🟨 原始图片(PIL Image)长这样:
类型:PIL.Image
数据:每个像素是 (R, G, B),值范围是 0~255(整数)
形状:宽 x 高(W x H)
这是你平时在电脑上看的图片。
🟩 模型要求的格式:
类型:torch.Tensor
数据:每个像素通道是 float 类型,值范围是 0.0 ~ 1.0
形状:(通道数 C, 高 H, 宽 W),比如 (3, 150, 150)
这是深度学习模型能“吃”的数据格式。
✅ ToTensor()
的作用就是:
👉 把图像从 PIL 格式 变成 Tensor 格式
转换前 (PIL) | 转换后 (Tensor) |
---|---|
数据是整数(0~255) | 数据是浮点数(0.0~1.0) |
形状是 (H, W, C) | 形状是 (C, H, W) |
RGB 图片(3通道) | Tensor(3通道) |
不能直接送入模型 | 可以直接送入模型 |
👇 看例子就明白了
假设我们有一张 2x2 的红色图片(简单到极致):
from PIL import Image
import numpy as np
from torchvision import transforms
import torch
# 创建一张红色图像,RGB(255, 0, 0)
red_image = Image.fromarray(np.array([
[[255, 0, 0], [255, 0, 0]],
[[255, 0, 0], [255, 0, 0]]
], dtype=np.uint8))
print("🔍 原始图像数据:")
print(np.array(red_image))
输出结果:
[[[255 0 0] [255 0 0]]
[[255 0 0] [255 0 0]]]
然后我们用 ToTensor()
转一下:
tensor_image = transforms.ToTensor()(red_image)
print("🔁 ToTensor 后:")
print(tensor_image)
输出结果:
tensor([[[1.0000, 1.0000],
[1.0000, 1.0000]], <-- 红色通道
[[0.0000, 0.0000],
[0.0000, 0.0000]], <-- 绿色通道
[[0.0000, 0.0000],
[0.0000, 0.0000]]]) <-- 蓝色通道
💡所以你只要记住:
ToTensor()
会:
- 把像素值从 0~255 ➜ 变成 0~1 的小数(除以 255)
- 把形状从 (H, W, C) ➜ 变成 (C, H, W)
- 把图片从 PIL 类型 ➜ 变成 Tensor 类型(模型可以理解)
❓那如果我不加 ToTensor()
会怎样?
✅ 不加的话,图片还是 PIL 格式,不能直接传给模型,会报错:
TypeError: pic should be Tensor or ndarray. Got <class 'PIL.Image.Image'>