0608个人总结十一(补)

总结十一

近期工作

这周的工作主要是1.任务训练的测试和解决相应问题、2.解决django读取数据库blob文件的二次编码(encode)问题

1.任务训练测试以及问题解决

首先进行router(路由)绑定,请求绑定到应用

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from Client.client_rear.rear_core import routing

application = ProtocolTypeRouter({
    'websocket': AuthMiddlewareStack(
        URLRouter(
            routing.websocket_urlpatterns
        )
    )
})

应用绑定到方法

from django.conf.urls import url

from . import consumers

websocket_urlpatterns = [
    # 模型训练
    url(r'^training/', consumers.TrainConsumer.as_asgi()),
    # 等待他人到来
    url(r'^waiting/', consumers.WaitingConsumer.as_asgi())
]

任务训练的代码

class TrainConsumer(WebsocketConsumer):
    def connect(self):
        self.accept()
        print('TrainConsumer connect!')

    def disconnect(self, code):
        print('TrainConsumer disconnect!' + code)

    '''
    import importlib


    def main():
        modname = input('输入想要导入的模块名称:')
        string = importlib.import_module(modname)
        string.helloworld()


    if __name__ == '__main__':
        main()
    
    通过importlib动态导入模块
    '''

    def receive(self, text_data=None, bytes_data=None):
        global clientSocket
        message = Message()

        # True 表示断开连接
        # False 表示没有断开连接
        judge = False

        # 前端提示返回值
        returnValue = {
            'message': '',
            'message_state': False,
            'accuracy': 0.0,
            'judge': judge,
            'step': 0,
        }

        # 首先获取到聚合端的ip和端口号(如果为空,就返回提示:尚未聚合端,请耐心等待)
        uniteServers = UniteServer.objects
        if not uniteServers.exists():
            # 数据库为空,返回给前端提示
            returnValue['message'] = '聚合端为空,请耐心等待!'
            returnValue['message_state'] = True
            self.send(text_data=json.dumps(returnValue))
            return
        uniteServer = uniteServers.first()
        ip, port = uniteServer.ip, uniteServer.port

        # 通过从数据库中读取的ip和port与聚合端建立连接
        clientSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        clientSocket.connect((ip, port))
        print("与聚合端建立连接")

        # 根据任务id,从数据库中获取模型的名称、模型的py文件和初始模型,放到固定位置
        text_data_json = json.loads(text_data)
        print(text_data_json)
        mission_id = text_data_json['task_id']

        # 这里需要更改任务状态
        mission = Mission.objects.get(mission_id=mission_id)
        # 这里的文件转换先不考虑
        # 文件转换保存已完成
        python_file, original_model, pyname, omname, launcher = \
            mission.python_file, mission.original_model, \
            mission.pyname, mission.omname, mission.mission_launcher
        if mission.mission_state == 'waiting':
            Mission.objects.filter(mission_id=mission_id).update(mission_state='training')
            message.setLabel('Create')
            message.setGroupId(int(mission_id))
            muteki_send(clientSocket, message)
            muteki_recv(clientSocket)

        # 模型文件和初始模型保存在model文件夹下
        print("当前路径", str(os.getcwd()))
        python_file = python_file.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
        print(python_file)
        with open('rear_core/model/' + pyname, 'wb+') as file:
            file.write(python_file)
        if original_model is not None:
            # print("original_model",str(original_model))
            original_model = original_model.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
            with open('rear_core/model/' + omname, 'wb+') as file:
                file.write(original_model)

        # 动态导入module
        model = importlib.import_module('Client.client_rear.rear_core.model.' + pyname[:-3])
        # 动态获取module中的类实例变量
        net = getattr(model, pyname[:-3])()

        if original_model is not None:
            # 加载初始模型
            net.load_state_dict(torch.load('rear_core/model/' + omname))

        self.send(text_data=json.dumps(returnValue))
        returnValue['step'] = 1

        dev = torch.device('cpu')
        # if torch.cuda.is_available():
        #     dev = torch.device("cuda")
        lossFun = F.cross_entropy
        opti = optim.SGD(net.parameters(), lr=0.01)

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            net = torch.nn.DataParallel(net)
        net = net.to(dev)

        # 获取训练和测试数据集
        train_dataset, test_dataset = model.getDataSetAllocation()

        # 循环:训练->与聚合端交互->计算精度->反馈给前端展示
        while True:  # desired_epoch代表要聚合的轮次
            # 模型训练,参数待定
            model_dict = model.modelTrain(net, dev, lossFun, opti, trainDataSet=train_dataset)

            message.setGroupId(mission_id)
            message.setData(model_dict)
            message.setLabel("Train")
            try:
                muteki_send(clientSocket, message)

                # 接受聚合端返回的模型数据
                message = pickle.loads(muteki_recv(clientSocket))
                if message.getLabel() == 'Close':
                    # 训练完成之后直接断开连接
                    returnValue['judge'] = disconnectWithAgg(message)
                    returnValue['step'] = 2
                    self.send(text_data=json.dumps(returnValue))
                    # 查询当前任务的状态和用户的
                    mission = Mission.objects.get(mission_id=mission_id)
                    mission_state = mission.mission_state
                    mission_omName = mission.omname
                    if mission_omName is None:
                        Mission.objects.filter(mission_id=mission_id).update(omname='default.model')
                        omname = 'default.model'
                    if mission_state == 'training':
                        # 保存模型数据为model文件
                        torch.save(message.getData(), 'rear_core/model/' + omname)
                        # 保存到数据库中
                        with open('rear_core/model/' + omname, 'rb') as file:
                            strb = ''.encode()
                            strb += file.read()
                        Mission.objects.filter(mission_id=mission_id).update(
                            mission_state='finish',
                            current_model=strb
                        )
                    break
                # 设置模型数据
                net.load_state_dict(message.getData(), strict=True)
            except socket.error:
                print('Socket.error happened')

            # 计算准确度,参数待定
            accuracy = model.accuracyCompute(net, dev, test_dataset)
            print('准确度等于%f' % accuracy)
            returnValue['accuracy'] = accuracy.cpu().detach().numpy().tolist()
    
            # 返回给前端准确度等参数展示, 使用self.send()方法
            # False 表示没有断开连接
            # True 表示断开连接
            # 断开连接后前端显示模型训练完成
            self.send(text_data=json.dumps(returnValue))
        self.close(code=1000)

