最近看了一篇论文,关于一个高效的遥感图像分割框架
三个方法的介绍
1、Problem Defined and Classical Self-Training Paradigm
作者首先重新定义了传统自训练的范式,其中λ是一个需要小心选择的超参数
但是λ的选择违背了我们所提出模型的精简性,因此作者对有标记的数据集重新采样,直至有标记数据集的数量Nl跟无标记Nu的一样,此时损失可以定义为:
其中α是固定值,表示Nu与Nl的比率
为什么要使NI和Nu接近呢?
在半监督网络中,如果标记数据集远小于未标记数据集时,模型会更多的学习标记数据集的特点,从而忽略掉许多未标记数据集的特征,并且容易造成过拟合。其次,如果按传统的自训练范式的话,λ的值很难取,取不好可能会使模型更倾向标记数据的特征。因此作者对有标记的数据集重新采样,直至有标记数据集的数量Nl跟无标记Nu的一样,简单的讲a带入损失函数即可。
下面是作者网络框图的数学公式:
其中S(.)是学生模型,T(.)是教师模型
是未标记数据集经过教师网络的预测值,及下图在Teacher模块后面的图片
这里是作者提出来的自训练步骤:训练教师模型->生成hard伪标签->重新训练学生模型
体现在作者网络框图上就是:先训练教师网络,将训练好的教师网络生成伪标签并经过SDA来与学生网络做loss来训练学生网络,最后总的loss为Lu+LI
2、Strong Data Augmentations Applicable to RS Images
这就是前面提到的SDA(数据增强),这里作者专注于设计用于RS图像的数据增强方法,作者在文中也没有说为什么这么选
![](https://i-blog.csdnimg.cn/blog_migrate/c5332921ca5901944155cecc6034a05c.png)
3、LSST: Linear Sampling Self-Training
作者提出自适应阈值是因为目前很多自训练都是对全体特征设置一个统一的阈值,忽略了网络对不同的特征有不同的效果,对于大背景来说可能阈值刚好,但是对小物体来说,本身就难识别,你再给他设个阈值,可能直接就没了,因此需要对特定类设置自适应阈值
这可以说是本文中主要的创新点,我就不从文中的信息入手了,太过晦涩,我直接从代码入手讲解
这里最重要的就是自适应阈值是如何产生的
"""
Adaptive Pseudo-Labeling
"""
print('\n\n\n================> Total stage 2/3: Adaptive Pseudo labeling all unlabeled images')
MODE = 'label'
dataset = SemiDataset(args.dataset, args.data_root, MODE, None, None, args.unlabeled_id_path)
dataloader = DataLoader(dataset,batch_size=args.batch_size,shuffle=False,pin_memory=True, num_workers=4,drop_last=False)
sparse_label(best_model, dataloader, args)
这是训练的第二阶段,此时就需要自适应阈值出手了,亮点在sparse_label这里
def sparse_label(model, dataloader, args):
model.eval()
tbar = tqdm(dataloader)
with torch.no_grad():
for img, _, id in tbar:
img = img.cuda()
pred = model(img)
soft_max_output, hard_output = pred.max(dim=1)
for j in range(soft_max_output.shape[0]):
soft, hard = soft_max_output[j].cpu().numpy(), hard_output[j].cpu().numpy()
need = []
for c in range(NUM_CLASSES[args.dataset]):
soft_clone, hard_clone = deepcopy(soft), deepcopy(hard)
need.append(ratio_sample(hard_clone, soft_clone, args.ratio, c))
need = np.min(np.array(need), axis=0)
pred = Image.fromarray(need.astype(np.uint8), mode='P')
pred.save('%s/%s' % (args.pseudo_mask_path, os.path.basename(id[j].split(' ')[1])))
可以看到,其中提取了预测图像的预测概率以及对应的hard标签,pre的维度应该是(batch_size,class_num,h,w),所以soft_max_output, hard_output维度为(batch_size,h,w),一直到
ratio_sample(hard_clone, soft_clone, args.ratio, c)
这里就到了我们的自适应阈值出场了:
def ratio_sample(hard_out, soft_max_out, ratio, s_class):
single_h = hard_out
single_s = soft_max_out
h_index = (single_h != s_class)
single_h[h_index] = 255
single_s[h_index] = 0
all = sorted(soft_max_out[(hard_out == s_class)], reverse=True)
num = len(all)
need_num = int(num * ratio + 0.5)
if need_num != 0:
adaptive_threshold = all[need_num - 1]
mask = (single_s >= adaptive_threshold)
index = (mask == False)
single_h[index] = 255
else:
single_h[(single_h != 255)] = 255
return single_h
这里将single_s中不是s_class的硬标签取255,软标签取0,及概率取0,输出为黑色,融合对其中认定为s_class的软标签进行排序,并取第need_num个的概率为阈值,并讲软标签中该类概率小于阈值时,对应的hard标签设为255,这就实现了自适应标签的功能