"""
code modified from https://github.com/frgfm/torch-cam/tree/master/torchcam
"""
import torch
from matplotlib import cm
import numpy as np
from PIL import Image
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
            Tensors are read as CHW (or HW), arrays as HWC (or HW).
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: if ``pic`` is neither a Tensor nor an ndarray, or if no
            mode was given and none can be inferred from the dtype.
        ValueError: if ``pic`` is not 2- or 3-dimensional, or if ``mode`` is
            not permitted for the channel count / dtype of ``pic``.
    """
    if not isinstance(pic, (torch.Tensor, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    if isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
        if pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)
        # Floating-point tensors are rescaled by 255 and cast to bytes unless
        # the caller explicitly asked for the float mode 'F'.
        if pic.is_floating_point() and mode != 'F':
            pic = pic.mul(255).byte()
        # CHW -> HWC
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
    else:
        if pic.ndim not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
        if pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)
        npimg = pic

    channels = npimg.shape[2]
    if channels == 1:
        # Single channel: the mode is dictated by the dtype.
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            # Report the dtype of the actual input (not the ``np.dtype`` class).
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode
    elif channels == 2:
        permitted_2_channel_modes = ['LA']
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'LA'
    elif channels == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        # Any other channel count is handled with the 3-channel rules.
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)
def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.7) -> Image.Image:
    """Overlay a colormapped mask on a background image

    Args:
        img: background image
        mask: mask to be overlayed in grayscale
        colormap: name of the matplotlib colormap to be applied on the mask
        alpha: weight of the background image in the blend (the colormapped
            mask is weighted by ``1 - alpha``); must be a float in ``[0, 1)``

    Returns:
        overlayed image

    Raises:
        TypeError: if ``img`` or ``mask`` is not a PIL image
        ValueError: if ``alpha`` is not a float in ``[0, 1)``
    """
    if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
        raise TypeError('img and mask arguments need to be PIL.Image')

    if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
        raise ValueError('alpha argument is expected to be of type float between 0 and 1')

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; prefer the
    # colormap registry when it is available and keep the legacy call as a
    # fallback for older releases.
    import matplotlib
    registry = getattr(matplotlib, 'colormaps', None)
    if hasattr(registry, 'get_cmap'):
        cmap = registry.get_cmap(colormap)
    else:
        cmap = cm.get_cmap(colormap)

    # Resize mask and apply colormap
    mask_values = np.asarray(mask.resize(img.size, resample=Image.BICUBIC))
    if mask_values.dtype == np.uint8:
        # Squaring uint8 data wraps around modulo 256 and integer inputs are
        # read by the colormap as raw lookup indices, so bring 8-bit masks to
        # the [0, 1] float range first.
        mask_values = mask_values / 255.0
    heatmap = (255 * cmap(mask_values ** 2)[:, :, :3]).astype(np.uint8)

    # Overlay the image with the mask
    blended = alpha * np.asarray(img) + (1 - alpha) * heatmap
    return Image.fromarray(blended.astype(np.uint8))
# [Python code: add a mask with alpha to an image]
# Latest recommended article published on 2024-06-20 17:54:56