pytorch实现梯度反转层(Gradient Reversal Layer)

本文介绍了如何在PyTorch中使用Function实现梯度反转层(GRL),以用于域适应网络(DANN)。通过两种不同的实现方式展示了如何在自定义层中应用GRL,详细解释了代码实现过程,并通过测试代码展示了GRL在反向传播中梯度反转的效果。

问题

在有些任务中,我们需要实现梯度反转层(Gradient Reversal Layer),目的是为了在梯度反向传播时,经过计算图某个节点之后梯度往反向更新(DANN网络中便需要GRL)。pytorch提供了Function用于实现这个方法,但是看网上的博客并没有详细的实现方法的用法。

实现方式

pytorch中的Function

pytorch自定义layer有两种方式:

  • 通过继承torch.nn.Module类来实现拓展。只需重新实现__init__forward函数。
  • 通过继承torch.autograd.Function,除了要实现__init__forward函数,还要实现backward函数(就是自定义求导规则)。
    方式一看着简单,但是当要定义自己的求导方式时,就要自己实现backward,也就是所谓的Extending torch.autograd

关于Function的学习可以参看这个博客:https://blog.csdn.net/qq_27825451/article/details/95189376

因为可以自定义求导的方式,所以我们使用Function实现GRL

实现代码

定义一些无关的类便于测试使用

from typing import Any, Optional, Tuple
from torch.autograd import Function
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy

random.seed(0)
torch.manual_seed(0)
numpy.random.seed(0)

第一种实现方式

  1. 定义一个继承自FunctionGradientReverseFunction
class GradientReverseFunction(Function):
    """
    重写自定义的梯度计算方式
    """
    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None
  1. 在需要反转的代码中使用GRF
class NormalClassifier(nn.Module):

    def __init__(self, num_features, num_classes, GRL=None):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        if GRL:
            self.grl = GRL()

    def forward(self, x):
        if getattr(self, 'grl', None) is not None:
            x = GradientReverseFunction.apply(x)                # 注意这里
        return self.linear(x)

第二种实现方式

如果感觉刚才使用apply的应用方式不习惯,可以包装成一个层

  1. 把第一种方式中的GradientRevers
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MaXuwl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值