1. 固定网络的部分参数进行训练
- 简单方法,插入 with 块中,则块中的部分不会更新参数
s1 = torch.tensor([1.])
s2 = torch.tensor([1.])
with torch.no_grad():
s3 = s1 + s2
- 第一种方法不适用时,可以根据 model.children() 查看模型所有参数,从中选择将部分网络的 requires_grad 参数设置为 False
- 在变量设置 volatile=True 或者 requires_grad=False
- 在变量后写上 xxx.detach()
2. 加载预训练模型
-对于完全一样的网络,可以用过 torch.load 直接添加,例如:
model = Model() # 网络
pretrained = './pretrained_model.pth'
state = torch.load(pretrained)['state_dict']
model.load_state_dict(state)
- 对于修改过的网络(添加了新的网络成分),要进行简单修改
def load_model(model, load_model):
dict_trained = torch.load(load_model)["state_dict"]
dict_new = model.state_dict().copy()
new_list = list(model.state_dict().keys())
trained_list = list (dict_trained.keys())
for i in new_list:
if i in trained_list:
dict_new[i] = dict_trained[i]
model.load_state_dict(dict_new)
3. 动态调整学习率
- 在训练一定次数后调整、降低学习率,进一步降低 Loss
def adapt_learning_rate(optimizer, epoch, decay_iter=20, decay_factor=0.9):
if epoch // decay_iter == 0:
for params in optimizer.param_groups:
params['lr'] = params['lr'] * decay_factor
# Omitting codes
for epoch in range(all_epochs):
test = model(input_test)
loss = loss_func(test, gt_value)
loss.backward()
optimizer.step()
adapt_learning_rate(optimizer, epoch)
4. Git 使用小 Tips
- 使用过程中需要忽略一些文件,例如需要忽略
.pyc
文件,则在工作根目录创建一个.gitignore
文件,在其中输入需要忽略的文件类型等即可,以下代码将pyc、pymodel、tar文件排除git范围:
tee ./.gitignore << EOF
*.pyc
*.pymodel
*.tar
EOF
git add .
git commit -m "Adding GitIgnore To Exclude Some File Types"
5. Pytorch 训练错误总结
- 训练过程中出现
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
错误 ,则可能是在网络构建过程中错误使用 += 进行赋值
# ... other codes
def forward(self, inputx):
# ....
x = self.conv1(x)
x += self.conv2(x)
# 改为
x = self.conv2(x) + x
# 如此可以避免 inplace 的 Error