在深度学习中,数据的预处理和增强是至关重要的步骤。而在 PyTorch 中,transforms.Compose()
函数提供了便捷、模块化的数据变换方式,极大地简化了预处理流程。本文将详细介绍 transforms.Compose()
,并通过实例演示如何在图像数据处理中使用它。
🧑 博主简介:现任阿里巴巴嵌入式技术专家,15年工作经验,深耕嵌入式+人工智能领域,精通嵌入式领域开发、技术管理、简历招聘面试。CSDN优质创作者,提供产品测评、学习辅导、简历面试辅导、毕设辅导、项目开发、C/C++/Java/Python/Linux/AI等方面的服务,如有需要请站内私信或者联系任意文章底部的的VX名片(ID:
gylzbk
)
💬 博主粉丝群介绍:① 群内初中生、高中生、本科生、研究生、博士生遍布,可互相学习,交流困惑。② 热榜top10的常客也在群里,也有数不清的万粉大佬,可以交流写作技巧,上榜经验,涨粉秘籍。③ 群内也有职场精英,大厂大佬,可交流技术、面试、找工作的经验。④ 进群免费赠送写作秘籍一份,助你由写作小白晋升为创作大佬。⑤ 进群赠送CSDN评论防封脚本,送真活跃粉丝,助你提升文章热度。有兴趣的加文末联系方式,备注自己的CSDN昵称,拉你进群,互相学习共同进步。
【PyTorch】掌握transforms.Compose:PyTorch数据预处理的强大工具
- 1. 📖 什么是 `transforms.Compose()`?
- 2. 💻 安装与基础使用
- 3. 🛠️ 常用的变换操作
- 3.1 📏 尺寸调整和裁剪
- 3.2 🎨 数据增强
- 3.3 🔄 归一化和张量转换
- 3.4 🌈 色彩变换
- 4. 📋 示例:图像数据预处理流水线
- 5. ✨ 自定义变换操作
- 6. 🖼️ 应用实例:图像分类
- 6.1 📂 数据准备
- 6.2 🏋️♂️ 训练模型
- 6.3 📊 测试模型
- 结论
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Sq2FG68L-1721647907473)(https://i-blog.csdnimg.cn/direct/ba1c6150dbcc4f4ab22048da5d2b9cd0.png)]
1. 📖 什么是 transforms.Compose()
?
transforms.Compose()
是 PyTorch 提供的一个简单实用的工具。它允许将多个图像变换操作组成一个序列,从而简化图像预处理流水线。transforms.Compose()
接受一个变换列表,并返回一个新的、组合后的变换。 这特别适合在处理图像时,需要链式应用多个变换操作的场景。
2. 💻 安装与基础使用
首先,确保你已经安装了 PyTorch 和 torchvision:
pip install torch torchvision
然后,你可以通过以下代码块来理解 transforms.Compose()
的基本用法:
import torch
from torchvision import transforms
from PIL import Image
# 定义一个由多个变换操作组成的序列
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
# 加载一个图片文件
img = Image.open("path/to/your/image.jpg")
# 应用变换
transformed_img = transform(img)
print(type(transformed_img)) # <class 'torch.Tensor'>
在这个例子中,我们定义了一系列的变换操作,包括调整图像大小、随机水平翻转和将图像转换为张量。通过应用 transform
,我们一次性地对图像进行了这些变换。
3. 🛠️ 常用的变换操作
在 transforms.Compose()
中,你可以使用多种变换,包括但不限于:
3.1 📏 尺寸调整和裁剪
transforms.Resize(size)
: 调整图像大小。transforms.CenterCrop(size)
: 从中心裁剪图像。transforms.RandomResizedCrop(size)
: 随机调整图像大小并裁剪。
3.2 🎨 数据增强
transforms.RandomHorizontalFlip(p=0.5)
: 随机水平翻转图像。transforms.RandomVerticalFlip(p=0.5)
: 随机垂直翻转图像。transforms.RandomRotation(degrees)
: 随机旋转图像一定角度。transforms.RandomAffine(degrees, translate)
: 随机仿射变换。
3.3 🔄 归一化和张量转换
transforms.ToTensor()
: 将 PIL 图像或 numpy.ndarray 转换为 Tensor,并归一化到 [0, 1] 之间。transforms.Normalize(mean, std)
: 对 Tensor 进行标准化处理,即(x - mean) / std
。
3.4 🌈 色彩变换
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
: 随机改变图像的亮度、对比度、饱和度和色相。transforms.Grayscale(num_output_channels=1)
: 将图像转换为灰度图。
4. 📋 示例:图像数据预处理流水线
下面的例子展示了一个组合多种变换操作的实际应用场景:
from torchvision import transforms
from PIL import Image
transform_pipeline = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomCrop(114),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = Image.open("path/to/your/image.jpg")
transformed_img = transform_pipeline(img)
print(type(transformed_img)) # <class 'torch.Tensor'>
在这个例子中,我们创建了一个由不同图像变换操作组成的更复杂的变换流水线。该流水线包括调整大小、随机裁剪、随机水平翻转、随机旋转、色彩变化、转换为 Tensor 和归一化。应用这些变换后,图像将变得更加适合用于训练深度学习模型。
5. ✨ 自定义变换操作
除了使用内置的变换操作之外,transforms.Compose()
还支持自定义变换。你可以定义自己的变换类,并将其作为 transforms.Compose()
的一部分:
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
class AddGaussianNoise(object):
def __init__(self, mean=0.0, std=1.0):
self.mean = mean
self.std = std
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
transform_pipeline = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
AddGaussianNoise(0., 1.)
])
img = Image.open("path/to/your/image.jpg")
transformed_img = transform_pipeline(img)
print(type(transformed_img)) # <class 'torch.Tensor'>
在这个例子中,我们定义了一个自定义的变换 AddGaussianNoise
,它向图像添加高斯噪声。然后我们将这个自定义变换添加到 transforms.Compose()
的序列中。
6. 🖼️ 应用实例:图像分类
为了更好地理解 transforms.Compose()
的实际应用,我们将使用一个图像分类的实例,演示如何使用数据预处理流水线来处理训练数据和测试数据。
6.1 📂 数据准备
假设我们使用 CIFAR-10 数据集进行图像分类任务。可以通过以下代码下载并加载数据:
import torchvision
import torchvision.transforms as transforms
# 定义变换操作序列
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# 下载并加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)
# 下载并加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
6.2 🏋️♂️ 训练模型
定义模型结构,损失函数和优化器,然后开始训练:
import torch.nn as nn
import torch.optim as optim
# 定义简单的卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = SimpleCNN()
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练模型
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criteria(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99: # 每 100 个 mini-batch 打印一次
print(f'[Epoch {epoch + 1}, Iter {i + 1}] loss: {running_loss / 100:.3f}')
running_loss = 0.0
print("Finished Training")
6.3 📊 测试模型
在测试数据集上评估模型性能:
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy on the 10000 test images: {100 * correct / total} %')
通过上述步骤,你可以看到 transforms.Compose()
在数据预处理流水线中的实际应用,简化了数据变换操作,提高了图像分类任务的效果和效率。
结论
transforms.Compose()
是 PyTorch 数据预处理和增强流程中的一个核心工具。通过组合多种变换操作,你可以方便地对图像数据进行各种各样的预处理,从而提高模型的泛化能力和训练效果。无论是使用内置的变换操作,还是自定义变换,transforms.Compose()
都能帮助你构建高效的图像处理流水线。希望本文能帮助你更好地掌握并应用 transforms.Compose()
,为你的深度学习项目保驾护航。