1. torchvision.utils.save_image
1.1 封装的原函数
@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """Write a tensor (or a list of tensors) to disk as a single image.

    The input is first laid out with ``make_grid``; a mini-batch therefore
    ends up as one grid image.

    Args:
        tensor (Tensor or list): Image data to save.
        fp (string or file object): Destination filename or open file object.
        format (Optional): Image format. When omitted it is derived from the
            filename extension, so it should always be given when ``fp`` is
            a file object.
        **kwargs: Forwarded unchanged to ``make_grid``.
    """
    scripting_or_tracing = torch.jit.is_scripting() or torch.jit.is_tracing()
    if not scripting_or_tracing:
        _log_api_usage_once(save_image)

    grid = make_grid(tensor, **kwargs)

    # mul() returns a new tensor, so the in-place steps below never touch `grid`.
    scaled = grid.mul(255)
    # Adding 0.5 before the uint8 cast rounds to the nearest integer.
    scaled.add_(0.5)
    scaled.clamp_(0, 255)
    # Move the leading axis last before handing the array to PIL.
    ndarr = scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    Image.fromarray(ndarr).save(fp, format=format)
1.2 调用示例
import torch
import os
import torchvision.utils as tvu
from PIL import Image, __version__ as PILLOW_VERSION
def save_image(img, file_directory):
    """Save ``img`` with ``torchvision.utils.save_image``, creating the parent directory first.

    Args:
        img: Passed unchanged to ``tvu.save_image``.
        file_directory: Path of the output image file (despite the name,
            this is a file path, not a directory).
    """
    parent = os.path.dirname(file_directory)
    # dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create a directory when the path actually has one.
    if parent:
        # exist_ok avoids the check-then-create race between concurrent writers.
        os.makedirs(parent, exist_ok=True)
    tvu.save_image(img, file_directory)
1.3 修改重写torchvision.utils.save_image函数
示例1:
import torch
import os
import torchvision.utils as tvu
from PIL import Image, __version__ as PILLOW_VERSION
def save_image_scale(img, file_directory, size):
    """Render ``img`` as a grid image, resize it to ``size`` and save it.

    Follows the same steps as ``torchvision.utils.save_image`` but inserts a
    PIL resize between building the image and writing it.

    Args:
        img: Tensor (or list of tensors) passed to ``tvu.make_grid``.
        file_directory: Path of the output image file (despite the name,
            this is a file path, not a directory).
        size: Target size handed to ``PIL.Image.Image.resize``, which
            expects ``(width, height)``.
    """
    parent = os.path.dirname(file_directory)
    # dirname() is "" for a bare filename and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        tvu._log_api_usage_once(tvu.save_image)

    grid = tvu.make_grid(img)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)

    # Pick the LANCZOS filter by feature detection instead of parsing the
    # version string: Image.Resampling only exists from Pillow 9.1, and
    # Image.ANTIALIAS was removed in Pillow 10, so a "starts with 9." test
    # breaks on both 9.0.x and every release from 10 onwards.
    resample = getattr(Image, "Resampling", Image).LANCZOS
    im.resize(size, resample).save(file_directory)
示例2:
@staticmethod
def save_img_batch_scale(batch, dirpath, fname, size, save_num=1):
    """Save the first ``save_num`` images of ``batch`` as one resized grid image.

    Args:
        batch: 4-D tensor; only ``batch[:save_num]`` is rendered.
        dirpath: Output directory, created through ``util.mkdir``.
        fname: Output file name inside ``dirpath``.
        size: Target size handed to ``PIL.Image.Image.resize``, which
            expects ``(width, height)``.
        save_num: Number of leading batch entries to include in the grid.
    """
    util.mkdir(dirpath)
    imgpath = osp.join(dirpath, fname)

    # To visualize a single image, call .unsqueeze(0) on it first.
    assert len(batch.shape) == 4

    img = batch[:save_num]
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        torchvision.utils._log_api_usage_once(torchvision.utils.save_image)

    grid = torchvision.utils.make_grid(img)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)

    # Feature-detect the filter rather than comparing the major version:
    # Image.Resampling was introduced in Pillow 9.1, so a ">= 9" test fails
    # on 9.0.x. Older releases expose LANCZOS directly on the Image module.
    resample = getattr(Image, "Resampling", Image).LANCZOS
    im.resize(size, resample).save(imgpath)
