框架是fedml
我是一名初学者,若是有也研究联邦学习的朋友看见这篇博文,欢迎私信或者加我好友,一起讨论一起学习。
我从client开始学习
先是客户端初始化(client initializer,入口函数为 init_client)
from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL, FEDML_CROSS_SILO_SCENARIO_HORIZONTAL
from .fedml_client_master_manager import ClientMasterManager
from .fedml_trainer_dist_adapter import TrainerDistAdapter
def init_client(
    args,
    device,
    comm,
    client_rank,
    client_num,
    model,
    train_data_num,
    train_data_local_num_dict,
    train_data_local_dict,
    test_data_local_dict,
    model_trainer=None,
):
    """Initialize a cross-silo FL client and start its message loop.

    Builds the distributed-trainer adapter, then selects the client manager
    that matches ``args.scenario``:

    * hierarchical: rank 0 inside the silo becomes the master manager and
      talks to the server; every other process gets a slave manager;
    * horizontal: the single client process always gets a master manager.

    Raises:
        Exception: if ``args.scenario`` is neither supported value.
    """
    backend = args.backend
    trainer_dist_adapter = get_trainer_dist_adapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    )
    if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
        # Only the silo's rank-0 process communicates with the server.
        if args.proc_rank_in_silo == 0:
            client_manager = get_client_manager_master(
                args, trainer_dist_adapter, comm, client_rank, client_num, backend
            )
        else:
            client_manager = get_client_manager_salve(args, trainer_dist_adapter)
    elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
        client_manager = get_client_manager_master(
            args, trainer_dist_adapter, comm, client_rank, client_num, backend
        )
    else:
        raise Exception(
            "we do not support {}. Please check whether this is typo.".format(
                args.scenario
            )
        )
    # Manager fully configured; enter the receive/dispatch loop.
    client_manager.run()
def get_trainer_dist_adapter(
    args,
    device,
    client_rank,
    model,
    train_data_num,
    train_data_local_num_dict,
    train_data_local_dict,
    test_data_local_dict,
    model_trainer,
):
    """Build the adapter that wraps the model trainer for distributed use.

    Fix: the original returned ``TrainDistAdapter(...)``, which does not
    exist — the name imported at the top of this file is
    ``TrainerDistAdapter`` — so every call raised ``NameError``.
    """
    return TrainerDistAdapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    )
def get_client_manager_master(
    args, trainer_dist_adapter, comm, client_rank, client_num, backend
):
    """Construct the master-side client manager (used by both scenarios)."""
    manager = ClientMasterManager(
        args, trainer_dist_adapter, comm, client_rank, client_num, backend
    )
    return manager
def get_client_manager_salve(args, trainer_dist_adapter):
    """Construct the slave-side manager for a non-rank-0 process in a silo.

    NOTE(review): "salve" is a typo for "slave", but the name is part of the
    public interface used by ``init_client``, so it is kept as-is.
    """
    # Imported lazily so horizontal runs never load the slave-manager module.
    from .fedml_client_slave_manager import ClientSlaveManager

    return ClientSlaveManager(args, trainer_dist_adapter)
