Per_FedAvg源码分析其二

Per_FeAVG源码分析——data目录下:

init.py

不做分析

utils.py
字典:DATASET_DICT

作用:它将字符串键(如 "mnist""cifar")映射到对应的类(MNISTDatasetCIFARDataset

DATASET_DICT = {
    "mnist": MNISTDataset,
    "cifar": CIFARDataset,
}
函数:CURRENT_DIR

作用:CURRENT_DIR 被设置为当前 Python 脚本文件的父目录的绝对路径

CURRENT_DIR = Path(__file__).parent.abspath()
函数:get_dataloader

作用:从一个预处理好的 pickle 文件中加载数据集,并根据给定的 client_id 分割为训练集和验证集。

def get_dataloader(dataset: str, client_id: int, batch_size=20, valset_ratio=0.1):
    pickles_dir = CURRENT_DIR / dataset / "pickles"
    if os.path.isdir(pickles_dir) is False:
        raise RuntimeError("Please preprocess and create pickles first.")

    with open(pickles_dir / str(client_id) + ".pkl", "rb") as f:
        client_dataset: DATASET_DICT[dataset] = pickle.load(f)

    val_num_samples = int(valset_ratio * len(client_dataset))
    train_num_samples = len(client_dataset) - val_num_samples

    trainset, valset = random_split(
        client_dataset, [train_num_samples, val_num_samples]
    )
    trainloader = DataLoader(trainset, batch_size, drop_last=True)
    valloader = DataLoader(valset, batch_size)

    return trainloader, valloader

函数:get_client_id_indices(dataset)

作用:从一个特定的 pickle 文件中加载并返回关于数据集分割的信息。从一个 seperation.pkl 文件中读取训练集、测试集以及总数目的索引或标识符。

def get_client_id_indices(dataset):
    dataset_pickles_path = CURRENT_DIR / dataset / "pickles"
    with open(dataset_pickles_path / "seperation.pkl", "rb") as f:
        seperation = pickle.load(f)
    return (seperation["train"], seperation["test"], seperation["total"])
preprocess.py
函数:CURRENT_DIR

作用:CURRENT_DIR 被设置为当前 Python 脚本文件的父目录的绝对路径

CURRENT_DIR = Path(__file__).parent.abspath()
字典:DATASET

作用:数据集名称映射到了两个元组,可以基于数据集名称来动态地加载和实例化相应的数据集

DATASET = {
    "mnist": (MNIST, MNISTDataset),
    "cifar": (CIFAR10, CIFARDataset),
}
字典:MEAN

作用:用来存储不同数据集的像素均值。这些均值通常用于数据归一化,只包含一个灰度通道,因此其均值是一个单元素元组 (0.1307,)。这意味着当你对 MNIST 数据集进行归一化时,你会从每个像素值中减去 0.1307。

MEAN = {
    "mnist": (0.1307,),
    "cifar": (0.4914, 0.4822, 0.4465),
}
字典:STD

作用:存储不同数据集的像素标准差。使用归一化时,标准差通常与均值一起使用,以确保数据的每个特征(在这个案例中是像素值)都有相似的尺度。它只包含一个灰度通道,因此其标准差是一个单元素元组 (0.3015,)。这意味着在归一化 MNIST 数据时,每个像素值都会根据其灰度通道的标准差进行缩放。

STD = {
    "mnist": (0.3015,),
    "cifar": (0.2023, 0.1994, 0.2010),
}
函数:preprocess()

作用:用于预处理数据集,在联邦学习或分布式学习的场景中,数据需要在多个客户端(或节点)之间分配。

def preprocess(args: Namespace) -> None:
    #参数和目录设置:
    #设置数据集目录(dataset_dir)和pickle文件目录(pickles_dir)。
    dataset_dir = CURRENT_DIR / args.dataset
    pickles_dir = CURRENT_DIR / args.dataset / "pickles"

    #设置随机数生成器的种子,以确保结果的可重复性。
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    num_train_clients = int(args.client_num_in_total * args.fraction)
    num_test_clients = args.client_num_in_total - num_train_clients

    #数据转换
    #定义了一个转换transform,它只包含标准化(假设MEAN和STD是预定义的字典,包含了每个数据集的均值和标准差),初始化了训练集和测试集的统计信息字典
    transform = transforms.Compose(
        [transforms.Normalize(MEAN[args.dataset], STD[args.dataset]),]
    )
    target_transform = None
    trainset_stats = {}
    testset_stats = {}

    #目录和文件处理:
    #检查数据集目录是否存在,如果不存在则创建它。
    #如果pickle目录已经存在,则删除它(可能是为了确保没有旧的pickle文件干扰)。
    #创建新的pickle目录。
    if not os.path.isdir(CURRENT_DIR / args.dataset):
        os.mkdir(CURRENT_DIR / args.dataset)
    if os.path.isdir(pickles_dir):
        os.system(f"rm -rf {pickles_dir}")
    os.mkdir(f"{pickles_dir}")

    #加载数据集
    #从预定义的DATASET字典中获取原始和目标数据集。
    #使用ori_dataset类创建训练集和测试集。注意,训练集在加载时还指定了download=True(用于自动下载数据集),而测试集没有。两者都使用了transforms.ToTensor()进行初步的数据转换。
    ori_dataset, target_dataset = DATASET[args.dataset]
    trainset = ori_dataset(
        dataset_dir, train=True, download=True, transform=transforms.ToTensor()
    )
    testset = ori_dataset(dataset_dir, train=False, transform=transforms.ToTensor())

    #分配类别到客户端
    #根据args.classes确定每个客户端应有的类别数量(默认为10类)
    #使用randomly_alloc_classes函数将类别随机分配给训练集和测试集的客户端
    #randomly_alloc_classes函数还返回每个子集的统计信息(
    num_classes = 10 if args.classes <= 0 else args.classes
    all_trainsets, trainset_stats = randomly_alloc_classes(
        ori_dataset=trainset,
        target_dataset=target_dataset,
        num_clients=num_train_clients,
        num_classes=num_classes,
        transform=transform,
        target_transform=target_transform,
    )
    all_testsets, testset_stats = randomly_alloc_classes(
        ori_dataset=testset,
        target_dataset=target_dataset,
        num_clients=num_test_clients,
        num_classes=num_classes,
        transform=transform,
        target_transform=target_transform,
    )

    #将所有训练集和测试集组合到一个列表all_datasets中
    all_datasets = all_trainsets + all_testsets

    #保存客户端数据集为pickle文件
    #通过enumerate(all_datasets)遍历all_datasets列表中的每个数据集和对应的client_id(即客户端的ID)。
    #使用pathlib的/操作符(如果pickles_dir是pathlib.Path对象)来构建pickle文件的路径。
    #使用pickle.dump()函数将每个数据集保存到对应的pickle文件中。
    for client_id, dataset in enumerate(all_datasets):
        with open(pickles_dir / str(client_id) + ".pkl", "wb") as f:
            pickle.dump(dataset, f)
     
    #保存客户端索引
    #创建一个字典,其中包含三个键:“train”、“test”和“total”,“total”的值就是总的客户端数量。
    #“train”和“test”的值是客户端ID的列表,分别代表训练集和测试集的客户端。这里假设num_train_clients表示训练集客户端的数量,而args.client_num_in_total表示总的客户端数量。
    #使用pickle.dump()函数将这个字典保存到名为“seperation.pkl”的文件中。这个文件用于在后续的训练和测试过程中区分哪些客户端是训练集,哪些是测试集
    with open(pickles_dir / "seperation.pkl", "wb") as f:
        pickle.dump(
            {
                "train": [i for i in range(num_train_clients)],
                "test": [i for i in range(num_train_clients, args.client_num_in_total)],
                "total": args.client_num_in_total,
            },
            f,
        )
        
    #保存数据集统计信息
    #trainset_stats和testset_stats是在之前的预处理步骤中收集的训练集和测试集的统计信息。
    #使用json.dump()函数将这些统计信息保存为JSON格式的文件“all_stats.json”。这个文件用于在后续的模型训练和评估过程中提供数据集的相关信息
    with open(dataset_dir / "all_stats.json", "w") as f:
        json.dump({"train": trainset_stats, "test": testset_stats}, f)
函数:randomly_alloc_classes

作用:将原始数据集(ori_dataset)中的样本随机分配给多个客户端(或用户),同时确保每个客户端获得指定数量的不同类别的样本。函数还返回了分配给每个客户端的数据集列表和相应的统计信息。

def randomly_alloc_classes(
    ori_dataset: Dataset,
    target_dataset: Dataset,
    num_clients: int,
    num_classes: int,
    transform=None,
    target_transform=None,
) -> Tuple[List[Dataset], Dict[str, Dict[str, int]]]:
    
    #分配样本
    #使用noniid_slicing函数来将ori_dataset中的样本分配给num_clients个客户端。这个函数应该返回一个字典,其中键是客户端ID,值是分配给该客户端的样本索引列表。
    dict_users = noniid_slicing(ori_dataset, num_clients, num_clients * num_classes)
    stats = {}
    #收集统计信息
    #对于每个客户端,从ori_dataset中提取标签(ori_dataset.targets),然后根据分配给该客户端的样本索引列表计算标签的类别分布。这些统计信息被存储在stats字典中。
    for i, indices in dict_users.items():
        targets_numpy = np.array(ori_dataset.targets)
        stats[f"client {i}"] = {"x": 0, "y": {}}
        stats[f"client {i}"]["x"] = len(indices)
        stats[f"client {i}"]["y"] = Counter(targets_numpy[indices].tolist())
    datasets = []
    #创建数据集
    #使用target_dataset类从ori_dataset中创建子集,每个子集对应于一个客户端。这是通过从ori_dataset中提取分配给该客户端的样本,并传递给target_dataset的构造函数来实现的。
    for indices in dict_users.values():
        datasets.append(
            target_dataset(
                [ori_dataset[i] for i in indices],
                transform=transform,
                target_transform=target_transform,
            )
        )
    return datasets, stats
函数:__name__==“main”

作用:基本的命令行参数解析设置,它使用argparse库来从命令行获取参数。这些参数包括数据集类型、客户端总数、训练客户端的比例、每个客户端数据所属的类别数量以及随机种子。

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset", type=str, choices=["mnist", "cifar"], default="mnist",
    )
    parser.add_argument("--client_num_in_total", type=int, default=200)
    parser.add_argument(
        "--fraction", type=float, default=0.9, help="Propotion of train clients"
    )
    parser.add_argument(
        "--classes",
        type=int,
        default=2,
        help="Num of classes that one client's data belong to.",
    )
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    preprocess(args)
dataset.py
类:MNISTDataset(Dataset)
函数:init

