一,本周工作
1,根据新设计,重新改模型架构如下:
①为了针对多种任务,并把每一个任务当做一个组,需要引入新的Group对象,用于管理每个组的全局变量
②为Message增加groupid属性,用于标识消息所属的组
③经测试,确实可以通过全局的GroupDict中的变量,全局地管理各个组
④从代码角度修改结构
修改后的变量部分:
# Global registry of groups; the request handler stores one Group per group id.
GroupDict = dict()
# Thread lock: every read or modification of GroupDict must hold this lock.
mutex = threading.Lock()
group对象管理全局变量:
class Group:
    """Hold the shared ("global") state of one group.

    Args:
        g_label: Identifier of the group.
        g_global_parameters: Parameters shared by the group; ``None`` until
            they are set.
        g_mutex: Lock associated with this group. When omitted (``None``), a
            new ``threading.Lock`` is created for this instance.
        g_num_client: Client count (the author notes this is probably no
            longer needed).
        g_unite_round: Round counter stored for the group; defaults to 1.
    """

    def __init__(self, g_label, g_global_parameters=None, g_mutex=None, g_num_client=1, g_unite_round=1):
        self.g_label = g_label
        self.g_global_parameters = g_global_parameters
        # A `threading.Lock()` default in the signature is evaluated only once,
        # so all groups would silently share a single lock. Create the lock
        # per instance instead.
        self.g_mutex = threading.Lock() if g_mutex is None else g_mutex
        self.g_num_client = g_num_client  # this parameter is probably no longer needed
        self.g_unite_round = g_unite_round

    def setGroupLabel(self, g_label):
        """Replace the group's label."""
        self.g_label = g_label

    def getGroupLabel(self):
        """Return the group's label."""
        return self.g_label

    def setGloPar(self, g_global_parameters):
        """Replace the group's shared parameters."""
        self.g_global_parameters = g_global_parameters

    def getGloPar(self):
        """Return the group's shared parameters (``None`` if never set)."""
        return self.g_global_parameters

    def setMutex(self, g_mutex):
        """Replace the lock associated with this group."""
        self.g_mutex = g_mutex

    def getMutex(self):
        """Return the lock associated with this group."""
        return self.g_mutex

    def setNumClient(self, g_num_client):
        """Replace the stored client count."""
        self.g_num_client = g_num_client

    def getNumClient(self):
        """Return the stored client count."""
        return self.g_num_client

    def setUniteRound(self, g_unite_round):
        """Replace the stored round counter."""
        self.g_unite_round = g_unite_round

    def getUniteRound(self):
        """Return the stored round counter."""
        return self.g_unite_round
新修改Message对象:
class Message:
    """Message carrying a payload, a label, a round number and a group id.

    Args:
        data: Payload of the message. When omitted (``None``), a new empty
            dict is created for this instance.
        label: Label of the message; defaults to ``'Train'``.
        round: Round number; defaults to 1.
        groupid: Id of the group the message belongs to; defaults to 0.
    """

    def __init__(self, data=None, label='Train', round=1, groupid=0):
        # A `{}` default in the signature is a single dict shared by every
        # Message built without `data`; mutating one would leak into the
        # others. Build a fresh dict per instance instead.
        self.data = {} if data is None else data
        self.label = label
        self.round = round
        self.groupid = groupid

    def getGroupId(self):
        """Return the group id."""
        return self.groupid

    def setGroupId(self, groupid):
        """Replace the group id."""
        self.groupid = groupid

    def getData(self):
        """Return the payload."""
        return self.data

    def getLabel(self):
        """Return the label."""
        return self.label

    def getRound(self):
        """Return the round number."""
        return self.round

    def setData(self, temp_dict):
        """Replace the payload."""
        self.data = temp_dict

    def setLabel(self, temp_label):
        """Replace the label."""
        self.label = temp_label

    def setRound(self, temp_round):
        """Replace the round number."""
        self.round = temp_round
新的algorithm结构:
def algorithm(data, rsocket, address):
    """Handle one request message and answer it on the given socket.

    Dispatches on ``data.getLabel()``:

    * ``'Close'``  -- close the socket.
    * ``'Create'`` -- store a new ``Group`` in ``GroupDict`` under
      ``data.groupid`` and send the message back with label ``"Created"``.
    * ``'Train'``  -- merge the message's parameters into the group's global
      parameters, then send the message back carrying the merged parameters,
      label ``"Trained"`` and its round increased by one.

    Args:
        data: Request message; must provide ``getLabel``/``setLabel``,
            ``getData``/``setData``, ``getRound``/``setRound`` and a
            ``groupid`` attribute.
        rsocket: Socket of the requester; used for ``sendall`` and ``close``.
        address: Address of the requester; only used in log output.

    Returns:
        ``True`` after a ``'Create'`` or ``'Train'`` request was answered,
        ``False`` after ``'Close'`` or when an exception occurred (the socket
        is closed in both cases), and ``None`` for any other label.
    """
    global mutex
    global GroupDict
    try:
        label = data.getLabel()
        if label == 'Close':
            print(f'接收到Close,{address}断开连接!')
            rsocket.close()
            return False
        if label == 'Create':
            # The requester is creating a new task/group.
            # `with` guarantees the registry lock is released even on error.
            with mutex:
                GroupDict[data.groupid] = Group(data.groupid)
            data.setLabel("Created")
            rsocket.sendall(pickle.dumps(data))
            return True
        if label == 'Train':
            # Malformed requests are not handled specially: the group id is
            # expected to be present in GroupDict. An unknown id raises
            # KeyError, which ends up in the handler below -- with the lock
            # released, so other requests are not blocked forever.
            # TODO: copy the global variables here
            with mutex:
                group = GroupDict[data.groupid]
            local_parameters = data.getData()
            # TODO: the group's own mutex keeps requests of the same group
            # from running at the same time.
            with group.getMutex():
                unite_round = group.getUniteRound()
                global_parameters = group.getGloPar()
                if global_parameters is None:
                    # First contribution for this group: adopt it as-is.
                    global_parameters = local_parameters
                else:
                    # Weight by training round so that later contributions
                    # count less. The weight does not depend on the loop
                    # variable, so it is computed once.
                    weight = float(unite_round) / (unite_round + data.getRound())
                    for var in global_parameters:
                        global_parameters[var] = global_parameters[var] * weight + local_parameters[var] * (1 - weight)
                # NOTE(review): indentation was lost in the source this block
                # was taken from; the increment is assumed to run once per
                # 'Train' request (after the if/else). Confirm against the
                # running code.
                unite_round = unite_round + 1
                data.setData(global_parameters)
                group.setGloPar(global_parameters)
                group.setUniteRound(unite_round)
                with mutex:
                    GroupDict[data.groupid] = group
            data.setRound(data.getRound() + 1)
            data.setLabel("Trained")
            rsocket.sendall(pickle.dumps(data))
            print(fr'已成功返回给{address}')
            # A handler that has just finished waits briefly, so one worker
            # does not keep grabbing the mutex while another never gets a
            # chance to learn.
            time.sleep(0.1)
            return True
    except Exception:
        # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        rsocket.close()
        print("客户" + str(address) + "异常断开了连接!")
        # To give the member a chance to reconnect, it is not removed from
        # the group for now.
        return False
二,下周安排
①聚合端启动要把自己的ip和端口号传到服务器==>需要去学习python中的数据库操作
②希望能通过代码自动获取本机的ip和端口号