MindSpore数据预处理,并生成batch装载数据

# 常用转化用算子
import mindspore.dataset.transforms.c_transforms as c
# 图像转化用算子
import mindspore.dataset.vision.c_transforms as cv
# mind spore.common包中会有诸如type形态转变、权重初始化等的常规工具。
from mindspore.common import dtype
# Mind spore模块主要用于本次实验卷积神经网络的构建,包括很多子模块。
import mindspore
# 主要包括CI FAR-10数据集的载入与处理,也可以自定义数据集。
import mindspore.dataset as ds



# 数据预处理,再生成装载数据

def my_dataset(my_data, batch__size=32, status="train"):

    # 设置类型变化
    typecast_op = c.TypeCast(dtype.int32)
    my_data = my_data.map(input_columns="label", operations=typecast_op)

    if status == "train":
        # 设置随机裁剪参数
        random_crop_op = cv.RandomCrop([32, 32], [4, 4, 4, 4])
        my_data = my_data.map(input_columns="image", operations=random_crop_op)
        # 设置随机翻转参数
        random_horizontal_op = cv.RandomHorizontalFlip()
        my_data = my_data.map(input_columns="image", operations=random_horizontal_op)

    # 重设大小
    resize_op = cv.Resize((32, 32))
    my_data = my_data.map(input_columns="image", operations=resize_op)

    # 归一化
    rescale = 1.0 / 255.0
    # 平移
    shift = 0.0
    rescale_op = cv.Rescale(rescale, shift)
    my_data = my_data.map(input_columns="image", operations=rescale_op)

    # RGB三通道分别设定mean和std
    normalize_op = cv.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    my_data = my_data.map(input_columns="image", operations=normalize_op)

    # 设置通道变化
    channel_swap_op = cv.HWC2CHW()
    my_data = my_data.map(input_columns="image", operations=channel_swap_op)

    # 打乱顺序shuffle
    my_data = my_data.shuffle(buffer_size=1000)

    # 切分数据集到batch_size
    my_data = my_data.batch(batch__size, drop_remainder=True)

    return my_data
  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在PyTorch中,数据预处理通常涉及以下几个步骤: 1. 加载数据集:使用PyTorch的数据加载器(如`torchvision.datasets`)加载数据集。可以是常见的图像数据集(如MNIST、CIFAR10)或自定义数据集。 2. 转换数据:使用`torchvision.transforms`模块中的转换函数对数据进行预处理。常见的转换包括缩放、裁剪、旋转、归一化等。可以根据需求组合多个转换操作。 3. 创建数据加载器:将转换后的数据集传递给`torch.utils.data.DataLoader`来创建一个数据加载器。数据加载器可以指定批处理大小、并发加载等参数。 下面是一个简单的示例,演示如何使用PyTorch进行数据预处理: ```python import torch import torchvision import torchvision.transforms as transforms # 1. 加载数据集 train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True) # 2. 转换数据 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为Tensor transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1, 1]范围 ]) train_dataset = train_dataset.transform(transform) # 3. 创建数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) ``` 在这个示例中,我们加载了MNIST数据集,并将图像转换为Tensor,并进行了归一化处理。然后使用`DataLoader`创建了一个批处理大小为64的数据加载器,同时打乱了数据的顺序。 这只是一个简单的例子,根据具体需求,你可能需要进行更复杂的数据预处理操作。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值