torchvision-transforms 常用函数总结

torchvision-transforms 常用函数总结

一、概述——为何要用transforms

在这里插入图片描述
需求是多样的,因此可以通过实例化一个transforms,满足转换的需要。具体的class可以参考transforms.py中的描述

二、函数介绍

1、ToTensor

功能:将PIL.image读取的PIL类型图片或者cv2.imread读取的numpy.ndarray转化为tensor类型
最简单的函数,没什么参数,直接默认构造函数然后调用即可,具体如下:

from torchvision import transforms
from PIL import Image

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    img2tensor = transforms.ToTensor()
    img_tensor = img2tensor(img)
    print(img_tensor)

2、Normalize

功能:输入RGB三通道的标准差和方差,输出正则化的图像矩阵

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
import cv2

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    img2tensor = transforms.ToTensor()
    img_tensor = img2tensor(img)
    writer = SummaryWriter("logs")
    writer.add_image("original", img_tensor)
    trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    img_norm = trans_norm(img_tensor)
    writer.add_image("Normalize", img_norm)

    writer.close()

原图
在这里插入图片描述

正则化后的图像
在这里插入图片描述

3、Resize(非常常用)

功能
1、Resize([h, w])——对一个图像进行缩放,虽然会改变长宽比,但图像未发生裁剪,因此可以通过Resize再次还原回来
2、Resize(x) ——对短边缩放到x,长宽比不变

注意
PIL image 的size属性返回的是w, h而Resize参数顺序是h,w,切勿弄错

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
import cv2

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    writer = SummaryWriter("logs")
    trans_resize = transforms.Resize((512, 512))
    resized_img = trans_resize(img)
    img2tensor = transforms.ToTensor()
    img_tensor = img2tensor(resized_img)
    writer.add_image("resized", img_tensor)
    writer.close()

缩放后的结果
在这里插入图片描述

4、Compose

功能:组合变换,参数是各种变换组成的列表“[transform1, transform2, …]”

    trans_resize = transforms.Resize((512, 512))
    img2tensor = transforms.ToTensor()
    trans = transforms.Compose([trans_resize, img2tensor])
    img_tensor = trans(img)

5、RandomCrop

功能:随机裁剪,和Resize类似

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
import cv2

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    writer = SummaryWriter("logs")
    trans_random_crop = transforms.RandomCrop((300, 400))
    img2tensor = transforms.ToTensor()
    trans = transforms.Compose([trans_random_crop, img2tensor])
    for i in range(5):
        img_tensor = trans(img)
        writer.add_image("random crop", img_tensor, i)
    writer.close()

三、transforms和数据集的结合使用

方法:先查看数据集里都有啥(调试),然后根据需求加transform

import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
if __name__ == '__main__':

    train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=trans, download=True)
    test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=trans, download=True)

    print(test_set)
    img, target = test_set[1]
    writer = SummaryWriter("logs")
    writer.add_image("pic1", img, 1)
    print(test_set.classes[target])
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torchvision.transforms是一个图像预处理库,可以用于对图像进行各种变换,例如裁剪、缩放、旋转、翻转等。常用函数有: 1. transforms.Resize(size):调整图像大小为指定的size 2. transforms.CenterCrop(size):按照中心裁剪图像为指定的size 3. transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):随机裁剪图像为指定的size,可选参数包括padding、fill、padding_mode 4. transforms.RandomHorizontalFlip(p=0.5):按照概率p随机水平翻转图像 5. transforms.RandomVerticalFlip(p=0.5):按照概率p随机垂直翻转图像 6. transforms.RandomRotation(degrees, resample=False, expand=False, center=None):随机旋转图像degrees度,可选参数包括resample、expand、center 7. transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机调整图像颜色,可选参数包括brightness、contrast、saturation、hue 8. transforms.ToTensor():将图像转换为张量 9. transforms.Normalize(mean, std):对张量进行标准化,mean和std分别为均值和标准差 可以通过组合transforms函数来构建一个预处理管道。例如: ```python transforms.Compose([ transforms.Resize(256), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ``` 该预处理管道首先将图像大小调整为256,然后随机裁剪为224,随机水平翻转,将图像转换为张量,最后对张量进行标准化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值