上一篇迁移学习的文章下,有朋友问在fine-tune的时候如何固定某些层不参与训练,下去研究并实验了一下,在这儿总结一下:
pytorch中关于网络的反向传播操作是基于Variable对象,Variable中有一个参数requires_grad,将requires_grad=False,网络就不会对该层计算梯度。
在用户手动定义Variable时,参数requires_grad默认值是False。而在Module中的层在定义时,相关Variable的requires_grad参数默认是True。
在训练时如果想要固定网络的底层,那么可以令这部分网络对应子图的参数requires_grad为False。这样,在反向过程中就不会计算这些参数对应的梯度。
冻结模型参考链接:
https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088model_ft = models.resnet50(pretrained=True) #读入resnet50模型
ct = 0
for child in model_ft.children():
ct += 1
if ct < 7:
for param in child.parameters():
param.requires_grad = False
然后需要在优化器中filter一下
optimizer.SGD(filter(lambda p: p.requires_grad, model_ft.parameters()), lr=1e-3)
那么model.children()是什么呢?
这里需要介绍一下self.children()与self.module()
参考链接:https://blog.csdn.net/dss_dssssd/article/details/83958518
例如下边的网络结构
![](https://i-blog.csdnimg.cn/blog_migrate/8d5377e2af0c6922c6466cf67827111e.jpeg)
![](https://i-blog.csdnimg.cn/blog_migrate/a3d81e455c90631a83afcdda0d5a20af.jpeg)
分别利用children()与module()输出的结果,分析二者:
self.children()
out:
children
0 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
1 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
self.modules()
out:
modules
######net节点
0 Net(
(layer): Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
(layer2): Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
(layer3): Linear(in_features=1, out_features=1, bias=True)
)
####左边第一个节点及其遍历
1 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
3 ReLU(inplace)
####中间节点及其下边的子节点
4 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
5 Linear(in_features=1, out_features=1, bias=True)
6 ReLU(inplace)
##最右边的节点
7 Linear(in_features=1, out_features=1, bias=True)
从上边可以看出,children()输出的是网络的子节点即net‘s children
module()输出所有节点的遍历