条件流模型loss.backward()后conditioner梯度为None

最近在写条件流模型代码时,再次遇到网络梯度为None的问题,有了上次的经验,这次排查起问题来就快多了。

首先,介绍一下条件流模型的相关内容,便于后续的展开。

流模型的简介可以看如下链接:

生成模型-流模型(Flow)

条件流模型是为了在流变换中考虑上下文因素而发明的,由transfomerconditioner构成,transformer是传统的流模型变换,例如最简单的仿射变换等,与传统流模型不同的是,transformer变换的参数由conditioner决定,博主选用的是最简单的平面流;conditioner是一个可以输入上下文信息输出transformer变换参数的网络,由于博主处理的是时序信息,故选用LSTM当做conditioner

言归正传,这次的问题在于使用流变换之前的变量进行反向传播梯度存在,而一旦使用经过流变换之后的参数进行反向传播,梯度就为None了,通过初步排查发现极大可能是因为使用conditoner的输出去不正确地更新transformer中的参数造成的。

原先的transformer是这样操作的,它首先会自身初始化self.wself.bself.u ,然后在后续步骤中使用conditioner的输出ws,bs,us更新self.wself.bself.u 这些参数,具体代码如下:

class PlanarFlow(nn.Module):
    def __init__(self, D, activation=torch.tanh):
        super().__init__()
        self.D = D
        self.w = nn.Parameter(torch.empty(D))
        self.b = nn.Parameter(torch.empty(1))
        self.u = nn.Parameter(torch.empty(D))
        self.activation = activation
        self.activation_derivative = ACTIVATION_DERIVATIVES[activation]
        nn.init.normal_(self.w, mean=0, std=0.01)
        nn.init.normal_(self.u, mean=0, std=0.01)
        nn.init.normal_(self.b, mean=0, std=0.01)

    def forward(self, z: torch.Tensor):
        lin = (z @ self.w + self.b).unsqueeze(1)
        f = z + self.u * self.activation(lin)
        phi = self.activation_derivative(lin) * self.w
        log_det = torch.log(torch.abs(1 + phi @ self.u) + 1e-4)
        return f, log_det

    def update_parameters(self, ws, bs, us):
        self.w.data = ws.to(self.w.device)
        self.b.data = bs.to(self.b.device)
        self.u.data = us.to(self.u.device)

可以看到,由于transformer自身存在self.wself.bself.u ,与外部传递进来的ws,bs,us不是同一变量,所以这样的信息传递中断了ws,bs,us的计算图,从而造成梯度无法继续传播。

基于这个问题,想到的解决办法是要让 self.wself.bself.u 与外部传入的 wsbsus 共享完全相同的计算图,代码修改如下:

class PlanarFlow(nn.Module):
    def __init__(self, D, activation=torch.tanh):
        super().__init__()
        self.D = D
        self.activation = activation
        self.activation_derivative = ACTIVATION_DERIVATIVES[activation]
        # 不初始化参数,而是在update_parameters中动态初始化

    def forward(self, z: torch.Tensor):
        lin = (z @ self.w + self.b).unsqueeze(1)
        f = z + self.u * self.activation(lin)
        phi = self.activation_derivative(lin) * self.w
        log_det = torch.log(torch.abs(1 + phi @ self.u) + 1e-4)
        return f, log_det

    def update_parameters(self, ws, bs, us):
        # 将外部参数赋值给w,b,u
        self.w = ws
        self.b = bs
        self.u = us

修改后,我们在初始化transformer时不会初始化self.wself.bself.u,而是在每次使用conditioner的输出进行参数更新时初始化self.wself.bself.u,这样操作使得self.wself.bself.uws,bs,us共享有相同的计算图,从而可以进行梯度的传播。

感谢您的阅读!如果对您有帮助,烦请点赞收藏呀!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值