baseline_main.py封装过后
经过封装后,在baseline_main.py中只有一行
train_dataset, test_dataset, _ = get_dataset(args)
util.py逐步展开
传参
也就是把args这个从命令行中获取的参数传入函数中,args.dataset选择数据集。
args
:一个包含各种设置和参数的对象。这里包括选择的数据集类型(如 'cifar' 或 'mnist')、是否使用独立同分布 (IID) 数据划分、用户数量等。
详见
处理 CIFAR-10 数据集
如果 args.dataset
设置为 'cifar',函数将进行以下操作:
1.设置数据目录。
2.定义一个图像转换操作,将图像转换为张量,并对其进行归一化。
3.获取 CIFAR-10 的训练和测试数据集
if args.dataset == 'cifar':
data_dir = '../data/cifar/'
apply_transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
transform=apply_transform)
test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
transform=apply_transform)
详见
Python中CIFAR10的图像数据预处理_MikingG的博客-CSDN博客以小见大,见微知著https://blog.csdn.net/weixin_64123373/article/details/1322473394.根据 args.iid
判断数据划分是否为 IID。如果是,则使用 cifar_iid
函数进行划分,否则根据 args.unequal
判断划分是否均等,并使用 cifar_noniid
函数进行划分。
if args.iid:
# Sample IID user data from Mnist
user_groups = cifar_iid(train_dataset, args.num_users)
else:
# Sample Non-IID user data from Mnist
if args.unequal:
# Chose uneuqal splits for every user
raise NotImplementedError()
else:
# Chose euqal splits for every user
user_groups = cifar_noniid(train_dataset, args.num_users)
详见
处理 MNIST 或 Fashion-MNIST 数据集
如果 args.dataset
设置为 'mnist' 或 'fmnist',函数将进行以下操作:
- 根据所选的数据集设置数据目录。
- 定义一个图像转换操作,将图像转换为张量,并对其进行归一化。
- 获取 MNIST 的训练和测试数据集。
- 根据
args.iid
判断数据划分是否为 IID。如果是,则使用mnist_iid
函数进行划分,否则根据args.unequal
判断划分是否均等,并可能使用mnist_noniid
或mnist_noniid_unequal
函数进行划分。
(细节处理参见CIFAR)
返回值
train_dataset
:训练数据集对象。test_dataset
:测试数据集对象。user_groups
:一个字典,其中键是用户索引,值是每个用户的相应数据。这些数据可能是均匀分布的,也可能是不均匀分布的,具体取决于args.iid
和args.unequal
的设置。
总结
get_dataset
函数是一个灵活的数据准备功能,可处理多个不同的数据集并按照多种方式对其进行划分。通过更改传递给该函数的 args
对象的属性,可以轻松控制数据准备过程。
完整代码
def get_dataset(args):
""" Returns train and test datasets and a user group which is a dict where
the keys are the user index and the values are the corresponding data for
each of those users.
"""
if args.dataset == 'cifar':
data_dir = '../data/cifar/'
apply_transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
transform=apply_transform)
test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
transform=apply_transform)
# sample training data amongst users
if args.iid:
# Sample IID user data from Mnist
user_groups = cifar_iid(train_dataset, args.num_users)
else:
# Sample Non-IID user data from Mnist
if args.unequal:
# Chose uneuqal splits for every user
raise NotImplementedError()
else:
# Chose euqal splits for every user
user_groups = cifar_noniid(train_dataset, args.num_users)
elif args.dataset == 'mnist' or 'fmnist':
if args.dataset == 'mnist':
data_dir = '../data/mnist/'
else:
data_dir = '../data/fmnist/'
apply_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(data_dir, train=True, download=True,
transform=apply_transform)
test_dataset = datasets.MNIST(data_dir, train=False, download=True,
transform=apply_transform)
# sample training data amongst users
if args.iid:
# Sample IID user data from Mnist
user_groups = mnist_iid(train_dataset, args.num_users)
else:
# Sample Non-IID user data from Mnist
if args.unequal:
# Chose uneuqal splits for every user
user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
else:
# Chose euqal splits for every user
user_groups = mnist_noniid(train_dataset, args.num_users)
return train_dataset, test_dataset, user_groups