PyTorch image loader and transform study notes

default_loader source code

from https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html

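from typing import Any

from PIL import Image


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:  # 'rb' reads raw bytes so the image is not decoded as text (which would raise UnicodeDecodeError)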
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

As shown, in the default case the image is read with PIL's Image.open and converted to RGB mode.
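To see the RGB conversion in action, here is a minimal sketch that writes a small grayscale ("L" mode) PNG to a temporary directory and reads it back with a copy of pil_loader. The file name gray.png and the pixel value 128 are arbitrary choices for illustration; only Pillow and the standard library are assumed.

```python
import os
import tempfile

from PIL import Image


def pil_loader(path: str) -> Image.Image:
    # Same logic as torchvision's pil_loader: open as bytes, then force RGB
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical test file: a 4x2 (width x height) grayscale image filled with 128
    path = os.path.join(tmp_dir, "gray.png")
    Image.new("L", (4, 2), color=128).save(path)
    img = pil_loader(path)

print(img.mode, img.size)    # RGB (4, 2)
print(img.getpixel((0, 0)))  # (128, 128, 128)
```

The single gray channel is copied into R, G and B, so even a grayscale file yields a 3-channel image, which matches a Normalize step configured with three mean/std values.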

Summary of transforms

transforms.ToTensor

Source: https://blog.csdn.net/qq_36560894/article/details/104480449
After transforms.ToTensor(), the data changes from [H, W, C] to [C, H, W], and (for uint8 input) the values lie in [0.0, 1.0].
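The effect of ToTensor on a uint8 image can be mimicked in NumPy. The sketch below uses a hypothetical helper, to_tensor_like, that transposes [H, W, C] to [C, H, W] and divides by 255. The real transform returns a torch.FloatTensor and, per its documentation, skips the scaling for non-uint8 arrays, whereas this sketch simply rejects them.

```python
import numpy as np


def to_tensor_like(img_hwc: np.ndarray) -> np.ndarray:
    """Mimic ToTensor for a uint8 [H, W, C] array."""
    if img_hwc.dtype != np.uint8:
        # The real ToTensor returns other dtypes unscaled; this sketch just refuses them
        raise TypeError("this sketch only handles uint8 input")
    chw = np.transpose(img_hwc, (2, 0, 1))  # [H, W, C] -> [C, H, W]
    return chw.astype(np.float32) / 255.0  # [0, 255] -> [0.0, 1.0]


img = np.zeros((2, 3, 3), dtype=np.uint8)  # H=2, W=3, C=3
img[..., 0] = 255  # first (red) channel fully on
out = to_tensor_like(img)
print(out.shape)  # (3, 2, 3)
print(out[0].max(), out[1].max())  # 1.0 0.0
```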

ToTensor documentation (original text)

Convert a PIL Image or numpy.ndarray to tensor.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8

In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the references_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
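The note about masks matters because segmentation masks store class indices, not intensities. A minimal NumPy sketch, assuming a hypothetical 2x2 mask with class ids 0, 1 and 2: dividing by 255 the way ToTensor does turns the ids into small fractions, while a plain integer conversion keeps them usable as labels.

```python
import numpy as np

# Hypothetical 2x2 segmentation mask with class ids 0, 1 and 2
mask = np.array([[0, 1], [2, 1]], dtype=np.uint8)

# ToTensor-style scaling: class id 2 becomes 2/255, no longer a valid id
scaled = mask.astype(np.float32) / 255.0

# Keep the class ids as integers (int64 is the usual dtype for class-index targets)
labels = mask.astype(np.int64)

print(scaled[1, 0])  # ~0.00784
print(labels[1, 0])  # 2
```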

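Purpose of image normalization

Source: https://blog.csdn.net/qq_36560894/article/details/104480449

According to the linked post, image normalization helps prevent vanishing or exploding gradients by centering each input channel around zero and giving the channels a comparable scale.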

transforms.Normalize() source code

from https://pytorch.org/vision/stable/transforms.html

class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

Note: Normalize operates on tensors only; it does not support PIL Images.
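The formula output[channel] = (input[channel] - mean[channel]) / std[channel] from the docstring can be written with NumPy broadcasting. This is a sketch with a hypothetical normalize_chw helper, not torchvision's F.normalize; like the default inplace=False, it returns a new array and leaves the input untouched.

```python
import numpy as np


def normalize_chw(img: np.ndarray, mean, std) -> np.ndarray:
    """Per-channel (x - mean[c]) / std[c] for a [C, H, W] float array."""
    mean = np.asarray(mean, dtype=img.dtype)[:, None, None]  # shape [C, 1, 1]
    std = np.asarray(std, dtype=img.dtype)[:, None, None]
    return (img - mean) / std  # broadcasts over H and W, allocates a new array


img = np.full((3, 2, 2), 0.5, dtype=np.float32)  # every pixel 0.5 after ToTensor
out = normalize_chw(img, mean=[0.5, 0.25, 0.0], std=[0.5, 0.25, 0.5])
print(out[:, 0, 0])  # [0. 1. 1.]
print(img[0, 0, 0])  # 0.5 (input unchanged)
```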

Example: defining image loading and transforms

The transform pipeline is defined as follows:

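from torchvision import transforms as T

transform = T.Compose([
    T.Resize(size=(256, 128), interpolation=T.InterpolationMode.BICUBIC),  # (height, width)
    T.ToTensor(),  # PIL Image -> float tensor [C, H, W] in [0.0, 1.0]
    # ImageNet statistics; std is not the same as mean
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])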
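T.Compose applies its transforms one after another, which is why ToTensor has to come before Normalize: Normalize only accepts tensors. The sketch below re-creates that chaining with a minimal, hypothetical Compose class and NumPy stand-ins for ToTensor and Normalize (Resize is omitted to keep it dependency-free); it illustrates the idea rather than torchvision's actual implementation.

```python
import numpy as np


class Compose:
    """Minimal stand-in for T.Compose: call each transform in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


def to_tensor_like(img_hwc):
    # uint8 [H, W, C] -> float32 [C, H, W] in [0.0, 1.0]
    return np.transpose(img_hwc, (2, 0, 1)).astype(np.float32) / 255.0


def normalize_like(mean, std):
    # Returns a callable that normalizes a [C, H, W] float array per channel
    m = np.asarray(mean, dtype=np.float32)[:, None, None]
    s = np.asarray(std, dtype=np.float32)[:, None, None]
    return lambda x: (x - m) / s


transform = Compose([
    to_tensor_like,
    normalize_like(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = np.full((4, 2, 3), 255, dtype=np.uint8)  # hypothetical all-white 4x2 RGB image
out = transform(img)
print(out.shape)  # (3, 4, 2)
```

Swapping the two entries in the torchvision pipeline would raise an error, because Normalize would receive the PIL Image that ToTensor has not yet converted.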

Image loading and the transform are applied in the dataset's __getitem__ method.

import os
import os.path as osp

import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader


class PersonDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        super(PersonDataset, self).__init__()
        # Check the directory before listing it; os.listdir would otherwise fail first
        if not os.path.exists(data_dir):
            raise IOError('dataset dir: {} does not exist'.format(data_dir))
        self.data_dir = data_dir
        self.img_paths = os.listdir(data_dir)
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        # default_loader reads the file with PIL and returns an RGB PIL Image
        im = default_loader(osp.join(self.data_dir, self.img_paths[index]))

        if self.transform is not None:
            im = self.transform(im)
        else:
            # Without a transform, still return a float tensor scaled to [0.0, 1.0]
            im = TF.to_tensor(im)
        return im

    def __repr__(self):
        format_string  = self.__class__.__name__ + '(num_imgs='
        format_string += '{:d})'.format(len(self))
        return format_string

Summary

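  1. In PyTorch (torchvision's default_loader), images are read with PIL's Image.open() and converted to RGB by default, whereas cv2.imread returns channels in BGR order.
  2. During the transform step, ToTensor converts the PIL Image to a torch.Tensor, after which pixel values lie in [0.0, 1.0].
  3. The mean and std used in Normalize are determined from the training set; the commonly used Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) comes from ImageNet.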
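Points 1 and 3 can both be illustrated in NumPy. The sketch assumes a hypothetical batch of images stored as a uint8 array of shape [N, H, W, C]: bgr_to_rgb reverses the channel axis (the conversion needed for arrays coming from cv2.imread), and channel_mean_std computes per-channel statistics on the [0.0, 1.0] scale that Normalize sees after ToTensor. Both helper names are made up for this example.

```python
import numpy as np


def bgr_to_rgb(img: np.ndarray) -> np.ndarray:
    # Reverse the last (channel) axis: B, G, R -> R, G, B
    return img[..., ::-1]


def channel_mean_std(images: np.ndarray):
    """Per-channel mean/std of a uint8 [N, H, W, C] batch, on the [0, 1] scale."""
    x = images.astype(np.float64) / 255.0
    return x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))


# Hypothetical "training set": two 2x2 images stored in BGR order
bgr = np.zeros((2, 2, 2, 3), dtype=np.uint8)
bgr[0, ..., 0] = 255  # blue channel fully on in the first image only
rgb = bgr_to_rgb(bgr)
mean, std = channel_mean_std(rgb)
print(mean, std)
```

The reversed array is a view with a negative stride; if it is handed to torch.from_numpy, call .copy() on it first.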