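Synchronization problems encountered in socket programming
Background
Reproducing the FedAvg baseline:
Intended behavior:
Synchronize the communication between the clients and the server: once every client has finished its local computation, the server collects all clients' model parameters, averages them, and then distributes the averaged parameters to every client.
Code:
1. server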
import pickle
import socket
import time

# Listen on the TCP socket outside the loop
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 7788))
server.listen(128)
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    # Reset theta to zero
    for key in theta.keys():
        theta[key] -= theta[key]
    # local_weights holds the local models
    local_weights = []
    # Establish the TCP connections
    client_list = []
    print(f"{time.ctime()} wait for connection")
    client01, addr01 = server.accept()
    print(f"{time.ctime()}: one client:{client01,addr01} is connected")
    client02, addr02 = server.accept()
    print(f"{time.ctime()}: one client:{client02,addr02} is connected")
    client_list.append(client01)
    client_list.append(client02)
    print(f"{time.ctime()} all client connect")
    # Receive the updated model from every selected client
    for i in range(0, m):
        print(f"{time.ctime()} client {i} start uploading")
        total_data = b''
        while True:
            print(f"{time.ctime()} start downloading...")
            data = client_list[i].recv(1024)
            print(f"{time.ctime()} finish one loop...")
            # recv() returns b'' only after the client has closed its socket
            if len(data) == 0: break
            total_data += data
        print(f"{time.ctime()} out of loop")
        # Close the connection
        client_list[i].close()
        local_weights.append(pickle.loads(total_data))
        print(f"{time.ctime()} download success...")
    # Average to obtain the consensus variable theta
    theta = average_weights(local_weights)
    # Send the updated consensus variable theta to all nodes
    if t < 9:
        # Establish the TCP connections
        client_list = []
        client01, addr01 = server.accept()
        print(f"{time.ctime()}: one client:{client01,addr01} is connected")
        client02, addr02 = server.accept()
        print(f"{time.ctime()}: one client:{client02,addr02} is connected")
        client_list.append(client01)
        client_list.append(client02)
        print("connect success...")
        for i in range(0, m):
            print("start2......")
            client_list[i].sendall(pickle.dumps(theta))
            print("send theta success!")
            # Close the connection
            client_list[i].close()
    # The server does not train
    # upload_msg, model, alpha = train_loop(train_dataloader, model, loss_fn, optimizer, alpha, theta)
    # The server evaluates the current loss and accuracy
    model.load_state_dict(theta)
    test_loop(test_dataloader, model, loss_fn)
server.close()
2. client
import pickle
import socket
import time

# `server` below is the server's (host, port) address tuple, presumably ("127.0.0.1", 7788)
for t in range(epochs):
    print(f"{time.ctime()} Epoch {t + 1}\n-------------------------------")
    if t:
        # Receive the updated theta from the server
        # Create a TCP socket
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Connect to the server
        print(f"{time.ctime()}: a client socket is created")
        client.connect(server)
        print(f"{time.ctime()}: connect success......")
        print(f"socket info: {client.getsockname()}")
        # Receive the data
        total_data = b''
        while True:
            print(f"{time.ctime()} start downloading...")
            data = client.recv(1024)
            print(f"{time.ctime()} finish one loop...")
            total_data += data
            # recv() returns b'' only after the server has closed the connection
            if len(data) == 0: break
        theta = pickle.loads(total_data)
        print("downloading success......")
        client.close()
    # Train locally and update the parameters
    upload_msg, model, alpha = train_loop(train_dataloader, model, loss_fn, optimizer, alpha, theta)
    # Upload upload_msg to the server
    # Create a TCP socket
    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    print(f"one client {client} is created")
    client.connect(server)
    print(f"{time.ctime()}: connect success......")
    print(f"socket info: {client.getsockname()}")
    print("uploading...")
    client.sendall(pickle.dumps(upload_msg))
    print(f"{time.ctime()} upload success")
    client.close()
    # Testing on the client is optional
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
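Logs:
The server got stuck in the loop that collects the second client's local parameters:
client0 reported a successful upload and was waiting for the server to send the parameters
client1 also reported a successful send and was waiting for the server to distribute the updated parameters
The result is a deadlock: the server is waiting for client 1 to upload its parameters, while both clients are waiting for the server to distribute parameters.
Printing the address of every connected socket, i.e. socket.getsockname(), revealed the cause. The two clients train at different speeds, so of the two accept() calls the server issued, the first was taken by client 0 to upload its local parameters and the second was taken by client 0 again, this time to receive the updated parameters. Client 1 never got an accept(), yet its connect() succeeded and it sent its parameters normally; it then connected to the server a second time and blocked waiting for the updated parameters.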
Properties of the basic socket calls
There are six basic socket calls:
- socket.listen()
- socket.accept()
- socket.connect()
- socket.send()/socket.sendall()
- socket.recv()
- socket.bind()
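In the usual server-client pattern, the server first creates a socket and binds it to a fixed port on a fixed IP:
import socket
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 1500))
The server then starts listening for connections:
server.listen(10)
The argument is the maximum length of the queue of pending connections. Meanwhile, a client creates its own socket and connects to the server's port, which adds a connection request to the listen queue.
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(("127.0.0.1", 1500))  # the IP and port to connect to must be given
To actually talk to the client, the server still has to accept the connection request with socket.accept(), which returns a dedicated socket for that connection (conn) and the client's address (addr):
conn, addr = server.accept()
When the client sends binary data with send(), the data arrives in the receive buffer of the server's conn.
## client:
client.send(b'hello world!')
## server:
data = conn.recv(1024)
Printing data then shows b'hello world!'.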
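The whole sequence can be tried in a single process. The following is a minimal sketch, assuming a loopback connection and an OS-assigned port (port 0) in place of the fixed port 1500; the helper name `run_echo_once` is hypothetical and only serves the illustration.

```python
import socket
import threading


def run_echo_once(message: bytes) -> bytes:
    """Start a one-shot server, send `message` from a client, return what the server read."""
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))      # port 0: let the OS pick a free port
    server.listen(1)
    address = server.getsockname()
    received = []

    def serve():
        conn, addr = server.accept()   # blocks until the client connects
        with conn:
            chunks = []
            while True:
                data = conn.recv(1024)
                if not data:           # b'' means the client closed the connection
                    break
                chunks.append(data)
            received.append(b"".join(chunks))

    worker = threading.Thread(target=serve)
    worker.start()

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
        client.connect(address)
        client.sendall(message)
    # Leaving the `with` block closes the client socket, which ends the server's recv loop.

    worker.join()
    server.close()
    return received[0]


print(run_echo_once(b"hello world!"))
```

The server side reads until recv() returns b'' rather than relying on a single recv(1024), so the sketch also works for messages longer than 1024 bytes.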
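Blocking and return behavior of the basic socket calls
To understand what went wrong in the baseline's networking code, it is not enough to know what these basic calls do; you also need to know when each one blocks and when it returns.
Calls that normally return right away:
- socket.bind()
- socket.listen()
- socket.send()/sendall() (they return once the data has been copied into the kernel's send buffer, so they block only while that buffer is full)
- socket.connect() (it returns once the TCP handshake has completed, which does not require the server to call accept())
Blocking calls:
- socket.accept()
- socket.recv()
The key point is that listen(size) creates a queue of pending connections. When a client calls connect(), the request is placed in that queue directly by the operating system. When the queue is empty, the server's accept() blocks until a connection request arrives.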
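This behavior is easy to check in isolation. The sketch below, which assumes a loopback socket on an OS-assigned port and uses the hypothetical helper name `connect_before_accept`, lets a client connect, send and close before the server has called accept() even once.

```python
import socket


def connect_before_accept(payload: bytes) -> bytes:
    """Connect, send and close on the client side before the server calls accept()."""
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))          # port 0: the OS picks a free port
    server.listen(5)

    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    client.connect(server.getsockname())   # succeeds: the kernel queues the connection
    client.sendall(payload)                # succeeds: the bytes wait in kernel buffers
    client.close()

    conn, addr = server.accept()           # the queued connection is handed over at once
    chunks = []
    while True:
        data = conn.recv(1024)
        if not data:                       # b'' once the client's close has been seen
            break
        chunks.append(data)
    conn.close()
    server.close()
    return b"".join(chunks)


print(connect_before_accept(b"local parameters"))
```

This is the mechanism behind the baseline's symptom: a client whose connection has not been accepted yet can still report a successful connect and upload.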
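Likewise, when a client sends data to the server (client.send(data)), the call returns as soon as the data has been handed to the operating system; it does not wait for the server to receive it. On the receiving side, conn.recv(size) returns in one of three situations:
- the receive buffer holds at least size bytes: recv() returns size bytes and leaves the rest for the next call
- the receive buffer holds some data but fewer than size bytes: recv() returns what is there without waiting for more (in my tests this looked like a return roughly 0.1 s after the last data arrived, but recv() itself has no such timer; the delay reflects when the data arrived)
- the peer closes the connection: recv() returns b''
So in the baseline, because the server never sends any data to the client, the client stays blocked in recv() forever.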
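Since recv() may hand back any number of bytes up to size, a receiver that needs to know where a message ends generally needs an explicit boundary. One common option is to prefix each pickled payload with its length. The sketch below assumes an 8-byte length header; the names `recv_exact`, `send_msg` and `recv_msg` are illustrative and do not appear in the baseline.

```python
import pickle
import socket
import struct

HEADER = struct.Struct("!Q")   # 8-byte unsigned length, network byte order (an assumption)


def recv_exact(sock, n):
    """Read exactly n bytes, looping because recv() may return fewer."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed before the full message arrived")
        buf.extend(chunk)
    return bytes(buf)


def send_msg(sock, obj):
    """Pickle obj and send it preceded by its length."""
    payload = pickle.dumps(obj)
    sock.sendall(HEADER.pack(len(payload)) + payload)


def recv_msg(sock):
    """Read one length-prefixed message and unpickle it."""
    (length,) = HEADER.unpack(recv_exact(sock, HEADER.size))
    return pickle.loads(recv_exact(sock, length))


def demo():
    left, right = socket.socketpair()
    send_msg(left, {"w": [1.0, 2.0]})   # two messages sent back to back
    send_msg(left, "second message")
    first, second = recv_msg(right), recv_msg(right)
    left.close()
    right.close()
    return first, second


print(demo())
```

With this framing neither side has to close the connection or sleep between messages to mark the end of a payload. As with the baseline, pickle should only be used between peers that trust each other.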
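The core of the problem is therefore to coordinate the order of the clients' connection requests, so that the server knows exactly which client is on the other end of each socket, instead of letting the clients race for the accept() calls.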
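One way to let the server know who is on the other end, sketched here as an assumption rather than as what the baseline does, is to have every client announce an ID immediately after connecting. The 4-byte ID header and the helper names `connect_as` and `accept_identified` are hypothetical.

```python
import socket
import struct

ID_HEADER = struct.Struct("!I")   # hypothetical 4-byte client ID sent right after connect()


def recv_exact(sock, n):
    """Read exactly n bytes, or fewer if the peer closes early."""
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            break
        buf += chunk
    return buf


def connect_as(address, client_id):
    """Client side: connect and announce who we are."""
    sock = socket.create_connection(address)
    sock.sendall(ID_HEADER.pack(client_id))
    return sock


def accept_identified(server, num_clients):
    """Server side: accept connections and key them by the announced ID."""
    conns = {}
    while len(conns) < num_clients:
        conn, addr = server.accept()
        raw = recv_exact(conn, ID_HEADER.size)
        if len(raw) < ID_HEADER.size:
            conn.close()              # the client vanished before announcing itself
            continue
        (client_id,) = ID_HEADER.unpack(raw)
        conns[client_id] = conn
    return conns


def run_demo():
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))
    server.listen(5)
    address = server.getsockname()
    client1 = connect_as(address, 1)  # client 1 happens to connect first
    client0 = connect_as(address, 0)
    conns = accept_identified(server, 2)
    conns[0].sendall(b"for client 0")
    conns[1].sendall(b"for client 1")
    got = (recv_exact(client0, 12), recv_exact(client1, 12))
    for sock in (client0, client1, conns[0], conns[1], server):
        sock.close()
    return sorted(conns), got


print(run_demo())
```

The order in which the connections are accepted no longer matters, because the server addresses each client through the ID it announced.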
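Solutions
1. Make the clients stop and wait (stop-and-wait)
In the initial baseline a client sends its model parameters as soon as local training finishes. Change this so that the client waits for the server's command after local training, which hands control of the communication back to the server.
The server polls the clients one by one and receives each client's locally trained parameters. For example, the server first asks client 0 whether it has finished training and then waits for the reply. When client 0 finishes local training, it reads the server's command, sees the query, replies that it is ready, sends its local data right away, and then waits for the next command. Once the server receives client 0's ready reply, it receives the parameters in a loop until a chunk is smaller than the maximum size, which it treats as the last chunk (a heuristic, since TCP can also deliver a short chunk in the middle of a stream). It then moves on to ask the next client whether it is ready.
After collecting the parameters from all clients, the server sends each client, one by one, the command to receive the updated parameters followed by the parameters themselves, and then the next round begins.
This logic keeps the clients and the server synchronized. Its drawback is low efficiency.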
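The stop-and-wait rounds can be sketched as follows. This is a simplified illustration under assumptions: a single float stands in for the model parameters, the one-byte commands `ASK_UPLOAD` and `PUSH_UPDATE` are hypothetical, socket.socketpair() replaces real TCP connections, and the clients run as threads in the same process.

```python
import socket
import struct
import threading

VALUE = struct.Struct("!d")               # one float stands in for the model parameters
ASK_UPLOAD, PUSH_UPDATE = b"U", b"D"      # hypothetical one-byte commands


def recv_exact(sock, n):
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed early")
        buf += chunk
    return buf


def client(sock, value, rounds, seen):
    for _ in range(rounds):
        # Local training would happen here; afterwards stop and wait for the server.
        if recv_exact(sock, 1) != ASK_UPLOAD:
            raise RuntimeError("expected an upload request")
        sock.sendall(VALUE.pack(value))
        if recv_exact(sock, 1) != PUSH_UPDATE:
            raise RuntimeError("expected an update")
        (value,) = VALUE.unpack(recv_exact(sock, VALUE.size))
        seen.append(value)


def server(conns, rounds):
    history = []
    for _ in range(rounds):
        uploads = []
        for conn in conns:                # poll the clients one at a time
            conn.sendall(ASK_UPLOAD)
            (value,) = VALUE.unpack(recv_exact(conn, VALUE.size))
            uploads.append(value)
        average = sum(uploads) / len(uploads)
        for conn in conns:                # then distribute the average
            conn.sendall(PUSH_UPDATE + VALUE.pack(average))
        history.append(average)
    return history


def run_demo():
    pairs = [socket.socketpair() for _ in range(2)]
    seen = [[], []]
    threads = [
        threading.Thread(target=client, args=(pairs[i][1], start, 2, seen[i]))
        for i, start in enumerate([1.0, 3.0])
    ]
    for thread in threads:
        thread.start()
    history = server([pair[0] for pair in pairs], 2)
    for thread in threads:
        thread.join()
    for left, right in pairs:
        left.close()
        right.close()
    return history, seen


print(run_demo())
```

Because the server speaks first in every exchange, it always knows which client it is reading from; the cost is that a slow client delays the polling of all clients behind it.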
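2. Introduce multithreading
A natural next step is to optimize the sequential polling of the first approach with multithreading, so that the server exchanges commands with several clients in parallel.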
server:
import csv
import pickle
import socket
import threading
import time

num_devices = 2
num_epoch = 5
local_weights = []  # filled by the download threads, reset at the start of every epoch

def log_event(message):
    # Append one row (time, server address, message) to the communication log
    with open("communication.txt", "a", newline="") as file:
        csv.writer(file).writerow([time.ctime(), server.getsockname(), message])

def download(conn, addr):
    # Ask one client for its local parameters and collect them
    conn.sendall(b'upload theta!')
    log_event(f'ask client {addr} to upload theta')
    print(f'ask client {addr} to upload theta')
    answer = conn.recv(1024)
    print(f'client {addr} reply {answer}')
    total_data = b''
    while True:
        data = conn.recv(1024)
        total_data += data
        # Changed termination test: a chunk shorter than 1024 bytes is taken as the last one.
        # This is a heuristic; TCP may also deliver a short chunk in the middle of a stream.
        if len(data) < 1024: break
    local_weights.append(pickle.loads(total_data))
    log_event(f"finish collecting theta from client {addr}")
    print(f"client {addr} finished uploading theta")

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(('127.0.0.1', 1500))
server.listen(10)
# Accept every client once and keep the connections open for all epochs
conn = []
for i in range(num_devices):
    connection, addr = server.accept()
    log_event(f' client {addr} is accepted')
    print(f'{time.ctime()}: client {addr} is accepted')
    conn.append([connection, addr])
print("all clients are connected")

for epoch in range(num_epoch):
    print(f'Epoch {epoch}---------------------------')
    # Download theta from all clients in parallel, one thread per client
    local_weights = []
    threads = [threading.Thread(target=download, args=(c, a)) for c, a in conn]
    for t in threads:
        t.start()
    # Wait until every download thread has finished (replaces busy-waiting on a flags list)
    for t in threads:
        t.join()
    log_event('finish collecting')
    print("finish collecting")
    # Average (average_weights comes from the training code)
    theta = average_weights(local_weights)
    log_event('finished averaging!')
    print("finished averaging!")
    # Deliver the consensus theta
    log_event(f'Epoch {epoch}: start delivering')
    print(f"Epoch {epoch}: start delivering")
    for connection, addr in conn:
        log_event(f'ask client {addr} to download consensus theta')
        connection.sendall(b'download theta!')
        time.sleep(1)  # crude separator so the command and the payload arrive as separate reads
        connection.sendall(pickle.dumps(theta))
        time.sleep(1)
    log_event(f'Epoch {epoch}: finish delivering')
    print(f"Epoch {epoch}: finish delivering")

# Close all connections once training is over
for connection, addr in conn:
    connection.close()
server.close()
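client:
import pickle
import socket
import time

# Connect once and keep the connection open for all epochs
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('127.0.0.1', 1500))
print("connection established!")
num_epoch = 5
for i in range(num_epoch):
    print(f"Epoch {i}: start calculation")
    # Calculate locally (train_loop, test_loop and their arguments come from the training code)
    upload_msg, model, alpha = train_loop(train_dataloader, model, loss_fn, optimizer, alpha, theta)
    test_loop(test_dataloader, model, loss_fn)
    print("calculation finished!")
    # Stop and wait for the server's commands
    while True:
        data = client.recv(1024)
        if data == b'upload theta!':
            print(f"Epoch {i}:server required uploading theta")
            client.sendall(b'Ready to upload theta')
            time.sleep(1)  # keep the reply and the payload in separate reads on the server
            client.sendall(pickle.dumps(upload_msg))
            print("send all theta")
        elif data == b'download theta!':
            print(f"Epoch {i}:download theta from server")
            total_data = b''
            while True:
                data = client.recv(1024)
                total_data += data
                # Same heuristic as on the server: a short chunk is taken as the last one
                if len(data) < 1024: break
            # Use the consensus parameters in the next epoch
            theta = pickle.loads(total_data)
            break
        elif not data:
            # b'' means the server closed the connection
            raise ConnectionError("server closed the connection")
        else:
            continue
print("finished!")
client.close()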
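With this change the server communicates with all clients in parallel over persistent connections, which resolves the deadlock and speeds up the exchange.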
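As a variation on the threaded collection step, the per-client threads can also be managed by a thread pool, which returns the results in client order without any shared list or flags. The sketch below swaps in `concurrent.futures.ThreadPoolExecutor` from the standard library; it is an alternative under assumptions, not the code used above. A single float stands in for the parameters, socket.socketpair() stands in for real connections, and `fake_client` with its `delay` argument is a hypothetical stand-in for a client that is still training.

```python
import socket
import struct
import threading
import time
from concurrent.futures import ThreadPoolExecutor

VALUE = struct.Struct("!d")   # one float stands in for a client's parameters


def recv_exact(sock, n):
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed early")
        buf += chunk
    return buf


def collect_one(conn):
    """Ask one client for its value and wait for the answer."""
    conn.sendall(b"U")
    (value,) = VALUE.unpack(recv_exact(conn, VALUE.size))
    return value


def collect_all(conns):
    """Query all clients in parallel; results keep the order of conns."""
    with ThreadPoolExecutor(max_workers=len(conns)) as pool:
        return list(pool.map(collect_one, conns))


def fake_client(sock, value, delay):
    recv_exact(sock, 1)       # wait for the server's request
    time.sleep(delay)         # stand-in for the rest of local training
    sock.sendall(VALUE.pack(value))


def run_demo():
    pairs = [socket.socketpair() for _ in range(2)]
    specs = [(1.0, 0.2), (3.0, 0.0)]   # client 0 is the slow one
    clients = [
        threading.Thread(target=fake_client, args=(pair[1], value, delay))
        for pair, (value, delay) in zip(pairs, specs)
    ]
    for thread in clients:
        thread.start()
    uploads = collect_all([pair[0] for pair in pairs])
    for thread in clients:
        thread.join()
    for left, right in pairs:
        left.close()
        right.close()
    return uploads


print(run_demo())
```

Even though client 1 answers first here, its value still ends up in second position, so the server can tell which upload came from which client.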