Pytorch_DDC(深度网络自适应,以resnet50为例)代码解读

本文分享了迁移学习中深度网络自适应DDC的代码实践,详细介绍了数据加载、模型选择及核心自适应层的实现过程。

最近跑了一下王晋东博士迁移学习简明手册上的深度网络自适应DDC(Deep Domain Confusion)的代码实现,在这里做一下笔记。
来源:Githup开源链接

总结代码的大体框架如下:
1.数据集选择:office31
2.模型选择:Resnet50

3.所用到的.py文件如下图所示:
在这里插入图片描述

下面来一个模块一个模块分析:

data_loader.py

from torchvision import datasets, transforms
import torch

#参数为 下载数据集的路径、batch_size、布尔型变量判断是否是训练集、数据加载器中的进程数
def load_data(data_folder, batch_size, train, kwargs):
    transform = {
   
   
        'train': transforms.Compose(
            [transforms.Resize([256, 256]),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])]),
        'test': transforms.Compose(
            [transforms.Resize([224, 224]),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])])
        }
    data = datasets.ImageFolder(root = data_folder, transform=transform['train' if train else 'test'])
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, **kwargs, drop_last = True if train else False)
    return data_loader

分析:
这部分代码与我之前写过的的finetune代码中的dataload部分大同小异,具体可参考我的上一篇文章Pytorch_finetune代码解读,这部分主要是处理实验所用的数据,使之可以直接输入到模型,参数在注释里列出。

bckbone.py

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable

#这里列出的是resnet50的网络
class ResNet50Fc(nn.Module):
    def __init__(self):
        super(ResNet50Fc, self).__init__()
        model_resnet50 = models.resnet50(pretrained=True)
        self.conv1 = model_resnet50.conv1
        self.bn1 = model_resnet50.bn1
        self.relu = model_resnet50.relu
        self.maxpool = model_resnet50.maxpool
        #resnet有四个block,每个block的层数分别为layers=[3,4,6,3]
        self.layer1 = model_resnet50.layer1
        self.layer2 = model_resnet50.layer2
        self.layer3 = model_resnet50.layer3
        self.layer4 = model_resnet50.layer4

        self.avgpool = model_resnet50.avgpool
        #获取全连接层的输入特征
        self.__in_features = model_resnet50.fc.in_features

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def output_num(self):
        return self.__in_features
        
 network_dict = {
   
   "alexnet": AlexNetFc,
                "resnet18": ResNet18Fc,
                "resnet34": ResNet34Fc,
                "resnet50": ResNet50Fc,
                "resnet101": ResNet101Fc,
                "resnet152": ResNet152Fc}

分析:
这部分代码实现了预模型参数的下载,这里给出了多个模型,我们只关注resnet50的模型参数即可,所以我把其他模型的配置删去了。
注意这里需要了解resnet的基本网络架构,参考资料如下:
resnet18 50网络结构以及pytorch实现代码
ResNet网络结构分析
ResNet的pytorch实现与解析

mmd.py

import torch
import torch.nn as nn


class MMD_loss(nn.Module):
    def 
评论 13
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值