如何用Python加载数据集?以联邦学习中加载CIFAR-10为例

baseline_main.py封装过后

经过封装后,在baseline_main.py中只有一行

   train_dataset, test_dataset, _ = get_dataset(args)

util.py逐步展开

传参

也就是把args这个从命令行中获取的参数传入函数中,args.dataset选择数据集。

args:一个包含各种设置和参数的对象。这里包括选择的数据集类型(如 'cifar' 或 'mnist')、是否使用独立同分布 (IID) 数据划分、用户数量等。

详见

Python中是如何接收在终端中输入(自定义的)命令行参数的?_MikingG的博客-CSDN博客读取在终端中输入的命令行参数的背后原理及简单使用教程https://blog.csdn.net/weixin_64123373/article/details/132246130?spm=1001.2014.3001.5501

处理 CIFAR-10 数据集

如果 args.dataset 设置为 'cifar',函数将进行以下操作:

1.设置数据目录。

2.定义一个图像转换操作,将图像转换为张量,并对其进行归一化。

3.获取 CIFAR-10 的训练和测试数据集

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)

详见

Python中CIFAR10的图像数据预处理_MikingG的博客-CSDN博客以小见大,见微知著https://blog.csdn.net/weixin_64123373/article/details/1322473394.根据 args.iid 判断数据划分是否为 IID。如果是,则使用 cifar_iid 函数进行划分,否则根据 args.unequal 判断划分是否均等,并使用 cifar_noniid 函数进行划分。

        if args.iid:
            # Sample IID user data from Mnist
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Chose uneuqal splits for every user
                raise NotImplementedError()
            else:
                # Chose euqal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

详见

独立同分布抽样。以CIFAR为例,Python。_MikingG的博客-CSDN博客IID(独立同分布):这里的数据被假设为独立和同分布的。也就是说,每个数据点都是从同一分布中随机、独立地抽取的。在实际操作中,这通常意味着数据在所有用户之间平均分配。https://blog.csdn.net/weixin_64123373/article/details/132250751

https://blog.csdn.net/weixin_64123373/article/details/132251015icon-default.png?t=N6B9https://blog.csdn.net/weixin_64123373/article/details/132251015

处理 MNIST 或 Fashion-MNIST 数据集

如果 args.dataset 设置为 'mnist' 或 'fmnist',函数将进行以下操作:

  • 根据所选的数据集设置数据目录。
  • 定义一个图像转换操作,将图像转换为张量,并对其进行归一化。
  • 获取 MNIST 的训练和测试数据集。
  • 根据 args.iid 判断数据划分是否为 IID。如果是,则使用 mnist_iid 函数进行划分,否则根据 args.unequal 判断划分是否均等,并可能使用 mnist_noniidmnist_noniid_unequal 函数进行划分。

(细节处理参见CIFAR)

返回值

  • train_dataset:训练数据集对象。
  • test_dataset:测试数据集对象。
  • user_groups:一个字典,其中键是用户索引,值是每个用户的相应数据。这些数据可能是均匀分布的,也可能是不均匀分布的,具体取决于 args.iidargs.unequal 的设置。

总结

get_dataset 函数是一个灵活的数据准备功能,可处理多个不同的数据集并按照多种方式对其进行划分。通过更改传递给该函数的 args 对象的属性,可以轻松控制数据准备过程。

 完整代码

def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Chose uneuqal splits for every user
                raise NotImplementedError()
            else:
                # Chose euqal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'mnist' or 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Chose uneuqal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Chose euqal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值