Pytorch_DDC(深度网络自适应,以resnet50为例)代码解读

最近跑了一下王晋东博士迁移学习简明手册上的深度网络自适应DDC(Deep Domain Confusion)的代码实现,在这里做一下笔记。
来源:GitHub开源链接

总结代码的大体框架如下:
1.数据集选择:office31
2.模型选择:Resnet50

3.所用到的.py文件如下图所示:
在这里插入图片描述

下面来一个模块一个模块分析:

data_loader.py

from torchvision import datasets, transforms
import torch

#参数为 下载数据集的路径、batch_size、布尔型变量判断是否是训练集、数据加载器中的进程数
def load_data(data_folder, batch_size, train, kwargs):
    """Build a DataLoader over an ImageFolder dataset.

    Args:
        data_folder: root directory of the image dataset (one sub-folder per class).
        batch_size: mini-batch size.
        train: True selects the augmented training transform,
            False the deterministic evaluation transform.
        kwargs: extra keyword arguments forwarded to DataLoader
            (e.g. {'num_workers': 4}).

    Returns:
        A shuffling torch DataLoader; incomplete last batches are dropped
        only in training mode.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if train:
        # Random crop + horizontal flip for augmentation at train time.
        transform = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # Deterministic resize for evaluation.
        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            normalize,
        ])
    data = datasets.ImageFolder(root=data_folder, transform=transform)
    return torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=True,
        drop_last=bool(train), **kwargs)

分析:
这部分代码与我之前写过的finetune代码中的dataload部分大同小异,具体可参考我的上一篇文章Pytorch_finetune代码解读,这部分主要是处理实验所用的数据,使之可以直接输入到模型,参数在注释里列出。

backbone.py

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable

#这里列出的是resnet50的网络
class ResNet50Fc(nn.Module):
    """ResNet-50 feature extractor.

    Wraps an ImageNet-pretrained torchvision ResNet-50 with its final
    fully-connected classifier removed, so forward() yields the pooled
    feature vector instead of class logits.
    """

    def __init__(self):
        super(ResNet50Fc, self).__init__()
        pretrained = models.resnet50(pretrained=True)
        # Reuse the pretrained modules directly so their weights are kept.
        self.conv1 = pretrained.conv1
        self.bn1 = pretrained.bn1
        self.relu = pretrained.relu
        self.maxpool = pretrained.maxpool
        # The four residual stages of ResNet-50 (3, 4, 6 and 3 blocks).
        self.layer1 = pretrained.layer1
        self.layer2 = pretrained.layer2
        self.layer3 = pretrained.layer3
        self.layer4 = pretrained.layer4
        self.avgpool = pretrained.avgpool
        # Width of the feature vector this module outputs
        # (input width of the discarded fc layer).
        self.__in_features = pretrained.fc.in_features

    def forward(self, x):
        # Stem, then the four residual stages, then global average pooling.
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3, self.layer4,
                      self.avgpool):
            x = stage(x)
        # Flatten (N, C, 1, 1) -> (N, C).
        return x.view(x.size(0), -1)

    def output_num(self):
        """Return the dimensionality of the extracted features."""
        return self.__in_features
        
 network_dict = {"alexnet": AlexNetFc,
                "resnet18": ResNet18Fc,
                "resnet34": ResNet34Fc,
                "resnet50": ResNet50Fc,
                "resnet101": ResNet101Fc,
                "resnet152": ResNet152Fc}

分析:
这部分代码实现了预训练模型参数的下载,这里给出了多个模型,我们只关注resnet50的模型参数即可,所以我把其他模型的配置删去了。
注意这里需要了解resnet的基本网络架构,参考资料如下:
resnet18 50网络结构以及pytorch实现代码
ResNet网络结构分析
ResNet的pytorch实现与解析

mmd.py

import torch
import torch.nn as nn


class MMD_loss(nn.Module):
    """Maximum Mean Discrepancy (MMD) between a source and a target batch.

    Args:
        kernel_type: 'rbf' (multi-kernel Gaussian) or 'linear'.
        kernel_mul: geometric ratio between consecutive RBF bandwidths.
        kernel_num: number of Gaussian kernels in the mixture.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None  # optional fixed bandwidth; None -> estimate from data
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Return the summed multi-bandwidth Gaussian kernel matrix.

        Evaluates `kernel_num` Gaussian kernels on every pair of rows of
        the stacked matrix [source; target] and sums them, giving an
        (n_s + n_t) x (n_s + n_t) kernel matrix.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # Pairwise squared Euclidean distances via broadcasting.
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Heuristic bandwidth: mean off-diagonal squared distance.
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        # Center a geometric ladder of `kernel_num` bandwidths on the estimate.
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Linear-kernel MMD^2: squared distance between the batch means."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        # delta is 1-D, so delta.dot(delta) is its squared norm
        # (the original's `.T` on a 1-D tensor was a no-op and warns
        # on recent PyTorch versions).
        return delta.dot(delta)

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul,
                kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # BUG FIX: the original wrapped the block below in
            # `with torch.no_grad():`, which detached the MMD value from the
            # autograd graph -- the transfer loss then contributed ZERO
            # gradient during training and the domain adaptation was a no-op.
            # The loss must stay differentiable w.r.t. the features.
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            # The four terms are scalars already; no extra torch.mean needed.
            loss = XX + YY - XY - YX
            return loss

