Learning federated learning programming based on cross_silo

The framework is FedML.

I am a beginner. If any readers who also study federated learning come across this post, feel free to message me or add me as a friend so we can discuss and learn together.

I will start from the client side.

First comes client initialization, client_initializer.py:

from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL, FEDML_CROSS_SILO_SCENARIO_HORIZONTAL
from .fedml_client_master_manager import ClientMasterManager
from .fedml_trainer_dist_adapter import TrainerDistAdapter


def init_client(
        args,
        device,
        comm,
        client_rank,
        client_num,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer=None,
):
    backend = args.backend
    trainer_dist_adapter = get_trainer_dist_adapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    )
    if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:  # hierarchical scenario: multiple processes per silo
        if args.proc_rank_in_silo == 0:
            client_manager = get_client_manager_master(
                args, trainer_dist_adapter, comm, client_rank, client_num, backend
            )
        else:
            client_manager = get_client_manager_salve(args, trainer_dist_adapter)

    elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:  # horizontal scenario: a single process per silo
        client_manager = get_client_manager_master(
            args, trainer_dist_adapter, comm, client_rank, client_num, backend
        )
    else:
        raise Exception(
            "we do not support {}. Please check whether this is typo.".format(
                args.scenario
            )
        )
    client_manager.run()  # the client manager is configured; start it running

def get_trainer_dist_adapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
):
    return TrainerDistAdapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    )
def get_client_manager_master(
        args,trainer_dist_adapter,comm,client_rank,client_num,backend
):
    return ClientMasterManager(  # this class appears again later in the post
        args, trainer_dist_adapter, comm, client_rank, client_num, backend
    )
def get_client_manager_salve(args, trainer_dist_adapter):
    from .fedml_client_slave_manager import ClientSlaveManager
    return ClientSlaveManager(args, trainer_dist_adapter)
# both of these functions are explained further below
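
To see how this entry point is used, here is a minimal sketch of calling init_client by hand. Everything in it (the backend string, the toy model, the empty stand-in loaders) is a hypothetical placeholder just to make the argument shapes concrete; in a real FedML run these values come from the config file and the data loaders.

import torch
from types import SimpleNamespace
from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HORIZONTAL

# hypothetical arguments; a real run builds these via load_arguments()
args = SimpleNamespace(
    backend="MQTT_S3",
    scenario=FEDML_CROSS_SILO_SCENARIO_HORIZONTAL,
    client_id_list="[1]",
    comm_round=10,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(10, 2)           # toy model
train_loader = test_loader = []          # stand-ins for real DataLoaders

init_client(
    args, device, None, 1, 1, model,     # comm=None, client_rank=1, client_num=1
    train_data_num=1000,
    train_data_local_num_dict={0: 1000},
    train_data_local_dict={0: train_loader},
    test_data_local_dict={0: test_loader},
)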

2. fedml_client_master_manager.py

Looking back at init_client, this master manager is used in both scenarios: horizontal silos use it directly, and in the hierarchical scenario the silo's rank-0 process uses it while the remaining processes get the slave manager described later. Since I am doing horizontal FL, this is the file I need. The constructor's default backend is "MPI", which requires an MPI setup for multi-machine parallel runs, though other backends can be passed in. There is a parameter comm whose meaning I could not figure out for a long time; from this file it turns out to be the communicator object handed down to the underlying ClientManager, while the number of global rounds actually comes from args.comm_round.

This is one of the files I consider most important, and it is where I plan to make some changes.
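
Before diving in, a quick illustration of the difference between comm and comm_round, assuming the MPI backend with mpi4py (the usual Python MPI binding; the exact wiring inside FedML may differ):

from mpi4py import MPI

comm = MPI.COMM_WORLD    # the communicator object: the group of processes that can exchange messages
rank = comm.Get_rank()   # this process's id within the communicator
size = comm.Get_size()   # total number of processes

# args.comm_round, by contrast, is just an integer read from the config:
# the number of global federated rounds (stored as self.num_rounds below).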

import json
import logging
import platform
import time
import torch.distributed as dist

from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
from .message_define import MyMessage
from .utils import convert_model_params_from_ddp, convert_model_params_to_ddp
from ...core.distributed.client.client_manager import ClientManager
from ...core.distributed.communication.message import Message
from ...core.mlops.mlops_metrics import MLOpsMetrics
from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent

class ClientMasterManager(ClientManager):
    def __init__(
            self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"
    ):
        super().__init__(args, comm, rank, size, backend)
        self.trainer_dist_adapter = trainer_dist_adapter
        self.args = args
        self.num_rounds = args.comm_round
        self.round_idx = 0
        self.rank = rank
        self.client_real_ids = json.loads(args.client_id_list)
        # read the client ids from the config
        logging.info("self.client_real_ids = {}".format(self.client_real_ids))
        # for the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others.
        self.client_real_id = self.client_real_ids[0]
        self.has_sent_online_msg = False  # must exist before handle_message_connection_ready checks it
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_metrics = MLOpsMetrics()
            self.mlops_metrics.set_messenger(self.com_manager_status, args)
            self.mlops_event = MLOpsProfilerEvent(self.args)
        # check whether MLOps reporting is enabled

    # register the handlers for messages received from the server
    def register_message_receive_handlers(self):
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready
        )

        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status
        )

        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
            self.handle_message_receive_model_from_server,
        )

        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_FINISH, self.handle_message_finish,
        )
    def handle_message_connection_ready(self, msg_params):
        logging.info("Connection is ready!")
        if not self.has_sent_online_msg:
            self.has_sent_online_msg = True
            self.send_client_status(0)

            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                # Notify MLOps with training status.
                self.report_training_status(
                    MyMessage.MSG_MLOPS_CLIENT_STATUS_INITIALIZING
                )

                # Open new process for report system performances to MQTT server
                MLOpsMetrics.report_sys_perf(self.args)

    def handle_message_check_status(self, msg_params):
        self.send_client_status(0)
    def handle_message_init(self, msg_params):
        global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        logging.info("data_silo_index = %s" % str(data_silo_index))
        self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING)

        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            global_model_params = convert_model_params_to_ddp(global_model_params)
            self.sync_process_group(0, global_model_params, data_silo_index)

        self.trainer_dist_adapter.update_model(global_model_params)
        self.trainer_dist_adapter.update_dataset(int(data_silo_index))
        self.round_idx = 0

        self.__train()
    def handle_message_receive_model_from_server(self, msg_params):
        logging.info("handle_message_receive_model_from_server.")
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            model_params = convert_model_params_to_ddp(model_params)
            self.sync_process_group(self.round_idx, model_params, client_index)

        self.trainer_dist_adapter.update_model(model_params)
        self.trainer_dist_adapter.update_dataset(int(client_index))
        if self.round_idx == self.num_rounds - 1:
            # this is the part that may need changes

            # Notify MLOps with the finished message
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_metrics.report_client_id_status(
                    self.args.run_id,
                    self.client_real_id,
                    MyMessage.MSG_MLOPS_CLIENT_STATUS_FINISHED,
                )
            return
        self.round_idx += 1
        self.__train()
    def handle_message_finish(self, msg_params):
        logging.info(" ====================cleanup ====================")
        self.cleanup()
    def cleanup(self):
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            # mlops_metrics = MLOpsMetrics()
            # mlops_metrics.set_sys_reporting_status(False)
            pass
        self.finish()
    def send_model_to_server(self, receive_id, weights, local_sample_num):
        tick = time.time()
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_started(
                "comm_c2s", event_value=str(self.round_idx)
            )
        message = Message(
            MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER,
            self.client_real_id,
            receive_id,
        )
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        self.send_message(message)
        MLOpsProfilerEvent.log_to_wandb(
            {"Communication/Send_Total": time.time() - tick}
        )
        # Report client model to MLOps
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL)
            model_info = {
                "run_id": self.args.run_id,
                "edge_id": self.client_real_id,
                "round_idx": self.round_idx + 1,
                "client_model_s3_address": model_url,
            }
            self.mlops_metrics.report_client_model_info(model_info)

    def send_client_status(self, receive_id, status="ONLINE"):
        logging.info("send_client_status")
        message = Message(
            MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id
        )
        sys_name = platform.system()
        if sys_name == "Darwin":
            sys_name = "Mac"
        # Debug for simulation mobile system
        # sys_name = MyMessage.MSG_CLIENT_OS_ANDROID

        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_STATUS, status)
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, sys_name)
        self.send_message(message)
    def report_training_status(self, status):
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_metrics.set_messenger(self.com_manager_status, self.args)
            self.mlops_metrics.report_client_training_status(
                self.client_real_id, status
            )

    def sync_process_group(
            self, round_idx, model_params=None, client_index=None, src=0
    ):
        logging.info("sending round number to pg")
        round_number = [round_idx, model_params, client_index]
        dist.broadcast_object_list(
            round_number,
            src=src,
            group=self.trainer_dist_adapter.process_group_manager.get_process_group(),
        )
        logging.info("round number %d broadcast to process group" % round_number[0])

    def __train(self):
        logging.info("#######training########### round_id = %d" % self.round_idx)
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_started("train", event_value=str(self.round_idx))

        weights, local_sample_num = self.trainer_dist_adapter.train(self.round_idx)

        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_ended("train", event_value=str(self.round_idx))

        # the current model is still DDP-wrapped under cross-silo-hi setting
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            weights = convert_model_params_from_ddp(weights)

        self.send_model_to_server(0, weights, local_sample_num)

    def run(self):
        super().run()
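
The least obvious part of this file is sync_process_group: in the hierarchical scenario, the silo's rank-0 process broadcasts the round number, model parameters, and client index to the other processes in its silo via torch.distributed.broadcast_object_list. A standalone sketch of that primitive, meant to be launched with torchrun --nproc_per_node=2 (all values hypothetical):

import torch.distributed as dist

dist.init_process_group(backend="gloo")   # torchrun supplies rank/world_size via env vars
if dist.get_rank() == 0:
    payload = [3, {"w": 1.0}, 7]          # round_idx, model_params, client_index
else:
    payload = [None, None, None]          # placeholders, overwritten by the broadcast
dist.broadcast_object_list(payload, src=0)
print("rank", dist.get_rank(), "received round", payload[0])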

client_launch.py

I feel this one basically doesn't need to be changed.

import os
import subprocess
import torch
from fedml.arguments import load_arguments
from fedml.constants import (
    FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL,
    FEDML_TRAINING_PLATFORM_CROSS_SILO,
    FEDML_CROSS_SILO_SCENARIO_HORIZONTAL,
)
from fedml.device import get_device_type

class CrossSiloLauncher:
    @staticmethod
    def launch_dist_trainer(torch_client_filename, inputs):
        args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO)
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            CrossSiloLauncher._run_cross_silo_hierarchical(
                args, torch_client_filename, inputs
            )
        elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
            CrossSiloLauncher._run_cross_silo_horizontal(
                args, torch_client_filename, inputs
            )
        else:
            raise Exception(
                "we do not support {}, check whether this is typo in args.scenario".format(
                    args.scenario
                )
            )
    @staticmethod
    def _run_cross_silo_horizontal(args, torch_client_filename, inputs):
        python_path = subprocess.run(
            ["which", "python"], capture_output=True, text=True
        ).stdout.strip()
        process_arguments = [python_path, torch_client_filename] + inputs
        subprocess.run(process_arguments)
    @staticmethod
    def _run_cross_silo_hierarchical(args, torch_client_filename, inputs):
        def get_torchrun_arguments(node_rank):
            torchrun_path = subprocess.run(
                ["which", "torchrun"], capture_output=True, text=True
            ).stdout.strip()

            return [
                torchrun_path,
                f"--nnodes={args.n_node_in_silo}",
                f"--nproc_per_node={args.n_proc_per_node}",
                # "--rdzv_backend=c10d",
                f"--rdzv_endpoint={args.master_address}:{args.launcher_rdzv_port}",
                f"--node_rank={node_rank}",
                "--rdzv_id=hi_fl",
                torch_client_filename,
            ] + inputs

        network_interface = (
            None if not hasattr(args, "network_interface") else args.network_interface
        )
        print(
            f"Using network interface {network_interface} for process group and TRPC communication"
        )
        env_variables = {
            "OMP_NUM_THREADS": "4",
        }
        if network_interface:
            env_variables = {
                **env_variables,
                "NCCL_SOCKET_IFNAME": network_interface,
                "GLOO_SOCKET_IFNAME": network_interface,
            }

        if args.n_node_in_silo == 1:
            args.node_rank = 0
            args.manual_launch = True
            if not (hasattr(args, "n_proc_per_node") and args.n_proc_per_node):
                print("Number of processes per node not specified.")
                device_type = get_device_type(args)
                if torch.cuda.is_available() and device_type == "gpu":
                    gpu_count = torch.cuda.device_count()
                    print(f"Using number of GPUs ({gpu_count}) as number of processeses.")
                    args.n_proc_per_node = gpu_count
                else: 
                    print(f"Using number 1 as number of processeses.")
                    args.n_proc_per_node = 1

        if hasattr(args, "manual_launch") and args.manual_launch:
            print(f"Manual Client Launcher")
            node_rank = args.node_rank
            torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
            process_args = torchrun_cmd_arguments
            print(f"Launching node {node_rank} of silo {args.rank}")
            subprocess.run(process_args, env=dict(os.environ, **env_variables))

        else:
            print(f"Automatic Client Launcher")

            which_pdsh = subprocess.run(
                ["which", "pdsh"], capture_output=True, text=True
            ).stdout.strip()

            if not which_pdsh:
                raise Exception(
                    f"Silo {args.rank} has {args.n_node_in_silo} nodes. Automatic Client Launcher for more than 1 nodes requires PSDH."
                )

            print(f"Launching nodes using pdsh")

            os.environ["PDSH_RCMD_TYPE"] = "ssh"
            node_addresses = ",".join(args.node_addresses)
            pdsh_cmd_aruments = ["pdsh", "-w", node_addresses]

            exports = ""
            for key, val in env_variables.items():
                exports += "export {}={}; ".format(key, val)
            prerun_args = [
                exports,
                f"cd {os.path.abspath('.')};",
            ]

            node_rank = "%n"
            torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
            process_args = pdsh_cmd_arguments + prerun_args + torchrun_cmd_arguments
            subprocess.run(process_args)
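
To make the hierarchical branch concrete, this is roughly the command that get_torchrun_arguments assembles for one node, using hypothetical values (n_node_in_silo=2, n_proc_per_node=4, master_address=192.168.1.10, launcher_rdzv_port=29500, node_rank=0, and a client script named torch_client.py):

torchrun_args = [
    "torchrun",                               # the real code resolves the full path via `which torchrun`
    "--nnodes=2",
    "--nproc_per_node=4",
    "--rdzv_endpoint=192.168.1.10:29500",
    "--node_rank=0",
    "--rdzv_id=hi_fl",
    "torch_client.py",
]
print(" ".join(torchrun_args))
# torchrun --nnodes=2 --nproc_per_node=4 --rdzv_endpoint=192.168.1.10:29500 --node_rank=0 --rdzv_id=hi_fl torch_client.py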

fedml_trainer.py

There is a train function here that may need to be changed.

import time

from ...constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent
from fedml.data import split_data_for_dist_trainers

class FedMLTrainer(object):
    def __init__(
            self,
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
    ):
        self.trainer = model_trainer
        self.client_index = client_index
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            self.train_data_local_dict = split_data_for_dist_trainers(
                train_data_local_dict, args.n_proc_in_silo
            )
        else:
            self.train_data_local_dict = train_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict
        self.test_data_local_dict = test_data_local_dict
        self.all_train_data_num = train_data_num
        self.train_local = None
        self.local_sample_number = None
        self.test_local = None
        self.device = device
        self.args = args

    def update_model(self, weights):
        self.trainer.set_model_params(weights)

    def update_dataset(self, client_index):
        self.client_index = client_index
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            self.train_local = self.train_data_local_dict[client_index][
                self.args.proc_rank_in_silo
            ]
        else:
            self.train_local = self.train_data_local_dict[client_index]
        self.local_sample_number = self.train_data_local_num_dict[client_index]
        self.test_local = self.test_data_local_dict[client_index]


    def train(self, round_idx=None):
        self.args.round_idx = round_idx
        tick = time.time()
        self.trainer.train(self.train_local, self.device, self.args)
        MLOpsProfilerEvent.log_to_wandb(
            {"Train/Time": time.time() - tick, "round": round_idx}
        )
        weights = self.trainer.get_model_params()
        return weights, self.local_sample_number


    def test(self):
        train_metrics = self.trainer.test(self.train_local, self.device, self.args)
        train_tot_correct, train_num_sample, train_loss = (
            train_metrics["test_correct"],
            train_metrics["test_total"],
            train_metrics["test_loss"],
        )

        test_metrics = self.trainer.test(self.test_local, self.device, self.args)
        test_tot_correct, test_num_sample, test_loss = (
            test_metrics["test_correct"],
            test_metrics["test_total"],
            test_metrics["test_loss"],
        )

        return (
            train_tot_correct,
            train_loss,
            train_num_sample,
            test_tot_correct,
            test_loss,
            test_num_sample,
        )
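
FedMLTrainer delegates the actual optimization to model_trainer, which, judging from the calls above, must expose set_model_params, get_model_params, train, and test (with test returning a dict containing test_correct, test_total, and test_loss). A minimal sketch of such a trainer, with a hypothetical loss, optimizer, and learning rate:

import torch

class MyModelTrainer:
    # minimal sketch mirroring only the calls FedMLTrainer makes;
    # the real FedML trainer base class has more to it
    def __init__(self, model):
        self.model = model
        self.id = 0

    def set_id(self, trainer_id):
        self.id = trainer_id

    def get_model_params(self):
        return self.model.cpu().state_dict()

    def set_model_params(self, model_parameters):
        self.model.load_state_dict(model_parameters)

    def train(self, train_data, device, args):
        model = self.model.to(device)
        model.train()
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.03)  # hypothetical choices
        for x, labels in train_data:
            x, labels = x.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), labels)
            loss.backward()
            optimizer.step()

    def test(self, test_data, device, args):
        model = self.model.to(device)
        model.eval()
        metrics = {"test_correct": 0, "test_total": 0, "test_loss": 0.0}
        criterion = torch.nn.CrossEntropyLoss(reduction="sum")
        with torch.no_grad():
            for x, labels in test_data:
                x, labels = x.to(device), labels.to(device)
                pred = model(x)
                metrics["test_loss"] += criterion(pred, labels).item()
                metrics["test_correct"] += (pred.argmax(1) == labels).sum().item()
                metrics["test_total"] += labels.size(0)
        return metrics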

fedml_trainer_dist_adapter.py

import logging

from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
from .fedml_trainer import FedMLTrainer
from .trainer.trainer_creator import create_model_trainer

class TrainerDistAdapter:
    def __init__(
            self,
            args,
            device,
            client_rank,
            model,
            train_data_num,
            train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
            model_trainer,
    ):
        model.to(device)
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            from torch.nn.parallel import DistributedDataParallel as DDP
            from .process_group_manager import ProcessGroupManager

            only_gpu = args.using_gpu
            self.process_group_manager = ProcessGroupManager(
                args.proc_rank_in_silo,
                args.n_proc_in_silo,
                args.pg_master_address,
                args.pg_master_port,
                only_gpu,
            )
            model = DDP(model, device_ids=[device] if only_gpu else None)

        if model_trainer is None:
            model_trainer = create_model_trainer(args, model)
        else:
            model_trainer.model = model

        client_index = client_rank - 1

        model_trainer.set_id(client_index)

        logging.info("Initiating Trainer")
        trainer = self.get_trainer(
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        )
        self.client_index = client_index
        self.client_rank = client_rank
        self.device = device
        self.trainer = trainer
        self.args = args


    def get_trainer(
            self,
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
    ):
        return FedMLTrainer(
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        )


    def train(self, round_idx):
        weights, local_sample_num = self.trainer.train(round_idx)
        return weights, local_sample_num

    def update_model(self, model_params):
        self.trainer.update_model(model_params)

    def update_dataset(self, client_index=None):
        _client_index = client_index or self.client_index
        self.trainer.update_dataset(int(_client_index))

    def cleanup_pg(self):
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            logging.info(
                "Cleaning up process group for client %s in silo %s"
                % (self.args.proc_rank_in_silo, self.args.rank_in_node)
            )
            )
            self.process_group_manager.cleanup()
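
For reference, the hierarchical branch in __init__ boils down to standard PyTorch distributed setup: initialize a process group, then wrap the model in DistributedDataParallel. Stripped of FedML's ProcessGroupManager wrapper, it looks roughly like this (addresses, port, and sizes are hypothetical):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ["MASTER_ADDR"] = "192.168.1.10"   # plays the role of pg_master_address
os.environ["MASTER_PORT"] = "29501"          # plays the role of pg_master_port
dist.init_process_group(backend="gloo", rank=0, world_size=2)  # blocks until all ranks join

model = torch.nn.Linear(10, 2)  # toy model
model = DDP(model)              # pass device_ids=[device] when each process owns one GPU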

I don't think the other files need to be changed.
