一、Timm创建模型
加载预训练模型
代码如下:
import timm
m = timm.create_model('vit_large_r50_s32_224.augreg_in21k', pretrained=True)
m.eval()
二、查看模型列表
1.引入库
代码如下:
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
输出结果
[‘bat_resnext26ts.ch_in1k’,
‘beit_base_patch16_224.in22k_ft_in22k’,
‘beit_base_patch16_224.in22k_ft_in22k_in1k’,
‘beit_base_patch16_384.in22k_ft_in22k_in1k’,
‘beit_large_patch16_224.in22k_ft_in22k’,
‘beit_large_patch16_224.in22k_ft_in22k_in1k’,
‘beit_large_patch16_384.in22k_ft_in22k_in1k’,
‘beit_large_patch16_512.in22k_ft_in22k_in1k’,
‘beitv2_base_patch16_224.in1k_ft_in1k’,
‘beitv2_base_patch16_224.in1k_ft_in22k’,
‘beitv2_base_patch16_224.in1k_ft_in22k_in1k’,
‘beitv2_large_patch16_224.in1k_ft_in1k’,
…]
## 根据模型名称进行过滤
import timm
from pprint import pprint
model_names = timm.list_models('*resne*t*')
pprint(model_names)
2.模型微调
代码如下(示例):
model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=2)
pprint(model)
模型输入通道的修改
3.特征提取
import timm
import torch
x = torch.randn(1, 3, 224, 224)
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
pprint(model)
features = model.forward_features(x)
print(features.shape)
4.指定模型特征输出
# 对输出特征进行限制
import torch
import timm
# # 指定输出索引,并修改stride
m = timm.create_model('resnet34', features_only=True,output_stride = 8, pretrained=True)
# print("模型修改前")
# pprint(m)
# print("模型修改后")
# output_stride降采样的倍数
# m = timm.create_model('resnet34', output_stride = 8,out_indices = [4,],features_only=True, pretrained=True)
# pprint(m)
print(f'Feature channels: {m.feature_info.channels()}')
print(f'Feature reduction: {m.feature_info.reduction()}')
o = m(torch.randn(2, 3, 256, 256))
for x in o:
print(x.shape)
print(o[0].shape)
总结
timm还支持图像变换,详细还是去官网看文档吧:Timm