参考乱觉先森的这篇文章:EfficientNet 训练测试自己的分类数据集
1、下载代码
Levigty/EfficientNet-Pytorch 可快速使用。(这里顺便提一下EfficientNet的pytorch版官方代码)
2、准备数据集
格式如下:
3、下载预训练模型
可复制链接至迅雷下载更快https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
下载完成之后放在eff_weights文件夹下,目录结构如下:
4、训练完整代码efficientnet_sample.py.更改一些训练参数即可,这里我遇到一个错误ForkingPickler(file, protocol).dump(obj) BrokenPipeError: [Errno 32] Broken pipe,将训练代码放进if __name__ == '__main__':就可以了!!!
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import time
import os
from efficientnet.model import EfficientNet
# some parameters
use_gpu = torch.cuda.is_available()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
data_dir = 'OxFlower17'
batch_size = 2
lr = 0.01
momentu