# Running bests used for checkpoint selection: the highest validation PSNR/SSIM
# seen so far and the epoch in which each was reached.
# NOTE(review): this must execute once, before the epoch loop; if it sits inside
# the loop, the trackers reset every epoch and every epoch counts as "best".
best_psnr = 0.0
best_ssim = 0.0
best_epoch_psnr = 0
best_epoch_ssim = 0
# Validation pass, run after training each epoch: score the model on
# val_loader, save a preview image, and keep the checkpoints with the best
# mean PSNR and the best mean SSIM seen so far.
# The modulus is the validation interval (1 = validate every epoch).
if epoch % 1 == 0:
    model.eval()
    psnr_vals, ssim_vals = [], []  # one score per validation sample
    pbar = tqdm(val_loader)
    # Inference only: no autograd graph is needed, which also saves GPU memory.
    with torch.no_grad():
        for data_val in pbar:
            input_ = data_val[0].cuda()
            target_SR = data_val[1].cuda()
            # NOTE(review): the input is passed twice; presumably the second
            # argument is a target placeholder unused when training=False.
            # Confirm against the model's forward().
            model.forward(input_, input_, training=False)
            restored_SR = model.sample(testing=True)

            batch_psnr, batch_ssim = [], []
            for res, tar in zip(restored_SR, target_SR):
                batch_psnr.append(torchPSNR(res, tar))
                # Score each sample on its own, as a batch of one (the same
                # batched layout torchSSIM was called with on whole batches),
                # instead of re-scoring the entire batch once per sample.
                batch_ssim.append(torchSSIM(res.unsqueeze(0), tar.unsqueeze(0)))
            psnr_vals.extend(batch_psnr)
            ssim_vals.extend(batch_ssim)

            # Show the mean of the per-sample scores just computed rather than
            # running both metrics over the whole batch a second time.
            pbar.set_description("[Epoch] {} [MODE] VALID [PSNR] {:.4f} [SSIM] {:.4f}".format(
                epoch,
                torch.stack(batch_psnr).mean().item(),
                torch.stack(batch_ssim).mean().item())
            )

    # Mean per-sample scores over the whole validation set (plain floats).
    PSNRs = torch.stack(psnr_vals).mean().item()
    SSIMs = torch.stack(ssim_vals).mean().item()

    # Preview of the first sample of the last batch: input | restored | target,
    # concatenated along the last (width) dimension.
    # NOTE(review): pad_img presumably pads the input up to 256x256 so its
    # height matches the restored output; confirm.
    save_image(torch.cat((pad_img(input_[0], (256, 256)),
                          restored_SR[0], target_SR[0]), -1),
               os.path.join(opt.output, str(epoch) + '.png'))  # save image

    # Save the best PSNR model of validation
    if PSNRs > best_psnr:
        best_psnr = PSNRs
        best_epoch_psnr = epoch
        # os.path.join works whether or not save_folder ends with a separator.
        model_out_path = os.path.join(opt.save_folder, "epoch_best_psnr.pth")
        torch.save(model.state_dict(), model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))
    print("[PSNR] {:.4f} [Best_PSNR] {:.4f} (epoch {})".format(PSNRs, best_psnr, best_epoch_psnr))

    # Save the best SSIM model of validation
    if SSIMs > best_ssim:
        best_ssim = SSIMs
        best_epoch_ssim = epoch
        model_out_path = os.path.join(opt.save_folder, "epoch_best_ssim.pth")
        torch.save(model.state_dict(), model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))
    print("[SSIM] {:.4f} [Best_SSIM] {:.4f} (epoch {})".format(SSIMs, best_ssim, best_epoch_ssim))

    # model.eval() above switched layers such as dropout/BatchNorm to inference
    # behaviour; restore training mode so the next epoch does not silently
    # train in eval mode.
    model.train()
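# 深度学习-训练网络后直接进行验证保存最优模型
# (Deep learning: run validation right after each training epoch and save the best model.)
# 最新推荐文章于 2023-10-10 20:41:21 发布
# (Source page footer: "latest recommended article published 2023-10-10 20:41:21".)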