模型训练的代码主要可以分为 块:

  1. 从数据库中获取聚合端的ip和port,与聚合端建立连接
  2. 从数据库中获取任务模型相关信息及文件,动态加载模型
  3. 循环:训练->与聚合端交互->计算精度反馈给前端展示
  4. 判断是否结束同时判断任务状态,是否进行模型参数的数据库保存

下面分别介绍一下每部分的代码

(1)从数据库中获取聚合端的ip和port,与聚合端建立连接
# 首先获取到聚合端的ip和端口号(如果为空,就返回提示:尚未聚合端,请耐心等待)
uniteServers = UniteServer.objects
if not uniteServers.exists():
	# 数据库为空,返回给前端提示
    returnValue['message'] = '聚合端为空,请耐心等待!'
    returnValue['message_state'] = True
    self.send(text_data=json.dumps(returnValue))
    return
uniteServer = uniteServers.first()
ip, port = uniteServer.ip, uniteServer.port

# 通过从数据库中读取的ip和port与聚合端建立连接
clientSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
clientSocket.connect((ip, port))
print("与聚合端建立连接")

从数据库获取的聚合端的ip和端口号,判断是否为空,如果不为空则与聚合端建立连接。

(2)从数据库中获取任务模型相关信息及文件,动态加载模型
# 根据任务id,从数据库中获取模型的名称、模型的py文件和初始模型,放到固定位置
text_data_json = json.loads(text_data)
print(text_data_json)
mission_id = text_data_json['task_id']

# 这里需要更改任务状态
mission = Mission.objects.get(mission_id=mission_id)
# 这里的文件转换先不考虑
# 文件转换保存已完成
python_file, original_model, pyname, omname, launcher = \
	mission.python_file, mission.original_model, \
    mission.pyname, mission.omname, mission.mission_launcher
if mission.mission_state == 'waiting':
    Mission.objects.filter(mission_id=mission_id).update(mission_state='training')
    message.setLabel('Create')
    message.setGroupId(int(mission_id))
    muteki_send(clientSocket, message)
    muteki_recv(clientSocket)

# 模型文件和初始模型保存在model文件夹下
print("当前路径", str(os.getcwd()))
python_file = python_file.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
# print(python_file)
with open('rear_core/model/' + pyname, 'wb+') as file:
    file.write(python_file)
if original_model is not None:
    # print("original_model",str(original_model))
    original_model = original_model.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')
    with open('rear_core/model/' + omname, 'wb+') as file:
    	file.write(original_model)

# 动态导入module
model = importlib.import_module('Client.client_rear.rear_core.model.' + pyname[:-3])
# 动态获取module中的类实例变量
net = getattr(model, pyname[:-3])()

if original_model is not None:
	# 加载初始模型
    net.load_state_dict(torch.load('rear_core/model/' + omname))

首先根据任务id从数据库中取出数据,然后将文件保存到本地,最后使用importlib.import_module进行动态导入。

