在 PyTorch 中,要查看模型中各层是否被冻结(即参数的 requires_grad 属性是否为 False),可以使用以下方法。
- 查看所有参数的冻结状态
使用 model.parameters() 遍历模型的所有参数,并检查每个参数的 requires_grad 属性。
for name, param in model.named_parameters():
print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")
示例输出
如果模型的部分层被冻结,你会看到类似以下输出:
Parameter: patch_embed.proj.weight, Requires Grad: False
Parameter: patch_embed.proj.bias, Requires Grad: False
Parameter: blocks.0.attn.qkv.weight, Requires Grad: True
...
- 查看特定层的冻结状态
如果只对某些特定层感兴趣,可以直接访问这些层的参数。
# 查看某层的参数是否被冻结,例如分类头
for name, param in model.head.named_parameters():
print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")
或者针对整个模块:
# 检查整个模块的参数是否被冻结
for param in model.head.parameters():
print(f"Requires Grad: {param.requires_grad}")
- 检查冻结层数量
统计模型中冻结的层数和未冻结的层数。
frozen = sum(param.requires_grad == False for param in model.parameters())
unfrozen = sum(param.requires_grad for param in model.parameters())
print(f"Frozen layers: {frozen}, Unfrozen layers: {unfrozen}")
- 查看模型是否全部冻结
如果需要检查整个模型是否被冻结,可以直接判断所有参数的状态:
is_frozen = all(param.requires_grad == False for param in model.parameters())
print(f"Is the entire model frozen? {is_frozen}")
- 示例:冻结部分层并检查状态
以下是一个冻结部分层并查看其状态的完整示例:
import timm
# 加载模型
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# 冻结前几个 Transformer block
for param in model.blocks[:6].parameters():
param.requires_grad = False
# 检查冻结状态
for name, param in model.named_parameters():
print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")
- 使用 requires_grad_() 快速设置冻结状态
如果你想快速切换冻结状态,可以使用 requires_grad_()。
# 冻结整个模型
model.requires_grad_(False)
# 解冻分类头
model.head.requires_grad_(True)
# 检查冻结状态
for name, param in model.named_parameters():
print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")
总结
• 检查所有参数:使用 model.named_parameters() 遍历所有参数。
• 检查特定层:针对感兴趣的模块或参数单独检查。
• 统计冻结层数量:快速计算冻结和未冻结层的数量。
• 快速设置状态:通过 requires_grad_() 方法快速冻结或解冻模型。