ImageNet-21K Pretraining 项目教程
项目介绍
ImageNet-21K Pretraining 项目是由 Alibaba-MIIL 团队开发的一个开源项目,旨在为计算机视觉任务提供大规模的预训练模型。该项目基于 ImageNet-21K 数据集,该数据集包含超过 21,000 个类别和数百万张图像,为深度学习模型提供了丰富的预训练资源。通过使用 ImageNet-21K 数据集进行预训练,可以显著提高模型在各种下游任务上的性能。
项目快速启动
安装依赖
首先,确保你已经安装了 PyTorch 和其他必要的依赖库。你可以使用以下命令安装这些依赖:
pip install torch torchvision timm
下载项目
使用以下命令从 GitHub 下载项目:
git clone https://github.com/Alibaba-MIIL/ImageNet21K.git
cd ImageNet21K
加载预训练模型
你可以使用 timm
库来加载预训练模型。以下是一些示例代码:
import timm
# 加载 mobilenetv3_large_100_miil_in21k 模型
model = timm.create_model('mobilenetv3_large_100_miil_in21k', pretrained=True)
# 加载 tresnet_m_miil_in21k 模型
model = timm.create_model('tresnet_m_miil_in21k', pretrained=True)
# 加载 vit_base_patch16_224_miil_in21k 模型
model = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=True)
# 加载 mixer_b16_224_miil_in21k 模型
model = timm.create_model('mixer_b16_224_miil_in21k', pretrained=True)
应用案例和最佳实践
图像分类
使用预训练模型进行图像分类是常见的应用场景。以下是一个简单的示例代码,展示如何使用预训练模型对图像进行分类:
import torch
from PIL import Image
import requests
from torchvision import transforms
# 加载预训练模型
model = timm.create_model('mobilenetv3_large_100_miil_in21k', pretrained=True)
model.eval()
# 加载并预处理图像
url = 'https://example.com/image.jpg'
image = Image.open(requests.get(url, stream=True).raw)
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = transform(image).unsqueeze(0)
# 进行推理
with torch.no_grad():
output = model(input_tensor)
# 获取预测结果
predictions = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(predictions, 5)
for prob, catid in zip(top5_prob, top5_catid):
print(f'类别: {catid}, 概率: {prob.item()}')
迁移学习
迁移学习是另一种常见的应用场景,特别是在数据集较小的情况下。以下是一个示例代码,展示如何使用预训练模型进行迁移学习:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 加载预训练模型并修改最后一层
model = timm.create_model('mobilenetv3_large_100_miil_in21k', pretrained=True)
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear