无人机图像语义分割--天池2019年县域农业大脑AI挑战赛冠军解决方案
比赛页面
链接: 比赛页面.
github
链接: github.
任务简介
通过无人机航拍的地面影像,探索像素级农作物分类的精准算法,识别薏仁米、玉米、烤烟、人造建筑以及背景(其它)等。
解决方案
我们队伍的最终方案融合了三个模型,本文仅仅介绍自己做的xception模型的相关介绍(所涉及的数据以及指标均为复赛阶段的数据和指标)。
数据处理
数据存在的问题
- 由于比赛提供的数据是几张非常大的图片(分辨率40000×40000级别),因此训练网络时需要对数据进行切分,预测结束后则要进行结果的拼接。
- 数据不平衡,背景占据60%以上,建筑物类别只有5%左右。
数据不平衡
针对数据不平衡问题,采用固定窗口大小的滑窗策略,动态检测窗口所处位置的类别,保留规则如下:首先利用mask(训练数据第4通道,1表示该位置的像素有实际意义,0表示该位置的像素为纯黑,没有实际意义,与背景类不同意义)删除纯黑像素占比大于当前窗口像素的 7 8 \frac{7}{8} 87的图片,其次若当前窗口背景类像素占比小于 1 3 \frac{1}{3} 31则减小下一步的滑动步长,即利用过采样稀少样本达到缓解数据不平衡的目的。核心代码如下:
while h_index < label_h - UNIT:
w_index = 0
while w_index < label_w - UNIT:
img_unit = img[h_index:h_index + UNIT, w_index:w_index + UNIT, :]
# 删除黑色大于7/8的图片
if np.sum(np.where(np.sum(img_unit, axis=2) != 0, 1, 0)) > UNIT*UNIT//8:
k = k + 1
print('\rcrop {} unit image'.format(k), end='', flush=True)
label_unit = label[h_index:h_index + UNIT, w_index:w_index + UNIT]
path_unit_img = os.path.join(output_dir, 'image', '{}_{}.png'.format(IMAGE, k))
path_unit_label = os.path.join(output_dir, 'label', '{}_{}.png'.format(IMAGE, k))
path_unit_vis = os.path.join(output_dir, 'vis', '{}_{}.png'.format(IMAGE, k))
io.imsave(path_unit_img, img_unit)
io.imsave(path_unit_label, label_unit)
io.imsave(path_unit_vis, vis(label=label_unit, img=img_unit))
# 如果0类个数小于1/3,则减小步长
if np.sum(np.where(label_unit == 0, 1, 0)) < UNIT*UNIT//3:
w_index = w_index + UNIT//2
else:
w_index = w_index + UNIT // 10 * 9
else:
w_index = w_index + UNIT // 10 * 9
h_index = h_index + UNIT // 10 * 9
拼接问题
采用膨胀预测缓解拼接边缘问题
模型选择
比赛尝试了各种模型,包括但不限于PSPNet,各种Backbone的UNet,最后选择了deeplabv3+。(404小公司还是牛逼啊)
模型选择方面的一点思考或者说疑问:在尝试UNet,FPN这种结合底层特征(直接从Backbone前端抽取)的网络模型时,发现预测结果非常分散,出现很多小点或者小洞,导致预测结果的指标都不是很好,而PSPNet和DeepLab等网络却没有这种问题。
个人认为这个问题和这次的比赛任务农作物的分割有关,农田这个集合本身的联系没有其他物体来的紧密,比如本次训练数据中有许多农田中农作物种植的比较稀疏,过多的结合太底层的特征(类似UNet,FPN直接从backbone中提取早期的特征的网络),则农田中植物之间露出的土地和其他非农田中的土地并没有区别,会导致网络将农田中较为稀疏的裸露的土地这一部分分成背景,而PSPnet和DeepLab网络则因为没有结合或者结合的比重较小而有较大的优势,预测的结果较为平滑。
其他trick
学习率预热和多项式衰减
参考《A disciplined approach to neural network hyper-parameters: Part 1 – learning rate, batch size, momentum, and weight decay》,onecycle学习率设置的思想:
前期利用较大的学习率冲破局部最优解,到达一个相对平缓的地带,后期利用较小的学习率搜索到一个精度较高的解。
使用warmup和poly_decay达到类似的学习率曲线,取得了不错的提升。
但是在本次试验中发现后期的学习率衰减不能像onecycle策略的那样衰减到很小,衰减到很小的学习率导致了模型的过拟合,因此调大了最后poly_decay策略衰减的学习率,使模型保持良好的泛华能力。
Radam优化器
正好在比赛期间看到,拿过来一用发现比adam稳定,参考论文《On the Variance of the Adaptive Learning Rate and Beyond》
标签平滑
在应用这个trick之前我们甚至使用了难样本挖掘,本来我们以为那些样本是难样本,使用之后效果反而大打折扣,后来才发现有些样本本来就是错误样本。发现许多地方存在标注错误的问题后(和实际问题和标注质量有关,农田边界本身就是模糊不清的),标注错误的样本会导致训练模型的时候难以收敛或者过拟合错误样本,因此参考论文《Rethinking the Inception Architecture for Computer Vision》,采用Label Smoothing进行标签软化。Szegedy在论文中说:
Intuitively, this happens because the model becomes too confident about its predictions
将标签从本来的非对即错状态变为所属类别的概率,这样更加符合客观规律,可以减少错误样本对模型的干扰。使得模型训练更加稳定,模型更加鲁棒。
伪标签
比赛大杀器,和融合一样没有明令禁止的话就无脑上。
基本流程就是将测试数据进行预测,剔除概率较低(低于Probability threshold)的部分,将概率较高的部分当成训练数据加入训练集重新训练。
Pseudolabeling(mIOU) | Probability threshold | Test A score(mIOU) |
---|---|---|
- | - | 0.794 |
0.794 | 86% | 0.804 |
0.804 | 86% | 0.810 |
后处理
根据观察训练数据可以得出农田呈现大片大片的状态不会有很碎的小碎块,因此利用检测连通区域的方法去除掉小碎块提高指标。
核心代码:
if area_threshold !=0:
result = to_categorical(result, num_classes=5, dtype='uint8')
for i in tqdm(range(5)):
result[:, :, i] = morphology.remove_small_objects(result[:, :, i] == 1, min_size=area_threshold, connectivity=1, in_place=True) + 0
result[:, :, i] = morphology.remove_small_holes(result[:, :, i] == 1, area_threshold=area_threshold, connectivity=1, in_place=True) + 0
# result[:, :, i] = cv2.morphologyEx(result[:, :, i], cv2.MORPH_OPEN, morphology.square(200))
# result[:, :, i] = cv2.morphologyEx(result[:, :, i], cv2.MORPH_CLOSE, morphology.square(200))
result = np.argmax(result, axis=2).astype(np.uint8)
首先将标签转化为onehot类型,然后对每个类别进行填充空洞和去除碎片的处理,由于开闭操作太费时间并且提升不明显,最后并没有使用。
最终决赛队伍指标
队伍名称 | A榜 | B榜 |
---|---|---|
冲鸭!大黄 | 0.810 | 0.818 |
AKLDF | 0.788 | 0.817 |
算法cj | - | 0.806 |
A-Force | - | 0.804 |
我们打野贼六 | 0.782 | 0.803 |
致谢
该项目网络基于tensorflow官方的deeplabv3+项目
感谢另外两位队友的大力支持
体会
- 融合是真的厉害,特别是模型之间的相关性越弱,融合之后提升越大。
- 花里胡哨一大堆trick,提分最大的还是和数据直接相关的伪标签,数据驱动决定了大的提升应该基本都和数据有紧密关联。
- 当backbone的拟合能力足够的时候,换backbone好像并没有很大的变化。
数据集
百度网盘 提取码:91c3
不能用于发论文或者商业行为