代码来源:zheng-ningxin/Pruning-from-scratch (github.com)
使用了lambda表达式并结合filter函数,生成需要求梯度参数的迭代器,具体代码如下:
# 冻结所有模型参数 for para in model.parameters(): para.requires_grad = False # 只更新BN层的参数 for layer in model.modules(): if isinstance(layer, nn.BatchNorm2d): for para in layer.parameters(): para.requires_grad = True model.train() criterion = nn.CrossEntropyLoss() # filter(lambda p: p.requires_grad, model.parameters()): 给出需要求梯度的参数的迭代器 optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)