Timm快速使用
关注B站可以观看更多实战教学视频:hallo128的个人空间
参考:Timm使用说明【推荐】
核心思路
代码使用了PyTorch和timm库加载了一个预训练的ResNet-18模型,并对一张图片进行了推理,输出了概率最高的Top-5标签及其概率值。这是一个常见的图像分类任务的示例,通过预训练模型可以快速实现图像识别的功能。在这个示例中,模型是基于ImageNet数据集训练过的,而IMAGENET_1k_LABELS则包含了ImageNet数据集中的标签信息。
代码
提前下载类别对应关系:wget https://mirror.coggle.club/imagenet_classes.txt
- 1.加载图片
- 2.加载模型
- 3.使用模型进行预测
- 4.得到类别对应关系
import timm
from PIL import Image
# 1.加载图片
image = Image.open('tyler-swift.jpg') #---更改图片路径
# 2.加载模型
model = timm.create_model('resnet18', pretrained=True)
# 得到模型对图片的预处理方式
transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)
image_tensor = transform(image) # 图片预处理
# 3.使用模型进行预测
output = model(image_tensor.unsqueeze(0))
probabilities = torch.nn.functional.softmax(output[0], dim=0)
values, indices = torch.topk(probabilities, 5)
# 4.得到类别对应关系
IMAGENET_1k_LABELS = open('imagenet_classes.txt').readlines()
[{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
Timm使用说明【推荐】