查看模型参数
for name, param in mpdel.named_parameters():
if param.requires_grad:
print(name, ':', param.size())
计算模型参数量
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (total/1e6))
迁移学习
EarlyStopping的使用
from pytorchtools import EarlyStopping
early_stopping = EarlyStopping(patience=patience, verbose=True)
for batch, data in enumerate(dataloader):
# 计算train_loss和val_loss
early_stopping(val_loss, model)
if early_stopping.early_stop:
print("Early stopping")
break
模型的保存和加载
训练过程中,有时候会由于各种原因停止训练,这时候在训练过程中就需要注意将每一个epoch的模型进行保存以便下次接着训练。
模型保存
# 标准
torch.save(model.state_dict(), 'checkpoint.pt')
# 如果是多GPU
torch.save(model.module.state_dict(), 'checkpoint.pt')
# 如果需要保存多个参数
state = {'epoch': epoch + 1, #保存的当前轮数
'state_dict': model.state_dict(), # 训练好的参数
'optimizer': optimizer.state_dict(), # 优化器参数,为了后续的resume
'best_pred': best_pred # 当前最好的精度
,....,...}
torch.save(state, 'checkpoint.pt')
模型加载
# 标准
model.load_state_dict(torch.load('checkpoint.pt'))
# 多参数的模型加载
checkpoint = torch.load('checkpiont.pt')
model.load_state_dict(checkpoint['state_dict']) # 模型参数
optimizer.load_state_dict(checkpoint['optimizer']) # 优化参数
epoch = checkpoint['epoch'] # epoch,可以用于更新学习率等