过滤出模型 self.global_model 中需要梯度更新的参数,为优化器的初始化提供参数列表。
1. self.global_model.parameters()
-
含义:
self.global_model.parameters()是一个生成器,返回模型中所有参数的迭代器。- 每个参数是一个 PyTorch 的
torch.Tensor,包含权重或偏置数据。
-
输出:
- 包括所有的权重和偏置(bias),无论它们是否参与梯度更新。
2. lambda p: p.requires_grad
-
p.requires_grad:- 这是 PyTorch 中的一个属性,表示参数
p是否需要计算梯度。 - 如果
requires_grad=True,表示该参数会参与反向传播,优化器会更新它的值。 - 如果
requires_grad=False,表示该参数是冻结的,通常用于冻结预训练模型的一部分,或在特定阶段不需要更新某些参数。
- 这是 PyTorch 中的一个属性,表示参数
-
lambda p: p.requires_grad:- 这是一个匿名函数,作用是检查每个参数的
requires_grad属性,返回True或False。
- 这是一个匿名函数,作用是检查每个参数的
3. filter(lambda p: p.requires_grad, self.global_model.parameters())
-
filter函数:- 它接受两个参数:
- 过滤条件(这里是
lambda p: p.requires_grad)。 - 待过滤的可迭代对象(这里是
self.global_model.parameters())。
- 过滤条件(这里是
filter会迭代self.global_model.parameters()中的所有参数,并返回满足requires_grad=True条件的参数。
- 它接受两个参数:
-
作用:
- 过滤掉
requires_grad=False的参数,保留所有需要参与梯度更新的参数。
- 过滤掉
4. 为什么要过滤参数?
- 在深度学习训练中,优化器(如 SGD、Adam)需要传入参数列表。通常我们只想更新那些需要梯度的参数,避免浪费计算资源。
- 通过过滤
requires_grad=False的参数,可以确保优化器只更新需要的参数,冻结的参数不会被修改。
5. 典型的用法
这段代码通常与优化器一起使用,例如:
model_parameters = filter(lambda p: p.requires_grad, self.global_model.parameters()) optimizer = torch.optim.Adam(model_parameters, lr=1e-3)
- 解释:
filter返回的对象被传递给优化器,确保优化器只会更新requires_grad=True的参数。- 优化器不会浪费计算资源在冻结参数上。
6. 示例
假设模型包含以下参数:
from torch import nn class ExampleModel(nn.Module): def __init__(self): super().__init__() self.layer1 = nn.Linear(10, 5) # 默认 requires_grad=True self.layer2 = nn.Linear(5, 2) self.layer2.weight.requires_grad = False # 冻结 layer2 的权重 model = ExampleModel()
所有参数:
for param in model.parameters(): print(param.requires_grad) # 输出: # True (layer1 weight) # True (layer1 bias) # False (layer2 weight) # True (layer2 bias)
过滤后:
model_parameters = filter(lambda p: p.requires_grad, model.parameters()) for param in model_parameters: print(param.requires_grad) # 输出: # True (layer1 weight) # True (layer1 bias) # True (layer2 bias)
最后
filter(lambda p: p.requires_grad, self.global_model.parameters()) 的作用是提取出模型中需要更新的参数,避免优化器处理不需要梯度更新的参数(如冻结层)。这种方式对于冻结部分预训练模型权重,或减少计算开销时非常有用。
904

被折叠的 条评论
为什么被折叠?



