pytorch中transforms.ToTensor()函数解析

torchvision.transforms.functional.py中to_tensor()函数源码:

def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img

从to_tensor()函数源码看出其接受PIL Image或numpy.ndarray格式,功能如下:

  • 先由HWC转置为CHW格式;
  • 再转为float类型;
  • 最后,每个像素除以255。
  • 24
    点赞
  • 49
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这段代码是一个 PyTorch 程序的主函数,它包括了数据预处理的相关操作和主要的模型训练过程。 首先,程序导入了一些必要的库,包括 os、json、torchtorch.nn、torch.optim、transforms 和 datasets 等。其,os 和 json 库用于处理文件和数据的读写,torch 库是 PyTorch 深度学习框架的核心库,torch.nn 库提供了深度学习模型的基础组件,torch.optim 库提供了常见的优化算法,transforms 库提供了一些常用的数据预处理操作,datasets 库则提供了常见的数据集加载方法。 然后,程序定义了一个主函数 main()。主函数首先通过 torch.cuda.is_available() 函数判断是否可以使用 GPU 加速,如果可以,则将设备设置为 CUDA 设备,否则设置为 CPU 设备。接着,程序定义了一个名为 data_transform 的字典,它包含了两个键值对,分别对应训练集和验证集的数据预处理操作。其,训练集的预处理操作包括随机裁剪、随机水平翻转、转换为张量以及标准化等,验证集的预处理操作包括缩放、心裁剪、转换为张量以及标准化等。 这段代码还引入了一个自定义的 resnet34 模型,这个模型基于 ResNet-34 架构,用于对图像进行分类。最后,主函数进入了一个循环,用于对模型进行训练和验证。其,训练数据集和验证数据集通过 datasets.ImageFolder 函数加载,模型的损失函数采用交叉熵损失函数,优化算法采用随机梯度下降算法,每个 epoch 的训练过程通过 tqdm 库进行可视化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值