#这两个函数后续都有说明
2.fedml_client_master_manager.py
这个类用于客户端里与服务器通信的 master 进程(横向场景下每个客户端只有这一个进程;层级 hierarchical 场景下,silo 内其余进程由 ClientSlaveManager 管理),要求后台可以支持 MPI 多机并行运算。参数 comm 是通信后端的通信器对象(例如 MPI communicator),不是迭代轮数;全局迭代轮数是 args.comm_round,下文的 self.num_rounds 读取的就是它,两者不要混淆。
这个是我认为比较重要的文件之一,在这个上面做一些改动
import json
import logging
import platform
import time
import torch.distributed as dist
from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
from .message_define import MyMessage
from .utils import convert_model_params_from_ddp, convert_model_params_to_ddp
from ...core.distributed.client.client_manager import ClientManager
from ...core.distributed.communication.message import Message
from ...core.mlops.mlops_metrics import MLOpsMetrics
from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent
class ClientMasterManager(ClientManager):
    """Message-loop manager for the master process of a cross-silo client.

    Implements the client side of the server<->client protocol: announces
    itself, receives the initial and per-round global model, runs local
    training through ``trainer_dist_adapter`` and uploads the updated
    weights each round.

    Fixes vs. the original:
      * ``self.has_sent_online_msg`` is initialized in ``__init__`` (it was
        read in ``handle_message_connection_ready`` before ever being
        assigned, raising ``AttributeError`` on the first message);
      * ``handle_message_init`` renamed from the misspelled
        ``handle_messsage_init`` so it matches the handler registered for
        ``MSG_TYPE_S2C_INIT_CONFIG``;
      * ``send_model_to_server`` renamed from the misspelled
        ``send_model_to_sever`` so the call inside ``__train`` resolves.
    """

    def __init__(
        self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"
    ):
        super().__init__(args, comm, rank, size, backend)
        self.trainer_dist_adapter = trainer_dist_adapter
        self.args = args
        self.num_rounds = args.comm_round  # total number of global FL rounds
        self.round_idx = 0
        self.rank = rank
        # Fix: guard flag must exist before the first CONNECTION_IS_READY.
        self.has_sent_online_msg = False
        # Client ids configured for this run.
        self.client_real_ids = json.loads(args.client_id_list)
        logging.info("self.client_real_ids = {}".format(self.client_real_ids))
        # for the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others.
        self.client_real_id = self.client_real_ids[0]
        # Optional MLOps reporting hooks.
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_metrics = MLOpsMetrics()
            self.mlops_metrics.set_messenger(self.com_manager_status, args)
            self.mlops_event = MLOpsProfilerEvent(self.args)

    def register_message_receive_handlers(self):
        """Register handlers for every server-to-client message type."""
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
            self.handle_message_receive_model_from_server,
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_FINISH, self.handle_message_finish,
        )

    def handle_message_connection_ready(self, msg_params):
        logging.info("Connection is ready!")
        if not self.has_sent_online_msg:
            self.has_sent_online_msg = True
            self.send_client_status(0)
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                # Notify MLOps with training status.
                self.report_training_status(
                    MyMessage.MSG_MLOPS_CLIENT_STATUS_INITIALIZING
                )
                # Open new process for report system performances to MQTT server
                MLOpsMetrics.report_sys_perf(self.args)

    def handle_message_check_status(self, msg_params):
        self.send_client_status(0)

    def handle_message_init(self, msg_params):
        """Receive the initial global model/config and kick off round 0."""
        global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        logging.info("data_silo_index = %s" % str(data_silo_index))
        self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING)
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Re-key the state dict for the DDP-wrapped model, then share
            # the round metadata with the other processes in this silo.
            global_model_params = convert_model_params_to_ddp(global_model_params)
            self.sync_process_group(0, global_model_params, data_silo_index)
        self.trainer_dist_adapter.update_model(global_model_params)
        self.trainer_dist_adapter.update_dataset(int(data_silo_index))
        self.round_idx = 0
        self.__train()

    def handle_message_receive_model_from_server(self, msg_params):
        logging.info("handle_message_receive_model_from_server.")
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            model_params = convert_model_params_to_ddp(model_params)
            self.sync_process_group(self.round_idx, model_params, client_index)
        self.trainer_dist_adapter.update_model(model_params)
        self.trainer_dist_adapter.update_dataset(int(client_index))
        # NOTE(review): the blog author flagged this branch as a likely
        # place to customize ("这里可能需要动").
        if self.round_idx == self.num_rounds - 1:
            # All rounds done for this client: report completion, stop training.
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_metrics.report_client_id_status(
                    self.args.run_id,
                    self.client_real_id,
                    MyMessage.MSG_MLOPS_CLIENT_STATUS_FINISHED,
                )
            return
        self.round_idx += 1
        self.__train()

    def handle_message_finish(self, msg_params):
        logging.info(" ====================cleanup ====================")
        self.cleanup()

    def cleanup(self):
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            # mlops_metrics = MLOpsMetrics()
            # mlops_metrics.set_sys_reporting_status(False)
            pass
        self.finish()

    def send_model_to_server(self, receive_id, weights, local_sample_num):
        """Upload the locally trained weights and sample count to the server."""
        tick = time.time()
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_started(
                "comm_c2s", event_value=str(self.round_idx)
            )
        message = Message(
            MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER,
            self.client_real_id,
            receive_id,
        )
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        self.send_message(message)
        MLOpsProfilerEvent.log_to_wandb(
            {"Communication/Send_Total": time.time() - tick}
        )
        # Report client model to MLOps
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL)
            model_info = {
                "run_id": self.args.run_id,
                "edge_id": self.client_real_id,
                "round_idx": self.round_idx + 1,
                "client_model_s3_address": model_url,
            }
            self.mlops_metrics.report_client_model_info(model_info)

    def send_client_status(self, receive_id, status="ONLINE"):
        logging.info("send_client_status")
        message = Message(
            MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id
        )
        sys_name = platform.system()
        if sys_name == "Darwin":
            sys_name = "Mac"
        # Debug for simulation mobile system
        # sys_name = MyMessage.MSG_CLIENT_OS_ANDROID
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_STATUS, status)
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, sys_name)
        self.send_message(message)

    def report_training_status(self, status):
        # No-op unless MLOps reporting is enabled.
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_metrics.set_messenger(self.com_manager_status, self.args)
            self.mlops_metrics.report_client_training_status(
                self.client_real_id, status
            )

    def sync_process_group(
        self, round_idx, model_params=None, client_index=None, src=0
    ):
        """Broadcast (round, params, client index) to the in-silo process group."""
        logging.info("sending round number to pg")
        round_number = [round_idx, model_params, client_index]
        dist.broadcast_object_list(
            round_number,
            src=src,
            group=self.trainer_dist_adapter.process_group_manager.get_process_group(),
        )
        logging.info("round number %d broadcast to process group" % round_number[0])

    def __train(self):
        """Train one local round and send the result back to the server."""
        logging.info("#######training########### round_id = %d" % self.round_idx)
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_started("train", event_value=str(self.round_idx))
        weights, local_sample_num = self.trainer_dist_adapter.train(self.round_idx)
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_ended("train", event_value=str(self.round_idx))
        # the current model is still DDP-wrapped under cross-silo-hi setting
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            weights = convert_model_params_from_ddp(weights)
        self.send_model_to_server(0, weights, local_sample_num)

    def run(self):
        super().run()