作用:用于初始化一个对象的状态

def __init__(
    #参数
    self,
    subset=None,
    data=None,
    targets=None,
    transform=None,
    target_transform=None,
) -> None:
    #处理
    self.transform = transform
    self.target_transform = target_transform
    #如果data和targets都非空,它将data增加一个新的维度(使用unsqueeze(1)),并将data和targets设置为对象的属性。
    if (data is not None) and (targets is not None):
        self.data = data.unsqueeze(1)
        self.targets = targets
    # 如果subset非空,它将遍历subset,检查元组中的每个元素是否为张量。如果不是,它将使用torch.tensor将其转换为张量。然后,它使用torch.stack将数据和标签分别堆叠成张量,并设置为对象的属性。   
    elif subset is not None:
        self.data = torch.stack(
            list(
                map(
                    lambda tup: tup[0]
                    if isinstance(tup[0], torch.Tensor)
                    else torch.tensor(tup[0]),
                    subset,
                )
            )
        )
        self.targets = torch.stack(
            list(
                map(
                    lambda tup: tup[1]
                    if isinstance(tup[1], torch.Tensor)
                    else torch.tensor(tup[1]),
                    subset,
                )
            )
        )
    #如果data和targets以及subset都为空,则抛出一个ValueError,说明需要提供数据格式。
    else:
        raise ValueError(
            "Data Format: subset: Tuple(data: Tensor / Image / np.ndarray, targets: Tensor) OR data: List[Tensor]  targets: List[Tensor]"
        )
