torchvision 教程

PyTorch torchvision 教程

torchvision 是 PyTorch 的一个子库,专为计算机视觉任务设计,提供了常用的数据集、预训练模型、以及图像转换和处理的工具。本文将介绍如何使用 torchvision 中的功能来加载数据集、预处理数据、使用预训练模型以及进行图像增强。

1. 安装 torchvision

首先,你需要安装 torchvision 库。可以使用 pip 安装:

pip install torchvision

2. torchvision 的主要组件

torchvision 的主要组件有:

  • torchvision.datasets:提供常用的数据集,例如 MNIST、CIFAR-10、ImageNet 等。
  • torchvision.transforms:用于图像的预处理和数据增强。
  • torchvision.models:提供预训练的深度学习模型。
  • torchvision.io:用于读取和写入图像、视频等数据。

3. 使用 torchvision.datasets 加载数据集

torchvision 提供了许多流行的数据集,可以直接从 torchvision.datasets 中加载。你可以加载数据集,并使用 DataLoader 迭代数据。

3.1 加载 MNIST 数据集

MNIST 是一个包含手写数字的经典数据集,常用于图像分类任务。

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义数据转换 (如归一化)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 下载并加载 MNIST 数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 使用 DataLoader 加载数据集
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 打印样本
for images, labels in train_loader:
    print(f"Image batch shape: {images.size()}")
    print(f"Labels batch shape: {labels.size()}")
    break
3.2 加载 CIFAR-10 数据集

CIFAR-10 是另一个常用的数据集,包含 10 类自然图片。

# 加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 使用 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

4. torchvision.transforms 图像预处理与增强

torchvision.transforms 提供了许多常用的图像预处理和增强方法,例如缩放、裁剪、旋转、翻转等。

4.1 基本预处理操作
transform = transforms.Compose([
    transforms.Resize((32, 32)),            # 调整图像大小
    transforms.RandomHorizontalFlip(),      # 随机水平翻转
    transforms.ToTensor(),                  # 转换为 PyTorch 张量
    transforms.Normalize((0.5,), (0.5,))    # 标准化
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
4.2 常见的 transforms 操作
  • transforms.Resize(size):调整图像大小为 size
  • transforms.CenterCrop(size):从图像中心裁剪大小为 size 的部分。
  • transforms.RandomCrop(size):随机裁剪图像。
  • transforms.RandomHorizontalFlip():随机水平翻转图像。
  • transforms.ColorJitter():随机更改图像的亮度、对比度和饱和度。
  • transforms.ToTensor():将 PIL 图像或 NumPy 数组转换为 PyTorch 张量。
  • transforms.Normalize(mean, std):标准化图像数据。

5. 使用 torchvision.models 的预训练模型

torchvision.models 提供了多种预训练模型,例如 ResNet、VGG、AlexNet 等,这些模型在 ImageNet 数据集上进行了预训练。

5.1 加载预训练模型

你可以加载一个预训练的 ResNet 模型并在新任务上进行微调。

import torchvision.models as models

# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)

# 查看模型架构
print(model)
5.2 微调预训练模型

如果你想要微调预训练模型(例如用于 CIFAR-10 数据集),你可以冻结预训练模型的部分参数,并修改最后一层以适应新的任务。

# 冻结所有层的参数
for param in model.parameters():
    param.requires_grad = False

# 修改最后一层以适应 CIFAR-10 (10 类分类任务)
model.fc = torch.nn.Linear(512, 10)

# 将模型移动到 GPU(如果有)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
5.3 训练模型
# 训练模型
for epoch in range(2):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # 清零梯度
        optimizer.zero_grad()

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # 反向传播与优化
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
5.4 测试模型
# 测试模型性能
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

6. torchvision.io 读取和保存图像

torchvision.io 提供了方便的图像读取和保存功能。

6.1 读取图像
import torchvision.io as io

# 读取图像
img = io.read_image('image.jpg')  # 读取为张量

# 显示张量信息
print(img.size())
6.2 保存图像
# 保存张量为图像文件
io.write_jpeg(img, 'output_image.jpg')

7. 完整示例

以下是一个使用 torchvision 加载数据集、进行数据增强、使用预训练模型微调并进行训练的完整示例:

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

# 数据增强与预处理
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 加载预训练的 ResNet18 模型并修改最后一层
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(512, 10)

# 设备设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(2):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs

 = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

8. 总结

torchvision 是 PyTorch 中处理计算机视觉任务的重要工具,它为常用的数据集、模型、数据处理和增强提供了便利的接口。通过本教程,你可以学习如何使用 torchvision 加载数据集、应用图像预处理、使用预训练模型进行微调,并训练模型来解决实际的计算机视觉任务。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吉小雨

你的激励是我创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值