导入包
import torch
import shutil
二、模块
1)保存模型参数,保存模型状态,状态中可以有模型参数,优化器参数,epoch等。如果是在验证集上表现比之前好,那么就是is_best=True,使用shutil.copyfile(src, des)将src文件直接拷贝到des,如果已经存在,就直接覆盖掉。
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):#state是一个字典,包含优化器、网络等参数
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
2)计算相关统计值,