Pytorch实现梯度反转报错non-static forward method的解决办法

由于pytorch版本更新,在1.3及以上的版本中要求forward方法必须为静态方法。而在许多较早版本的代码实现中,梯度反转通常是这样写的:

from torch.autograd import Function


class GradReverse(Function):
    def __init__(self, lambd):
        self.lambd = lambd

    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return (grad_output * -self.lambd)


def grad_reverse(x, lambd=1.0):
    return GradReverse(lambd)(x)

而上述代码在1.3及以上版本的pytorch中运行时会报如下错误:
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.

而在声明时添加 @staticmethod装饰器 的方案并不能解决该问题。
后来找到了一篇比较新的梯度反转实现方案,这里使用了其中的方案一,修改后的梯度反转实现代码如下:
实现方案参考这里

from typing import Any, Optional, Tuple
import torch

class GradReverse(torch.autograd.Function):
    """
        重写自定义的梯度计算方式
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None


def grad_reverse(x, coeff):
    return GradReverse.apply(x, coeff)

在前向传播中的调用示例:

	def grad_reverse(x, coeff):
    	return GradReverse.apply(x, coeff)
    	
    def forward(self, x, coeff=1, reverse=False):
        x = self.generator(x)         # 已经定义的网络部分
        if reverse:
            x = grad_reverse(x, coeff)
        x = self.classifier(x)		  # 已经定义的网络部分
        return x
  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值