AI Challenger 2018 农作物病害细粒度分类-----Pytorch 深度学习实战

AI Challenger 2018 农作物病害细粒度分类

1 前言
2 代码组织结构
3 完整流程解析
     3.1 EDA
     3.2 参数定义
     3.3 数据加载过程
     3.4 数据处理 DataAugmentation 和TTA
     3.5 模型定义
     3.6 训练过程定义
     3.7 模型融合
     3.8 测试与结果提交
     3.9 log记录,可视化训练过程及其他trick
4 收获与心得

前言

    本文以AI Challenger 2018 农作物病害细粒度分类为例,比赛详细信息和数据见文末,基于Pytorch 0.4.0 构建项目 其中模型训练部分是在jupyter中完成 因此没有将整个训练过程封装为可执行的py文件,做这个比赛的初衷是熟悉一下pytorch,还有就是了解一下打比赛的整个流程,在过程中排名一度还不错 让自己产生了可能能拿奖的错觉 果然还是年轻啊 第一次打比赛还想拿奖 最终acc 在0.883 如果有好心的大佬能够告诉我这个怎么调参调到0.89以上 十分感谢

    不过通过这次比赛 让自己学习到了很多编程的技巧 熟悉了流程 收获还是很大的 有想要一起打比赛的小伙伴可以组队呀。下面就将这次比赛的整个流程的收获做一下总结,方便日后参考,同时也能够作为一份真正的实战指导,虽然做的菜 但总归有可以借鉴的地方.

代码github地址

代码组织结构

    在使用pytroch过程中可以将整个流程分为如下部分:数据分析过程(EDA), 参数定义 ,数据加载过程,数据处理 Data Augmentation和TTA(Test Time Augmentation),模型定义,训练过程定义,验证过程定义,测试过程定义,log定义与训练过程可视化 ,模型融合。 大致可以分为上述部分,每一部分在下文中做具体展开。

整体代码结构如下:

• code
  ▫CropModel.py
  ▫CropDataSet.py
  ▫utils.py 
   ......
• config
   ▫config.py
• data
  ▫trainData
  ▫validationData
  ▫testData
• model
  ▫ResNet50
    ▫2018-11-03_acc.pth
• feature
  ▫ ResNet50
      ▫ val_all_prediction.pth
      ▫ val_crop_prediction.pth
      ▫ test_all_prediction.pth
      ......
• log
    ▫ 2018-11-01
           ▫ ResNet50
                  ▫ tensorBoardX
                  ▫ logtxt
           ▫ ResNet 101
     ▫ 2018-11-02
    
• submision:
        ▫ 2018-11-02


在这次比赛中我发现良好的代码组织以及模型组织是必不可少的,只有这样才能更好的实现源源不断的idea的修改,使得代码不至于不可控,上述代码组织结构是这次比赛摸索出来的,肯定还有不好的地方 需要之后实践中不断修改。
code : 存放项目代码其中CropModel .py 将项目使用到的所有模型进行封装 ,CropDataSet .py 存放数据加载类以及不同的transform的方法 ,utils存放各种工具方法
data:在data中下分三个文件夹 trainData ,testData,validationData 每个文件夹下面存放着对应的annotation.json以及img文件夹保存图像
model:model用来保存不同训练模型结果 以模型名称命名,在每个文件夹下以 日期+acc.pth 保留当日最好模型,日期+_loss.pth 保留当日最好loss模型。 在这里其实可以改进model的保存方式 可以每一轮(或者固定周期)都将模型保存下来 然后把最好的模型另外创建一个和model完全一致结构的checkpoint文件夹 专门保存最优模型
feature:feature文件夹存储TTA之后生成的结果(之所以称为feature 是在stacking的时候 第二层的算法是将第一次算法结果作为特征的 所以这里就使用feature来命名这些TTA的结果)
log:存储tensorboardX生成的训练过程图 以及自定义的训练过程中的log输出。在这次比赛中没有存储log输出而是使用jupyter 直接打印出来 这样做是有风险的 不利于log的回溯 同时如果jupyter断开与服务器的连接 那么log信息就会丢失
submission:存储提交结果

完整流程解析
EDA

对于该问题EDA相对而言较为简单 可以分为如下几个步骤
1.将annotation转化为pandas格式
2.查询trainData validateData testData中是否有缺失值存在
3.生成各类样本数量分布图 并按样本数量大小排序
4. 展示若干样本图像

首先通过使用matplotlib 和pandas 对数据进行简单的统计和可视化
matplotlib 可能会出现中文注解乱码的问题 可以通过下述代码解决

import matplotlib
matplotlib.rcParams[u'font.sans-serif'] = ['simhei']
matplotlib.rcParams['axes.unicode_minus'] = False

将json文件转化为pandas

with open("../data/AgriculturalDisease_trainingset/AgriculturalDisease_train_annotations.json") as datafile1:
    trainDataFram=pd.read_json(datafile1,orient='records')
with open("../data/AgriculturalDisease_validationset/AgriculturalDisease_validation_annotations.json") as datafile2: #first check if it's a valid json file or not
    validateDataFram =pd.read_json(datafile2,orient='records')    

查看数据中Null的分布情况:

total=trainDataFram.isnull().sum().sort_values(ascending=False)
percent=(trainDataFram.isnull().sum())/(trainDataFram.isnull().count()).sort_values(ascending = False)
missing_validation_data = pd.concat([total, percent], axis=1, keys=['Total', 'Percent'],sort=False)
missing_validation_data.head()

在这里插入图片描述

查看数据分布情况

dataDistribute=trainDataFram.groupby(by=['disease_class']).size()
plt.figure(figsize=(50,20),dpi=100)
plt.xticks(range(len(dataDistribute)),dataDistribution.index.tolist(),fontsize=40)
plt.yticks(fontsize=40)
bar=plt.bar(dataDistribution.index.tolist(), dataDistribute.tolist(),width=0.7)
 
for b in bar:
    h=b.get_height()
    plt.text(b.get_x()+b.get_width()/2,h,int(h),ha='center',fontsize=30)
plt.show()

在这里插入图片描述

validate data
在这里插入图片描述

由此可见在训练过程中可以将44,45 label删除 提升正确率

根据数据量的大小排序

trainDataFram['disease_class'].value_counts().plot(kind='bar',figsize=(60,30),fontsize =60,title="Number of Training Examples Versus Class").title.set_size(80)

在这里插入图片描述

按大小排列同时在柱状图上增加数据量大小

dataDistribute=trainDataFram['disease_class'].value_counts()
plt.figure(figsize=(50,20),dpi=100)
plt.xticks(range(len(dataDistribute)),dataDistribute.index.tolist(),fontsize=40) #第一个参数是在哪些位置需要放置坐标值  第二个参数是放置的坐标值大小
plt.yticks(fontsize=40)
bar=plt.bar(range(len(dataDistribute)),dataDistribute.tolist(),width=0.6)
for b in bar:
    h=b.get_height()
    plt.text(b.get_x()+b
  • 37
    点赞
  • 231
    收藏
    觉得还不错? 一键收藏
  • 42
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 42
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值