前言
这篇博客主要解决 DHAN-SHR 网络官方未发布推理代码的问题。本文给出了 DHAN-SHR 网络的推理代码（由新手编写，可能存在问题），希望能对大家有所帮助。脚本会加载 config.yml 中 TESTING.WEIGHT 指定的权重，对 images 文件夹中的图片逐张推理，并将结果以同名文件保存到 result 文件夹。
import gc
import os
import numpy as np
import torchvision.transforms.functional as F
import torch
import warnings
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from models import Model
from utils import load_checkpoint
from config import Config
from tqdm import tqdm
from accelerate import Accelerator
# Silence every Python warning for the remainder of the script.
warnings.filterwarnings('ignore')

# Inference settings; TESTING.WEIGHT below is read from this config object.
opt = Config('config.yml')

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network on the selected device. load_checkpoint (from utils) is
# handed the model and the TESTING.WEIGHT path; presumably it loads the
# weights into the model in place. eval() switches the module to inference
# behaviour (dropout / batch-norm layers, if any).
model = Model()
model = model.to(device)
load_checkpoint(model, opt.TESTING.WEIGHT)
model.eval()
# Example usage: run the model on every image in ./images and write the
# outputs to ./result under the same file names.

# File extensions treated as input images (compared case-insensitively).
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')


def list_image_files(folder):
    """Return the sorted names of the image files directly inside *folder*.

    Sub-directories and files whose extension is not in IMAGE_EXTENSIONS
    (e.g. ``.DS_Store``, ``Thumbs.db``) are skipped so they never reach
    ``Image.open``. Sorting makes the processing order deterministic, since
    ``os.listdir`` returns names in arbitrary order.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


if __name__ == '__main__':
    images_path = r"images"
    result_path = "result"
    # PIL's save() does not create missing parent directories.
    os.makedirs(result_path, exist_ok=True)

    accelerator = Accelerator()
    # Prepare the model once, outside the loop: repeated prepare() calls
    # would re-wrap and re-register the same model for every image.
    model = accelerator.prepare(model)
    # Inputs must be on the device Accelerate placed the model on, which can
    # differ from the module-level `device` (e.g. Apple MPS).
    run_device = accelerator.device

    for file_name in tqdm(list_image_files(images_path)):
        input_image_path = os.path.join(images_path, file_name)
        output_image_path = os.path.join(result_path, file_name)

        try:
            # The context manager closes the file handle once the RGB copy
            # has been converted to a tensor.
            with Image.open(input_image_path) as pil_image:
                # to_tensor yields a (C, H, W) float tensor scaled to [0, 1].
                image = F.to_tensor(pil_image.convert('RGB'))
        except OSError as err:  # also covers PIL.UnidentifiedImageError
            # Skip unreadable/corrupt files instead of aborting the batch.
            tqdm.write(f"Skipping {input_image_path}: {err}")
            continue

        # Add the batch dimension and move to the model's device.
        image = image.unsqueeze(0).to(run_device)
        with torch.no_grad():
            # NOTE(review): assumes the model returns a single (1, 3, H, W)
            # image tensor — confirm against models.Model.
            result = model(image).clamp(0, 1)

        # Drop the batch dimension and move to CPU for PIL conversion.
        result = result.squeeze(0).cpu()
        transforms.ToPILImage()(result).save(output_image_path)

        # Release per-image tensors before the next iteration.
        del image, result
        gc.collect()
        torch.cuda.empty_cache()