2. 图像缩放
2.1 torch.nn.functional示例
import torch.nn.functional as F
def restore_scale(self, val_loader):
    """Restore every batch of ``val_loader`` and save the result at its source size.

    For each batch the first three channels of ``x`` are padded on the
    bottom/right to a multiple of 32, passed through
    ``self.diffusive_restoration``, cropped back to the unpadded size and
    written to ``<image_folder>/<y[0]>_cfwd1.jpg``. When the restored size
    differs from ``(src_w, src_h)`` the image is resized on save.

    Args:
        val_loader: Iterable yielding ``(x, y, src_w, src_h)`` tuples.
            NOTE(review): ``y[0]`` and the scalar use of ``src_w``/``src_h``
            assume a batch size of 1 -- confirm against the loader config.
    """
    image_folder = os.path.join(self.args.image_folder, self.config.data.val_dataset)
    factor = 32  # spatial sides are padded up to a multiple of this

    with torch.no_grad():
        for x, y, src_w, src_h in val_loader:
            print(f'x.shape: {x.shape}, y[0]: {y[0]}, size: ({src_w}, {src_h})')

            x_cond = x[:, :3, :, :].to(self.diffusion.device)
            h, w = x_cond.shape[-2:]

            # Amount needed to reach the next multiple of `factor`
            # (0 when the side is already aligned).
            padh = (-h) % factor
            padw = (-w) % factor
            x_cond = F.pad(x_cond, (0, padw, 0, padh), 'reflect')

            x_output_resized = self.diffusive_restoration(x_cond)
            # Drop the padding again.
            x_output_resized = x_output_resized[:, :, :h, :w]

            _, _, dst_h, dst_w = x_output_resized.shape
            print(f'dst_h: {dst_h}, dst_w: {dst_w}')

            out_path = os.path.join(image_folder, f"{y[0]}_cfwd1.jpg")
            # int() accepts both plain ints and one-element tensors.
            src_size = (int(src_w), int(src_h))

            # Compare both dimensions; the loader may have downscaled the
            # image, in which case the output must be scaled back up.
            if (dst_w, dst_h) != src_size:
                utils.logging.save_image_scale(x_output_resized, out_path, src_size)
                print('resize---------------')
            else:
                utils.logging.save_image(x_output_resized, out_path)
                print('no resize=============')
2.2 PIL.Image示例:
# 2024-08-27 add
def resize_image_if_larger(self, image, scale=1536):
    """Downscale ``image`` so its longer side is at most ``scale`` pixels.

    The aspect ratio is kept. Images whose longer side is already within
    ``scale`` are returned unchanged (the same object).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (a PIL image).
        scale: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original image, or a bicubic-resampled copy.
    """
    w, h = image.size
    longest = max(w, h)
    if longest <= scale:
        return image

    # Multiply before dividing so the longer side lands exactly on `scale`
    # instead of occasionally truncating to scale - 1 through float error,
    # and clamp to 1 so extreme aspect ratios never produce a 0-pixel side.
    new_size = (
        max(1, int(w * scale / longest)),
        max(1, int(h * scale / longest)),
    )
    return image.resize(new_size, Image.BICUBIC)
def get_images(self, index):
    """Load, downscale and transform the input/ground-truth pair at ``index``.

    Args:
        index: Position in ``self.input_names`` / ``self.gt_names``.

    Returns:
        tuple: ``(pair, img_id, w, h)`` where ``pair`` is the transformed
        input and ground truth concatenated along dim 0, ``img_id`` is the
        input file name without directory or extension, and ``w``/``h`` are
        the input image's size before any downscaling.
    """
    input_name = self.input_names[index].replace('\n', '')
    gt_name = self.gt_names[index].replace('\n', '')

    # Take the part after the last '/' and strip whatever extension it has;
    # slicing off a fixed four characters breaks on e.g. ".jpeg".
    img_id = os.path.splitext(input_name.rsplit('/', 1)[-1])[0]

    # Resolve the paths first so both branches open and convert the image
    # the same way (previously only the self.dir branch converted to RGB).
    input_path = os.path.join(self.dir, input_name) if self.dir else input_name
    gt_path = os.path.join(self.dir, gt_name) if self.dir else gt_name
    input_img = Image.open(input_path).convert('RGB')
    gt_img = Image.open(gt_path).convert('RGB')

    # Remember the source size so callers can scale results back.
    w, h = input_img.size
    input_img = self.resize_image_if_larger(input_img)
    gt_img = self.resize_image_if_larger(gt_img)

    input_img, gt_img = self.transforms(input_img, gt_img)
    return torch.cat([input_img, gt_img], dim=0), img_id, w, h
