基于pytorch实现的图像分类源码
这个代码是干嘛的?
这个代码是基于pytorch框架实现的深度学习图像分类,主要针对各大有图像分类需求的使用者。
当然这个代码不适合大佬使用,对于大佬我建议是直接使用mmcv或者timm。
timm是我认为目前比较顶流的图像分类框架,也有很多图像分割、目标检测的源码使用timm作为backbone。
mmcv就更不用说了,就是大佬中的大佬。
当然除了大佬,我不建议各位使用timm和mmcv,不是因为他不好用,而是因为他使用难度较高,对于代码能力一般的同学,跑通估计就已经比较吃力,就不需要说根据自己的需求进行修改代码了。
因此我花了很多时间去研究bubbliiing(我是他粉丝)、yolov5、timm…比较优秀的开源框架后,编写并整合各大优秀源码到一个代码里面。
当然也希望通过这个项目,能够提升自己的编程水平和对深度学习的进一步理解,因为代码是我自己一个人进行编写和整合,虽然经过一些测试,但是没办法进行各方位测试,如果使用者遇到bug、出现报错、有不对的地方,可以通过留言、私信、邮箱(1069614715@qq.com)进行联系作者,咱们可以一起讨论讨论。
也希望通过这个平台,能交到更多这行的朋友,感谢各位!
源码地址:https://github.com/z1069614715/pytorch-classifier
源码使用案例:使用pytorch实现花朵分类
源码中的损失函数代码案例:pytorch代码-图像分类损失函数
源码地址中有更详细的解释,后续也会在哔哩哔哩中上传如何使用的视频。
如果这个代码帮助了你,请在博客点个赞,请在github点个star,谢谢!
为什么推荐你使用这个代码?
- 丰富的可视化功能
- 训练图像可视化.
- 损失函数,精度,学习率迭代图像可视化.
- 热力图可视化.
- TSNE可视化.
- 数据集识别情况可视化.(metrice.py文件中–visual参数,开启可以自动把识别正确和错误的文件路径,类别,概率保存到csv中,方便后续分析)
- 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,accuracy,f0.5,f1,f2,auc,aupr)
- 总体精度可视化.(kappa,precision,recll,f1,accuracy,mpa)
- 丰富的模型库
- 由作者整合的丰富模型库,主流的模型基本全部支持,支持的模型个数高达50+,其全部支持ImageNet的预训练权重,详细请看Model Zoo.(变形金刚系列后续更新)
- 目前支持的模型都是通过作者从github和torchvision整合,因此支持修改、改进模型进行实验,并不是直接调用库创建模型.
- 丰富的训练策略
- 支持断点续训,只需要设定一个参数(–resume).
- 支持多种常见的损失函数.(目前支持PolyLoss,CrossEntropyLoss,FocalLoss)
- 支持一个参数即可设置类别平衡.
- 支持混合精度训练.(使你的机器能支持更大的batchsize)
- 支持知识蒸馏.
- 丰富的数据增强策略
- 支持RandAugment, AutoAugment, TrivialAugmentWide, AugMix, Mixup, CutMix, CutOut, TTA等强大的数据增强.
- 支持添加torchvision中的数据增强.
- 支持添加自定义数据增强.详细看Some explanation第十四点
- 丰富的学习率调整策略
本程序支持学习率预热,支持预热后的自定义学习率策略.详细看Some explanation第五点 - 支持导出各种常用推理框架模型
目前支持导出torchscript,onnx,tensorrt推理模型. - 简单的安装过程
- 安装好pytorch, torchvision(pytorch==1.12.0+torchvision==0.13.0+)
可以在pytorch官网找到对应的命令进行安装. - pip install -r requirements.txt
- 安装好pytorch, torchvision(pytorch==1.12.0+torchvision==0.13.0+)
- 人性化的设定
- 大部分可视化数据(混淆矩阵,tsne,每个类别的指标)都会以csv或者log的格式保存到本地,方便后期美工图像.
- 程序大部分输出信息使用PrettyTable进行美化输出,大大增加可观性.
- 后续更新
后续将会更新一些使用的图像分类的tricks到这个代码里面,例如SWA,R-Drop等等。
更新日志
-
2022.11.12
具体更新细节和实验可以看pytorch-classifier-v1.1更新日志 -
2022-11-27
具体更新细节和推理速度对比实验可以看pytorch-classifier-v1.2更新日志
Model Zoo
目前支持的模型,以下模型全部都支持基于ImageNet的预训练权重。
model | model_name |
---|---|
resnet | resnet18,resnet34,resnet50,resnet101,wide_resnet50,wide_resnet101,resnext50,resnext101 resnest50,resnest101,resnest200,resnest269 |
shufflenet | shufflenet_v2_x0_5,shufflenet_v2_x1_0 |
mobilenet | mobilenetv2,mobilenetv3_small,mobilenetv3_large |
densenet | densenet121,densenet161,densenet169,densenet201 |
vgg | vgg11,vgg11_bn,vgg13,vgg13_bn,vgg16,vgg16_bn,vgg19,vgg19_bn |
efficientnet | efficientnet_b0,efficientnet_b1,efficientnet_b2,efficientnet_b3,efficientnet_b4,efficientnet_b5,efficientnet_b6,efficientnet_b7 efficientnet_v2_s,efficientnet_v2_m,efficientnet_v2_l |
nasnet | mnasnet0_5,mnasnet1_0 |
vovnet | vovnet39,vovnet59 |
convnext | convnext_tiny,convnext_small,convnext_base,convnext_large,convnext_xlarge |
ghostnet | ghostnet |
repvgg | RepVGG-A0,RepVGG-A1,RepVGG-A2,RepVGG-B0,RepVGG-B1,RepVGG-B1g2,RepVGG-B1g4 RepVGG-B2,RepVGG-B2g4,RepVGG-B3,RepVGG-B3g4,RepVGG-D2se |
sequencer | sequencer2d_s,sequencer2d_m,sequencer2d_l |
darknet | darknet53,darknetaa53 |
cspnet | cspresnet50,cspresnext50,cspdarknet53,cs3darknet_m,cs3darknet_l,cs3darknet_x,cs3darknet_focus_m,cs3darknet_focus_l cs3sedarknet_l,cs3sedarknet_x,cs3edgenet_x,cs3se_edgenet_x |
dpn | dpn68,dpn68b,dpn92,dpn98,dpn107,dpn131 |