Revisiting Knowledge Distillation via Label Smoothing Regularization

摘要
知识提炼(KD)旨在将繁琐的教师模型中的知识提炼为轻量级的学生模型。它的成功通常归功于教师模型提供的关于类别之间相似性的特权信息,从这个意义上说,只有强大的教师模型被部署在实践中教较弱的学生。在这项工作中,我们通过以下实验观察来挑战这一共同信念:1)除了承认教师可以改善学生,学生还可以通过逆转KD程序显著提高教师;2) 一位训练有素的教师,其准确度比学生低得多,仍然可以显著提高后者。为了解释这些观察结果,我们对KD和标签平滑正则化之间的关系进行了理论分析。我们证明了1)KD是一种学习标签平滑正则化,2)标签平滑正则化为KD提供了一个虚拟教师模型。从这些结果来看,我们认为KD的成功不完全是由于教师类别之间的相似性信息,而是由于软目标的正规化,这一点同样重要,甚至更重要。
基于这些分析,我们进一步提出了一个新的无教师知识提取(Tf-KD)框架,其中学生模型从自身或手动设计的正则化分布中学习。Tf-KD的表现与高级教师提供的普通KD相当,在没有更强大的教师模型时可以很好地应用。同时,Tf-KD是通用的,可以直接用于深层神经网络的训练。在没有任何额外计算成本的情况下,Tf-KD与成熟的基线模型相比,在ImageNet上实现了高达0.65%的改进,这优于标签平滑正则化。

1.介绍

知识提炼(KD)[7]旨在将知识从一个神经网络(教师)转移到另一个神经网络(学生)。通常,教师模式具有较强的学习能力和较高的绩效,通过提供“软目标”来教授能力较低的学生模式。人们普遍认为,教师模型的软目标可以转移包含不同类别之间相似性特权信息的“暗知识”,以增强学生模型。

在这项工作中,我们首先通过以下探索性实验来检验这样一个共同的信念:1)让学生模型通过传递学生的软目标来教授教师模型;(2) 让表现较差的、训练较少的教师模型来教学生。基于常识来说,预计教师模型不会通过来自学生的培训得到显著增强,而训练有素的教师也不会增强学生,因为薄弱的学生和训练有素的教师模型无法提供类别之间可靠的相似信息。然而,在对各种模型和数据集进行大量实验后,我们观察到了相互矛盾的结果:薄弱的学生可以提高教师水平,而训练有素的教师也可以显著提高学生水平。这些有趣的结果促使我们将KD解释为一个正则化术语,我们从标签平滑正则化(LSR)[16]的角度重新审视了知识提取,它通过用平滑标签替换一个热标签来正则化模型训练。

然后,我们从理论上分析了KD和LSR之间的关系。对于LSR,通过将平滑标签分成两部分并检查相应的损失,我们发现第一部分是真实标签分布(一个热标签)和模型输出的普通交叉熵,第二部分对应于一个虚拟教师模型,该模型提供了统一的分布来教授模型。对于KD,通过将教师的软目标与one-hot 真实标签相结合,我们发现KD是一个学习的LSR,其中KD的平滑分布来自教师模型,而LSR的平滑分布是手动设计的。总之,我们发现KD是一个学习的LSR,LSR是一个特殊的KD。 这种关系可以解释上述违反直觉的结果——来自弱学生和训练不足的教师模型的软目标可以有效地规范模型训练,尽管它们缺乏类别之间的强相似信息。因此,我们认为类别之间的相似信息不能完全解释KD中的暗知识,而教师模型中的软目标确实为学生模型提供了有效的正则化,这一点同样重要,甚至更重要。

