由于pytorch版本更新,在1.3及以上的版本中要求forward方法必须为静态方法。而在许多较早版本的代码实现中,梯度反转通常是这样写的:
from torch.autograd import Function
class GradReverse(Function):
def __init__(self, lambd):
self.lambd = lambd
def forward(self, x):
return x.view_as(x)
def backward(self, grad_output):
return (grad_output * -self.lambd)
def grad_reverse(x, lambd=1.0):
return GradReverse(lambd)(x)
而上述代码在1.3及以上版本的pytorch中运行时会报如下错误:
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.
而在声明时添加 @staticmethod装饰器 的方案并不能解决该问题。
后来找到了一篇比较新的梯度反转实现方案,这里使用了其中的方案一,修改后的梯度反转实现代码如下:
实现方案参考这里
from typing import Any, Optional, Tuple
import torch
class GradReverse(torch.autograd.Function):
"""
重写自定义的梯度计算方式
"""
@staticmethod
def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
ctx.coeff = coeff
output = input * 1.0
return output
@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
return grad_output.neg() * ctx.coeff, None
def grad_reverse(x, coeff):
return GradReverse.apply(x, coeff)
在前向传播中的调用示例:
def grad_reverse(x, coeff):
return GradReverse.apply(x, coeff)
def forward(self, x, coeff=1, reverse=False):
x = self.generator(x) # 已经定义的网络部分
if reverse:
x = grad_reverse(x, coeff)
x = self.classifier(x) # 已经定义的网络部分
return x