预训练模型项目教程
项目介绍
pretrained-models.pytorch
是一个开源项目,旨在提供预训练的卷积神经网络(ConvNets)模型,支持 PyTorch 框架。该项目的目标是帮助用户复现研究论文的结果,特别是在迁移学习设置中,并提供一个统一的接口/API 来访问这些预训练模型。
项目快速启动
安装
首先,确保你已经安装了 PyTorch。然后,你可以通过以下命令安装 pretrained-models.pytorch
:
pip install git+https://github.com/Cadene/pretrained-models.pytorch.git
使用示例
以下是一个简单的示例,展示如何加载并使用一个预训练模型进行图像分类:
import pretrainedmodels
import torch
from torchvision import transforms
from PIL import Image
# 加载预训练模型
model_name = 'resnet18'
model = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载图像
img = Image.open('path_to_your_image.jpg')
img_tensor = transform(img).unsqueeze(0) # 增加批次维度
# 模型预测
with torch.no_grad():
output = model(img_tensor)
_, preds = torch.max(output, 1)
print(f'预测类别: {preds.item()}')
应用案例和最佳实践
迁移学习
迁移学习是使用预训练模型的一种常见方式。你可以加载一个预训练模型,并替换其最后一层以适应你的特定任务。以下是一个示例:
import torch.nn as nn
# 加载预训练模型
model = pretrainedmodels.__dict__['resnet18'](num_classes=1000, pretrained='imagenet')
# 替换最后一层
num_ftrs = model.last_linear.in_features
model.last_linear = nn.Linear(num_ftrs, num_classes=2) # 假设你有2个类别
# 微调模型
# ...
模型评估
在微调模型后,你可以使用以下代码进行模型评估:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 数据加载和预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = datasets.ImageFolder('path_to_your_dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in dataloader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'准确率: {100 * correct / total}%')
典型生态项目
TorchVision
TorchVision
是 PyTorch 的一个官方库,提供了许多用于计算机视觉任务的工具和预训练模型。你可以通过以下命令安装:
pip install torchvision
TorchText
TorchText
是 PyTorch 的一个官方库,专注于自然语言处理任务。你可以通过以下命令安装:
pip install torch