Synchronization Problems in Socket Programming

Background

Reproducing the FedAvg baseline

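Figure 1: FedAvg flowchart
Intended behavior:
Synchronize the communication between the clients and the server: after each client finishes its local computation, the server collects the model parameters from all clients, averages them, and then distributes the averaged parameters back to every client.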
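
The averaging step can be sketched as follows. This is only an illustrative stand-in: the `average_weights` used in the code below is not shown in this post, and here each client update is assumed to be a dict mapping a parameter name to a flat list of floats (real code would work on model tensors instead). The parameter names are made up for the example.

```python
def average_weights(local_weights):
    """Element-wise mean of a list of parameter dicts (hypothetical stand-in)."""
    if not local_weights:
        raise ValueError("need at least one client update")
    averaged = {}
    for key in local_weights[0]:
        # Group the i-th element of this parameter across all clients
        columns = zip(*(weights[key] for weights in local_weights))
        averaged[key] = [sum(col) / len(local_weights) for col in columns]
    return averaged


# Two clients upload their local parameters; the server averages them.
clients = [
    {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
    {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]},
]
theta = average_weights(clients)
print(theta)
```

Everything else in this post is about getting `local_weights` to the server, and `theta` back to the clients, in the right order.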

Code:

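1. server
import pickle
import socket
import time

# epochs, m, theta, model, average_weights, test_loop, etc. come from the training script (not shown)
# Listen on the TCP socket outside the loop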
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 7788))
server.listen(128)

for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")

    # Zero out theta
    for key in theta.keys():
        theta[key] -= theta[key]

    # local_weights holds the local models
    local_weights = []

    # Establish the TCP connections
    client_list = []
    print(f"{time.ctime()}   wait for connection")
    client01, addr01 = server.accept()
    print(f"{time.ctime()}: one client:{client01,addr01} is connected")
    client02, addr02 = server.accept()
    print(f"{time.ctime()}: one client:{client02,addr02} is connected")
    client_list.append(client01)
    client_list.append(client02)
    print(f"{time.ctime()}   all client connect")

    # Receive the updated model from each selected client
    for i in range(0, m):
        print(f"{time.ctime()}   client {i} start uploading")
        total_data = b''
        while True:
            print(f"{time.ctime()}   start downloading...")
            data = client_list[i].recv(1024)
            print(f"{time.ctime()}   finish one loop...")
            if len(data)==0: break
            total_data += data
        print(f"{time.ctime()}   out of loop")
        # Close the connection
        client_list[i].close()
        local_weights.append(pickle.loads(total_data))
    print(f"{time.ctime()}   download success...")

    # Average to obtain the consensus variable theta
    theta = average_weights(local_weights)

    # Send the updated consensus variable theta to all nodes (skipped after the final epoch)
    if t < epochs - 1:
        # Establish the TCP connections
        client_list = []
        client01, addr01 = server.accept()
        print(f"{time.ctime()}: one client:{client01,addr01} is connected")
        client02, addr02 = server.accept()
        print(f"{time.ctime()}: one client:{client02,addr02} is connected")
        client_list.append(client01)
        client_list.append(client02)
        print("connect success...")
        for i in range(0, m):
            print("start2......")
            client_list[i].sendall(pickle.dumps(theta))
            print("send theta success!")
            # Close the connection
            client_list[i].close()

    # The server does not train
    # upload_msg, model, alpha = train_loop(train_dataloader, model, loss_fn, optimizer, alpha, theta)

    # The server evaluates the current loss and accuracy
    model.load_state_dict(theta)
    test_loop(test_dataloader, model, loss_fn)

server.close()
2. client
for t in range(epochs):
    print(f"{time.ctime()}   Epoch {t + 1}\n-------------------------------")

    if t:
        # Receive the updated theta from the server
        # Create a TCP socket
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Connect to the server (here `server` must be the server's (host, port) address tuple)
        print(f"{time.ctime()}:   a client socket is created")
        client.connect(server)
        print(f"{time.ctime()}:   connect success......")
        print(f"socket info: {client.getsockname()}")

        # Receive the data
        total_data = b''
        while True:
            print(f"{time.ctime()}   start downloading...")
            data = client.recv(1024)
            print(f"{time.ctime()}   finish one loop...")
            total_data += data
            if len(data)==0: break
        theta = pickle.loads(total_data)
        print("downloading success......")
        client.close()

    # Local training, update the parameters
    upload_msg, model, alpha = train_loop(train_dataloader, model, loss_fn, optimizer, alpha, theta)

    # Upload upload_msg to the server
    # Create a TCP socket
    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    print(f"one client {client} is created")
    client.connect(server)
    print(f"{time.ctime()}:   connect success......")
    print(f"socket info: {client.getsockname()}")
    print("uploading...")
    client.sendall(pickle.dumps(upload_msg))
    print(f"{time.ctime()}   upload success")
    client.close()

    # The client does not strictly need to test
    test_loop(test_dataloader, model, loss_fn)

print("Done!")

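Logs:

The server is stuck in the loop that collects the second client's local parameters.
Client 0 reports that its upload succeeded and is waiting for the server to send parameters.
Client 1 also reports that its upload succeeded and is waiting for the server to distribute the updated parameters.
The result is a deadlock: the server is waiting for an upload on its second connection, which it assumes belongs to client 1, while both clients are waiting for the server to distribute parameters.
Printing the address of each connected socket with socket.getsockname() reveals the cause: the two clients train at different speeds. Of the two accept() calls the server issues, the first is taken by client 0 to upload its local parameters, and the second is taken by client 0 again, this time to receive the updated parameters. Client 1 never gets an accept(), but its connect() still succeeds because the connection is queued by listen(), so it sends its parameters normally, then connects to the server again and blocks waiting for the updated parameters.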
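
The key observation above, that a client can connect and upload before the server has called accept(), can be reproduced in a few lines. This is a minimal single-process sketch, not the post's training code: the function name `demo_backlog` and the payload are made up, and port 0 is used so the OS picks a free port.

```python
import socket


def demo_backlog():
    # Server side: bind and listen, but do NOT call accept() yet
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))  # port 0: let the OS choose a free port
    server.listen(5)
    address = server.getsockname()

    # Client side: connect() and sendall() both succeed, because the kernel
    # completes the handshake and buffers the data for the queued connection
    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    client.connect(address)
    client.sendall(b"local params")
    client.close()

    # Only now does the server accept; the data is already waiting
    conn, _ = server.accept()
    chunks = []
    while True:
        data = conn.recv(1024)
        if not data:  # empty bytes: the client closed its side
            break
        chunks.append(data)
    conn.close()
    server.close()
    return b"".join(chunks)


print(demo_backlog())
```

So a successful connect() or sendall() on the client tells it nothing about which accept() call, if any, the server will pair the connection with.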

Basic socket calls and their behavior

There are six basic socket calls:

  1. socket.listen()
  2. socket.accept()
  3. socket.connect()
  4. socket.send()/socket.sendall()
  5. socket.recv()
  6. socket.bind()

In the usual server-client pattern, the server first creates a socket and binds it to a fixed port on a fixed IP:

import socket
server=socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 1500))

The server then starts listening for connections:

server.listen(10)

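The argument is the maximum length of the queue of pending connections. Meanwhile, a client creates a socket and connects to the server's port, which adds a connection request to the listen queue.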

client=socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(("127.0.0.1", 1500))  # the IP and port to connect to must be specified

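To actually talk to the client, the server still has to accept the client's connection request by calling socket.accept(), which returns a dedicated socket for that connection (conn) and the client's address (addr):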

conn,addr=server.accept()

If the client send()s binary data, the data arrives in the receive buffer of the server's conn.

## client:
client.send(b'hello world!')

## server:
data=conn.recv(1024)

Printing data then gives b'hello world!'.

Blocking and returning behavior of the basic socket calls

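To understand what went wrong in the baseline's networking code, it is not enough to know what these basic calls do; you also need to know when each one blocks and when it returns.
Calls that normally return right away:

  1. socket.bind()
  2. socket.listen()
  3. socket.send()/sendall() (they return once the data has been copied into the kernel send buffer, and block only while that buffer is full)
  4. socket.connect() (it returns once the TCP handshake completes, which the kernel finishes without waiting for the server to call accept())

Calls that block until something arrives:

  1. socket.accept()
  2. socket.recv()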

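Note that listen(size) creates a queue of pending connections: when a client issues a connect request, the request goes straight into the queue. When the listen queue is empty, the server's accept() call blocks until a connection request arrives.
Similarly, when a client sends data to the server (client.send(data)), the data is handed to the kernel immediately, without waiting for the server to receive it. A blocking recv(size) on the receiving side, on the other hand, returns only in the following cases:

  1. Data is available in the receive buffer: recv(size) returns at most size bytes, and possibly fewer.
  2. The peer has closed the connection: recv() returns an empty bytes object (b'').

In particular, recv() does not wait for size bytes to accumulate, and it has no built-in timeout unless one is set with settimeout().
So in the baseline, because the server never sends any data on that connection, the client stays blocked in recv() forever.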
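
The behavior of recv() can be checked with a connected pair of sockets in one process. This is a small sketch using `socket.socketpair()` from the standard library; the variable names are arbitrary.

```python
import socket

sender, receiver = socket.socketpair()  # two already-connected sockets

sender.sendall(b"abc")
first = receiver.recv(1024)   # returns the 3 available bytes, does not wait for 1024

sender.close()
second = receiver.recv(1024)  # peer closed and nothing is left: empty bytes

receiver.close()
print(first, second)
```

If the sender had neither sent anything nor closed, the first recv() would simply have blocked, which is exactly the state both clients end up in.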

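The core of the problem is therefore coordinating the order of the clients' connection requests, so that the server knows exactly which client is on the other end of each socket, rather than letting the clients race for connection order.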

Solutions

1. Make the clients stop and wait (stop-and-wait)

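Instead of sending its model parameters as soon as local training finishes, as in the initial baseline, each client now waits for an instruction from the server after training, which hands control of the communication back to the server.
The server polls the clients one by one and receives their locally trained parameters. For example, the server first asks client 0 whether it has finished training and then waits for a reply. When client 0 finishes local training, it reads the server's instruction, sees that the server is asking for its status, replies that it is ready, immediately sends its local data, and then waits for the next instruction. Once the server receives client 0's ready reply, it receives the parameters in a loop until a chunk arrives that is smaller than the maximum size, which it treats as the last packet. Note that this end-of-message test is a heuristic: TCP is a byte stream, so recv() may return a short chunk in the middle of a message, and a payload whose length is an exact multiple of the chunk size never produces a short chunk. The server then goes on to ask the other clients whether they are ready.
After collecting the parameters from all clients, the server sends each client, one by one, the instruction to receive the updated parameters, followed by the updated parameters themselves. Then the next round begins.
This logic synchronizes the clients and the server. The drawback is that it is fairly inefficient.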
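
The "a short chunk means the last packet" test described above is fragile on a TCP byte stream. One common alternative, which is not what the code in this post does, is to prefix every message with its length. A minimal sketch, where the helper names `send_msg`, `recv_msg` and `recv_exact` and the 4-byte header are assumptions made for this example:

```python
import pickle
import socket
import struct

HEADER = struct.Struct("!I")  # 4-byte unsigned length, network byte order


def recv_exact(sock, n):
    """Read exactly n bytes, looping because recv() may return fewer."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed before the full message arrived")
        buf.extend(chunk)
    return bytes(buf)


def send_msg(sock, obj):
    """Pickle obj and send it behind a length header."""
    payload = pickle.dumps(obj)
    sock.sendall(HEADER.pack(len(payload)) + payload)


def recv_msg(sock):
    """Read one length-prefixed message and unpickle it."""
    (length,) = HEADER.unpack(recv_exact(sock, HEADER.size))
    return pickle.loads(recv_exact(sock, length))


# Round trip over a connected socket pair
left, right = socket.socketpair()
send_msg(left, {"w": [0.5] * 100})
received = recv_msg(right)
left.close()
right.close()
print(len(received["w"]))
```

With framing like this the receiver knows where each message ends without relying on sleeps or chunk sizes, so commands and parameters can share one long-lived connection. As with any use of pickle, this is only appropriate between peers that trust each other.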

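2. Introduce multithreading

A natural next step is to turn the wait-and-ask step of the first solution into multithreaded code, so that the server exchanges instructions with multiple clients in parallel.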

server:
import csv
import pickle
import socket
import threading
import time

num_devices = 2
num_epoch = 5
local_weights = []  # filled by the download threads, cleared every epoch

def log(message):
    # Append a timestamped row to the communication log and echo it
    with open("communication.txt", "a", newline="") as file:
        csv.writer(file).writerow([time.ctime(), server.getsockname(), message])
    print(f"{time.ctime()}: {message}")

def download(conn, addr):
    # Ask one client for its local parameters and collect them
    conn.send(b'upload theta!')
    log(f'ask client {addr} to upload theta')
    answer = conn.recv(1024)
    print(f'client {addr} reply {answer}')
    total_data = b''
    while True:
        data = conn.recv(1024)
        total_data += data
        # A chunk shorter than 1024 bytes is treated as the last one. This is a
        # heuristic: it breaks if recv() returns a short chunk mid-message or if
        # the payload length is an exact multiple of 1024.
        if len(data) < 1024: break
    local_weights.append(pickle.loads(total_data))  # list.append is atomic in CPython
    log(f"finish collecting theta from client {addr}")

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(('127.0.0.1', 1500))
server.listen(10)
# Accept every client once; the connections stay open for all epochs
conn = []
for i in range(num_devices):
    connection, addr = server.accept()
    log(f'client {addr} is accepted')
    conn.append([connection, addr])
print("all clients are connected")
for epoch in range(num_epoch):
    print(f'Epoch {epoch}---------------------------')
    # Collect theta from all clients in parallel, one thread per client
    local_weights.clear()
    threads = [threading.Thread(target=download, args=(c, a)) for c, a in conn]
    for t in threads:
        t.start()
    # Wait for every download thread to finish (replaces busy-polling a flag list)
    for t in threads:
        t.join()
    log('finish collecting')
    # Averaging; average_weights comes from the training code (not shown)
    theta = average_weights(local_weights)
    log('finished averaging!')
    # Deliver the consensus theta
    log(f'Epoch {epoch}: start delivering')
    for connection, addr in conn:
        log(f'ask client {addr} to download consensus theta')
        connection.send(b'download theta!')
        time.sleep(1)  # keeps the command and the payload in separate recv() calls
        connection.sendall(pickle.dumps(theta))
        time.sleep(1)
    log(f'Epoch {epoch}: finish delivering')
for connection, addr in conn:
    connection.close()
server.close()
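client:
import pickle
import socket
import time

# train_loop, test_loop, the dataloaders, model, loss_fn, optimizer, alpha and
# the initial theta come from the training code (not shown)
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('127.0.0.1', 1500))
print("connection established!")
num_epoch = 5
for i in range(num_epoch):
    print(f"Epoch {i}: start calculation")
    # Local computation
    upload_msg, model, alpha = train_loop(train_dataloader, model, loss_fn, optimizer, alpha, theta)
    test_loop(test_dataloader, model, loss_fn)
    print("calculation finished!")

    # Wait for the server's instructions
    while True:
        data = client.recv(1024)
        if data == b'upload theta!':
            print(f"Epoch {i}:server required uploading theta")
            client.send(b'Ready to upload theta')
            time.sleep(1)  # keeps the reply and the payload in separate recv() calls
            client.sendall(pickle.dumps(upload_msg))
            print("send all theta")

        elif data == b'download theta!':
            print(f"Epoch {i}:download theta from server")
            total_data = b''
            while True:
                data = client.recv(1024)
                total_data += data
                if len(data) < 1024: break  # same short-chunk heuristic as on the server
            theta = pickle.loads(total_data)  # use the consensus theta in the next epoch
            break
        elif not data:
            # recv() returned b'': the server closed the connection
            raise ConnectionError("server closed the connection")
        else:
            continue
print("finished!")
client.close()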

With this, the communication between the server and the clients is optimized.
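
The parallel collection step can also be written with a thread pool, which returns the results in the same order as the connections and so keeps each update tied to a known client. This is a self-contained sketch rather than the post's server: the function names are made up, the clients are simulated with `socket.socketpair()`, and for brevity each simulated client closes its socket to mark the end of its upload (with long-lived connections a length header would be needed instead).

```python
import socket
from concurrent.futures import ThreadPoolExecutor


def recv_until_closed(conn):
    """Read from conn until the peer closes its side."""
    chunks = []
    while True:
        data = conn.recv(1024)
        if not data:
            break
        chunks.append(data)
    return b"".join(chunks)


def collect_all(conns):
    """Receive from every connection in parallel; results follow the order of conns."""
    with ThreadPoolExecutor(max_workers=len(conns)) as pool:
        return list(pool.map(recv_until_closed, conns))


# Simulate two clients, each uploading one message
pairs = [socket.socketpair() for _ in range(2)]
for idx, (client_side, _) in enumerate(pairs):
    client_side.sendall(f"update-from-client-{idx}".encode())
    client_side.close()

results = collect_all([server_side for _, server_side in pairs])
for _, server_side in pairs:
    server_side.close()
print(results)
```

Leaving the `with` block waits for all worker threads, so no flag polling or explicit join() is needed.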
