训练过程source和target采用不同的BN参数,在测试阶段就不用指定是使用哪个域的BN参数了

行人重试别的无监督训练过程,在backbone中为了使源域和目标域的BN训练参数区分开,不相互影响,使用了下面的代码:

import torch
import torch.nn as nn

# Domain-specific BatchNorm

class DSBN2d(nn.Module):
    def __init__(self, planes):
        super(DSBN2d, self).__init__()
        self.num_features = planes
        self.BN_S = nn.BatchNorm2d(planes)
        self.BN_T = nn.BatchNorm2d(planes)

    def forward(self, x):
        if (not self.training):
            return self.BN_T(x)

        bs = x.size(0)
        assert (bs%2==0)
        split = torch.split(x, int(bs/2), 0)
        out1 = self.BN_S(split[0].contiguous())
        out2 = self.BN_T(split[1].contiguous())
        out = torch.cat((out1, out2), 0)
        return out

class DSBN1d(nn.Module):
    def __init__(self, planes):
        super(DSBN1d, self).__init__()
        self.num_features = planes
        self.BN_S = nn.BatchNorm1d(planes)
        self.BN_T = nn.BatchNorm1d(planes)

    def forward(self, x):
        if (not self.training):
            return self.BN_T(x)

        bs = x.size(0)
        assert (bs%2==0)
        split = torch.split(x, int(bs/2), 0)
        out1 = self.BN_S(split[0].contiguous())
        out2 = self.BN_T(split[1].contiguous())
        out = torch.cat((out1, out2), 0)
        return out

def convert_dsbn(model):
    for _, (child_name, child) in enumerate(model.named_children()):
        assert(not next(model.parameters()).is_cuda)
        if isinstance(child, nn.BatchNorm2d):
            m = DSBN2d(child.num_features)
            m.BN_S.load_state_dict(child.state_dict())
            m.BN_T.load_state_dict(child.state_dict())
            setattr(model, child_name, m)
        elif isinstance(child, nn.BatchNorm1d):
            m = DSBN1d(child.num_features)
            m.BN_S.load_state_dict(child.state_dict())
            m.BN_T.load_state_dict(child.state_dict())
            setattr(model, child_name, m)
        else:
            convert_dsbn(child)

def convert_bn(model, use_target=True):
    for _, (child_name, child) in enumerate(model.named_children()):
        assert(not next(model.parameters()).is_cuda)
        if isinstance(child, DSBN2d):
            m = nn.BatchNorm2d(child.num_features)
            if use_target:
                m.load_state_dict(child.BN_T.state_dict())
            else:
                m.load_state_dict(child.BN_S.state_dict())
            setattr(model, child_name, m)
        elif isinstance(child, DSBN1d):
            m = nn.BatchNorm1d(child.num_features)
            if use_target:
                m.load_state_dict(child.BN_T.state_dict())
            else:
                m.load_state_dict(child.BN_S.state_dict())
            setattr(model, child_name, m)
        else:
            convert_bn(child, use_target=use_target)

在前向过程中,将源域和目标域数据经过不同的BN层,这样相对于不加区分的来说,有了很大的提升。在测试阶段,我们也可以将目标域的数据只通过目标域对应的BN层,但是实验结果发现,这样的性能和将测试数据统一送入网络不加区分源域还是目标域的BN层,结果上没什么区别。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
在Pytorch中,BN层是Batch Normalization的缩写,用于在深度学习模型中对输入数据进行归一化处理。BN层的作用是通过对每个小批量的输入数据进行归一化,使得模型在训练过程中更加稳定和快速收敛。\[1\] 在Pytorch中,使用BN层的方法如下所示: ```python from torch import nn # 创建一个BN层对象,需要传入特征的通道数num_features作为参数 bn = nn.BatchNorm2d(num_features) # 输入数据 input = torch.randn(batch_size, num_features, height, width) # 将输入数据传入BN层进行处理 output = bn(input) ``` 其中,`num_features`表示输入数据的通道数,`batch_size`表示输入数据的批量大小,`height`和`width`表示输入数据的高度和宽度。\[1\] 在BN层的类中,还有一些其他的参数可以进行设置,例如`eps`表示用于数值稳定性的小值,默认为1e-5;`momentum`表示用于计算移动平均的动量,默认为0.1;`affine`表示是否学习BN层的参数γ和β,默认为True;`track_running_stats`表示是否跟踪训练过程中的统计数据,默认为True。\[2\] 需要注意的是,BN层的参数γ和β是否可学习是由`affine`参数控制的,默认情况下是可学习的,即可通过反向传播进行更新。而BN层的统计数据更新是在每一次训练阶段的`model.train()`后的`forward()`方法中自动实现的,而不是在梯度计算与反向传播中更新`optim.step()`中完成。\[3\] #### 引用[.reference_title] - *1* [一起来学PyTorch——神经网络(BN层)](https://blog.csdn.net/TomorrowZoo/article/details/129531658)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [pytorch中的BN层简介](https://blog.csdn.net/lpj822/article/details/109772094)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值