Pytorch视觉模型库--timm

更新于2021.7.1



一、timm是什么?

  1. github:https://github.com/rwightman/pytorch-image-models
  2. Lean tutorials: https://rwightman.github.io/pytorch-image-models/
  3. Comprehensive tutorials:https://fastai.github.io/timmdocs/#
  4. https://paperswithcode.com/lib/timm

Pytorch Image Models (timm) 整合了常用的models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts,它的目的是将各种SOTA模型整合在一起,并具有再现ImageNet训练结果的能力。

作者:Ross Wightman,来自加拿大温哥华。

二、安装

pip install timm

建议在python>=3.6, pytorch>=1.4环境下使用

三、使用

import timm

1. 查询模型:timm.list_models()

1.1. 查看存在的所有模型

model_list = timm.list_models()
print(len(model_list), model_list[:3])
# Results 541 ['adv_inception_v3', 'botnet26t_256', 'botnet50ts_256']

1.2. 查看具有预训练参数的模型

model_pretrain_list = timm.list_models(pretrained=True)
print(len(model_pretrain_list), model_pretrain_list[:3])
# Results:396 ['adv_inception_v3', 'cait_m36_384', 'cait_m48_448']

1.3. 检索特定模型,采用模糊查询,如resnet系列

model_resnet = timm.list_models('*resnet*')
print(len(model_resnet), model_resnet[:3])
# Results: 117 ['cspresnet50', 'cspresnet50d', 'cspresnet50w']

可进一步查看想用的模型是否提供了预训练参数

print('resnet50: ', 'resnet50' in model_pretrain_list,
      'resnet101: ', 'resnet101' in model_pretrain_list)
# Results:     resnet50:  True      resnet101:  False

2. 创建模型:timm.create_model()

2.1. 创建预定义的完整的分类模型,可通过pretrained选项选择是否加载预训练参数

import torch
x = torch.randn([1, 3, 224, 224])
model_resnet50 = timm.create_model('resnet50', pretrained=True)
out = model_resnet50(x)
print(out.shape)
# Results: torch.Size([1, 1000])

2.2. 改变输出类别数目,微调模型:num_classes

model_resnet50_finetune = timm.create_model('resnet50', pretrained=True, num_classes=10)
out = model_resnet50_finetune (x)
print(out.shape)
# Results: torch.Size([1, 10])

2.3. 改变输入通道数:in_chans

# 通道数改变后,对应的权重参数会进行相应的处理,此处不作详细说明,
#可参照:https://fastai.github.io/timmdocs/models或直接查看源代码
x = torch.randn([1, 1, 224, 224])
feature_extractor = timm.create_model('resnet50', in_chans=1, features_only=True, out_indices=[1, 3, 4])

2.4. 获取分类层前(倒数第二层)的特征

参考链接:https://blog.csdn.net/qq_41917697/article/details/115026308

①直接调用forward_features()函数

x = torch.randn([1, 3, 224, 224])
Backbone1 = timm.create_model('vit_base_patch16_224')
Backbone2 = timm.create_model('resnet50')
feature1 = Backbone1.forward_features(x)
feature2 = Backbone2.forward_features(x)
print('vit_feature:', feature1.shape, 'resnet_feature:', feature2.shape)
# Results: vit_feature: torch.Size([1, 768])    resnet_feature: torch.Size([1, 2048, 7, 7])

②直接创建没有池化和分类层的模型,对于基于CNN的模型可以这样做

x = torch.randn([1, 3, 224, 224])
Backbone1 = timm.create_model('resnet50', num_classes=0, global_pool='')
Backbone2 = timm.create_model('resnet50', num_classes=0)
feature1 = Backbone1(x)
feature2 = Backbone2(x)
print('before pooling:', feature1.shape, 'after pooling:', feature2.shape)
# Results: before pooling: torch.Size([1, 2048, 7, 7])    after pooling: torch.Size([1, 2048])

③通过移除层来获得

x = torch.randn([1, 3, 224, 224])
Backbone1 = timm.create_model('resnet50')
Backbone2 = timm.create_model('resnet50')
Backbone1.reset_classifier(0, '')
Backbone2.reset_classifier(0)
feature1 = Backbone1(x)
feature2 = Backbone2(x)
print('before pooling:', feature1.shape, 'after pooling:', feature2.shape)
# Results: before pooling: torch.Size([1, 2048, 7, 7]) after pooling: torch.Size([1, 2048])

2.5. 获取中间层特征:features_only

x = torch.randn([1, 3, 224, 224])
feature_extractor = timm.create_model('resnet50', features_only=True)  # 并非所有model都有此选项
feature_list = feature_extractor(x)
for a in feature_list:
    print(a.shape)
# Results:
# torch.Size([1, 64, 112, 112])
# torch.Size([1, 256, 56, 56])
# torch.Size([1, 512, 28, 28])
# torch.Size([1, 1024, 14, 14])
# torch.Size([1, 2048, 7, 7])

可通过out_indices参数指定从哪个level获取feature

feature_extractor = timm.create_model('resnet50', features_only=True, out_indices=[1, 3, 4])
feature_list = feature_extractor(x)
for a in feature_list:
    print(a.shape)
# Results:
# torch.Size([1, 256, 56, 56])
# torch.Size([1, 1024, 14, 14])
# torch.Size([1, 2048, 7, 7])
  • 39
    点赞
  • 127
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值