[迁移学习]域自适应代码解析

一、概述

        代码来自:https://github.com/jindongwang/transferlearning,可以前往github下载代码,本文涉及的代码的位置为:Code->DeepDA。理论基础可以参见:[迁移学习]域自适应

        整体网络结构如下:可以视为一个分类网络(如Resnet50)+一个fc_adapt模块组成 。其损失函数为L=L_c(x_i,y_i)+\lambda Distance(D_s,D_t),即原来的交叉熵损失函数后面添加一个衡量源域和目标域之间距离的损失函数,一般为MMD,超参数\lambda用来控制此损失函数的权重。

 二、代码分析

        1.dataloader

        转到main函数,可以看到与dataloader相关的代码为:

source_loader, target_train_loader, target_test_loader, n_class = load_data(args)

        跳转后可以看到,在load_data(.)函数中数据集被分成下面三个部分,分别对应源域、目标域训练集、目标域测试集。

source_loader,target_train_loader,target_test_loader

        而实际起作用的是data_loader中的load_data,该函数重写了DataSet类并调用了dataloader

data = datasets.ImageFolder(root=data_folder, transform=transform['train' if train else 'test'])
    data_loader = get_data_loader(data, batch_size=batch_size, 
                                shuffle=True if train else False, 
                                num_workers=num_workers, **kwargs, drop_last=True if train else False)
    n_class = len(data.classes)

        2.model

        main函数中,与model有关的代码为:

model = get_model(args)
def get_model(args):
    model = models.TransferNet(
        args.n_class, transfer_loss=args.transfer_loss, base_net=args.backbone, max_iter=args.max_iter, use_bottleneck=args.use_bottleneck).to(args.device)
    return model

        该函数实际是从models文件中的TransferNet类中提取网络模型,继续转到TransferNet,该类有以下几个参数:类别个数,骨干网络类型,transfer_loss类型,以及一些调整骨干网络的参数。

class TransferNet(nn.Module):
    def __init__(self, num_class, base_net='resnet50', 
        transfer_loss='mmd', use_bottleneck=True, 
        bottleneck_width=256, max_iter=1000, **kwargs):

        随后看该网络的前向传递函数,基本可以归类为以下几部:

                ①骨干网络提取

source = self.base_network(source)
target = self.base_network(target)

                ②源域分类

source_clf = self.classifier_layer(source)

                代码中的分类器是一个全连接层,对应的参数是隐藏层通道数和输出类别数

self.classifier_layer = nn.Linear(feature_dim, num_class)

                ③源域分类损失函数计算

clf_loss = self.criterion(source_clf, source_label)

                代码中采用了一个交叉熵损失函数

self.criterion = torch.nn.CrossEntropyLoss()

                ④迁移学习

                这一小节是域自适应迁移学习和传统分类网络的最大区别,除了传统的cls_loss之外,该网络还计算了transfer_loss。

                代码中提供了lmmd,daan,bnm三种方式,其代码基本大同小异,这里选取lmmd进行解析,lmmd代码如下:

if self.transfer_loss == "lmmd":
    kwargs['source_label'] = source_label
    target_clf = self.classifier_layer(target)
    kwargs['target_logits'] = torch.nn.functional.softmax(target_clf, dim=1)

                该段代码的主要功能是从参数列表中获取source_label,同时使用上面同样的分类器对由骨干网络提取到的目标域特征进行分类(使用softmax进行分类,结果记录为target_logits),随后源域特征和目标域特征以及参数会被送入adapt_loss(.)模块用以计算transfer_loss

                同时从后续代码得知,如果使用最简单的mmd是不需要经过这一步处理的

transfer_loss = self.adapt_loss(source, target, **kwargs)
self.adapt_loss = TransferLoss(**transfer_loss_args)

                TransferLoss类为一个transfer_loss的提取模块,提供了6种不同的损失函数,这里以mmd为例。

if loss_type == "mmd":
    self.loss_func = MMDLoss(**kwargs)

                MMDLoss的完整代码如下:

import torch
import torch.nn as nn

class MMDLoss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0-total1)**2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        loss = 0.0
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        loss = delta.dot(delta.T)
        return loss

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            loss = torch.mean(XX + YY - XY - YX)
            return loss

                里面有大量的数学变换在这里就不进行深究了。其前向函数的作用是将源域和目标域之间的差距计算出来并返回loss

                ⑤返回函数

                该模型的前向传递函数最后会返回两个参数:clf_loss, transfer_loss,这两个参数将会被用于后面的反向传递。

        3.训练

        训练过程中,会从source_loader中提取源域图像和标签,从target_train_loader中提取目标域图像(不需要目标域的标签)。

        然后将数据这三个数据送入模型,得到clf_loss和transfer_loss。最后将transfer_loss×权重系数lambda后与clf_loss相加后可以得到最终的损失函数loss。

clf_loss, transfer_loss = model(data_source, data_target, label_source)
loss = clf_loss + args.transfer_loss_weight * transfer_loss

        再然后就是自动的反向传递:

optimizer.zero_grad()
loss.backward()
optimizer.step()

        4.测试

        测试过程与上面的训练大同小异,主要是不再需要前向传递。使用的是model中的predict函数而不是默认的前向传递函数

s_output = model.predict(data)

        该函数与之前的前向传递函数相比没有adapt_loss模块:

def predict(self, x):
    features = self.base_network(x)
    x = self.bottleneck_layer(features)
    clf = self.classifier_layer(x)
    return clf
  • 6
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值