基于这些分析,我们推测,如果教师模型中的类别之间的相似性信息不可靠,甚至为零,KD仍然可以很好地改进学生模型。因此,我们提出了一个新的无教师知识提炼(Tf-KD)框架,该框架有两个实现。第一种方法是自行训练学生模型(即自我训练),第二种方法是手动设计目标分布作为虚拟教师模型,具有100%的准确性。第一种方法的动机是用模型本身的预测取代暗知识,第二种方法的灵感来自KD和LSR之间的关系。我们通过大量实验验证了Tf-KD的两种实现都是简单而有效的。特别是,在虚拟教师中没有相似信息的第二个实现中,Tf-KD仍然实现了与正常KD相当的性能,这清楚地证明: dark知识不仅包括类别之间的相似性,而且还对学生的训练进行了规范化。

Tf-KD很好地适用于学生模型太强而无法找到教师模型或培训教师模型的计算资源有限的场景。例如,如果我们将一个笨重的单一模型ResNeXt101-32×8d[18]作为学生模型(ImageNet上的参数为88.79M,FLOPs次数为16.51G),那么训练一个更强大的教师模型很难,或者计算成本很高。我们部署了我们的虚拟教师来教这个强大的学生,并在ImageNet上实现了0.48%的改进,而无需任何额外的计算成本。同样,当使用一个功能强大的单模型ResNeXt29-8×64d和34.53M参数作为学生模型时,我们的自我培训实现比CIFAR100提高了1.0%以上(从81.03%提高到82.08%)。

我们的贡献主要如下:
1.通过在KD教师模型上设计两个探索性实验,我们观察到了违反直觉的结果,这促使我们将KD解释为一种正则化方法。
2.然后,我们提供了理论分析,以揭示KD和标签平滑正则化之间的关系。
3.我们提出了无教师知识蒸馏(Tf-KD),其性能与普通的知识蒸馏相当,并优于ImageNet2012上的标签平滑正则化。

2.探索性实验和违反常识的观察

为了检验KD中关于dark知识的普遍看法,我们进行了两个探索性实验:
(1)标准的知识蒸馏是让一位老师教一个较弱的学生。如果我们取消这个操作呢?基于这一常识,教师应该不会有显著的提升,因为学生太弱而无法传授有效的知识。
(2) 如果我们用一位训练较少的老师来教学生,而老师的表现比学生差得多,那么我们就认为这不会给学生带来任何进步。例如,如果在一个图像分类任务中采用了一个培训不足、准确率只有10%的教师,那么学生将从其90%的错误软目标中学习,因此学生不应该得到改进,甚至表现更差。

我们将“学生-教师”命名为反向知识蒸馏(Re-KD),将“培训不良的教师-教师-学生”命名为缺陷知识提取(De-KD)(图1)。我们使用各种神经网络在CIFAR10、CIFAR100和微小的ImageNet数据集上进行Re-KD和De-KD实验。为了进行公平比较,所有实验都在相同的设置下进行,并通过网格搜索从70个epoch(总共200个epoch)的训练中获得超参数。补充材料中给出了详细的实现和实验设置。

2.1反向知识蒸馏

我们分别在这三个数据集上进行了Re-KD实验。CIFAR10和CIFAR100[9]分别包含10类和100类32x32像素的自然RGB图像,而Tiny ImageNet是200类ImageNet[3]的一个子集,其中每个图像的大小减小到64x64像素。为了实验的通用性,我们采用了5层普通CNN、MobilenetV2[15]和ShufflenetV2[10]作为学生模型,ResNet18、ResNet50[6]、DenseNet121[8]和ResNeXt29-8×64d作为教师。这三个数据集上的Re KD结果在选项卡中给出。1到3。

在表1上。通过向学生学习,教师模型得到了显著改进,尤其是教师模型ResNet18和ResNet50。这两位教师在MobileNet V2和ShuffleNet V2。我们也可以在CIFAR10和Tiny ImageNet上观察到类似的结果。当比较Re-KD(S→T) 正常KD(T→S) ,我们可以看到,在大多数情况下,正常KD获得更好的结果。值得注意的是,Re-KD以教师的准确度作为基线准确度,远高于正常KD。然而,在某些情况下,我们可以发现Re-KD优于正常KD。例如,在Tab2(第三排),学生模型(普通CNN)在MobileNet V2教授时只能提高0.31%,但教师(MobileNet V2)可以通过向学生学习提高0.92%。我们对ResNeXt29和ResNet18有类似的观察结果(表2第4行)。

