DehazeFormer如何处理自己的数据输出去雾图像
源代码:DehazeFormer
论文:Vision Transformers for Single Image Dehazing
一定要参考test.py
,源代码的Readme.md
中预训练模型和训练数据集的链接。
先把需要加载的模型,输出的结果,数据集的路径先开出来。
saved_model_dir = "./saved_models/indoor/dehazeformer-s.pth"
result_dir = "./result"
dataset_dir = "./image"
os.makedirs(os.path.join(result_dir, 'imgs'), exist_ok=True)
加载模型:
这里需要先做一个single
,这个single
是test.py
里面做的,然后再加载
model = dehazeformer_s()
model.cuda()
model.load_state_dict(single(saved_model_dir))
加载要处理的数据集,这里要用到datasets.loader
里的SingleLoader
test_dataset = SingleLoader(dataset_dir)
test_loader = DataLoader(test_dataset,
batch_size=1,
num_workers=16,
pin_memory=True)
然后输入到模型中进行处理
for idx, batch in enumerate(test_loader):
input = batch["img"].cuda()
filename = batch['filename'][0]
with torch.no_grad():
# clamp()函数的功能将输入input张量每个元素的值压缩到区间 [min,max],并返回结果到一个新张量。
output = model(input).clamp_(-1, 1)
# [-1, 1] to [0, 1]
output = output * 0.5 + 0.5
_, _, H, W = output.size()
out_img = chw_to_hwc(output.detach().cpu().squeeze(0).numpy())
write_img(os.path.join(result_dir, 'imgs', filename), out_img)
print(idx,": ",filename," 已输出")
整个的代码
import os
import torch
from torch.utils.data import DataLoader
from models import *
from test import single
from datasets.loader import SingleLoader
from utils import chw_to_hwc, write_img
saved_model_dir = "./saved_models/indoor/dehazeformer-s.pth"
result_dir = "./result"
dataset_dir = "./image"
os.makedirs(os.path.join(result_dir, 'imgs'), exist_ok=True)
model = dehazeformer_s()
model.cuda()
model.load_state_dict(single(saved_model_dir))
test_dataset = SingleLoader(dataset_dir)
test_loader = DataLoader(test_dataset,
batch_size=1,
num_workers=16,
pin_memory=True)
for idx, batch in enumerate(test_loader):
input = batch["img"].cuda()
filename = batch['filename'][0]
with torch.no_grad():
# clamp()函数的功能将输入input张量每个元素的值压缩到区间 [min,max],并返回结果到一个新张量。
output = model(input).clamp_(-1, 1)
# [-1, 1] to [0, 1]
output = output * 0.5 + 0.5
_, _, H, W = output.size()
out_img = chw_to_hwc(output.detach().cpu().squeeze(0).numpy())
write_img(os.path.join(result_dir, 'imgs', filename), out_img)
print(idx,": ",filename," 已输出")