torchvision.utils.save_image

43 篇文章 10 订阅
16 篇文章 1 订阅

1. torchvision.utils.save_image

1.1 封装的原函数

@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a tensor, a list of tensors or a mini-batch as a single image file.

    Args:
        tensor (Tensor or list): Image(s) to save. A mini-batch (or list) is
            tiled into one grid with ``make_grid`` before saving.
        fp (string or file object): Destination filename or file object.
        format (Optional[str]): Image format handed to ``PIL.Image.Image.save``.
            When omitted it is inferred from the filename extension, so it should
            always be given when ``fp`` is a file object.
        **kwargs: Forwarded unchanged to ``make_grid``.
    """
    compiling = torch.jit.is_scripting() or torch.jit.is_tracing()
    if not compiling:
        _log_api_usage_once(save_image)

    grid = make_grid(tensor, **kwargs)

    # Map [0, 1] floats to [0, 255]; adding 0.5 before the uint8 cast rounds
    # to the nearest integer instead of truncating.
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    channels_last = scaled.permute(1, 2, 0)
    pixels = channels_last.to("cpu", torch.uint8).numpy()

    Image.fromarray(pixels).save(fp, format=format)

1.2 调用示例

import torch
import os
import torchvision.utils as tvu
from PIL import Image, __version__ as PILLOW_VERSION

def save_image(img, file_directory):
    """Save ``img`` with ``torchvision.utils.save_image``, creating the parent directory.

    Args:
        img: tensor or list of tensors accepted by ``torchvision.utils.save_image``.
        file_directory: destination *file* path (despite the name), e.g.
            ``"out/result.png"``. A bare filename writes to the current directory.
    """
    parent = os.path.dirname(file_directory)
    # A bare filename has an empty parent; os.makedirs("") would raise.
    # exist_ok=True also removes the race between an exists() check and makedirs().
    if parent:
        os.makedirs(parent, exist_ok=True)
    tvu.save_image(img, file_directory)

1.3 修改重写torchvision.utils.save_image函数

示例1:

import torch
import os
import torchvision.utils as tvu
from PIL import Image, __version__ as PILLOW_VERSION

def save_image_scale(img, file_directory, size):
    """Save ``img`` as an image file after resizing it to ``size``.

    Follows ``torchvision.utils.save_image``:
    1. Tile the input with ``make_grid``.
    2. Map [0, 1] floats to uint8.
    3. Resample the result with a Lanczos filter before writing it.

    Args:
        img: tensor or list of tensors accepted by ``torchvision.utils.make_grid``.
        file_directory: destination file path; its parent directory is created
            if it does not exist.
        size: target ``(width, height)`` passed to ``PIL.Image.Image.resize``.
    """
    parent = os.path.dirname(file_directory)
    # A bare filename has an empty parent; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        tvu._log_api_usage_once(tvu.save_image)
    # Like the upstream save_image (decorated with @torch.no_grad()), avoid
    # recording autograd history for this pure output conversion.
    with torch.no_grad():
        grid = tvu.make_grid(img)
        # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    # Pillow >= 9.1 exposes filters on Image.Resampling, and Pillow 10 removed
    # Image.ANTIALIAS. Older Pillow keeps LANCZOS as a module-level constant.
    # getattr picks whichever exists, so every Pillow version works.
    resample = getattr(Image, "Resampling", Image).LANCZOS
    im.resize(size, resample).save(file_directory)

示例2:

@staticmethod
def save_img_batch_scale(batch, dirpath, fname, size, save_num=1):
    """Save the first ``save_num`` images of ``batch`` as one resized grid image.

    Args:
        batch: 4-D image batch (``len(batch.shape) == 4``). To save a single
            image, call ``.unsqueeze(0)`` on it first. Presumably (N, C, H, W)
            floats in [0, 1], as ``torchvision.utils.make_grid`` expects.
            TODO: confirm with callers.
        dirpath: output directory, created with ``util.mkdir``.
        fname: output file name inside ``dirpath``.
        size: target ``(width, height)`` passed to ``PIL.Image.Image.resize``.
        save_num: number of leading images of ``batch`` to include in the grid.

    Raises:
        ValueError: if ``batch`` is not 4-dimensional.
    """
    util.mkdir(dirpath)
    imgpath = osp.join(dirpath, fname)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(batch.shape) != 4:
        raise ValueError(
            f"save_img_batch_scale expects a 4-D batch, got shape {tuple(batch.shape)}; "
            "call .unsqueeze(0) to save a single image"
        )
    img = batch[:save_num]
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        torchvision.utils._log_api_usage_once(torchvision.utils.save_image)
    # Same no-grad context as torchvision's own save_image.
    with torch.no_grad():
        grid = torchvision.utils.make_grid(img)
        # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    # Image.Resampling only exists from Pillow 9.1 (a ">= 9" version check breaks
    # on 9.0.x). Image.ANTIALIAS was removed in Pillow 10. Fall back to the
    # module-level LANCZOS constant on older Pillow.
    resample = getattr(Image, "Resampling", Image).LANCZOS
    im.resize(size, resample).save(imgpath)

2. 图像缩放

2.1 torch.nn.functional示例

import torch.nn.functional as F

def restore_scale(self, val_loader):
    """Restore every validation image and save it at its original resolution.

    ``val_loader`` yields ``(x, y, src_w, src_h)``. For each batch:

    1. Take the first three channels of ``x`` and move them to
       ``self.diffusion.device``.
    2. Reflect-pad them on the bottom/right up to a multiple of 32.
    3. Run ``self.diffusive_restoration`` and crop the output back to the
       unpadded size.
    4. Write the output to ``<image_folder>/<y[0]>_cfwd1.jpg``. If the output
       size differs from the original ``(src_w, src_h)`` (i.e. the dataset
       downscaled the input), resize it back with
       ``utils.logging.save_image_scale``.

    Assumes batch size 1: only ``y[0]`` names the file, and
    ``src_w``/``src_h`` are converted with ``int()``. TODO: confirm with the
    loader config.
    """
    image_folder = os.path.join(self.args.image_folder, self.config.data.val_dataset)
    # Spatial dims are padded to a multiple of this before the network.
    # Presumably required by its downsampling depth — confirm.
    factor = 32
    with torch.no_grad():
        for x, y, src_w, src_h in val_loader:
            print(f'x.shape: {x.shape}, y[0]: {y[0]}, size: ({src_w}, {src_h})')
            # Presumably the degraded input; extra channels (e.g. a concatenated
            # GT) are dropped — verify against the dataset.
            x_cond = x[:, :3, :, :].to(self.diffusion.device)
            _, _, h, w = x_cond.shape

            # Amount needed to reach the next multiple of `factor` (0 if aligned).
            # Note: 'reflect' requires the pad to be smaller than the padded dim.
            padh = (-h) % factor
            padw = (-w) % factor
            x_cond = F.pad(x_cond, (0, padw, 0, padh), 'reflect')

            x_output = self.diffusive_restoration(x_cond)
            # Crop the padding back off.
            x_output = x_output[:, :, :h, :w]

            _, _, dst_h, dst_w = x_output.shape
            print(f'dst_h: {dst_h}, dst_w: {dst_w}')

            out_path = os.path.join(image_folder, f"{y[0]}_cfwd1.jpg")
            orig_w, orig_h = int(src_w), int(src_h)
            # Compare both dimensions, not just the height.
            if (dst_h, dst_w) != (orig_h, orig_w):
                utils.logging.save_image_scale(x_output, out_path, (orig_w, orig_h))
                print('resize---------------')
            else:
                utils.logging.save_image(x_output, out_path)
                print('no resize=============')

2.2 PIL.Image示例:

# 2024-08-27 add 
    def resize_image_if_larger(self, image, scale=1536):
        """Downscale a PIL image so its longer side equals ``scale``, keeping the aspect ratio.

        Args:
            image: PIL image (``image.size`` is ``(width, height)``).
            scale: maximum allowed length, in pixels, of the longer side.

        Returns:
            ``image`` itself when its longer side is already ``<= scale``,
            otherwise a new bicubic-resampled image.
        """
        w, h = image.size
        longest = max(w, h)
        if longest <= scale:
            return image
        ratio = scale / longest
        # round() instead of int(): float error can make w * ratio land just
        # below the intended value (e.g. 1535.999...), and truncation would
        # leave the long side a pixel short. max(1, ...) stops very thin images
        # from collapsing to a zero-sized dimension.
        new_size = (max(1, round(w * ratio)), max(1, round(h * ratio)))
        return image.resize(new_size, Image.BICUBIC)
    
    def get_images(self, index):
        """Load, downscale and transform the (input, ground-truth) pair at ``index``.

        Returns:
            Tuple ``(pair, img_id, w, h)``:
            - ``pair``: the two outputs of ``self.transforms`` concatenated
              along dim 0.
            - ``img_id``: the input file name without directory or extension.
            - ``w``, ``h``: the input image size *before* downscaling, so
              results can be resized back to the original resolution later.
        """

        def load_rgb(name):
            # Resolve against self.dir when set; otherwise use the path as given.
            path = os.path.join(self.dir, name) if self.dir else name
            # Convert in every case so grayscale/RGBA/palette files still give
            # 3 channels. `with` closes the file; convert() returns a new image.
            with Image.open(path) as img:
                return img.convert('RGB')

        input_name = self.input_names[index].replace('\n', '')
        gt_name = self.gt_names[index].replace('\n', '')
        # splitext handles any extension length (".jpeg", ".tiff"), unlike a
        # fixed [:-4] slice.
        img_id = os.path.splitext(input_name.split('/')[-1])[0]
        input_img = load_rgb(input_name)
        gt_img = load_rgb(gt_name)

        # Record the original size before downscaling.
        w, h = input_img.size
        input_img = self.resize_image_if_larger(input_img)
        gt_img = self.resize_image_if_larger(gt_img)

        input_img, gt_img = self.transforms(input_img, gt_img)

        return torch.cat([input_img, gt_img], dim=0), img_id, w, h

    def __getitem__(self, index):
        """Return sample ``index``; the loading work is delegated to ``get_images``."""
        return self.get_images(index)

    def __len__(self):
        """Number of samples: one per entry in ``self.input_names``."""
        count = len(self.input_names)
        return count

2.3 python-opencv示例:

# 类外定义
# 2024-08-29 add 
def resize_image_if_larger(image, scale=1536):
    """Downscale an OpenCV/NumPy image so its longer side equals ``scale``.

    Args:
        image: array indexed as ``image[row, col, ...]`` (as returned by
            ``cv2.imread``), so ``image.shape[:2]`` is ``(height, width)``.
        scale: maximum allowed length, in pixels, of the longer side.

    Returns:
        ``image`` itself when its longer side is already ``<= scale``,
        otherwise a bicubic-resized copy with the aspect ratio preserved.
    """
    # NumPy shape is (rows, cols) == (height, width). Unpacking it as (w, h)
    # would make cv2.resize, whose dsize is (width, height), transpose the
    # aspect ratio of non-square images.
    h, w = image.shape[:2]
    longest = max(h, w)
    if longest <= scale:
        return image
    ratio = scale / longest
    # round() keeps the long side at exactly `scale` despite float error;
    # max(1, ...) prevents a zero-sized dimension for very thin images.
    new_w = max(1, round(w * ratio))
    new_h = max(1, round(h * ratio))
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    
# 类中定义   
def __getitem__(self, idx):
    """Return the image pair with the given identifier, loaded lazily from disk.

    Only the image pair needed for the current batch is read into memory.
    Both images are converted from OpenCV's BGR order to RGB and downscaled
    with ``resize_image_if_larger``. The input's size before downscaling is
    stored under ``WIDTH``/``HEIGHT``.

    :param idx: image pair identifier
    :returns: dictionary containing the input (and, when ``self.have_gt``,
        the GT) image and file path, plus the original ``WIDTH``/``HEIGHT``
    :rtype: dictionary
    :raises FileNotFoundError: if OpenCV cannot read one of the images
    """

    def read_rgb(path, flags=cv2.IMREAD_COLOR):
        img = cv2.imread(path, flags)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {path}")
        return img[:, :, [2, 1, 0]]  # BGR -> RGB

    res_item = {INPUT_FPATH: self.input_list[idx]}

    # Different seed for each item, but the same seed for GT and INPUT within
    # one item, so both receive identical augmentation.
    seed = random.randint(0, 100000)
    input_img = read_rgb(self.input_list[idx])

    # Record the original size (WIDTH/HEIGHT are defined in globalenv.py)
    # before downscaling.
    res_item[WIDTH] = input_img.shape[1]
    res_item[HEIGHT] = input_img.shape[0]
    input_img = resize_image_if_larger(input_img)

    if self.have_gt and self.gt_list[idx].endswith('.hdr'):
        input_img = torch.Tensor(input_img / 255).permute(2, 0, 1)
    else:
        input_img = augment_one_img(input_img, seed, transform=self.transform)

    res_item[INPUT] = input_img

    if self.have_gt:
        gt_path = self.gt_list[idx]
        res_item[GT_FPATH] = gt_path

        if gt_path.endswith('.hdr'):
            # HDR ground truth is not augmented.
            # IMREAD_ANYDEPTH keeps the float values. IMREAD_COLOR requests
            # 3 channels explicitly; with ANYDEPTH alone OpenCV may decode a
            # single channel, which the BGR->RGB indexing cannot handle.
            gt_img = read_rgb(gt_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
            gt_img = resize_image_if_larger(gt_img)
            gt_img = torch.Tensor(np.log10(gt_img + 1)).permute(2, 0, 1)
        else:
            gt_img = read_rgb(gt_path)
            # Downscale exactly like the input. Otherwise large pairs can end
            # up with mismatched shapes (unless the transform resizes anyway).
            gt_img = resize_image_if_larger(gt_img)
            gt_img = augment_one_img(gt_img, seed, transform=self.transform)

        res_item[GT] = gt_img
        assert res_item[GT].shape == res_item[INPUT].shape

    print(f"res_item[INPUT] shape: {input_img.shape}, res_item[INPUT_FPATH]: {res_item[INPUT_FPATH]}")

    return res_item
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值