client_launch.py
这个感觉基本不用动吧
import os
import subprocess
import torch
from fedml.arguments import load_arguments
from fedml.constants import (
FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL,
FEDML_TRAINING_PLATFORM_CROSS_SILO,
FEDML_CROSS_SILO_SCENARIO_HORIZONTAL,
)
from fedml.device import get_device_type
class CrossSiloLauncher:
    """Launches the cross-silo client process(es) for the configured scenario.

    Fixes vs. the original:
      * the three launcher methods take no ``self``/``cls`` and are invoked
        on the class, so they are now declared ``@staticmethod``;
      * ``args.scenarios`` (typo causing ``AttributeError``) corrected to
        ``args.scenario`` in the horizontal branch;
      * typos fixed in user-facing messages ("processeses", "PSDH").
    """

    @staticmethod
    def launch_dist_trainer(torch_client_filename, inputs):
        """Entry point: dispatch to the launcher matching ``args.scenario``."""
        args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO)
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            CrossSiloLauncher._run_cross_silo_hierarchical(
                args, torch_client_filename, inputs
            )
        elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
            CrossSiloLauncher._run_cross_silo_horizontal(
                args, torch_client_filename, inputs
            )
        else:
            raise Exception(
                "we do not support {}, check whether this is typo in args.scenario".format(
                    args.scenario
                )
            )

    @staticmethod
    def _run_cross_silo_horizontal(args, torch_client_filename, inputs):
        # Horizontal: run the client script directly with the current python.
        python_path = subprocess.run(
            ["which", "python"], capture_output=True, text=True
        ).stdout.strip()
        process_arguments = [python_path, torch_client_filename] + inputs
        subprocess.run(process_arguments)

    @staticmethod
    def _run_cross_silo_hierarchical(args, torch_client_filename, inputs):
        def get_torchrun_arguments(node_rank):
            # Build the torchrun command line for one node of this silo.
            torchrun_path = subprocess.run(
                ["which", "torchrun"], capture_output=True, text=True
            ).stdout.strip()
            return [
                torchrun_path,
                f"--nnodes={args.n_node_in_silo}",
                f"--nproc_per_node={args.n_proc_per_node}",
                # "--rdzv_backend=c10d",
                f"--rdzv_endpoint={args.master_address}:{args.launcher_rdzv_port}",
                f"--node_rank={node_rank}",
                "--rdzv_id=hi_fl",
                torch_client_filename,
            ] + inputs

        network_interface = (
            None if not hasattr(args, "network_interface") else args.network_interface
        )
        print(
            f"Using network interface {network_interface} for process group and TRPC communication"
        )
        env_variables = {
            "OMP_NUM_THREADS": "4",
        }
        if network_interface:
            # Pin NCCL/Gloo traffic to the requested interface.
            env_variables = {
                **env_variables,
                "NCCL_SOCKET_IFNAME": network_interface,
                "GLOO_SOCKET_IFNAME": network_interface,
            }
        if args.n_node_in_silo == 1:
            # Single-node silo: launch locally; infer process count if absent.
            args.node_rank = 0
            args.manual_launch = True
            if not (hasattr(args, "n_proc_per_node") and args.n_proc_per_node):
                print("Number of processes per node not specified.")
                device_type = get_device_type(args)
                if torch.cuda.is_available() and device_type == "gpu":
                    gpu_count = torch.cuda.device_count()
                    print(f"Using number of GPUs ({gpu_count}) as number of processes.")
                    args.n_proc_per_node = gpu_count
                else:
                    print(f"Using number 1 as number of processes.")
                    args.n_proc_per_node = 1
        if hasattr(args, "manual_launch") and args.manual_launch:
            print(f"Manual Client Launcher")
            node_rank = args.node_rank
            torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
            process_args = torchrun_cmd_arguments
            print(f"Launching node {node_rank} of silo {args.rank}")
            subprocess.run(process_args, env=dict(os.environ, **env_variables))
        else:
            print(f"Automatic Client Launcher")
            which_pdsh = subprocess.run(
                ["which", "pdsh"], capture_output=True, text=True
            ).stdout.strip()
            if not which_pdsh:
                raise Exception(
                    f"Silo {args.rank} has {args.n_node_in_silo} nodes. Automatic Client Launcher for more than 1 nodes requires PDSH."
                )
            print(f"Launching nodes using pdsh")
            os.environ["PDSH_RCMD_TYPE"] = "ssh"
            node_addresses = ",".join(args.node_addresses)
            pdsh_cmd_arguments = ["pdsh", "-w", node_addresses]
            exports = ""
            for key, val in env_variables.items():
                exports += "export {}={}; ".format(key, val)
            prerun_args = [
                exports,
                f"cd {os.path.abspath('.')};",
            ]
            # pdsh expands %n to each node's rank within the host list.
            node_rank = "%n"
            torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
            process_args = pdsh_cmd_arguments + prerun_args + torchrun_cmd_arguments
            subprocess.run(process_args)
