一,本周工作
1),详细读了《2017-FedAvgCommunication-Efficient Learning of Deep Networks》中代码的实现部分,并做了如下记录
辅助代码理解
1,数据应该没有被送到数据中心
【该数据是隐私敏感的或者大小很大(与模型的大小相比),因此最好不要纯粹为了模型训练的目的而将其记录到数据中心(服务于集中收集原则)。】
2,模型结构
①有K个固定客户端,每个客户端有自己的本地数据集
②每轮随机算则C个客户端,服务器把当前全局算法状态发给这C个
③每个选定的客户端基于全局状态及本地数据集执行本地计算,并行向服务器发送更新
④服务器将这些更新应用于全局状态
⑤重复该过程
3,本实验要优化通讯成本,代码有改动,让每轮客户端增加了更多计算
4,模型优化的核心应该是SGD(梯度下降)
5,C=1对应于全批次(非随机)
6,计算由三个关键参数控制
C:每轮执行运算的客户端比例
E:每个客户端,在一轮中对本地数据集训练的次数
B:用于客户端更新的本地微型批次大小(B=无穷,表示本地数据集被视作单个迷你批次)
7,本文提高一种关键方法:naive parameter averaging(朴素参数平均)
8,模型算法伪代码
2)伪代码对应原模型
对应server部分
# TODO 从这里开始就是伪代码的结构
for i in range(args.num_comm): # 这个num_comm指的是训练轮次
print("communicate round {}".format(i))
order = np.arange(args.num_of_clients)
np.random.shuffle(order)
clients_in_comm = ['client{}'.format(i) for i in order[0:num_in_comm]]
sum_vars = None
for client in tqdm(clients_in_comm):
local_vars = myClients.ClientUpdate(client, global_vars) # TODO 调用客户端的更新
if sum_vars is None:
sum_vars = local_vars
else: # TODO zip函数,将传入的两个列表,打包成一个元组
for sum_var, local_var in zip(sum_vars, local_vars):
sum_var += local_var
对应clientUpdate部分
def ClientUpdate(self, client, global_vars):
all_vars = tf.trainable_variables()
for variable, value in zip(all_vars, global_vars):
variable.load(value, self.session)
for i in range(self.E):
for j in range(self.clientsSet[client].dataset_size // self.B):
train_data, train_label = self.clientsSet[client].next_batch(self.B)
self.session.run(self.train, feed_dict={self.inputsx: train_data, self.inputsy: train_label})
return self.session.run(tf.trainable_variables())
二,下周工作
在本地/服务器配置环境,让代码跑起来,去确定后端,聚合端传递的val是什么类型,同时研究聚合算法,看能否移用