联邦学习:基于 FedML 的编程

先部署何老师的 FedML 框架,然后了解每一部分的作用。

 

我也是纯新手,希望研究联邦学习的朋友们看到这篇文章后可以联系我,加个好友一起学习。

这里的文件夹对应联邦学习的两大场景:“跨设备(cross-device)”和“跨孤岛(cross-silo)”。

跨设备指的是联合大量移动端和边缘设备上的应用程序进行训练,例如手机输入法键盘。

跨孤岛只涉及少量相对可靠的客户端,例如多个组织合作训练一个模型。

我先学习 centralized(集中式训练)这个文件。

首先是设置命令行参数(add_args):

def add_args(parser, argv=None):
    """Register the centralized-training command-line options and parse them.

    Args:
        parser: an ``argparse.ArgumentParser`` to register the options on.
        argv: optional list of argument strings to parse. ``None`` (the
            default) keeps the previous behaviour of parsing ``sys.argv``;
            passing a list makes the function usable from tests and other
            scripts.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser.add_argument(
        "--model",
        type=str,
        default="mobilenet",
        metavar="N",
        help="neural network used in training",
    )

    # 0/1 flag: 1 switches the __main__ script to torch.distributed + DDP.
    parser.add_argument(
        "--data_parallel",
        type=int,
        default=0,
        help="if distributed training",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        metavar="N",
        help="dataset used for training",
    )

    parser.add_argument(
        "--data_dir",
        type=str,
        default="./../../../data/cifar10",
        help="data directory",
    )

    parser.add_argument(
        "--partition_method",
        type=str,
        default="hetero",
        metavar="N",
        help="how to partition the dataset on local workers",
    )

    parser.add_argument(
        "--partition_alpha",
        type=float,
        default=0.5,
        metavar="PA",
        help="partition alpha (default: 0.5)",
    )

    parser.add_argument(
        "--client_num_in_total",
        type=int,
        default=1000,
        metavar="NN",
        help="number of workers in a distributed cluster",
    )

    parser.add_argument(
        "--client_num_per_round",
        type=int,
        default=4,
        metavar="NN",
        help="number of workers",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )

    parser.add_argument(
        "--client_optimizer",
        type=str,
        default="adam",
        help="SGD with momentum; adam",
    )

    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="learning rate (default: 0.001)",
    )

    parser.add_argument(
        "--wd",
        help="weight decay parameter;",
        type=float,
        default=0.0001,
    )

    parser.add_argument(
        "--epochs",
        type=int,
        default=5,
        metavar="EP",
        help="how many epochs will be trained locally",
    )

    parser.add_argument(
        "--comm_round",
        type=int,
        default=10,
        help="how many rounds of communication we should use",
    )

    parser.add_argument(
        "--is_mobile",
        type=int,
        default=0,
        help="whether the program is running on the FedML-Mobile server side",
    )

    parser.add_argument(
        "--frequency_of_train_acc_report",
        type=int,
        default=10,
        help="the frequency of training accuracy report",
    )

    parser.add_argument(
        "--frequency_of_test_acc_report",
        type=int,
        default=1,
        help="the frequency of test accuracy report",
    )

    # NOTE: "sever" (not "server") is kept in these two option names so
    # existing command lines and `args.gpu_sever_num` readers keep working.
    parser.add_argument(
        "--gpu_sever_num",
        type=int,
        default=1,
        help="gpu_server_num",
    )

    parser.add_argument(
        "--gpu_num_per_sever",
        type=int,
        default=4,
        help="gpu_num_per_server",
    )

    parser.add_argument(
        "--ci",
        type=int,
        default=0,
        help="CI",
    )

    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="gpu",
    )

    # Comma-separated GPU ids; the __main__ script derives world_size from
    # the number of entries.
    parser.add_argument(
        "--gpu_util",
        type=str,
        default="0",
        help="gpu utils",
    )

    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="given by torch.distributed.launch",
    )

    return parser.parse_args(argv)


然后再看参数配置,根据数据集名称加载并划分数据(load_data):

def load_data(args, dataset_name):
    """Load ``dataset_name`` and split it across (simulated) clients.

    Args:
        args: parsed namespace from ``add_args``. Depending on the dataset
            this reads ``batch_size``, ``dataset``, ``data_dir``,
            ``data_parallel``, ``partition_method``, ``partition_alpha``,
            ``client_num_in_total`` and, for distributed ImageNet,
            ``world_size`` and ``rank``.
        dataset_name: which dataset to load. Any name not matched below
            (including "cifar10") uses the CIFAR-10 loader.

    Returns:
        ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.

    Side effects:
        Overwrites ``args.client_num_in_total`` for the federated benchmarks
        (with the client count the loader returns) and for the Landmarks
        datasets (233 for gld23k, 1262 for gld160k). For the Landmarks
        datasets it also appends ``images`` to ``args.data_dir``.
    """
    logging.info("load_data.dataset_name=%s", dataset_name)

    # These loaders return (client_num, <the eight fields>); every other
    # loader below returns only the eight fields.
    federated_benchmarks = (
        "mnist",
        "femnist",
        "shakespeare",
        "fed_shakespeare",
        "fed_cifar100",
        "stackoverflow_lr",
        "stackoverflow_nwp",
    )

    if dataset_name == "mnist":
        loaded = load_partition_data_mnist(args.batch_size)
    elif dataset_name == "femnist":
        loaded = load_partition_data_federated_emnist(args.dataset, args.data_dir)
    elif dataset_name == "shakespeare":
        # NOTE(review): unlike the fed_shakespeare call below, this passes
        # only the batch size; the non-federated Shakespeare split probably
        # needs its own loader. Confirm against the data-loading module.
        loaded = load_partition_data_federated_shakespeare(args.batch_size)
    elif dataset_name == "fed_shakespeare":
        # Called as (dataset, data_dir) like the femnist/stackoverflow
        # loaders; the batch size used to be passed in the data_dir slot.
        loaded = load_partition_data_federated_shakespeare(args.dataset, args.data_dir)
    elif dataset_name == "fed_cifar100":
        # Same fix as fed_shakespeare: data_dir, not batch_size.
        loaded = load_partition_data_federated_cifar100(args.dataset, args.data_dir)
    elif dataset_name == "stackoverflow_lr":
        loaded = load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir)
    elif dataset_name == "stackoverflow_nwp":
        loaded = load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir)
    elif dataset_name in ("ILSVRC2012", "ILSVRC2012_hdf5"):
        if args.data_parallel == 1:
            loaded = distributed_centralized_ImageNet_loader(
                dataset=dataset_name,
                data_dir=args.data_dir,
                world_size=args.world_size,
                rank=args.rank,
                batch_size=args.batch_size,
            )
        else:
            loaded = load_partition_data_ImageNet(
                dataset=dataset_name,
                data_dir=args.data_dir,
                partition_method=None,
                partition_alpha=None,
                client_num=args.client_num_in_total,
                batch_size=args.batch_size,
            )
    elif dataset_name in ("gld23k", "gld160k"):
        # The Landmarks splits come with a fixed number of users.
        args.client_num_in_total = 233 if dataset_name == "gld23k" else 1262
        fed_train_map_file = os.path.join(
            args.data_dir, "data_user_dict/%s_user_dict_train.csv" % dataset_name
        )
        fed_test_map_file = os.path.join(
            args.data_dir, "data_user_dict/%s_user_dict_test.csv" % dataset_name
        )
        # The images live in a sub-folder next to the user-dict CSVs.
        args.data_dir = os.path.join(args.data_dir, "images")
        loaded = load_partition_data_landmarks(
            dataset=dataset_name,
            data_dir=args.data_dir,
            fed_train_map_file=fed_train_map_file,
            fed_test_map_file=fed_test_map_file,
            partition_method=None,
            partition_alpha=None,
            client_number=args.client_num_in_total,
            batch_size=args.batch_size,
        )
    else:
        if dataset_name == "cifar100":
            data_loader = load_partition_data_cifar100
        elif dataset_name == "cinic10":
            data_loader = load_partition_data_cinic10
        else:
            # "cifar10" (previously mapped to the CIFAR-100 loader by
            # mistake) and any unrecognised name.
            data_loader = load_partition_data_cifar10
        loaded = data_loader(
            args.dataset,
            args.data_dir,
            args.partition_method,
            args.partition_alpha,
            args.client_num_in_total,
            args.batch_size,
        )

    if dataset_name in federated_benchmarks:
        client_num, *loaded = loaded
        args.client_num_in_total = client_num

    # Explicit unpacking so a loader returning the wrong number of fields
    # fails here rather than later.
    (
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = loaded
    return [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ]

接下来是创建模型(create_model):

def create_model(args, model_name, output_dim):
    """Build the network selected by ``model_name`` (and ``args.dataset``).

    Args:
        args: parsed namespace; only ``args.dataset`` is read, to choose the
            dataset-specific lr/cnn/rnn/resnet18_gn variants.
        model_name: value of ``--model``.
        output_dim: number of output classes, forwarded to the models whose
            constructors take it.

    Returns:
        The constructed model, or ``None`` when no branch matches the
        (model_name, dataset) pair (a warning is logged in that case).
    """
    logging.info("create_model.model_name=%s,output_dim=%s", model_name, output_dim)

    dataset = args.dataset
    model = None
    if model_name == "lr" and dataset == "mnist":
        logging.info("LogisticRegression+MNIST")
        # 28 * 28: each MNIST image is flattened into the input vector.
        model = LogisticRegression(28 * 28, output_dim)
    elif model_name == "cnn" and dataset == "femnist":
        # Was compared against "femni", so this branch could never match.
        logging.info("CNN+FederatedEMNIST")
        model = CNN_DropOut(False)
    elif model_name == "resnet18_gn" and dataset == "fed_cifar100":
        logging.info("ResNet18_GN+Federated_CIFAR100")
        model = resnet18()
    elif model_name == "rnn" and dataset in ("shakespeare", "fed_shakespeare"):
        # Logs "RNN+shakespeare" / "RNN+fed_shakespeare" as before; the
        # fed_shakespeare branch used to assign to a misspelled variable.
        logging.info("RNN+%s", dataset)
        model = RNN_OriginalFedAvg()
    elif model_name == "lr" and dataset == "stackoverflow_lr":
        logging.info("lr+stackoverflow_lr")
        model = LogisticRegression(10004, output_dim)
    elif model_name == "rnn" and dataset == "stackoverflow_nwp":
        logging.info("RNN+stackoverflow_nwp")
        model = RNN_StackOverFlow()
    elif model_name == "resnet56":
        model = resnet56(class_num=output_dim)
    elif model_name == "mobilenet":
        model = mobilenet(class_num=output_dim)
    elif model_name == "mobilenet_v3":
        model = MobileNetV3(model_mode="LARGE", num_classes=output_dim)
    elif model_name == "efficientnet":
        # Only the B0 variant is built.
        model = EfficientNet.from_name(
            model_name="efficientnet-b0", num_classes=output_dim
        )

    if model is None:
        logging.warning(
            "create_model: no model for model_name=%s, dataset=%s", model_name, dataset
        )
    # Previously this return sat inside the efficientnet branch, so every
    # other model name returned None.
    return model

接下来是主函数

if __name__ == "__main__":
    # Entry point: parse options, optionally join a torch.distributed group,
    # load data, build the model and run the centralized trainer.
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    # One process per GPU id listed in --gpu_util.
    args.world_size = len(args.gpu_util.split(","))
    process_id = 0

    if args.data_parallel == 1:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        args.rank = torch.distributed.get_rank()
        gpu_util = [int(item.strip()) for item in args.gpu_util.split(",")]
        # NOTE(review): the GPU list is indexed with the *global* rank,
        # which only lines up on a single node; multi-node runs may need
        # args.local_rank here. Confirm.
        torch.cuda.set_device(gpu_util[args.rank])
        process_id = args.rank
    else:
        args.rank = 0

    # Configure logging before the first logging call. A module-level
    # logging.info() on an unconfigured root logger installs a default
    # WARNING-level handler, after which basicConfig() does nothing and
    # every INFO message below would be dropped.
    logging.basicConfig(
        level=logging.INFO,
        format=str(process_id)
        + " - %(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
    )
    logging.info(args)
    setproctitle.setproctitle("Fedml (single):" + str(process_id))

    hostname = socket.gethostname()
    logging.info(
        "#############process ID = "
        + str(process_id)
        + ", host name = "
        + hostname
        + "########"
        + ", process ID = "
        + str(os.getpid())
        + ", process Name = "
        + str(psutil.Process(os.getpid()))
    )

    # Only rank 0 reports to the wandb experiment tracker (https://www.wandb.com/).
    if process_id == 0:
        wandb.init(
            project="fedml",
            name="Fedml (central)"
            + str(args.partition_method)
            + "r"
            + str(args.comm_round)
            + "-e"
            + str(args.epochs)
            + "-lr"
            + str(args.lr),
            config=args,
        )

    # Fix the seeds so runs are reproducible: np.random drives the dataset
    # partition, torch.manual_seed the initial weights.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    logging.info("process_id = %d, size = %d" % (process_id, args.world_size))

    dataset = load_data(args, args.dataset)
    [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ] = dataset

    # Note: deep models (e.g. ResNet) train very slowly here; the FedML
    # distributed version (./fedml_experiments/distributed_fedavg) is
    # recommended for them.
    model = create_model(args, model_name=args.model, output_dim=class_num)

    if args.data_parallel == 1:
        device = torch.device("cuda:" + str(gpu_util[args.rank]))
        model.to(device)
        model = DistributedDataParallel(
            model, device_ids=[gpu_util[args.rank]], output_device=gpu_util[args.rank]
        )
    else:
        device = torch.device(
            "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
        )

    single_trainer = CentralizedTrainer(dataset, model, device, args)
    single_trainer.train()

跨设备的代码学习

1.mqtt_s3_fedavg_cifar10_resnet20_example

首先是数据集文件 dataset.py(服务器脚本中写的是 `from my_dataset import Cifar10Dataset`,所以实际文件名可能是 my_dataset.py):

import MNN
from torchvision.datasets import CIFAR10

F = MNN.expr


class Cifar10Dataset(MNN.data.Dataset):
    """One CIFAR-10 split (train or test) exposed as an MNN dataset.

    Pixels are scaled to [0, 1]. The channel axis is moved to position 1 so
    each sample matches the ``[3, 32, 32]`` NCHW shape used in
    ``__getitem__``.
    """

    def __init__(self, training_dataset=True):
        # Was ``MNN.data.dataset``; the MNN base class is ``Dataset``.
        super(Cifar10Dataset, self).__init__()
        self.is_training_dataset = training_dataset
        # Only the requested split is loaded; previously both were read.
        split = CIFAR10(root="./data", train=bool(training_dataset), download=True)
        self.data = split.data.transpose(0, 3, 1, 2) / 255.0
        self.labels = split.targets

    def __getitem__(self, index):
        image = F.const(
            self.data[index].flatten().tolist(), [3, 32, 32], F.data_format.NCHW
        )
        # Scalar label ([] shape); int() normalises numpy / tensor scalars.
        # The dtype was misspelled "unit8".
        label = F.const(
            [int(self.labels[index])], [], F.data_format.NCHW, F.dtype.uint8
        )
        # MNN data loaders expect ([inputs...], [targets...]).
        return [image], [label]

    def __len__(self):
        # Was hard-coded 50000 / 10000; use the actual split size.
        return len(self.labels)

然后是服务器端脚本 torch_server.py,感觉和 PyTorch 的编程很像:

import MNN

import fedml
from fedml.cross_device import ServerMNN
from my_dataset import Cifar10Dataset

if __name__ == "__main__":
    # Start the FedML cross-device (MNN backend) server for CIFAR-10.
    args = fedml.init()
    device = fedml.device.get_device(args)

    train_dataset = Cifar10Dataset(True)
    test_dataset = Cifar10Dataset(False)
    # NOTE(review): the training loader is built with a hard-coded batch
    # size and is not passed to ServerMNN below; confirm whether it is needed.
    train_loader = MNN.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Was ``MNN.data.Dataloader``; the MNN class is ``DataLoader``.
    test_loader = MNN.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False
    )

    class_num = 10
    # NOTE(review): `model` is not used below; kept in case
    # fedml.model.create has side effects for the MNN backend. Verify.
    model = fedml.model.create(args, output_dim=class_num)
    server = ServerMNN(args, device, test_loader, None)
    # The original only referenced `run` without calling it, so the server
    # never started.
    server.run()

采用 YAML 配置文件控制参数:

# FedML cross-device configuration (MQTT + S3 transport, MNN backend).
# Every `key: value` pair needs a space after the colon; without it YAML
# reads the common_args lines as one long string instead of a mapping.
common_args:
  training_type: "cross_device"
  using_mlops: false
  random_seed: 0
  config_version: release
environment_args:
  bootstrap: config/bootstrap.sh

data_args:
  dataset: "cifar"
  data_cache_dir: ~/fedml_data
  partition_method: "hetero"
  partition_alpha: 0.5
  train_size: 10000
  test_size: 5000

model_args:
  model: "resnet20"
  deeplearning_backend: "mnn"
  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
  global_model_file_path: "./model_file_cache/global_model.mnn"

train_args:
  federated_optimizer: "FedAvg"
  client_id_list: "[138]"
  client_num_in_total: 1
  client_num_per_round: 1
  comm_round: 3
  epochs: 1
  batch_size: 100
  client_optimizer: sgd
  learning_rate: 0.03
  weight_decay: 0.001
validation_args:
  frequency_of_the_test: 5
device_args:
  worker_num: 1 # this only reflects on the client number, not including the server
  using_gpu: false
  gpu_mapping_file: config/gpu_mapping.yaml
  gpu_mapping_key: mapping_default
comm_args:
  backend: "MQTT_S3_MNN"
  mqtt_config_path: config/mqtt_config.yaml
  s3_config_path: config/s3_config.yaml

tracking_args:
  log_file_dir: ./log
  enable_wandb: false
  wandb_project: fedml
  run_name: fedml_torch_fedavg_cifar_lr

2.mqtt_s3_fedavg_mnist_lenet_example

dataset.py(服务器脚本中同样以 `from my_dataset import MnistDataset` 导入),后面紧接着是服务器脚本:

import MNN
from torchvision.datasets import MNIST

F = MNN.expr


class MnistDataset(MNN.data.Dataset):
    """One MNIST split (train or test) exposed as an MNN dataset.

    Pixels are scaled to [0, 1] and each sample is shaped ``[1, 28, 28]``
    (NCHW).
    """

    def __init__(self, training_dataset=True):
        super(MnistDataset, self).__init__()
        self.is_training_dataset = training_dataset
        # Only the requested split is loaded; previously both were read.
        split = MNIST(root="./data", train=bool(training_dataset), download=True)
        self.data = split.data / 255.0
        self.labels = split.targets

    def __getitem__(self, index):
        image = F.const(
            self.data[index].flatten().tolist(),
            [1, 28, 28],
            F.data_format.NCHW,
        )
        # int() turns the label scalar into a plain Python int; the dtype
        # was misspelled "unit8".
        label = F.const(
            [int(self.labels[index])], [], F.data_format.NCHW, F.dtype.uint8
        )
        # The original never returned, so every sample was None. MNN data
        # loaders expect ([inputs...], [targets...]).
        return [image], [label]

    def __len__(self):
        # Previously nested inside __getitem__, so the class had no __len__.
        return len(self.labels)
import MNN

import fedml
from fedml.cross_device import ServerMNN
from my_dataset import MnistDataset

if __name__ == "__main__":
    # Start the FedML cross-device (MNN backend) server for MNIST.
    args = fedml.init()
    device = fedml.device.get_device(args)

    train_dataset = MnistDataset(True)
    test_dataset = MnistDataset(False)
    # NOTE(review): the training loader is not passed to ServerMNN below;
    # confirm whether it is needed.
    train_dataloader = MNN.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Was ``MNN.data.DatatLoader`` (typo).
    test_dataloader = MNN.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False
    )

    class_num = 10
    # NOTE(review): `model` is not used below; kept in case
    # fedml.model.create has side effects for the MNN backend. Verify.
    model = fedml.model.create(args, output_dim=class_num)
    server = ServerMNN(args, device, test_dataloader, None)
    server.run()

  • 5
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值