一,本周工作
根据3月20日开会讨论,进行聚合端的设计和开发,因为我们的打算是先用代码将整个联邦学习流程过一遍,故而现在的聚合端只需要针对一种聚合模型。
设计如下:
1,聚合端服务器采用多线程技术,对于每一个到来的训练请求,都分一个线程去处理。
2,设计一组全局变量,比如mutex:用于锁变量,防止并发,socket_list:管理每个用户的socket,global_parametser:保存留在聚合服务器上的训练模型,等等。
3,设定聚合端与后端通过Message对象沟通,初步设定有,data:用于与后端沟通的数组,
label:表明用户意图的标签,初步有Train和Close.
具体代码如下:
交互类Message:
class Message:
def __init__(self, data={}, label='Train'):
self.data = data
self.label = label
def getData(self):
return self.data
def getLabel(self):
return self.label
def getRound(self):
return self.round
def getGroupLabel(self):
return self.glabel
def setData(self, temp_dict):
self.data = temp_dict
启动聚合服务器:
if __name__ == '__main__':
Aggregate_Server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
HOST = get_ip()
print('host',HOST)
# POST = get_port()
POST = 9005 # easy to test
Aggregate_Server.bind((HOST, POST))
# conn = pymysql.connect(host='rm-bp1ocx5t725vc9t85fo.mysql.rds.aliyuncs.com', port=3306, user='user1', db='federal',
# password='chy352196487!')
# cursor = conn.cursor(cursor=pymysql.cursors.DictCursor)
# judge = find_exist(conn, cursor, HOST, POST)
# if find_exist(conn, cursor, HOST, POST) is False:
# insert(conn, cursor, HOST, POST)
# print("insert server info success")
# conn.close()
print("聚合端初始化完毕,准备接受模型...")
# uniteThread().start()
while True:
Aggregate_Server.listen(1)
receive_socket, address = Aggregate_Server.accept()
# socket_List.append(receive_socket)
handleThread(receive_socket, address).start() # 新建一个进程来接受请求并提供服务
线程分配结构:值的一提的是,这里我们采用的socket传输数据,用pickle封装对象
class handleThread (threading.Thread):
def __init__(self,receive_socket, address):
threading.Thread.__init__(self)
self.rsocket = receive_socket
self.address = address
def run(self):
print("starting",str(self.address))
while True: # 这里循环持续与客户端交互r
# try:
data = self.rsocket.recv(8000000)
print(sys.getsizeof(data))
data = pickle.loads(data) # 这里接受到Message类型的对象
print(str(self.address) + "的数据loads完成")
if algorithm(data, self.rsocket, self.address) != True:
break
算法部分:
def algorithm(data, rsocket, address):
global mutex
global GroupDict
try:
if data.getLabel() == 'Close':
print(f'接收到Close,{address}断开连接!')
rsocket.close()
return False
# if data.getLabel() == 'Create':
# # 如果是一个创建任务的人
# mutex.acquire()
# GroupDict[data.getGroupId()] = Group(data.getGroupId())
# mutex.release()
# data.setLabel("Created")
# rsocket.sendall(pickle.dumps(data))
# return True
if data.getLabel() == 'Train': # 这里不考虑错误请求,即发送的Label一定在字典中
# TODO 这一块复制好全局变量
mutex.acquire()
group = GroupDict[data.getGroupId()]
mutex.release()
local_parameters = data.getData()
group.getMutex().acquire()
unite_round = group.getUniteRound()
global_parameters = group.getGloPar()
if global_parameters is None:
global_parameters = local_parameters
else:
for var in global_parameters: # 优化训练轮次参与权重,后面来的项目贡献就低
weight = float(unite_round) / (unite_round + data.getRound())
global_parameters[var] = global_parameters[var] * weight + local_parameters[var] * (1 - weight)
unite_round = unite_round + 1
data.setData(global_parameters)
group.setGloPar(global_parameters)
group.setUniteRound(unite_round)
group.getMutex().release() # TODO 采用了这个组的mutex来保证同一组的请求无法同时执行
data.setRound(data.getRound() + 1)
data.setLabel("Trained")
rsocket.sendall(pickle.dumps(data))
print(fr'已成功返回给{address}')
time.sleep(0.1) # 已经处理完的经常稍微等待一下,防止一个进程一直mutex抢占,另外一个进程没机会学习
return True
二,开会总结
本周后端实现了模型处理与训练,前端确定了界面架构,并进行了相关开发。
聚合端测试基本搭建了聚合端框架,并进行了模型聚合。
三,下周工作
我需要继续完成结构,并与后端同学交流,确定Message的最终结构。
同时要为与他人主机通过socket传递信息做相关准备,目前仅在自己主机上完成测试。