def __getitem__(self, index):
    """Return whatever ``self.get_images`` produces for ``index``."""
    return self.get_images(index)
def __len__(self):
    """Return the number of entries in ``self.input_names``."""
    names = self.input_names
    return len(names)
2.3 python-opencv示例:
# 类外定义
# 2024-08-29 add
def resize_image_if_larger(image, scale=1536):
    """Downscale ``image`` so its longer side is at most ``scale`` pixels.

    The aspect ratio is kept. Arrays whose longer side is already within
    ``scale`` are returned unchanged (the same object).

    Args:
        image: Array whose first two shape entries are ``(height, width)``.
        scale: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original array, or a bicubic-interpolated resized copy.
    """
    # shape is (rows, cols, ...) == (height, width, ...); unpacking it as
    # (w, h) would hand cv2.resize a transposed target size.
    h, w = image.shape[:2]
    longest = max(h, w)
    if longest <= scale:
        return image

    ratio = scale / longest
    # Clamp to 1 so extreme aspect ratios never request a 0-pixel side.
    new_w = max(1, int(w * ratio))
    new_h = max(1, int(h * ratio))
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
# 类中定义
def __getitem__(self, idx):
    """Return the sample at ``idx`` as a dictionary, loading images lazily.

    Only the images needed for the requested item are read from disk.
    The original input width/height are recorded before the input is
    downscaled by ``resize_image_if_larger``.

    :param idx: image pair identifier
    :returns: dictionary holding the input path, original width/height,
        the input image and, when ``self.have_gt`` is set, the ground-truth
        path and image
    :rtype: dictionary
    """
    input_path = self.input_list[idx]
    res_item = {INPUT_FPATH: input_path}

    # A fresh seed per item, shared by INPUT and GT within that item.
    seed = random.randint(0, 100000)

    # Reverse the channel order of the loaded image (OpenCV reads BGR).
    input_img = cv2.imread(input_path)[:, :, [2, 1, 0]]

    # Record the source size before any downscaling.
    res_item[WIDTH] = input_img.shape[1]
    res_item[HEIGHT] = input_img.shape[0]
    input_img = resize_image_if_larger(input_img)

    gt_is_hdr = self.have_gt and self.gt_list[idx].endswith('.hdr')
    if gt_is_hdr:
        # HDR pairs are not augmented; only scale to [0, 1] and go channel-first.
        input_img = torch.Tensor(input_img / 255).permute(2, 0, 1)
    else:
        input_img = augment_one_img(input_img, seed, transform=self.transform)
    res_item[INPUT] = input_img

    if self.have_gt:
        gt_path = self.gt_list[idx]
        res_item[GT_FPATH] = gt_path
        if gt_is_hdr:
            gt_img = cv2.imread(gt_path, flags=cv2.IMREAD_ANYDEPTH)[:, :, [2, 1, 0]]
            gt_img = resize_image_if_larger(gt_img)
            gt_img = torch.Tensor(np.log10(gt_img + 1)).permute(2, 0, 1)
        else:
            gt_img = cv2.imread(gt_path)[:, :, [2, 1, 0]]
            # Downscale GT exactly like the input; otherwise large pairs
            # reach the shape assertion below with mismatched sizes.
            gt_img = resize_image_if_larger(gt_img)
            gt_img = augment_one_img(gt_img, seed, transform=self.transform)
        res_item[GT] = gt_img
        assert res_item[GT].shape == res_item[INPUT].shape

    print(f"res_item[INPUT] shape: {input_img.shape}, res_item[INPUT_FPATH]: {res_item[INPUT_FPATH]}")
    return res_item