(3)循环:训练->与聚合端交互->计算精度反馈给前端展示
while True:  # desired_epoch代表要聚合的轮次
	# 模型训练,参数待定
    model_dict = model.modelTrain(net, dev, lossFun, opti, trainDataSet=train_dataset)

    message.setGroupId(mission_id)
    message.setData(model_dict)
    message.setLabel("Train")
    try:
    	muteki_send(clientSocket, message)

        # 接受聚合端返回的模型数据
        message = pickle.loads(muteki_recv(clientSocket))
        if message.getLabel() == 'Close':
        	# 训练完成之后直接断开连接
            returnValue['judge'] = disconnectWithAgg(message)
            returnValue['step'] = 2
            self.send(text_data=json.dumps(returnValue))
            # 查询当前任务的状态和用户的
            mission = Mission.objects.get(mission_id=mission_id)
            mission_state = mission.mission_state
            mission_omName = mission.omname
            if mission_omName is None:
            	Mission.objects.filter(mission_id=mission_id).update(omname='default.model')
                omname = 'default.model'
            if mission_state == 'training':
                # 保存模型数据为model文件
                torch.save(message.getData(), 'rear_core/model/' + omname)
                # 保存到数据库中
                with open('rear_core/model/' + omname, 'rb') as file:
                    strb = ''.encode()
                    strb += file.read()
                Mission.objects.filter(mission_id=mission_id).update(
                    mission_state='finish',
                    current_model=strb
                )
            break
        # 设置模型数据
        net.load_state_dict(message.getData(), strict=True)
	except socket.error:
        print('Socket.error happened')

    # 计算准确度,参数待定
    accuracy = model.accuracyCompute(net, dev, test_dataset)
    print('准确度等于%f' % accuracy)
    returnValue['accuracy'] = accuracy.cpu().detach().numpy().tolist()

    # 返回给前端准确度等参数展示, 使用self.send()方法
    # False 表示没有断开连接
    # True 表示断开连接
    # 断开连接后前端显示模型训练完成
    self.send(text_data=json.dumps(returnValue))

获取到模型参数,发送给聚合端聚合,接收聚合端返回的模型参数,判断是否断开连接,如果断开则执行(4);如果不断开连接,则更新本地模型参数,计算精度给前端返回,重复执行上述操作。

(4)判断是否结束同时判断任务状态,是否进行模型参数的数据库保存
if message.getLabel() == 'Close':
	# 训练完成之后直接断开连接
    returnValue['judge'] = disconnectWithAgg(message)
    returnValue['step'] = 2
    self.send(text_data=json.dumps(returnValue))
    # 查询当前任务的状态和用户的
    mission = Mission.objects.get(mission_id=mission_id)
    mission_state = mission.mission_state
    mission_omName = mission.omname
    if mission_omName is None:
        Mission.objects.filter(mission_id=mission_id).update(omname='default.model')
    	omname = 'default.model'
    if mission_state == 'training':
        # 保存模型数据为model文件
        torch.save(message.getData(), 'rear_core/model/' + omname)
        # 保存到数据库中
        with open('rear_core/model/' + omname, 'rb') as file:
            strb = ''.encode()
            strb += file.read()
        Mission.objects.filter(mission_id=mission_id).update(
            mission_state='finish',
            current_model=strb
        )
	break

如果聚合端返回的信息是断开连接,则通过disconnectWithAgg方法与聚合端断开连接,反馈给前端信息,接下来判断任务状态,如果是training状态,标识模型尚未更新,将本地模型上传到数据库中,最后结束外层循环。

2.二次编码问题

将文件从数据库中读取的过程存在着比较大的问题:django的models(数据库处理文件)无法处理数据库的blob类型的数据,于是乎他认为是字符串类型,同时需要转成二进制类型,就有了提到的二次编码问题。

下面简单展示一下二次编码问题:

# 假设我有一个初始的字符串
temp = ' \
'
# 数据库中blob存储的是二进制串
temp = temp.encode()
# 此时temp = b'\r\n'
# django从数据库中读取到后以为它是字符串类型,同时需要二进制类型又对它进行的编码
temp = temp.encode()  # 这个做不到的,只是描述一下这个过程
# 此时temp = b'b\'\\r\\n\''
# 而我们需要的是temp = b'\r\n',这就是问题所在
# 于是乎我通过replace方法解决了这个问题
temp = temp.replace(b'b\'', b'').replace(b'\\r', b'\r').replace(b'\\n', b'\n').replace(b'\'', b'')

但是这个目前还存在着缺陷:

  1. 无法处理中文的转码
  2. 文件中所有单引号会转化成空

上述缺陷暂时没有解决方法。

总结

以上便是我项目实训最后的博客内容了,完成了对于项目最后一个部分(任务训练)的介绍和代码展示。

整个项目走下来,收获满满,同时感觉自己的文笔都有了进步,希望接下来的道路一路通畅。

下步计划

准备结项检查

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值