需求
为了实现在服务器上进行网络模型的调用计算,单独设计一个模块,可以独立于服务器主体,这样便于管理和调用。
模型计算需要一个6x8的list(6个特征值,8组历史值)作为输入,同时只有当接受完整的6x8数据后才进行计算。
socket
首先实现简单的socket通信
# server
import socket
# AF_INET表示IPV4地址家族
# SOCK_STREAM表示使用TCP协议
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("0.0.0.0", 1234)) # 绑定IP地址和端口号
s.listen() # 监听状态
c, addr = s.accept() # s用于监听,c用于通信
with c:
print(addr, "connected.")
while True:
data = c.recv(1024) # 1024表示一次性接受数据的最大长度
if not data:
break
c.sendall(data)
# client
import socket
# AF_INET表示IPV4地址家族
# SOCK_STREAM表示使用TCP协议
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", 1234))
s.sendall(b"Hello, Sever!")
data = s.recv(1024)
print("Received:", repr(data))
threading
安装netcat
使用netcat
工具可以更加清晰明了地进行socket通信。
对代码进行改进,可以通过threading实现多线程socket通信。
# server
import socket
import threading
def handle_client(c, addr):
print(addr, "connected.")
while True:
data = c.recv(1024) # 1024表示一次性接受数据的最大长度
if not data:
break
c.sendall(data)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("0.0.0.0", 1234)) # 绑定IP地址和端口号
s.listen() # 监听状态
while True:
c, addr = s.accept()
# 创建新的线程
t = threading.Thread(target=handle_client, args=(c, addr))
t.start()
json
由于模型的计算需要一个多维的列表形式,所以使用json传输所需的数据,random_state()
为随机生成状态矩阵的函数。
# client
import socket
import json
import numpy as np
import time
S_INFO = 6
S_LEN = 8
def random_state(num):
state = [num, num, num, num, num, num]
data_array = state
return data_array
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", 1234))
num = 0
while True:
data_array = random_state(num)
data = json.dumps(data_array)
s.sendall(bytes(data.encode('utf-8')))
data = []
time.sleep(4)
num = num + 1
同时在server
端修改handle_client()
函数。
def handle_client(c, addr):
print(addr, "connected.")
state = np.zeros(shape=(8, 6))
action = default_action
num = 0
while True:
# c.settimeout(15)
json_string = json.loads(c.recv(1024)) # list格式
if not json_string:
break
print("地址:", addr)
print("数据:", json_string)
# print('\n')
if num < S_LEN: # 没有充满state时,采用默认的决策
num = num + 1
state = add_state(state, json_string)
else:
state = add_state(state, json_string)
action = cal_action(state)
print(state)
print("action:", action)
print('\n')
c.close()
最后加入Tensorflow
模型加载的代码,完整代码如下所示。其中add_state()
函数负责对新的状态矩阵进行更新,cal_action()
是调用模型进行计算的函数。
# ars负责接受sever发送过来的信息,每满6*8之后,就可以进行一次模型的计算
import socket
import threading
import numpy as np
import json
# import tensorflow as tf
import our_a3c as a3c
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
S_INFO = 6
S_LEN = 8
A_DIM = 6
ACTOR_LR_RATE = 0.0001
CRITIC_LR_RATE = 0.001
default_action = 0
NN_MODEL = './model/nn_model_ep_13300.ckpt'
rand_times = 1
RAND_RANGE = 1000
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# [[0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0.]
# [1. 1. 1. 1. 1. 1.]
# [2. 2. 2. 2. 2. 2.]
# [3. 3. 3. 3. 3. 3.]
# [4. 4. 4. 4. 4. 4.]
# [5. 5. 5. 5. 5. 5.]]
def add_state(state, json_string):
state = np.roll(state, -1, axis=0)
state[-1, 0] = json_string[0]
state[-1, 1] = json_string[1]
state[-1, 2] = json_string[2]
state[-1, 3] = json_string[3]
state[-1, 4] = json_string[4]
state[-1, 5] = json_string[5]
new_state = state
return new_state
def cal_action(state):
# action = state[-1][-1]
tf.reset_default_graph() # 清除计算图
with tf.Session(config=config) as sess:
actor = a3c.ActorNetwork(sess,
state_dim=[S_INFO, S_LEN], action_dim=A_DIM,
learning_rate=ACTOR_LR_RATE)
critic = a3c.CriticNetwork(sess,
state_dim=[S_INFO, S_LEN],
learning_rate=CRITIC_LR_RATE)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver() # save neural net parameters
nn_model = NN_MODEL
if nn_model is not None: # nn_model is the path to file
saver.restore(sess, nn_model)
print("Model restored.")
action_prob = actor.predict(np.reshape(state, (1, S_INFO, S_LEN)))
action_cumsum = np.cumsum(action_prob)
hitcount = np.zeros(A_DIM)
for i in range(rand_times):
hit = (action_cumsum > np.random.randint(1, RAND_RANGE) / float(RAND_RANGE)).argmax()
hitcount[hit] = hitcount[hit] + 1
action = hitcount.argmax()
return action
def handle_client(c, addr):
print(addr, "connected.")
state = np.zeros(shape=(8, 6))
action = default_action
num = 0
while True:
# c.settimeout(15)
json_string = json.loads(c.recv(1024)) # list格式
if not json_string:
break
print("地址:", addr)
print("数据:", json_string)
# print('\n')
if num < S_LEN: # 没有充满state时,采用默认的码率决策
num = num + 1
state = add_state(state, json_string)
else:
state = add_state(state, json_string)
action = cal_action(state)
print(state)
print("action:", action)
print('\n')
# add_state(json_string)
c.close()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("0.0.0.0", 1234)) # 绑定IP地址和端口号
s.listen() # 监听状态
while True:
c, addr = s.accept()
# 创建新的线程
t = threading.Thread(target=handle_client, args=(c, addr))
t.start()
Ref
https://blog.csdn.net/zhangpeterx/article/details/83685758
https://blog.csdn.net/tang6457/article/details/109121387
https://blog.csdn.net/qq_37585545/article/details/82250984