最近跑了一下王晋东博士迁移学习简明手册上的深度网络自适应DDC(Deep Domain Confusion)的代码实现,在这里做一下笔记。
来源:Githup开源链接
总结代码的大体框架如下:
1.数据集选择:office31
2.模型选择:Resnet50
3.所用到的.py文件如下图所示:

下面来一个模块一个模块分析:
data_loader.py
from torchvision import datasets, transforms
import torch
#参数为 下载数据集的路径、batch_size、布尔型变量判断是否是训练集、数据加载器中的进程数
def load_data(data_folder, batch_size, train, kwargs):
transform = {
'train': transforms.Compose(
[transforms.Resize([256, 256]),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])]),
'test': transforms.Compose(
[transforms.Resize([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
}
data = datasets.ImageFolder(root = data_folder, transform=transform['train' if train else 'test'])
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, **kwargs, drop_last = True if train else False)
return data_loader
分析:
这部分代码与我之前写过的的finetune代码中的dataload部分大同小异,具体可参考我的上一篇文章Pytorch_finetune代码解读,这部分主要是处理实验所用的数据,使之可以直接输入到模型,参数在注释里列出。
bckbone.py
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable
#这里列出的是resnet50的网络
class ResNet50Fc(nn.Module):
def __init__(self):
super(ResNet50Fc, self).__init__()
model_resnet50 = models.resnet50(pretrained=True)
self.conv1 = model_resnet50.conv1
self.bn1 = model_resnet50.bn1
self.relu = model_resnet50.relu
self.maxpool = model_resnet50.maxpool
#resnet有四个block,每个block的层数分别为layers=[3,4,6,3]
self.layer1 = model_resnet50.layer1
self.layer2 = model_resnet50.layer2
self.layer3 = model_resnet50.layer3
self.layer4 = model_resnet50.layer4
self.avgpool = model_resnet50.avgpool
#获取全连接层的输入特征
self.__in_features = model_resnet50.fc.in_features
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
def output_num(self):
return self.__in_features
network_dict = {
"alexnet": AlexNetFc,
"resnet18": ResNet18Fc,
"resnet34": ResNet34Fc,
"resnet50": ResNet50Fc,
"resnet101": ResNet101Fc,
"resnet152": ResNet152Fc}
分析:
这部分代码实现了预模型参数的下载,这里给出了多个模型,我们只关注resnet50的模型参数即可,所以我把其他模型的配置删去了。
注意这里需要了解resnet的基本网络架构,参考资料如下:
resnet18 50网络结构以及pytorch实现代码
ResNet网络结构分析
ResNet的pytorch实现与解析
mmd.py
import torch
import torch.nn as nn
class MMD_loss(nn.Module):
def

本文分享了迁移学习中深度网络自适应DDC的代码实践,详细介绍了数据加载、模型选择及核心自适应层的实现过程。
最低0.47元/天 解锁文章
2757





