Gradient backpropagation in the PyTorch framework

loss.backward() in PyTorch is a central piece of most training code. In a recent project I needed to extract a model's intermediate gradients for further processing, and found that some open-source code and online explanations of this are wrong, and that several basics of the .backward() method are easy to confuse. I therefore wrote a simple model to test how gradients propagate backwards in torch. In this article, "computational graph" refers to the dynamic computational graph used by the torch framework.

Model structure

Three MLP models are chained together to test the gradient backpropagation process.

# coding: utf-8
# only train one step

import torch.nn as nn
import torch
import os
import argparse


class MLPem0(nn.Module):
    def __init__(self, num_embeddings0, embedding_dim0, model_dim0) -> None:
        super(MLPem0, self).__init__()

        self.embed = nn.Embedding(num_embeddings0, embedding_dim0)
        self.layer = nn.Linear(embedding_dim0, model_dim0)
    
    def forward(self, input):
        d1 = self.embed(input)
        out = self.layer(d1)
        return out


class MLP1(nn.Module):
    def __init__(self, embedding_dim1, model_dim1) -> None:
        super(MLP1, self).__init__()

        # self.embed = nn.Embedding(num_embeddings, embedding_dim)
        self.layer = nn.Linear(embedding_dim1, model_dim1) # embedding dim1 = model dim 0
    
    def forward(self, input):
        # d1 = self.embed(input)
        out = self.layer(input)

        return out
    

class MLP2L(nn.Module):
    def __init__(self, embedding_dim2, model_dim2) -> None:
        super(MLP2L, self).__init__()

        # self.embed = nn.Embedding(num_embeddings, embedding_dim)
        self.layer = nn.Linear(embedding_dim2, model_dim2) # embedding dim2 = model dim 1; model dim 2 = label dim
    
    def forward(self, input, label):
        # d1 = self.embed(input)
        out = self.layer(input)
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(out, label)  # label holds class indices; with out of shape [1,5,3], CrossEntropyLoss treats dim 1 (size 5) as the class dimension

        return out, loss

def parameters1():
    parser = argparse.ArgumentParser()
    parser.add_argument("--embedding_dim", type=int, help="feature dimention", 
                        default=6)
    parser.add_argument("--num_embedding", type=int, help="number of embending words", 
                        default=10)
    parser.add_argument("--model_dim01", type=int, help="feature dimention",
                        default=6) 
    parser.add_argument("--model_dim2", type=int, help="output dimention", default=3)
    parser.add_argument("--save_name",type=str, default = 'pytorch_model_parabatch')
    parser.add_argument("--client_blocks", type=int, help='number of client\'s blocks', default=1)
    args = parser.parse_args()

    return args

def main():
    torch.manual_seed(42)
    args = parameters1()
    
    model1 = MLPem0(args.num_embedding, args.embedding_dim, args.model_dim01)
    model2 = MLP1(args.embedding_dim, args.model_dim01)
    model3 = MLP2L(args.embedding_dim, args.model_dim2)

    optim1 = torch.optim.Adam(model1.parameters(), 0.01)
    optim2 = torch.optim.Adam(model2.parameters(), 0.01)
    optim3 = torch.optim.Adam(model3.parameters(), 0.01)
    # optim3 = torch.optim.Adam([{'params':model1.parameters()}, {'params':model2.parameters()}, {'params':model3.parameters()}], 0.01) # with this optimizer, use f1 and f2 directly (no detach) as the intermediate variables

    input = torch.LongTensor([[2,1,3,0,9]])
    label = torch.LongTensor([[0,1,0]])
    
    f1 = model1(input) # [1,5]-->[1,5,6]-->[1,5,6]

    ff1 = f1.clone().detach().requires_grad_(True)
    f2 = model2(ff1) # -->[1,5,6]
    print(f2)
    ff2 = f2.clone().detach().requires_grad_(True)
    out, loss = model3(ff2, label) #  -->[1,5,3]

    for name, parameter in model1.named_parameters():
        print(name)
        print(parameter)

    for name, parameter in model2.named_parameters():
        print(name)
        print(parameter)

    for name, parameter in model3.named_parameters():
        print(name)
        print(parameter)

    
    print("out is : {}".format(out))
    print("loss is: {}".format(loss.item()))
    print(loss)
    optim3.zero_grad()
    loss.backward() # computes gradients of model3's parameters; also fills ff2.grad, since ff2 is a leaf with requires_grad=True
    optim3.step()  # model3 step
    # print(ff2.grad)

    optim2.zero_grad()
    f2.backward(ff2.grad)
    optim2.step() # model2 step
    # print(ff1.grad)

    optim1.zero_grad()
    f1.backward(ff1.grad)
    optim1.step()  # model1 step

    for name, parameter in model1.named_parameters():
        print(name)
        print(parameter)

    for name, parameter in model2.named_parameters():
        print(name)
        print(parameter)

    for name, parameter in model3.named_parameters():
        print(name)
        print(parameter)


if __name__ == "__main__":
    main()

# output
# f2
tensor([[[ 0.2715, -0.1838, -0.0243,  0.0205,  0.4656, -0.2684],
         [ 0.3276, -0.4536,  1.0650, -0.1313,  0.3529,  0.0665],
         [ 0.8405,  0.1754, -0.4132,  0.6693,  0.2858,  0.3899],
         [ 0.5250,  0.0374,  0.3187,  0.4030, -0.4056, -0.4441],
         [ 0.5251, -0.1842,  0.2542,  0.3018,  0.2963,  0.0432]]],
       grad_fn=<ViewBackward0>) 

out is : tensor([[[ 0.4878, -0.5974, -0.0024],
         [ 0.5564, -0.5523, -0.1109],
         [ 0.4062, -0.7355,  0.2404],
         [ 0.3436, -0.4390,  0.1391],
         [ 0.4473, -0.6542,  0.0482]]], grad_fn=<ViewBackward0>)
loss is: 1.6084858179092407 # loss.item()
tensor(1.6085, grad_fn=<NllLoss2DBackward0>)
# ff2.grad
tensor([[[-0.1723,  0.0278,  0.0554,  0.1180,  0.0238, -0.0246],
         [ 0.0896, -0.0918, -0.0331,  0.0252,  0.0503,  0.0087],
         [ 0.0309,  0.0177, -0.0098, -0.0470, -0.0251,  0.0065],
         [ 0.0240,  0.0251, -0.0064, -0.0491, -0.0276,  0.0052],
         [ 0.0278,  0.0211, -0.0060, -0.0471, -0.0214,  0.0042]]])

Notes

# When the computational graph is detached:
#....
ff2 = f2.clone().detach().requires_grad_(True)
out, loss = model3(ff2, label)
# .detach() separates the intermediate tensor f2 from the computational graph. If it is not detached, printing f2's gradient produces a warning:
>>> UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/core/TensorBody.h:480.)
>>>  print(f2.grad) # f2 作为中间变量,不属于叶节点,其梯度属于计算f2的输出层的参数梯度的中间量,计算后不会保留,所以此处为None. 如果需要该节点的梯度保留以供调取,可以使用.retain_grad()
>>> None

# When the models are simply chained (graph not detached):
#...
f1=model1(input)
f1.retain_grad()
f2=model2(f1)
out, loss=model3(f2)
#...
loss.backward()
#....
print(f1.grad)
>>> tensor([[[ 1.4645e-05, -4.5453e-03, -1.3644e-02,  2.8668e-02,  3.3055e-02,
           1.2564e-02],
         [ 8.2688e-03, -3.2370e-02,  2.4661e-02,  7.1180e-03, -1.6006e-02,
          -6.6946e-03],
         [-3.1059e-03,  1.1493e-02, -2.1026e-03, -1.2253e-02, -6.7652e-03,
          -2.9434e-03],
         [-3.3075e-03,  1.3811e-02, -4.5127e-03, -1.2450e-02, -5.1925e-03,
          -1.9075e-03],
         [-1.8700e-03,  1.1611e-02, -4.4019e-03, -1.1083e-02, -5.0915e-03,
          -1.0186e-03]]])


# Gradient backpropagation through the second model when the graph is detached
    f2.backward(ff2.grad) # loss.backward() omits the gradient argument; for a scalar loss it defaults to 1, i.e. loss.backward(gradient=torch.ones_like(loss)). f2 is not a scalar, so the upstream gradient ff2.grad must be passed explicitly.
    print('ff1.grad')
    print(ff1.grad)

>>> ff1.grad:
>>> tensor([[[ 1.4645e-05, -4.5453e-03, -1.3644e-02,  2.8668e-02,  3.3055e-02,
           1.2564e-02],
         [ 8.2688e-03, -3.2370e-02,  2.4661e-02,  7.1180e-03, -1.6006e-02,
          -6.6946e-03],
         [-3.1059e-03,  1.1493e-02, -2.1026e-03, -1.2253e-02, -6.7652e-03,
          -2.9434e-03],
         [-3.3075e-03,  1.3811e-02, -4.5127e-03, -1.2450e-02, -5.1925e-03,
          -1.9075e-03],
         [-1.8700e-03,  1.1611e-02, -4.4019e-03, -1.1083e-02, -5.0915e-03,
          -1.0186e-03]]]) # identical to the result above when the graph is not detached
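
# The same equality can be checked on a smaller standalone example (two generic
# linear stages, not the MLP classes above): relaying hh.grad into the first
# stage reproduces the gradient of a single connected backward pass.
import torch

torch.manual_seed(0)
w1 = torch.randn(4, 4, requires_grad=True)
w2 = torch.randn(4, 1, requires_grad=True)
x = torch.randn(2, 4)

# end-to-end: one connected graph
h = x @ w1
loss = (h @ w2).sum()
loss.backward()
g_ref = w1.grad.clone()

# relayed: detach at the stage boundary, then hand the gradient back
w1.grad = None
h = x @ w1
hh = h.clone().detach().requires_grad_(True)
(hh @ w2).sum().backward()   # fills hh.grad (and accumulates into w2.grad)
h.backward(hh.grad)          # relay dL/dh into the first stage
print(torch.allclose(w1.grad, g_ref))  # True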

Comparing model parameters before and after backpropagation

1. Model parameters before and after backpropagation:
[figure: printout of the three models' parameters before vs. after the optimizer steps]

2. Parameter changes when optim3 is built over all three models' parameters vs. only over model3's parameters:
[figure: printout of the parameters for the two optim3 configurations]
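
The before/after comparison shown in the screenshots can also be reproduced in code. Below is a minimal sketch (using a toy nn.Linear rather than the three models above) that snapshots the state_dict before an optimizer step and reports which parameters the step actually changed:

# minimal sketch: diff a model's parameters around one optimizer step
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), 0.01)

before = copy.deepcopy(model.state_dict())   # snapshot before the update
loss = model(torch.randn(1, 4)).sum()
optim.zero_grad()
loss.backward()
optim.step()

for name, p in model.state_dict().items():
    changed = not torch.equal(before[name], p)
    print(name, "changed" if changed else "unchanged")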

Summary

  1. When the models are chained without detaching the computational graph, loss.backward() only computes gradients; updating the parameters is the optimizer's job, and it updates exactly the parameters it was given.
  2. .detach() separates a tensor from the computational graph.
  3. Backpropagation alone never updates parameters; an optimizer step is always required.
  4. A single optimizer can update the parameters of several models at the same time (see the sketch after this list).
  5. Note, however, that only model parameters are updated: the input has no gradient and should not have one, while the output and intermediate variables are links in the chain rule. Printing them confirms that gradients can be computed for them, consistent with the backpropagation derivation.
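
Points 4 and 5 can be confirmed with a small standalone example (two toy nn.Linear modules, not the models above): a single Adam instance built from both modules' parameters updates both in one step, while the plain input tensor never receives a gradient.

# minimal sketch: one optimizer instance updating two modules at once
import torch
import torch.nn as nn

torch.manual_seed(0)
m1 = nn.Linear(4, 4)
m2 = nn.Linear(4, 2)
optim = torch.optim.Adam(list(m1.parameters()) + list(m2.parameters()), 0.01)
# equivalently: torch.optim.Adam([{'params': m1.parameters()},
#                                 {'params': m2.parameters()}], 0.01)

x = torch.randn(1, 4)        # plain input: requires_grad is False
loss = m2(m1(x)).sum()
optim.zero_grad()
loss.backward()
optim.step()                 # both m1 and m2 are updated

print(x.grad)                # None: the input never receives a gradient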