pytorch_classification
利用pytorch实现图像分类,其中包含的densenet,resnext,mobilenet,efficientnet, resnet等图像分类网络,可以根据需要再行利用torchvision扩展其他的分类算法
实现功能
基础功能利用pytorch实现图像分类
包含带有warmup的cosine学习率调整
warmup的step学习率优调整
多模型融合预测,加权与投票融合
利用flask实现模型云端api部署
使用tta测试时增强进行预测
添加label smooth的pytorch实现(标签平滑)
添加使用cnn提取特征,并使用SVM,RF,MLP,KNN等分类器进行分类。
更新添加了模型蒸馏的的训练方法
运行环境
python3.7
pytorch 1.1
torchvision 0.3.0
代码仓库的使用
数据集形式
原始数据集存储形式为,同个类别的图像存储在同一个文件夹下,所有类别的图像存储在一个主文件夹data下。
|-- data
|-- train
|--label1
|--*.jpg
|--label2
|--*.jpg
|--label
|--*.jpg
...
|-- val
|--*.jpg
利用preprocess.py将数据集格式进行转换(个人习惯这种数据集的方式)
python ./data/preproces