import cv2
import os
import torch
import torch.optim
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Default checkpoint used by dehaze_image; can be overridden per call.
DEFAULT_MODEL_PATH = 'saved_models/dehaze_net_epoch_17.pth'

# Loaded networks keyed by (absolute checkpoint path, device string), so a
# batch run deserializes the checkpoint once instead of once per image.
_MODEL_CACHE = {}


def _resolve_device(device):
    """Return a torch.device: the requested one, else CUDA if available, else CPU."""
    if device is None:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(device, torch.device):
        return device
    return torch.device(device)


def _load_dehaze_net(model_path, device):
    """Load (or fetch from the cache) the pickled network at model_path on device."""
    import inspect

    key = (os.path.abspath(model_path), str(device))
    net = _MODEL_CACHE.get(key)
    if net is None:
        load_kwargs = {'map_location': device}
        # The checkpoint is a whole pickled nn.Module (it is called directly
        # by dehaze_image), so only load files you trust. torch >= 2.6
        # defaults to weights_only=True, which cannot unpickle a full module;
        # request the legacy behaviour where the keyword exists.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        net = torch.load(model_path, **load_kwargs)
        # Inference mode; matters only if the net has dropout/batch-norm layers.
        net.eval()
        _MODEL_CACHE[key] = net
    return net


def dehaze_image(image_addr, name, *, model_path=DEFAULT_MODEL_PATH,
                 results_dir='./results', device=None):
    """Dehaze one image with the saved network and write the result to disk.

    Args:
        image_addr: Path of the hazy input image.
        name: Output file name (including extension, which selects the
            format) written inside ``results_dir``.
        model_path: Checkpoint holding the pickled dehazing network.
        results_dir: Output directory; created if it does not exist.
        device: Torch device (string or ``torch.device``) to run on.
            Defaults to CUDA when available, otherwise CPU.
    """
    device = _resolve_device(device)
    dehaze_net = _load_dehaze_net(model_path, device)

    # convert('RGB') normalises grayscale/RGBA/palette inputs to 3 channels so
    # the HWC -> CHW permute below always sees an (H, W, 3) array; the context
    # manager closes the file handle PIL would otherwise keep open.
    with Image.open(image_addr) as img:
        hazy = np.asarray(img.convert('RGB')) / 255.0

    # (H, W, 3) -> (1, 3, H, W)
    hazy_tensor = (
        torch.from_numpy(hazy).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )

    # Pure inference: skip building an autograd graph for every image.
    with torch.no_grad():
        clean = dehaze_net(hazy_tensor)

    # (1, 3, H, W) -> (H, W, 3). squeeze(0) drops only the batch axis, so
    # images with a height or width of 1 keep their shape.
    clean_image = clean.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # plt.imsave rejects float RGB data outside [0, 1], and the network
    # output is not guaranteed to stay within that range.
    clean_image = np.clip(clean_image, 0.0, 1.0)

    os.makedirs(results_dir, exist_ok=True)
    plt.imsave(os.path.join(results_dir, name), clean_image)
def dehazeFile(img_Dir):
    """Dehaze every regular file in ``img_Dir`` by calling ``dehaze_image``.

    Each result is saved under the same file name as its input. Entries are
    processed in sorted order and progress is printed after each image.

    Args:
        img_Dir: Directory containing the hazy input images.
    """
    # Collect the names of all regular files (subdirectories are skipped);
    # sorting makes the processing order deterministic.
    img_pathDir = sorted(
        entry for entry in os.listdir(img_Dir)
        if os.path.isfile(os.path.join(img_Dir, entry))
    )
    print("img_Dir:", img_Dir)  # print the directory path
    print("img_pathDir:", img_pathDir)  # print the list of file names
    print(len(img_pathDir))  # print the number of files

    total = len(img_pathDir)
    for num, file_name in enumerate(img_pathDir, start=1):
        # Join against img_Dir itself (not a fixed folder) so any input
        # directory passed by the caller is actually used.
        addr = os.path.join(img_Dir, file_name)
        dehaze_image(addr, file_name)
        print('handling:{}/{}'.format(num, total))
if __name__ == "__main__":
    # Local folder holding the hazy images to process.
    img_Dir = "./test_images"
    dehazeFile(img_Dir)
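# AODNet复现: 用gpu批量处理图片 (AOD-Net reproduction: batch-process images on the GPU)
# 最新推荐文章于 2023-11-22 10:58:49 发布 (blog page residue: "latest recommended article published 2023-11-22 10:58:49")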