我们认为,虽然标准的知识蒸馏可以提高学生在所有数据集上的表现,但正如Re-KD实验所表明的那样,优秀教师也可以通过向弱学生学习而得到显著提升。
在这里插入图片描述

2.2. 缺陷知识蒸馏

我们在CIFAR100和Tiny ImageNet上进行De-KD(缺陷知识蒸馏)。我们采用MobileNetV2和ShuffleNetV2作为学生模型,采用ResNet18、ResNet50和ResNeXt29(8×64d)作为教师模型。培训不佳的教师接受1个时代(ResNet18)或50个时代(ResNet50和ResNeXt29)的培训,表现非常差。例如,经过1个历元的训练,ResNet18在CIFAR100上的准确率仅为15.48%,在Tiny ImageNet上的准确率仅为9.41%;经过50个epoch(总共200个epoch)的训练,ResNet50在CIFAR100和Tiny ImageNet上的准确率分别为45.82%和31.01%。

根据表4中CIFAR100上的De-KD实验结果。我们观察到,即使由训练较少的老师进行知识蒸馏,学生也能得到极大的提升。例如,当one-epoch-train训练的ResNet18以15.48%的准确率(第二排)教授MobileNetV2和ShuffleNetV2时,可以提升2.27%和1.48%。对于训练较差的ResNeXt29,准确率为51.94%(第四排),我们发现ResNet18仍可以提高1.41%,而Mo bileNetV2则可以提高3.14%。根据Tab4中Tiny ImageNet上的De-KD实验结果。我们发现,准确率为9.14%的ResNet18仍然可以将教师模型MobileNetV2提高1.16%。其他训练有素的教师都能在一定程度上提高学生的能力。

为了更好地展示学生在接受过不同程度准确度的低水平培训的教师授课时的蒸馏准确度,我们在正常培训过程中保留了ResNet18和ResNeXt29的9个检查点。将这些检查点作为教授MobileNetV2的教师模型,我们观察到,训练不好的ResNet18或训练不好的ResNeXt29总是可以提高MobileNetV2的准确性(图2)。因此,我们可以说,虽然一位训练有素的老师为学生提供了更多具有噪声的logit,但学生仍然可以得到提升。De-KD的实验结果也与普遍的看法相矛盾。

Re-KD和De-KD的反直觉结果让我们重新思考KD中的“暗知识”,我们认为它不只是包含相似信息。缺乏足够的相似性信息,一个模型仍然可以提供“暗知识”来增强其他模型。为了解释这一点,我们做出了合理的假设,并将知识蒸馏视为一种模型正则化,并研究了模型“暗知识”中的附加信息。接下来,我们将分析知识提取和标签平滑正则化之间的关系,以解释Re-KD和De-KD的实验结果。
在这里插入图片描述在这里插入图片描述

3.知识蒸馏与标签平滑正则化

我们从数学上分析了知识蒸馏(KD)和标签平滑正则化(LSR)之间的关系,希望能解释Sec2的探索性实验的有趣结果。给定要训练的神经网络S,我们首先给出S的LSR损失函数。对于每个训练示例x,S输出每个标签的概率 :
在这里插入图片描述
其中zi是神经网络S的logit。标签上的真实值分布是q(k | x)。 为了简单起见,我们把p(k | x)写成p(k),把q(k | x)写成q(k) 。模型S可以通过最小化交叉熵损失来训练:H(q,p)= 在这里插入图片描述

对于一个真实标签y,对于所有k≠y,q(y | x)=1和q(k | x)=0。

