timm笔记

快速开始

安装 timm
pip install timm

timm.create_model

(model_name: str,
 pretrained:  bool = False,
 pretrained_cfg:Union = None,
 pretrained_cfg_overlay: Optional = None,
 checkpoint_path: str = '',
 scriptable: Optional = None,
 exportable: Optional = None,
 no_jit: Optional = None, **kwargs)

timm.create_model 详细解读

create_model 函数用于创建一个模型。它的参数如下:

  • model_name: 模型名称 (字符串)。
  • pretrained: 是否加载预训练权重 (布尔值,默认值为 False)。
  • pretrained_cfg: 预训练配置 (可选)。
  • pretrained_cfg_overlay: 预训练配置覆盖 (可选)。
  • checkpoint_path: 检查点路径 (字符串,默认值为空)。
  • scriptable: 是否可脚本化 (可选)。
  • exportable: 是否可导出 (可选)。
  • no_jit: 是否禁用 JIT 编译 (可选)。
  • **kwargs: 其他关键字参数。

关键字参数

  • drop_rate: 分类器训练时的 dropout 率 (浮点数)。
  • drop_path_rate: 训练时随机深度 drop 路径率 (浮点数)。
  • global_pool: 分类器的全局池化类型 (字符串)。

示例

from timm import create_model

# 创建一个没有预训练权重的 MobileNetV3-Large 模型。
model = create_model('mobilenetv3_large_100')

# 创建一个带有预训练权重的 MobileNetV3-Large 模型。
model = create_model('mobilenetv3_large_100', pretrained=True)
model.num_classes  # 1000

# 创建一个带有预训练权重和新分类头的 MobileNetV3-Large 模型 (10 类)。
model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
model.num_classes  # 10

这个函数会通过入口函数将相关参数传递给 timm.models.build_model_with_cfg,然后调用模型类的 __init__ 方法。如果 kwargs 的值为 None,则在传递前会被剔除。

加载预训练模型
import timm
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()

注意:返回的 PyTorch 模型默认设置为训练模式,因此如果你计划使用它进行推理,则必须在其上调用 .eval()。

列出预训练模型
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
微调预训练模型
model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
特征提取
x = torch.randn(1, 3, 224, 224)
features = model.forward_features(x)
print(features.shape)
图像增强
transform = timm.data.create_transform((3, 224, 224))
预处理数据
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
使用预训练模型进行推理
image = Image.open(requests.get(url, stream=True).raw)
image_tensor = transform(image).unsqueeze(0)
output = model(image_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
values, indices = torch.topk(probabilities, 5)

特征提取

特征提取

倒数第二层特征 (分类器前特征)

你可以通过多种方式获取模型的倒数第二层特征,无需修改模型:

  • 未池化特征

    • 使用 model.forward_features(input) 来获取未池化特征。
    • 创建模型时不包含分类器和池化层。
    • 使用 reset_classifier(0, '') 移除分类器和池化层。
  • 池化特征

    • 使用 model.forward_features() 并手动池化结果。
    • 创建模型时只移除分类器。
多尺度特征图 (特征金字塔)

可以创建一个只输出特征图的模型,使用 features_only=True 参数,并可通过 out_indices 指定输出哪些层的特征。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值