前言
efficientNet确实很牛逼,而pytorch也已经在第一时间上线了调用efficientNet的方法。但是其调用的方法对于非科学上网的开发者来说很不友好(因为调用该模型需要在pytorch的终端当中进行模型的下载,而访问pytorch的终端对于国内用户来说太慢了,和访问stackoverflow速度差不多。)而一个模型例如b4,b5则要几十上百mb的大小。
efficientnet_pytorch模块安装
pip install efficientnet_pytorch
下面是正常情况下对efficientNet的调用
//以加载efficientNet-b5预训练模型为例子
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientNet-b5')
//其通常需要访问远程终端下载,速度非常慢
当然,不通过终端调efficientNet预训练模型的前提是本地已经有efficientNet的预训练模型,不用担心,我已经把b3-b6保存在了百度云,百度云的链接在此:百度网盘 ,提取码为:3bpc
//现在则是在本地调用efficientNet的预训练模型了,还是以efficientNet-b5为例子
//首先我们需要pytorch提供好的网络结构
model = EfficientNet.from_name('efficientnet-b5')
//而后需要通过pytorch来加载本地的pth权重文件
state_dict = torch.load('xxx/xxx/efficientnet-b5.pth')
//把effcientnet-b5的权重load到efficientnet-b5的网络结构中
model.load_state_dict(state_dict)
一般情况下,efficientNet因其分类数为1000,若使用还需要修改分类数为我们需要的,如下
import torch
from efficientnet_pytorch import EfficientNet
from torch import nn
model = EfficientNet.from_name('efficientnet-b4')
state_dict = torch.load('efficientnet-b4.pth')
model.load_state_dict(state_dict)
//前三步照常
in_fea = model._fc.in_features //读取全连接层的in_features
//要改的是最终输出的特征维度out_features(假设分类数为40)
model._fc = nn.Linear(in_features=in_fea, out_features= 40,bias=True)
//Finish!