在LSR中,它最小化了修改后的标签分布q’(k)和网络输出p(k)之间的交叉熵,其中q’(k)是表示为:
在这里插入图片描述
它是q(k)和固定分布u(k)的混合物,权重为α。通常,u(k)是均匀分布的u(k)=1/K。定义在平滑标签上的交叉熵损失H(q‘,p)为 :
在这里插入图片描述
其中D KL是Kullback-Leibler散度(KL散度),H(u)表示u的熵,是固定均匀分布u(k)的常数。因此,模型S的标签平滑损失函数可以写成:
在这里插入图片描述在这里插入图片描述对于知识蒸馏,师生学习机制被用于提高学生的表现。我们假设学生是具有输出预测p(k)的模型S,教师网络的输出预测是:
在这里插入图片描述
其中 z t z^t zt是教师网络的输出逻辑,τ是软化 p t p^t pt(k)的温度(软化后写为 p τ t p^t_τ pτt(k))。知识蒸馏背后的理念是让学生(模型S)通过最小化交叉熵损失和学生与教师预测之间的KL差异来模仿教师 :
在这里插入图片描述
比较式(3)和式(4),我们发现这两个损失函数的形式相似。唯一的区别是, D K L D_{KL} DKL p τ t p^t_τ pτt p τ p_τ pτ)中的 p τ t p^t_τ pτt(k)是来自教师模型的分布,而 D K L D_{KL} DKL(u,p)中的u(k)是预定义的均匀分布。从这个观点,我们可以考虑KD作为LSR的特殊情况,其中平滑分布是学习的,但不是预先定义的。另一方面,如果我们将正则化项 D K L D_{KL} DKL(u,p)视为知识提炼的虚拟教师模型,该教师模型将为所有班级提供统一的概率,这意味着它具有随机精度(CIFAR100的精度为1%,ImageNet的精度为0.1%)。
由于:
在这里插入图片描述
对于固定的教师模型,当熵H( p τ t p^t_τ pτt)为常数时,我们可以将公式(4)重新表示为:
在这里插入图片描述
如果我们设置温度τ=1,我们得到:
在这里插入图片描述
其中~qt是 :
在这里插入图片描述如果我们将公式(6)与公式(1)进行比较,可以更清楚地看到KD是LSR的一个特例。此外,分布 p t p^t pt(k)是一个学习分布(来自训练好的教师),而不是均匀分布u(k)。我们将教师的输出概率 p t p^t pt(k)可视化,并将其与辅助材料中的标签平滑进行比较,发现温度τ越高, p t p^t pt(k)越类似于标签平滑的均匀分布u(k)。 基于两种损失函数的比较,我们总结了知识蒸馏和标签平滑正则化之间的关系,如下所示:
1.知识蒸馏是一种学习到的标签平滑正则化,它具有与后者类似的功能,即正则化模型的分类器层。
2.标签平滑是一种特殊的知识蒸馏,可以作为一个随机精度和温度τ=1的教师模型进行重新访问。
3.随着温度的升高,知识蒸馏中教师软目标的分布更接近于标签平滑的均匀分布。

因此,Re-KD和De-KD的实验结果可以解释为该模型在高温下的软目标更接近于标签平滑的均匀分布,学习到的软目标可以为教师模型提供模型正则化。这就是为什么一个学生可以提升老师,而一个缺乏训练的老师仍然可以改进学生模型。

4.无教师的知识蒸馏

如上所述,教师模型中的“暗知识”更多地是一个正则化术语,而不是类别之间的相似信息。直观地,我们考虑用简单的模型代替教师模型的输出分布。因此,我们提出了一个新的无教师知识提炼(Tf-KD)框架,该框架有两个实现。Tf-KD特别适用于没有更强大的教师模型,或仅提供有限计算资源的情况。

