最近在写条件流模型代码时,再次遇到网络梯度为None的问题,有了上次的经验,这次排查起问题来就快多了。
首先,介绍一下条件流模型的相关内容,便于后续的展开。
流模型的简介可以看如下链接:
条件流模型是为了在流变换中考虑上下文因素而发明的,由transfomer和conditioner构成,transformer是传统的流模型变换,例如最简单的仿射变换等,与传统流模型不同的是,transformer变换的参数由conditioner决定,博主选用的是最简单的平面流;conditioner是一个可以输入上下文信息输出transformer变换参数的网络,由于博主处理的是时序信息,故选用LSTM当做conditioner。
言归正传,这次的问题在于使用流变换之前的变量进行反向传播梯度存在,而一旦使用经过流变换之后的参数进行反向传播,梯度就为None了,通过初步排查发现极大可能是因为使用conditoner的输出去不正确地更新transformer中的参数造成的。
原先的transformer是这样操作的,它首先会自身初始化self.w
, self.b
, self.u
,然后在后续步骤中使用conditioner的输出ws,bs,us更新self.w
, self.b
, self.u
这些参数,具体代码如下:
class PlanarFlow(nn.Module):
def __init__(self, D, activation=torch.tanh):
super().__init__()
self.D = D
self.w = nn.Parameter(torch.empty(D))
self.b = nn.Parameter(torch.empty(1))
self.u = nn.Parameter(torch.empty(D))
self.activation = activation
self.activation_derivative = ACTIVATION_DERIVATIVES[activation]
nn.init.normal_(self.w, mean=0, std=0.01)
nn.init.normal_(self.u, mean=0, std=0.01)
nn.init.normal_(self.b, mean=0, std=0.01)
def forward(self, z: torch.Tensor):
lin = (z @ self.w + self.b).unsqueeze(1)
f = z + self.u * self.activation(lin)
phi = self.activation_derivative(lin) * self.w
log_det = torch.log(torch.abs(1 + phi @ self.u) + 1e-4)
return f, log_det
def update_parameters(self, ws, bs, us):
self.w.data = ws.to(self.w.device)
self.b.data = bs.to(self.b.device)
self.u.data = us.to(self.u.device)
可以看到,由于transformer自身存在self.w
, self.b
, self.u
,与外部传递进来的ws,bs,us不是同一变量,所以这样的信息传递中断了ws,bs,us的计算图,从而造成梯度无法继续传播。
基于这个问题,想到的解决办法是要让 self.w
, self.b
, self.u
与外部传入的 ws
, bs
, us
共享完全相同的计算图,代码修改如下:
class PlanarFlow(nn.Module):
def __init__(self, D, activation=torch.tanh):
super().__init__()
self.D = D
self.activation = activation
self.activation_derivative = ACTIVATION_DERIVATIVES[activation]
# 不初始化参数,而是在update_parameters中动态初始化
def forward(self, z: torch.Tensor):
lin = (z @ self.w + self.b).unsqueeze(1)
f = z + self.u * self.activation(lin)
phi = self.activation_derivative(lin) * self.w
log_det = torch.log(torch.abs(1 + phi @ self.u) + 1e-4)
return f, log_det
def update_parameters(self, ws, bs, us):
# 将外部参数赋值给w,b,u
self.w = ws
self.b = bs
self.u = us
修改后,我们在初始化transformer时不会初始化self.w
, self.b
, self.u,
而是在每次使用conditioner的输出进行参数更新时初始化
self.w
, self.b
, self.u,
这样操作使得
self.w
, self.b
, self.u
与
ws,bs,us共享有相同的计算图,从而可以进行梯度的传播。
感谢您的阅读!如果对您有帮助,烦请点赞收藏呀!