fedml_trainer.py
这里有一个train函数可能需要改
import time
from ...constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent
from fedml.data import split_data_for_dist_trainers
class FedMLTrainer(object):
    """Holds one client's local data views and runs its local training.

    Fix vs. the original: ``test`` evaluated the train split via
    ``self.train.test(...)`` — but ``self.train`` is this class's *method*,
    so that raised ``AttributeError``. It now calls
    ``self.trainer.test(...)`` like the test-split branch already did.
    """

    def __init__(
        self,
        client_index,
        train_data_local_dict,
        train_data_local_num_dict,
        test_data_local_dict,
        train_data_num,
        device,
        args,
        model_trainer,
    ):
        self.trainer = model_trainer
        self.client_index = client_index
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Shard each client's data across the processes inside the silo.
            self.train_data_local_dict = split_data_for_dist_trainers(
                train_data_local_dict, args.n_proc_in_silo
            )
        else:
            self.train_data_local_dict = train_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict
        self.test_data_local_dict = test_data_local_dict
        self.all_train_data_num = train_data_num
        self.train_local = None
        self.local_sample_number = None
        self.test_local = None
        self.device = device
        self.args = args

    def update_model(self, weights):
        """Load the aggregated global weights into the wrapped trainer."""
        self.trainer.set_model_params(weights)

    def update_dataset(self, client_index):
        """Point the local train/test views at ``client_index``'s data."""
        self.client_index = client_index
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Hierarchical: select this process's shard of the client's data.
            self.train_local = self.train_data_local_dict[client_index][
                self.args.proc_rank_in_silo
            ]
        else:
            self.train_local = self.train_data_local_dict[client_index]
        self.local_sample_number = self.train_data_local_num_dict[client_index]
        self.test_local = self.test_data_local_dict[client_index]

    def train(self, round_idx=None):
        """Run one round of local training; return (weights, sample count)."""
        self.args.round_idx = round_idx
        tick = time.time()
        self.trainer.train(self.train_local, self.device, self.args)
        MLOpsProfilerEvent.log_to_wandb(
            {"Train/Time": time.time() - tick, "round": round_idx}
        )
        weights = self.trainer.get_model_params()
        return weights, self.local_sample_number

    def test(self):
        """Evaluate on the local train and test splits; return the metrics.

        Returns:
            tuple: (train_tot_correct, train_loss, train_num_sample,
                    test_tot_correct, test_loss, test_num_sample)
        """
        # Fix: was self.train.test(...); self.train is a method of this class.
        train_metrics = self.trainer.test(self.train_local, self.device, self.args)
        train_tot_correct, train_num_sample, train_loss = (
            train_metrics["test_correct"],
            train_metrics["test_total"],
            train_metrics["test_loss"],
        )
        test_metrics = self.trainer.test(self.test_local, self.device, self.args)
        test_tot_correct, test_num_sample, test_loss = (
            test_metrics["test_correct"],
            test_metrics["test_total"],
            test_metrics["test_loss"],
        )
        return (
            train_tot_correct,
            train_loss,
            train_num_sample,
            test_tot_correct,
            test_loss,
            test_num_sample,
        )