**第一种Tf-KD方法是自训练知识提取,称为Tf-KDself。**如前所述,教师可以由学生教授,而训练有素的教师也可以提高学生的能力。因此,如果没有更强大的教师模式,我们建议部署“自我培训”。值得注意的是,KD中的老师总是意味着更强的榜样。我们将自我培训称为无教师培训,因为该模式不是一个学习能力比自身更强的教师。我们的自我类似于重生网络[4],但有两个区别。我们的动机(自我训练\自我调节)不同于重生网络;我们的方法使用模型自身的软目标作为正则化,而再生的网络使用学生模型的集合来迭代地训练自己。 具体地说,我们首先以正常的方式训练学生模型,以获得预先训练的模型,然后使用该模型提供软标签来训练自己,如等式(4)所示。形式上,给定一个模型S,我们将其预训练模型表示为Sp;然后,我们尝试通过Tf KDself最小化S和 S p S^p Sp之间逻辑的KL散度。Tf KDself对训练模型S的损失函数为 :
在这里插入图片描述
其中p, p τ t p^t_τ pτt分别是S和 S p S^p Sp的输出概率,τ是温度,α是权重。

**我们的Tf-KD方法的第二个实现是手动设计一个100%准确率的教师。**在sec3中,我们发现LSR是一个具有随机精度的虚拟教师模型。因此,如果我们设计一个更准确的老师,我们可以认为这会给学生带来更多的进步。我们建议结合KD和LSR构建一个简单的教师模型,该模型将输出类别分布,如下所示:
在这里插入图片描述
其中K是类的总数,c是正确的标签,a是正确类的正确概率。我们总是设定一个a≥ 0.9,因此正确上课的概率远高于错误上课的概率,手动设计的教师模型对任何数据集都有100%的准确性。

通过手动设计正则化,我们将这种方法命名为无教师KD,表示为Tf KDreg。损失函数是 :

在这里插入图片描述
式中,τ是软化手动设计分布 p d p^d pd的温度(作为软化后的 p τ d p^d_τ pτd)。我们设定了高温τ≥ 20使这个虚拟教师输出一个软概率,这样它就获得了LSR的平滑特性。我们在图3中展示了手动设计的教师的分布。如图3所示,这种手动设计的教师模型输出的软目标具有100%的分类精度,并且具有标签平滑的平滑特性。但是Tf KDreg不是LSR的过度参数化版本,因为温度τ≥1.因此,当我们调整参数α、a或u(k)时,等式9将不等于等式3。 Tf-KDself和Tf-KDreg这两种无教师指导的方法非常简单但有效,在下一节中通过大量实验验证了这一点。
在这里插入图片描述

5.Tf-KD实验

5.1. Experiments for Self-training(Tf-KDself)

(1)CIFAR100
在这里插入图片描述
在这里插入图片描述
表5显示了六种模型的测试精度。可以看出,我们的Tf KDself始终优于基线。例如,作为一个具有34.52M参数的强大模型,ResNeXt29通过自正则化将自身提高了1.05%。即使在表5(第4列)中与具有优秀教师的普通KD相比,我们的方法也取得了相当的性能(Tf KD和普通KD的实验设置相同,并且搜索Tf KDself和普通KD两者的超参数)。例如,使用ResNet50教授ReseNet18,学生的成绩提高了1.19%,但我们的方法在不使用任何更强的教师模型的情况下实现了1.23%的提高。在图4中,我们还通过Tf KDself获得了MobileNetV2的类似结果。

(2)Tiny-ImageNet.
在这里插入图片描述
在Tiny ImageNet上,我们使用基线模型,包括MobileNetV2、ShuffleNetV2、ResNet50和DenseNet121。对于MobileNetV2、ShuffleNetV2和ResNet50、DenseNet121,它们的批量大小分别为bn=128和bn=64。初始学习率为η=0.1*bn128,然后在第60、120和160个历元时除以10。我们使用动量为0.9的SGD优化器,权重衰减设置为5e-4。表6显示了Tiny ImageNet上Tf KDself的结果。可以看出,Tf KDself持续改进了基线模型,并实现了与正常KD相当的改进。