分析:
这部分代码是深度网络自适应的核心之一,这里使用mmd算法作为自适应度量。
下图loss计算公式中红框标出部分,说明了这一部分代码的作用,作为总的损失函数的一部分组成,主要度量源域和目标域的数据分布是否达到一致。

注意: coral.py模块也是一种度量方法,我们这里使用了mmd方法,coral就不再列出,功能是一样的。

在这里插入图片描述

model.py

import torch.nn as nn
from Coral import CORAL
import mmd
import backbone


#注意有两个网络流向,adaptation layer跟mmd有关(比较两个网络流向)
#classifier跟网络本身有关
class Transfer_Net(nn.Module):
    """Classification network with a domain-adaptation loss.

    Two data streams share one backbone: source features continue into the
    classifier head, while the (source, target) feature pair is compared by
    an adaptation criterion (MMD or CORAL) to align the two domains.
    """

    def __init__(self, num_class, base_net='resnet50', transfer_loss='mmd',
                 use_bottleneck=False, bottleneck_width=256, width=1024):
        super(Transfer_Net, self).__init__()
        # Shared feature extractor, looked up by name in backbone.py.
        self.base_network = backbone.network_dict[base_net]()
        # Whether to project features through the bottleneck before
        # computing the domain distance.
        self.use_bottleneck = use_bottleneck
        # Name of the adaptation criterion ('mmd' or 'coral').
        self.transfer_loss = transfer_loss
        feature_dim = self.base_network.output_num()
        # Bottleneck: reduces the backbone's top-level features before the
        # distance computation (only used when use_bottleneck is True).
        self.bottleneck_layer = nn.Sequential(
            nn.Linear(feature_dim, bottleneck_width),
            nn.BatchNorm1d(bottleneck_width),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        # Classifier head operating on the raw backbone features.
        self.classifier_layer = nn.Sequential(
            nn.Linear(feature_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, num_class),
        )
        # Custom initialisation for the freshly added (non-pretrained) layers.
        self.bottleneck_layer[0].weight.data.normal_(0, 0.005)
        self.bottleneck_layer[0].bias.data.fill_(0.1)
        for idx in (0, 3):  # the two Linear layers inside the classifier
            self.classifier_layer[idx].weight.data.normal_(0, 0.01)
            self.classifier_layer[idx].bias.data.fill_(0.0)

    def forward(self, source, target):
        # Extract features for both domains with the shared backbone.
        source = self.base_network(source)
        target = self.base_network(target)
        # Classify the (labelled) source features.
        source_clf = self.classifier_layer(source)
        # Optionally reduce dimensionality before measuring domain distance
        # (disabled here: use_bottleneck defaults to False).
        if self.use_bottleneck:
            source = self.bottleneck_layer(source)
            target = self.bottleneck_layer(target)
        # Adaptation term: distance between the two feature distributions.
        transfer_loss = self.adapt_loss(source, target, self.transfer_loss)
        return source_clf, transfer_loss

    def predict(self, x):
        """Classify a batch without computing any adaptation loss."""
        return self.classifier_layer(self.base_network(x))

    def adapt_loss(self, X, Y, adapt_loss):
        """Compute adaptation loss, currently we support mmd and coral

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix
            adapt_loss {string} -- loss type, 'mmd' or 'coral'. You can add your own loss

        Returns:
            [tensor] -- adaptation loss tensor
        """
        if adapt_loss == 'mmd':
            return mmd.MMD_loss()(X, Y)
        if adapt_loss == 'coral':
            return CORAL(X, Y)
        # Unknown criterion: contribute nothing to the total loss.
        return 0

分析:
那么这一部分代码就是整个深度网络自适应算法的最核心之处,实现了源域和目标域的距离的输出,得出了上面图中公式右端第二个loss的具体数值。这里bottleneck我们先不要管,只关注self.classifier_layer部分,这一部分是常规的网络训练的构建。下面forward中self.adapt_loss函数是这部分代码的核心之处,理解时要以这个函数为核心,向外延伸;这部分理解透了,自适应部分也就明白了,后续就是一些常规的模型训练。

utils.py

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every running statistic back to zero."""
        self.val = 0    # most recently observed value
        self.avg = 0    # running mean of all values seen
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight observed so far

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count

分析:
用来辅助acc的平均loss的计算。

main.py

import argparse
import torch
import os
import data_loader
import models
import utils
import numpy as np

# Pick GPU when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Per-epoch [clf_loss, transfer_loss, total_loss] rows; train() dumps them to CSV.
log = []

# Command-line settings.
parser = argparse.ArgumentParser(description='DDC_DCORAL')
parser.add_argument('--model', type=str, default='resnet50')
parser.add_argument('--batchsize', type=int, default=32)
parser.add_argument('--src', type=str, default='amazon')
parser.add_argument('--tar', type=str, default='webcam')
parser.add_argument('--n_class', type=int, default=31)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--n_epoch', type=int, default=10)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--decay', type=float, default=5e-4)
# FIX: use a raw string for the Windows path so the backslash is never
# interpreted as an escape sequence (a plain '...\...' literal triggers
# invalid-escape-sequence warnings on modern Python); value is unchanged.
parser.add_argument('--data', type=str, default=r'D:\迁移学习\Original_images')
parser.add_argument('--early_stop', type=int, default=20)
parser.add_argument('--lamb', type=float, default=10)  # weight of the transfer loss
parser.add_argument('--trans_loss', type=str, default='mmd')
args = parser.parse_args()

def test(model, target_test_loader):
    """Evaluate `model` on the target test set; return accuracy in percent."""
    model.eval()
    test_loss = utils.AverageMeter()
    correct = 0
    criterion = torch.nn.CrossEntropyLoss()
    len_target_dataset = len(target_test_loader.dataset)
    # No gradients needed for evaluation.
    with torch.no_grad():
        for data, target in target_test_loader:
            data = data.to(DEVICE)
            target = target.to(DEVICE)
            output = model.predict(data)
            test_loss.update(criterion(output, target).item())
            # Predicted class = index of the largest logit.
            prediction = output.argmax(dim=1)
            correct += torch.sum(prediction == target)
    return 100. * correct / len_target_dataset

#参数:源域数据、目标域数据、测试数据、模型数据、优化器数据
#参数:源域数据、目标域数据、测试数据、模型数据、优化器数据
def train(source_loader, target_train_loader, target_test_loader, model, optimizer):
    """Train the transfer network and evaluate after every epoch.

    Each epoch walks min(len(source), len(target)) batches in lockstep and
    minimises  clf_loss(source) + args.lamb * transfer_loss(source, target).
    Early-stops after `args.early_stop` epochs with no accuracy improvement.
    """
    len_source_loader = len(source_loader)
    len_target_loader = len(target_train_loader)
    best_acc = 0
    stop = 0  # epochs since the last accuracy improvement
    for e in range(args.n_epoch):
        stop += 1

        # Running averages of this epoch's losses.
        train_loss_clf = utils.AverageMeter()
        train_loss_transfer = utils.AverageMeter()
        train_loss_total = utils.AverageMeter()
        model.train()
        # Fresh iterators so the two loaders advance in lockstep.
        iter_source, iter_target = iter(source_loader), iter(target_train_loader)
        # One epoch = as many batches as the smaller of the two loaders.
        n_batch = min(len_source_loader, len_target_loader)
        criterion = torch.nn.CrossEntropyLoss()
        for _ in range(n_batch):
            # BUG FIX: Python 3 iterators (including modern DataLoader
            # iterators) have no .next() method -- the original
            # `iter_source.next()` raises AttributeError.  Use the
            # builtin next() instead.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)  # target labels are never used
            data_source, label_source = data_source.to(
                DEVICE), label_source.to(DEVICE)
            data_target = data_target.to(DEVICE)

            optimizer.zero_grad()
            # model(...) runs Transfer_Net.forward: source logits + domain distance.
            label_source_pred, transfer_loss = model(data_source, data_target)
            clf_loss = criterion(label_source_pred, label_source)
            # Total objective: source classification loss plus the weighted
            # domain-adaptation (MMD/CORAL) loss.
            loss = clf_loss + args.lamb * transfer_loss
            loss.backward()
            optimizer.step()

            train_loss_clf.update(clf_loss.item())
            train_loss_transfer.update(transfer_loss.item())
            train_loss_total.update(loss.item())

        # Evaluate on the target test set after every epoch.
        acc = test(model, target_test_loader)
        log.append([train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg])
        np_log = np.array(log, dtype=float)
        np.savetxt('train_log.csv', np_log, delimiter=',', fmt='%.6f')
        print('Epoch: [{:2d}/{}], cls_loss: {:.4f}, transfer_loss: {:.4f}, total_Loss: {:.4f}, acc: {:.4f}'.format(
                    e, args.n_epoch, train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg, acc))
        if best_acc < acc:
            best_acc = acc
            stop = 0
        # Early stopping: `early_stop` consecutive epochs without improvement.
        if stop >= args.early_stop:
            break
    print('Transfer result: {:.4f}'.format(best_acc))
    
#下载数据集,参数为源域文件名,目标域文件名,数据所在目录
#于finetune中dataload.py里实现功能一致
#下载数据集,参数为源域文件名,目标域文件名,数据所在目录
#于finetune中dataload.py里实现功能一致
def load_data(src, tar, root_dir):
    """Build the three loaders: source train, target train, target test.

    `src` and `tar` are the domain folder names under `root_dir`.
    """
    loader_kwargs = {'num_workers': 4}
    folder_src = os.path.join(root_dir, src)
    folder_tar = os.path.join(root_dir, tar)
    source_loader = data_loader.load_data(
        folder_src, args.batchsize, True, loader_kwargs)
    target_train_loader = data_loader.load_data(
        folder_tar, args.batchsize, True, loader_kwargs)
    target_test_loader = data_loader.load_data(
        folder_tar, args.batchsize, False, loader_kwargs)
    return source_loader, target_train_loader, target_test_loader


if __name__ == '__main__':
    torch.manual_seed(0)  # reproducibility

    # FIX: honour the --src/--tar command-line options instead of the
    # hard-coded "amazon"/"webcam" strings, which silently ignored the
    # arguments defined above (the defaults are unchanged).
    source_name = args.src
    target_name = args.tar

    print('Src: %s, Tar: %s' % (source_name, target_name))

    source_loader, target_train_loader, target_test_loader = load_data(
        source_name, target_name, args.data)

    # Model: args.n_class outputs, MMD transfer loss, ResNet-50 backbone.
    model = models.Transfer_Net(
        args.n_class, transfer_loss=args.trans_loss, base_net=args.model).to(DEVICE)

    # SGD optimiser; the freshly initialised layers (bottleneck and
    # classifier) get a 10x learning rate relative to the pretrained backbone.
    optimizer = torch.optim.SGD([
        {'params': model.base_network.parameters()},
        {'params': model.bottleneck_layer.parameters(), 'lr': 10 * args.lr},
        {'params': model.classifier_layer.parameters(), 'lr': 10 * args.lr},
    ], lr=args.lr, momentum=args.momentum, weight_decay=args.decay)

    # Train and evaluate.
    train(source_loader, target_train_loader,
          target_test_loader, model, optimizer)

分析:
对以上模块进行整合,训练网络,得到loss和准确率并打印。

end

  • 19
    点赞
  • 116
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值