Simple and Efficient: A Semisupervised Learning Framework for Remote Sensing Image Semantic Seg

文章介绍了利用半监督学习改进的遥感图像分割框架,通过重新定义自训练范式、强数据增强和LSST方法,特别是提出了自适应阈值策略,以提高模型对复杂场景的适应性。
摘要由CSDN通过智能技术生成

最近看了一篇论文,关于一个高效的遥感图像分割框架

论文地址:Simple and Efficient: A Semisupervised Learning Framework for Remote Sensing Image Semantic Segmentation | IEEE Journals & Magazine | IEEE Xplorel

论文代码:xiaoqiang-lu/LSST: The official PyTorch implementation of our paper (Simple and Efficient: A Semisupervised Learning Framework for Remote Sensing Image Semantic Segmentation) Accepted by TGRS2022 (github.com)

三个方法的介绍

1、Problem Defined and Classical Self-Training Paradigm

作者首先重新定义了传统自训练的范式,其中λ是一个需要小心选择的超参数

但是λ的选择违背了我们所提出模型的精简性,因此作者对有标记的数据集重新采样,直至有标记数据集的数量Nl跟无标记Nu的一样,此时损失可以定义为:

其中α是固定值,表示Nu与Nl的比率

为什么要使NI和Nu接近呢?

在半监督网络中,如果标记数据集远小于未标记数据集时,模型会更多的学习标记数据集的特点,从而忽略掉许多未标记数据集的特征,并且容易造成过拟合。其次,如果按传统的自训练范式的话,λ的值很难取,取不好可能会使模型更倾向标记数据的特征。因此作者对有标记的数据集重新采样,直至有标记数据集的数量Nl跟无标记Nu的一样,简单的讲a带入损失函数即可。

下面是作者网络框图的数学公式:

其中S(.)是学生模型,T(.)是教师模型

\hat{y}^{u}_{i}是未标记数据集经过教师网络的预测值,及下图在Teacher模块后面的图片

这里是作者提出来的自训练步骤:训练教师模型->生成hard伪标签->重新训练学生模型

体现在作者网络框图上就是:先训练教师网络,将训练好的教师网络生成伪标签并经过SDA来与学生网络做loss来训练学生网络,最后总的loss为Lu+LI

2、Strong Data Augmentations Applicable to RS Images

这就是前面提到的SDA(数据增强),这里作者专注于设计用于RS图像的数据增强方法,作者在文中也没有说为什么这么选

从左往右依次是:原始图像,CT,GT,Cutout,经过所以操作的图像

3、LSST: Linear Sampling Self-Training

作者提出自适应阈值是因为目前很多自训练都是对全体特征设置一个统一的阈值,忽略了网络对不同的特征有不同的效果,对于大背景来说可能阈值刚好,但是对小物体来说,本身就难识别,你再给他设个阈值,可能直接就没了,因此需要对特定类设置自适应阈值

这可以说是本文中主要的创新点,我就不从文中的信息入手了,太过晦涩,我直接从代码入手讲解

这里最重要的就是自适应阈值是如何产生的

 """
        Adaptive Pseudo-Labeling
    """
    print('\n\n\n================> Total stage 2/3: Adaptive Pseudo labeling all unlabeled images')

    MODE = 'label'
    dataset = SemiDataset(args.dataset, args.data_root, MODE, None, None, args.unlabeled_id_path)
    dataloader = DataLoader(dataset,batch_size=args.batch_size,shuffle=False,pin_memory=True, num_workers=4,drop_last=False)

    sparse_label(best_model, dataloader, args)

这是训练的第二阶段,此时就需要自适应阈值出手了,亮点在sparse_label这里

def sparse_label(model, dataloader, args):
    model.eval()
    tbar = tqdm(dataloader)

    with torch.no_grad():
        for img, _, id in tbar:
            img = img.cuda()
            pred = model(img)
            soft_max_output, hard_output = pred.max(dim=1)
            for j in range(soft_max_output.shape[0]):
                soft, hard = soft_max_output[j].cpu().numpy(), hard_output[j].cpu().numpy()
                need = []
                for c in range(NUM_CLASSES[args.dataset]):
                    soft_clone, hard_clone = deepcopy(soft), deepcopy(hard)
                    need.append(ratio_sample(hard_clone, soft_clone, args.ratio, c))
                need = np.min(np.array(need), axis=0)
                pred = Image.fromarray(need.astype(np.uint8), mode='P')
                pred.save('%s/%s' % (args.pseudo_mask_path, os.path.basename(id[j].split(' ')[1])))

可以看到,其中提取了预测图像的预测概率以及对应的hard标签,pre的维度应该是(batch_size,class_num,h,w),所以soft_max_output, hard_output维度为(batch_size,h,w),一直到

ratio_sample(hard_clone, soft_clone, args.ratio, c)

这里就到了我们的自适应阈值出场了:

def ratio_sample(hard_out, soft_max_out, ratio, s_class):
    single_h = hard_out
    single_s = soft_max_out
    h_index = (single_h != s_class)
    single_h[h_index] = 255
    single_s[h_index] = 0
    all = sorted(soft_max_out[(hard_out == s_class)], reverse=True)
    num = len(all)
    need_num = int(num * ratio + 0.5)
    if need_num != 0:
        adaptive_threshold = all[need_num - 1]
        mask = (single_s >= adaptive_threshold)
        index = (mask == False)
        single_h[index] = 255
    else:
        single_h[(single_h != 255)] = 255

    return single_h

这里将single_s中不是s_class的硬标签取255,软标签取0,及概率取0,输出为黑色,融合对其中认定为s_class的软标签进行排序,并取第need_num个的概率为阈值,并讲软标签中该类概率小于阈值时,对应的hard标签设为255,这就实现了自适应标签的功能

  • 23
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值