函数:getitem

作用:允许类的实例像列表、元组或其他可迭代对象那样进行索引访问。在你提供的上下文中,这个方法通常用于数据加载器(如PyTorch的DataLoader),以便在训练或评估模型时能够按索引访问数据集中的单个样本。

    def __getitem__(self, index):
        data, targets = self.data[index], self.targets[index]

        if self.transform is not None:
            data = self.transform(self.data[index])

        if self.target_transform is not None:
            targets = self.target_transform(self.targets[index])

        return data, targets
函数:len

作用:确定self.data的长度

    def __len__(self):
        return len(self.targets)
类:CIFARDataset(Dataset)
函数:init

作用:用于初始化一个对象的状态

def __init__(
    self,
    subset=None,
    data=None,
    targets=None,
    transform=None,
    target_transform=None,
) -> None:
    self.transform = transform
    self.target_transform = target_transform
    if (data is not None) and (targets is not None):
        self.data = data.unsqueeze(1)
        self.targets = targets
    elif subset is not None:
        self.data = torch.stack(
            list(
                map(
                    lambda tup: tup[0]
                    if isinstance(tup[0], torch.Tensor)
                    else torch.tensor(tup[0]),
                    subset,
                )
            )
        )
        self.targets = torch.stack(
            list(
                map(
                    lambda tup: tup[1]
                    if isinstance(tup[1], torch.Tensor)
                    else torch.tensor(tup[1]),
                    subset,
                )
            )
        )
    else:
        raise ValueError(
            "Data Format: subset: Tuple(data: Tensor / Image / np.ndarray, targets: Tensor) OR data: List[Tensor]  targets: List[Tensor]"
        )
函数:getitem

作用:允许类的实例像列表、元组或其他可迭代对象那样进行索引访问。在你提供的上下文中,这个方法通常用于数据加载器(如PyTorch的DataLoader),以便在训练或评估模型时能够按索引访问数据集中的单个样本。

def __getitem__(self, index):
    img, targets = self.data[index], self.targets[index]

    if self.transform is not None:
        img = self.transform(self.data[index])

    if self.target_transform is not None:
        targets = self.target_transform(self.targets[index])

    return img, targets
函数:len

作用:返回长度

def __len__(self):
    return len(self.targets)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值