总结十一
近期工作
这周的工作主要是1.任务训练的测试和解决相应问题、2.解决django读取数据库blob文件的二次编码(encode)问题
1.任务训练测试以及问题解决
首先进行router(路由)绑定,请求绑定到应用
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from Client.client_rear.rear_core import routing
application = ProtocolTypeRouter({
'websocket': AuthMiddlewareStack(
URLRouter(
routing.websocket_urlpatterns
)
)
})
应用绑定到方法
from django.conf.urls import url
from . import consumers
websocket_urlpatterns = [
# 模型训练
url(r'^training/', consumers.TrainConsumer.as_asgi()),
# 等待他人到来
url(r'^waiting/', consumers.WaitingConsumer.as_asgi())
]
任务训练的代码
class TrainConsumer(WebsocketConsumer):
def connect(self):
self.accept()
print('TrainConsumer connect!')
def disconnect(self, code):
print('TrainConsumer disconnect!' + code)
'''
import importlib
def main():
modname = input('输入想要导入的模块名称:')
string = importlib.import_module(modname)
string.helloworld()
if __name__ == '__main__':
main()
通过importlib动态导入模块
'''
def receive(self, text_data=None, bytes_data=None):
global clientSocket
message = Message()
# True 表示断开连接
# False 表示没有断开连接
judge = False
# 前端提示返回值
returnValue = {
'message': '',
'message_state': False,
'accuracy': 0.0,
'judge': judge,
'step': 0,
}
# 首先获取到聚合端的ip和端口号(如果为空,就返回提示:尚未聚合端,请耐心等待)
uniteServers = UniteServer.objects
if not uniteServers.exists():
# 数据库为空,返回给前端提示
returnValue['message'] = '聚合端为空,请耐心等待!'
returnValue['message_state'] = True
self.send(text_data=json.dumps(returnValue))
return
uniteServer = uniteServers.first()
ip, port = uniteServer.ip, uniteServer.port
# 通过从数据库中读取的ip和port与聚合端建立连接
clientSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
clientSocket.connect((ip, port))
print("与聚合端建立连接")
# 根据任务id,从数据库中获取模型的名称、模型的py文件和初始模型,放到固定位置
text_data_json = json.loads(text_data)
print(text_data_json)
mission_id = text_data_json['task_id']
# 这里需要更改任务状态
mission = Mission.objects.get(mission_id=mission_id)
# 这里的文件转换先不考虑
# 文件转换保存已完成
python_file, original_model, pyname, omname, launcher = \
mission.python_file, mission.original_model, \
mission.pyname, mission.omname, mission.mission_launcher
if mission.mission_state == 'waiting':
Mission.objects.filter(mission_id=mission_id).update(mission_state='training')
message.setLabel('Create')
message.setGroupId(int(mission_id))
muteki_send(clientSocket, message)
muteki_recv(clientSocket)
# 模型文件和初始模型保存在model文件夹下
print("当前路径", str(os.getcwd()))
python_file = python_file.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
print(python_file)
with open('rear_core/model/' + pyname, 'wb+') as file:
file.write(python_file)
if original_model is not None:
# print("original_model",str(original_model))
original_model = original_model.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
with open('rear_core/model/' + omname, 'wb+') as file:
file.write(original_model)
# 动态导入module
model = importlib.import_module('Client.client_rear.rear_core.model.' + pyname[:-3])
# 动态获取module中的类实例变量
net = getattr(model, pyname[:-3])()
if original_model is not None:
# 加载初始模型
net.load_state_dict(torch.load('rear_core/model/' + omname))
self.send(text_data=json.dumps(returnValue))
returnValue['step'] = 1
dev = torch.device('cpu')
# if torch.cuda.is_available():
# dev = torch.device("cuda")
lossFun = F.cross_entropy
opti = optim.SGD(net.parameters(), lr=0.01)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
net = torch.nn.DataParallel(net)
net = net.to(dev)
# 获取训练和测试数据集
train_dataset, test_dataset = model.getDataSetAllocation()
# 循环:训练->与聚合端交互->计算精度->反馈给前端展示
while True: # desired_epoch代表要聚合的轮次
# 模型训练,参数待定
model_dict = model.modelTrain(net, dev, lossFun, opti, trainDataSet=train_dataset)
message.setGroupId(mission_id)
message.setData(model_dict)
message.setLabel("Train")
try:
muteki_send(clientSocket, message)
# 接受聚合端返回的模型数据
message = pickle.loads(muteki_recv(clientSocket))
if message.getLabel() == 'Close':
# 训练完成之后直接断开连接
returnValue['judge'] = disconnectWithAgg(message)
returnValue['step'] = 2
self.send(text_data=json.dumps(returnValue))
# 查询当前任务的状态和用户的
mission = Mission.objects.get(mission_id=mission_id)
mission_state = mission.mission_state
mission_omName = mission.omname
if mission_omName is None:
Mission.objects.filter(mission_id=mission_id).update(omname='default.model')
omname = 'default.model'
if mission_state == 'training':
# 保存模型数据为model文件
torch.save(message.getData(), 'rear_core/model/' + omname)
# 保存到数据库中
with open('rear_core/model/' + omname, 'rb') as file:
strb = ''.encode()
strb += file.read()
Mission.objects.filter(mission_id=mission_id).update(
mission_state='finish',
current_model=strb
)
break
# 设置模型数据
net.load_state_dict(message.getData(), strict=True)
except socket.error:
print('Socket.error happened')
# 计算准确度,参数待定
accuracy = model.accuracyCompute(net, dev, test_dataset)
print('准确度等于%f' % accuracy)
returnValue['accuracy'] = accuracy.cpu().detach().numpy().tolist()
# 返回给前端准确度等参数展示, 使用self.send()方法
# False 表示没有断开连接
# True 表示断开连接
# 断开连接后前端显示模型训练完成
self.send(text_data=json.dumps(returnValue))
self.close(code=1000)
模型训练的代码主要可以分为 块:
- 从数据库中获取聚合端的ip和port,与聚合端建立连接
- 从数据库中获取任务模型相关信息及文件,动态加载模型
- 循环:训练->与聚合端交互->计算精度反馈给前端展示
- 判断是否结束同时判断任务状态,是否进行模型参数的数据库保存
下面分别介绍一下每部分的代码
(1)从数据库中获取聚合端的ip和port,与聚合端建立连接
# 首先获取到聚合端的ip和端口号(如果为空,就返回提示:尚未聚合端,请耐心等待)
uniteServers = UniteServer.objects
if not uniteServers.exists():
# 数据库为空,返回给前端提示
returnValue['message'] = '聚合端为空,请耐心等待!'
returnValue['message_state'] = True
self.send(text_data=json.dumps(returnValue))
return
uniteServer = uniteServers.first()
ip, port = uniteServer.ip, uniteServer.port
# 通过从数据库中读取的ip和port与聚合端建立连接
clientSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
clientSocket.connect((ip, port))
print("与聚合端建立连接")
从数据库获取的聚合端的ip和端口号,判断是否为空,如果不为空则与聚合端建立连接。
(2)从数据库中获取任务模型相关信息及文件,动态加载模型
# 根据任务id,从数据库中获取模型的名称、模型的py文件和初始模型,放到固定位置
text_data_json = json.loads(text_data)
print(text_data_json)
mission_id = text_data_json['task_id']
# 这里需要更改任务状态
mission = Mission.objects.get(mission_id=mission_id)
# 这里的文件转换先不考虑
# 文件转换保存已完成
python_file, original_model, pyname, omname, launcher = \
mission.python_file, mission.original_model, \
mission.pyname, mission.omname, mission.mission_launcher
if mission.mission_state == 'waiting':
Mission.objects.filter(mission_id=mission_id).update(mission_state='training')
message.setLabel('Create')
message.setGroupId(int(mission_id))
muteki_send(clientSocket, message)
muteki_recv(clientSocket)
# 模型文件和初始模型保存在model文件夹下
print("当前路径", str(os.getcwd()))
python_file = python_file.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
# print(python_file)
with open('rear_core/model/' + pyname, 'wb+') as file:
file.write(python_file)
if original_model is not None:
# print("original_model",str(original_model))
original_model = original_model.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
with open('rear_core/model/' + omname, 'wb+') as file:
file.write(original_model)
# 动态导入module
model = importlib.import_module('Client.client_rear.rear_core.model.' + pyname[:-3])
# 动态获取module中的类实例变量
net = getattr(model, pyname[:-3])()
if original_model is not None:
# 加载初始模型
net.load_state_dict(torch.load('rear_core/model/' + omname))
首先根据任务id从数据库中取出数据,然后将文件保存到本地,最后使用importlib.import_module
进行动态导入。
(3)循环:训练->与聚合端交互->计算精度反馈给前端展示
while True: # desired_epoch代表要聚合的轮次
# 模型训练,参数待定
model_dict = model.modelTrain(net, dev, lossFun, opti, trainDataSet=train_dataset)
message.setGroupId(mission_id)
message.setData(model_dict)
message.setLabel("Train")
try:
muteki_send(clientSocket, message)
# 接受聚合端返回的模型数据
message = pickle.loads(muteki_recv(clientSocket))
if message.getLabel() == 'Close':
# 训练完成之后直接断开连接
returnValue['judge'] = disconnectWithAgg(message)
returnValue['step'] = 2
self.send(text_data=json.dumps(returnValue))
# 查询当前任务的状态和用户的
mission = Mission.objects.get(mission_id=mission_id)
mission_state = mission.mission_state
mission_omName = mission.omname
if mission_omName is None:
Mission.objects.filter(mission_id=mission_id).update(omname='default.model')
omname = 'default.model'
if mission_state == 'training':
# 保存模型数据为model文件
torch.save(message.getData(), 'rear_core/model/' + omname)
# 保存到数据库中
with open('rear_core/model/' + omname, 'rb') as file:
strb = ''.encode()
strb += file.read()
Mission.objects.filter(mission_id=mission_id).update(
mission_state='finish',
current_model=strb
)
break
# 设置模型数据
net.load_state_dict(message.getData(), strict=True)
except socket.error:
print('Socket.error happened')
# 计算准确度,参数待定
accuracy = model.accuracyCompute(net, dev, test_dataset)
print('准确度等于%f' % accuracy)
returnValue['accuracy'] = accuracy.cpu().detach().numpy().tolist()
# 返回给前端准确度等参数展示, 使用self.send()方法
# False 表示没有断开连接
# True 表示断开连接
# 断开连接后前端显示模型训练完成
self.send(text_data=json.dumps(returnValue))
获取到模型参数,发送给聚合端聚合,接收聚合端返回的模型参数,判断是否断开连接,如果断开则执行(4);如果不断开连接,则更新本地模型参数,计算精度给前端返回,重复执行上述操作。
(4)判断是否结束同时判断任务状态,是否进行模型参数的数据库保存
if message.getLabel() == 'Close':
# 训练完成之后直接断开连接
returnValue['judge'] = disconnectWithAgg(message)
returnValue['step'] = 2
self.send(text_data=json.dumps(returnValue))
# 查询当前任务的状态和用户的
mission = Mission.objects.get(mission_id=mission_id)
mission_state = mission.mission_state
mission_omName = mission.omname
if mission_omName is None:
Mission.objects.filter(mission_id=mission_id).update(omname='default.model')
omname = 'default.model'
if mission_state == 'training':
# 保存模型数据为model文件
torch.save(message.getData(), 'rear_core/model/' + omname)
# 保存到数据库中
with open('rear_core/model/' + omname, 'rb') as file:
strb = ''.encode()
strb += file.read()
Mission.objects.filter(mission_id=mission_id).update(
mission_state='finish',
current_model=strb
)
break
如果聚合端返回的信息是断开连接,则通过disconnectWithAgg
方法与聚合端断开连接,反馈给前端信息,接下来判断任务状态,如果是training状态,标识模型尚未更新,将本地模型上传到数据库中,最后结束外层循环。
2.二次编码问题
将文件从数据库中读取的过程存在着比较大的问题:django的models(数据库处理文件)无法处理数据库的blob类型的数据,于是乎他认为是字符串类型,同时需要转成二进制类型,就有了提到的二次编码问题。
下面简单展示一下二次编码问题:
# 假设我有一个初始的字符串
temp = ' \
'
# 数据库中blob存储的是二进制串
temp = temp.encode()
# 此时temp = b'\r\n'
# django从数据库中读取到后以为它是字符串类型,同时需要二进制类型又对它进行的编码
temp = temp.encode() # 这个做不到的,只是描述一下这个过程
# 此时temp = b'b\'\\r\\n\''
# 而我们需要的是temp = b'\r\n',这就是问题所在
# 于是乎我通过replace方法解决了这个问题
temp = temp.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
但是这个目前还存在着缺陷:
- 无法处理中文的转码
- 文件中所有单引号会转化成空
上述缺陷暂时没有解决方法。
总结
以上便是我项目实训最后的博客内容了,完成了对于项目最后一个部分(任务训练)的介绍和代码展示。
整个项目走下来,收获满满,同时感觉自己的文笔都有了进步,希望接下来的道路一路通畅。
下步计划
准备结项检查