# Command-line options: choose which training epoch's saved checkpoint
# (pre-trained weights) to resume from.
parser = argparse.ArgumentParser(description="PyTorch Net")
parser.add_argument(
    "--resume_epoch",
    type=int,
    default=1225,
    help="Resume from checkpoint epoch",
)
opt = parser.parse_args()
# Resume training state from a saved checkpoint: the model weights, optimizer
# state, LR-scheduler state and the loss logger ("4-piece set").
# Pass --resume_epoch 0 to skip loading and keep the freshly initialised model.
if opt.resume_epoch:
    resume_path = join(model_dir, 'model_epoch_{}.pth'.format(opt.resume_epoch))
    if os.path.isfile(resume_path):
        print("==> loading checkpoint 'epoch{}'".format(opt.resume_epoch))
        # map_location="cpu": deserialise onto the CPU so a checkpoint written
        # on a GPU also loads on a CPU-only host, and so the loaded tensors do
        # not temporarily double GPU memory use. load_state_dict() below copies
        # the values into the existing parameters on their current device.
        # NOTE(review): 'losslogger' may be a non-tensor Python object; on
        # PyTorch >= 2.6 torch.load defaults to weights_only=True and may
        # refuse to unpickle it. Confirm, and pass weights_only=False only for
        # trusted checkpoint files.
        checkpoint = torch.load(resume_path, map_location="cpu")
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        losslogger = checkpoint['losslogger']
    else:
        # Best-effort: a missing file is reported and training continues from
        # the current (untrained) weights.
        print("==> no model found at 'epoch{}'".format(opt.resume_epoch))
# NOTE(review): `losslogger` is only bound when a checkpoint was loaded;
# confirm that the code after this point initialises it otherwise.
# pytorch添加迁移学习 (PyTorch: adding transfer learning / resuming from a checkpoint)
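# 最新推荐文章于 2023-12-06 11:31:40 发布 (scraped web-page footer: "latest recommended article published 2023-12-06 11:31:40"; not part of the program)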