Unified Deep Supervised Domain Adaptation and Generalization

论文概述

问题研究背景:supervised domain adaptation(SDA),源域有大量带标签的数据,目标域仅有少量可使用的数据

问题的难点:目标域数据不足导致概率分布在语义上很难对齐和区分。对齐指的是源域图片类别之间的关系与目标域图片类别之间的关系尽可能的相似;区分指的是同一个domain中,不同类别的特征要尽可能不同。

方法的优点:适应的速度快,仅需要少量的数据就可以获得很不错的效果。易于拓展成DG方法。

具体的方法实现
一般来说,一个DA模型对应的 f u n c t i o n function function可以被看作是两个函数的组合 f = h   o   g . f=h\ o\ g. f=h o g. 其中 g : χ → Z , χ g:\chi \rightarrow Z,\chi g:χZ,χ代表输入的特征空间,Z代表embedding space。 h : z → y 。 h:z\rightarrow y。 h:zyh代表利用embedding space中的特征进行预测的函数。对于源域的数据有 f s = h s   o   g s f_s=h_s\ o\ g_s fs=hs o gs,对于目标域的数据有 f t = h t   o   g t f_t=h_t\ o\ g_t ft=ht o gt
为了对齐源域和目标域之间数据的分布(即上述中的g),常使用下面形式的loss:
L C A ( g ) = d ( p ( g ( X s ) ) , p ( X t ) ) L_{CA}(g)=d(p(g(X^s)),p(X^t)) LCA(g)=d(p(g(Xs)),p(Xt))
这个loss的作用就是让源域和目标域中的数据经过映射以后无法被分辨。用原文中的话来说就是 I n   t h e   e m b e d d i n g   s p a c e   Z ,   f e a t u r e s   a r e   a s s u m e d   t o   b e   d o m a i n   i n v a r i n t In\ the\ embedding\ space\ Z,\ features\ are\ assumed\ to\ be\ domain\ invarint In the embedding space Z, features are assumed to be domain invarint
上面的loss对于unsupervised domain adaptation来说很适合,但是存在一个很大的问题:没办法保证不同域之间的语义是对齐的。SDA相较于UDA就可以利用label信息来对齐语义。loss被改写成如下的形式:
L S A ( g ) = ∑ a = 1 C d ( p ( g ( X a s ) , p ( g ( X a t ) ) ) ) L_{SA}(g)=\displaystyle\sum_{a=1}^Cd(p(g(X^s_a),p(g(X^t_a)))) LSA(g)=a=1Cd(p(g(Xas),p(g(Xat))))
上述loss被称为semantic alignment loss,d是一个度量距离的函数,具体来说,它用来度量不同域中同一类别样本的特征在被映射到embedding space之后的距离,我们希望这个距离越小越好。

光有上述loss还不行,因为模型学习的方向可能使得所有的类别分布趋同。为了使得不同域的不同类别之间距离尽可能变大,需要加上下面的separation loss
L S ( g ) = ∑ a , b ∣ a ≠ b k ( p ( g ( X a s ) ) , p ( g ( X b t ) ) ) L_S(g)=\displaystyle\sum_{a,b|a\ne b }k(p(g(X^s_a)),p(g(X^t_b))) LS(g)=a,ba=bk(p(g(Xas)),p(g(Xbt)))
k表示相似性函数,当源域中的a类与目标域中b类靠的太近时会施加惩罚。
最后是个用来分类的loss L C L_C LC,多任务分类一般使用的是交叉熵函数。最后loss的表达形式如下:
L C C S A ( f ) = L C ( h   o   g ) + L S A ( g ) + L S ( g ) L_{CCSA}(f)=L_C(h\ o\ g) + L_{SA}(g)+L_S(g) LCCSA(f)=LC(h o g)+LSA(g)+LS(g)
由于目标域中的数据很少,文章中提出逐点计算loss。具体来说,作者将目标域中每个样本与源域中的样本进行配对,每个样本对应的embedding feature之间进行loss的计算。
在这里插入图片描述

代码实现

模型部分

class NetWork(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(1152, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        feature = self.cnn(x)
        prediction = self.classifier(feature)
        return prediction, feature

模型部分,其中包括一个卷积神经网络用作特征提取器,它对应的就是function g。function h对应的是全连接层用来实现分类功能。模型会返回提取的特征feature以及分类器输出的概率分布。

loss

DA方法loss占据着核心地位。

def csa_loss(x, y, class_eq):
    margin = 1
    dist = F.pairwise_distance(x, y, 2)
    loss = class_eq * dist.pow(2)
    loss += (1 - class_eq) * (margin - dist).clamp(min=0).pow(2)
    return loss.mean()

x表示源域样本的embedding feature,y是目标域样本中的embedding feature,class_eq代表源域样本和目标域样本是否是同一种类。首先,计算两个特征图之间各像素点的二范数平方和,这是用于语义对齐的损失。后面,计算separation loss。semantic alighment loss和separation loss只会有一个存在。

训练过程

def train(net, loader):
    net.train()
    for i, (src_img, src_label, tar_img, tar_label) in enumerate(loader):
        src_img = src_img.to(device)
        src_label = src_label.to(device).long()
        tar_img = tar_img.to(device)
        tar_label = tar_label.to(device).long()

        src_pred, src_feature = net(src_img)
        _, tar_feature = net(tar_img)

        ce = entropy_loss(src_pred, src_label)
        csa = csa_loss(src_feature, tar_feature, (src_label == tar_label).float())

        loss = (1 - alpha) * ce + alpha * csa
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print("loss : %4f" % (loss.item()))

        for i, (tar_img, tar_label, src_img, src_label) in enumerate(loader):
            src_img = src_img.to(device)
            src_label = src_label.to(device).long()
            tar_img = tar_img.to(device)
            tar_label = tar_label.to(device)

            src_pred, src_feature = net(src_img)
            _, tar_feature = net(tar_img)

            ce = entropy_loss(src_pred, src_label)
            csa = csa_loss(src_feature, tar_feature,
                        (src_label == tar_label).float())
            loss = (1 - alpha) * ce + alpha * csa
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值