(3)ImageNet.
在这里插入图片描述
我们可以看到,自我训练可以进一步提高ImageNet-2012的基线性能。作为比较,我们还使用DenseNet121在ImageNet上教授ResNet18,ResNet18获得了0.56%的改进,这与我们的Tf KDself相当(表8)。

5.2. Experiments for Manually-designed Regularization(Tf-KDreg)

(1)CIFAR100 and Tiny-ImageNet.
在这里插入图片描述对于CIFAR100和Tiny ImageNet上的Tf KDreg实验,我们将正确类的概率设置为a=0.99(等式(8))。对于不同的基线模型,方程(9)中的温度τ和α是不同的(见补充材料)。从表9和表10中,我们可以观察到,在没有使用教师的情况下,只添加了一个正则化术语,Tf KDreg在CIFAR100和TinyImageNet上实现了与Normal KD相当的性能。

(2)ImageNet.
在这里插入图片描述
对于ImageNet上的Tf KDreg,我们采用温度τ=20作为正常知识蒸馏,并且α=0.1作为标记平滑正则化。手动设计的教师正确上课的概率为a=0.99(等式(9))。我们使用四个基线模型测试Tf KDreg:ResNet18、ResNet50、DenseNet121和ResNeXt101(32x8d)。作为一个正则化术语,与基线相比,手动设计的教师实现了一致的改进。例如,建议的TfKDreg将ImageNet-2012上ResNet50的top1精度提高了0.65%(表11)。即使对于具有88.79M参数的巨大单模型ResNeXt101(32x8d),我们的方法也通过使用手动设计的教师实现了0.48%的改进。

比较我们的两种方法Tf KDself和Tf KDreg,我们观察到Tf KDself小数据集(CIFAR100)中工作得更好,而Tf KDeg大数据集(ImageNet)中表现得稍微更好。

5.3Comparison with LSR

Tf KDreg的动机是LSR,这可以看作是LSR的修改。这种修改显著提高了神经元网络的性能,而无需额外的计算成本。与LSR相同,Tf KDreg可以作为常规训练神经网络的通用正则化方法。我们将Tf KDreg与CIFAR100、TinyImageNet和ImageNet上的标签平滑进行了比较。为了公平比较,Tf KDreg和LSR的实验设置是相同的。结果见表9、10和11。可以看出,Tf KDreg始终优于LSR。此外,KDRman的公式类似于LSR,但它不是标签平滑的过度参数化版本。我们详细比较了Tf KDreg和LSR,以显示补充材料中的差异。

6.结论

在这项工作中,我们通过实验和分析发现,教师模型的“暗知识”更多地是一个正则化术语,而不是类别的相似性信息。基于KD和LSR之间的关系,我们提出了无教师KD。实验结果表明,我们的Tf-KD在图像分类中可以获得与普通KD相当的结果。我们的工作还表明,当很难为强大的模型找到更强大的教师或计算资源被限制为训练教师模型时,目标模型仍然可以通过自我训练或手动设计的正则化项来增强。

以下为文章所述Tf-KDreg第一种方法的损失函数:

class DIST(nn.Module):
    def __init__(self, beta=1., gamma=1.):
        super(DIST, self).__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(self, y_s, y_t):
        assert y_s.ndim in (2, 4)
        if y_s.ndim == 4:
            num_classes = y_s.shape[1]
            y_s = y_s.transpose(1, 3).reshape(-1, num_classes)
            y_t = y_t.transpose(1, 3).reshape(-1, num_classes)
        y_s = y_s.softmax(dim=1)
        y_t = y_t.softmax(dim=1)
        inter_loss = inter_class_relation(y_s, y_t)
        intra_loss = intra_class_relation(y_s, y_t)
        loss = self.beta * inter_loss + self.gamma * intra_loss
        return loss

以下为测试部分代码:

class Evaluator(object):
    def __init__(self, args, num_gpus):
        self.args = args
        self.num_gpus = num_gpus
        self.device = torch.device(args.device)

        ignore_label = -1
        self.id_to_trainid = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
                        3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
                        7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
                        14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
                        18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
                        28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}

        # dataset and dataloader
        self.val_dataset = CSTestSet(args.data, args.data_list)

        val_sampler = make_data_sampler(self.val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=self.val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model, 
                                            backbone=args.backbone,
                                            aux=args.aux, 
                                            pretrained=args.pretrained, 
                                            pretrained_base='None',
                                            local_rank=args.local_rank,
                                            norm_layer=BatchNorm2d).to(self.device)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model,
                device_ids=[args.local_rank], output_device=args.local_rank)
        self.model.to(self.device)

        self.metric = SegmentationMetric(self.val_dataset.num_class)

    def id2trainId(self, label, id_to_trainid, reverse=False):
        label_copy = label.copy()
        if reverse:
            for v, k in id_to_trainid.items():
                label_copy[label == k] = v
        else:
            for k, v in id_to_trainid.items():
                label_copy[label == k] = v
        return label_copy


    def reduce_tensor(self, tensor):
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        return rt

    def predict_whole(self, net, image, tile_size):
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        prediction = net(image.cuda())
        if isinstance(prediction, tuple) or isinstance(prediction, list):
            prediction = prediction[0]
        prediction = interp(prediction)
        return prediction

    def eval(self):
        self.metric.reset()
        self.model.eval()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(len(self.val_loader)))
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            N_, C_, H_, W_ = image.size()
            tile_size = (H_, W_)
            full_probs = torch.zeros((1, self.val_dataset.num_class, H_, W_)).cuda()

            scales = args.scales
            with torch.no_grad():
                for scale in scales:
                    scale = float(scale)
                    print("Predicting image scaled by %f" % scale)
                    scale_image = F.interpolate(image, scale_factor=scale, mode='bilinear', align_corners=True)
                    scaled_probs = self.predict_whole(model, scale_image, tile_size)

                    if args.flip_eval:
                        print("flip evaluation")
                        flip_scaled_probs = self.predict_whole(model, torch.flip(scale_image, dims=[3]), tile_size)
                        scaled_probs = 0.5 * (scaled_probs + torch.flip(flip_scaled_probs, dims=[3]))
                    full_probs += scaled_probs
                full_probs /= len(scales)  

            if self.args.save_pred:
                pred = torch.argmax(full_probs, 1)
                pred = pred.cpu().data.numpy()
                seg_pred = self.id2trainId(pred, self.id_to_trainid, reverse=True)
                
                predict = seg_pred.squeeze(0)
                # mask = get_color_pallete(predict, self.args.dataset)
                mask = PILImage.fromarray(predict.astype('uint8'))
                mask.save(os.path.join(args.outdir, os.path.splitext(filename[0])[0] + '.png'))
                print('Save mask to ' + os.path.splitext(filename[0])[0] + '.png' + ' Successfully!')

        synchronize()


if __name__ == '__main__':
    args = parse_args()
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1
    if not args.no_cuda and torch.cuda.is_available():
        cudnn.benchmark = True
        args.device = "cuda"
    else:
        args.distributed = False
        args.device = "cpu"
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # TODO: optim code
    outdir = '{}_{}_{}_{}'.format(args.model, args.backbone, args.dataset, args.method)
    args.outdir = os.path.join(args.save_dir, outdir)
    if args.save_pred:
        if (args.distributed and args.local_rank == 0) or args.distributed is False:
            if not os.path.exists(args.outdir):
                os.makedirs(args.outdir)

    logger = setup_logger("semantic_segmentation", args.save_dir, get_rank(),
                          filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+')

    evaluator = Evaluator(args, num_gpus)
    evaluator.eval()
    torch.cuda.empty_cache()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值