fedml_trainer_dist_adapter.py
import logging
from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
from .fedml_trainer import FedMLTrainer
from .trainer.trainer_creator import create_model_trainer
class TrainerDistAdapter:
    """Adapter between the client manager and the local model trainer.

    Fixes vs. the original:
      * class renamed from the misspelled ``TrainaDistAdapter`` to
        ``TrainerDistAdapter`` — the name under which the client
        initializer imports it from this module;
      * ``update_dataset`` now forwards to ``self.trainer.update_dataset``
        (it called ``update_model`` with an int index);
      * ``update_dataset`` honours an explicit index of 0 (the old
        ``client_index or self.client_index`` discarded any falsy index).
    """

    def __init__(
        self,
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    ):
        model.to(device)
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Hierarchical silo: set up the in-silo process group and wrap
            # the model in DDP so training is data-parallel inside the silo.
            from torch.nn.parallel import DistributedDataParallel as DDP
            from .process_group_manager import ProcessGroupManager

            only_gpu = args.using_gpu
            self.process_group_manager = ProcessGroupManager(
                args.proc_rank_in_silo,
                args.n_proc_in_silo,
                args.pg_master_address,
                args.pg_master_port,
                only_gpu,
            )
            model = DDP(model, device_ids=[device] if only_gpu else None)
        if model_trainer is None:
            model_trainer = create_model_trainer(args, model)
        else:
            model_trainer.model = model
        # Client ranks are 1-based; trainer ids/indices are 0-based.
        client_index = client_rank - 1
        model_trainer.set_id(client_index)
        logging.info("Initiating Trainer")
        trainer = self.get_trainer(
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        )
        self.client_index = client_index
        self.client_rank = client_rank
        self.device = device
        self.trainer = trainer
        self.args = args

    def get_trainer(
        self,
        client_index,
        train_data_local_dict,
        train_data_local_num_dict,
        test_data_local_dict,
        train_data_num,
        device,
        args,
        model_trainer,
    ):
        """Construct the FedMLTrainer that owns the local data and model."""
        return FedMLTrainer(
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        )

    def train(self, round_idx):
        """Run one local round; return (weights, local sample count)."""
        weights, local_sample_num = self.trainer.train(round_idx)
        return weights, local_sample_num

    def update_model(self, model_params):
        self.trainer.update_model(model_params)

    def update_dataset(self, client_index=None):
        # `is not None` so an explicit index of 0 is honoured.
        _client_index = client_index if client_index is not None else self.client_index
        # Fix: the original mistakenly forwarded to update_model here.
        self.trainer.update_dataset(int(_client_index))

    def cleanup_pg(self):
        """Tear down the in-silo process group (hierarchical scenario only)."""
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            logging.info(
                "Cleaningup process group for client %s in silo %s"
                % (self.args.proc_rank_in_silo, self.args.rank_in_node)
            )
            self.process_group_manager.cleanup()
其他的文件我觉得是不需要动的