PyTorch requires_grad/detach

前情提要

在排查GAN训练失真问题时,对pytorch中梯度相关知识进行了回顾,特此记录,以便自我回顾。

参考文章

测试代码

import torch
import torch.nn as nn


def test_requires_grad(requires_grad=False):
    """Show which ``.grad`` fields get populated after ``backward()`` when the
    parameters of one middle layer (lin2) have ``requires_grad`` toggled.

    Gradients still flow *through* a frozen layer to earlier layers; only the
    frozen layer's own parameters end up with ``grad is None``.
    """
    torch.manual_seed(0)
    inp = torch.randn(2, 2)
    print(f'============ input ======== \n {inp} \n =========================')
    # inp.requires_grad = True

    # Four stacked 2x2 linear layers; created in the same order as before so
    # the seeded RNG produces identical weights.
    layers = [nn.Linear(2, 2) for _ in range(4)]

    # Freeze (or unfreeze) only the third layer's parameters.
    for param in layers[2].parameters():
        print(f'is_leaf: {param.is_leaf}')  # module parameters are always leaves
        param.requires_grad = requires_grad

    # Forward pass through the whole stack, keeping every intermediate.
    activations = [inp]
    for layer in layers:
        activations.append(layer(activations[-1]))
    activations[-1].sum().backward()

    # Frozen layer prints None for its grad; the others print real tensors.
    for layer in layers:
        print(layer.weight.grad)

    # grad_fn is None for the leaf input, set for every computed activation.
    for act in activations:
        print(act.grad_fn)


def test_detach(detach=False):
    """Show how ``detach()`` severs the autograd graph.

    When ``detach`` is True the graph is cut between lin1 and lin2, so layers
    before the cut (and the input) receive no gradients, while layers after
    the cut still do.
    """
    torch.manual_seed(0)
    inp = torch.randn(2, 2)
    print(f'============ input ======== \n {inp} \n =========================')
    inp.requires_grad = True
    print(inp.is_leaf)  # True: user-created tensor is a leaf

    # Same creation order as before so the seeded weights are identical.
    lin0, lin1, lin2, lin3 = (nn.Linear(2, 2) for _ in range(4))

    h1 = lin0(inp)
    h2 = lin1(h1)
    # Optionally sever the graph between lin1 and lin2.
    h3 = lin2(h2.detach() if detach else h2)
    h4 = lin3(h3)
    h4.sum().backward()

    print(lin0.weight.grad)
    print(lin1.weight.grad)
    print(lin2.weight.grad)
    print(lin3.weight.is_leaf, lin3.weight.grad)

    for t in (inp, h1, h2, h3, h4):
        print(t.grad_fn)


if __name__ == '__main__':

    # Demo of detach() — disabled by default:
    # test_detach(True)
    # test_detach(False)

    # Demo the effect of freezing lin2's parameters via requires_grad.
    for flag in (True, False):
        test_requires_grad(flag)



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值