model_parameters = filter(lambda p: p.requires_grad, self.global_model.parameters())

过滤出模型 self.global_model需要梯度更新的参数,为优化器的初始化提供参数列表。


1. self.global_model.parameters()

  • 含义:

    • self.global_model.parameters() 是一个生成器,返回模型中所有参数的迭代器。
    • 每个参数是一个 PyTorch 的 torch.Tensor,包含权重或偏置数据。
  • 输出:

    • 包括所有的权重和偏置(bias),无论它们是否参与梯度更新。

2. lambda p: p.requires_grad

  • p.requires_grad:

    • 这是 PyTorch 中的一个属性,表示参数 p 是否需要计算梯度。
    • 如果 requires_grad=True,表示该参数会参与反向传播,优化器会更新它的值。
    • 如果 requires_grad=False,表示该参数是冻结的,通常用于冻结预训练模型的一部分,或在特定阶段不需要更新某些参数。
  • lambda p: p.requires_grad:

    • 这是一个匿名函数,作用是检查每个参数的 requires_grad 属性,返回 TrueFalse

3. filter(lambda p: p.requires_grad, self.global_model.parameters())

  • filter 函数:

    • 它接受两个参数:
      1. 过滤条件(这里是 lambda p: p.requires_grad)。
      2. 待过滤的可迭代对象(这里是 self.global_model.parameters())。
    • filter 会迭代 self.global_model.parameters() 中的所有参数,并返回满足 requires_grad=True 条件的参数。
  • 作用:

    • 过滤掉 requires_grad=False 的参数,保留所有需要参与梯度更新的参数。

4. 为什么要过滤参数?

  • 在深度学习训练中,优化器(如 SGD、Adam)需要传入参数列表。通常我们只想更新那些需要梯度的参数,避免浪费计算资源。
  • 通过过滤 requires_grad=False 的参数,可以确保优化器只更新需要的参数,冻结的参数不会被修改。

5. 典型的用法

这段代码通常与优化器一起使用,例如:

model_parameters = filter(lambda p: p.requires_grad, self.global_model.parameters())
optimizer = torch.optim.Adam(model_parameters, lr=1e-3)
  • 解释:
    • filter 返回的对象被传递给优化器,确保优化器只会更新 requires_grad=True 的参数。
    • 优化器不会浪费计算资源在冻结参数上。

6. 示例

假设模型包含以下参数:

from torch import nn

class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)  # 默认 requires_grad=True
        self.layer2 = nn.Linear(5, 2)
        self.layer2.weight.requires_grad = False  # 冻结 layer2 的权重

model = ExampleModel()
所有参数:
for param in model.parameters():
    print(param.requires_grad)
# 输出:
# True (layer1 weight)
# True (layer1 bias)
# False (layer2 weight)
# True (layer2 bias)
过滤后:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
for param in model_parameters:
    print(param.requires_grad)
# 输出:
# True (layer1 weight)
# True (layer1 bias)
# True (layer2 bias)

最后

filter(lambda p: p.requires_grad, self.global_model.parameters()) 的作用是提取出模型中需要更新的参数,避免优化器处理不需要梯度更新的参数(如冻结层)。这种方式对于冻结部分预训练模型权重,或减少计算开销时非常有用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值