【代码记录】Python——使用socket+threading+tensorflow实现python多线程调用模型

需求

为了实现在服务器上进行网络模型的调用计算,单独设计一个模块,可以独立于服务器主体,这样便于管理和调用。
模型计算需要一个6x8的list(6个特征值,8组历史值)作为输入,同时只有当接受完整的6x8数据后才进行计算。

socket

首先实现简单的socket通信

#  server
import socket
# AF_INET表示IPV4地址家族
# SOCK_STREAM表示使用TCP协议
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("0.0.0.0", 1234))   # 绑定IP地址和端口号
    s.listen()  # 监听状态
    c, addr = s.accept()    # s用于监听,c用于通信
    with c:
        print(addr, "connected.")
        while True:
            data = c.recv(1024)     # 1024表示一次性接受数据的最大长度
            if not data:
                break
            c.sendall(data)
#  client
import socket
# AF_INET表示IPV4地址家族
# SOCK_STREAM表示使用TCP协议
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.connect(("127.0.0.1", 1234))
    s.sendall(b"Hello, Sever!")
    data = s.recv(1024)
    print("Received:", repr(data))

在这里插入图片描述

在这里插入图片描述

threading

安装netcat

使用netcat工具可以更加清晰明了地进行socket通信。
对代码进行改进,可以通过threading实现多线程socket通信。

#  server
import socket
import threading

def handle_client(c, addr):
    print(addr, "connected.")

    while True:
        data = c.recv(1024)     # 1024表示一次性接受数据的最大长度
        if not data:
            break
        c.sendall(data)

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("0.0.0.0", 1234))   # 绑定IP地址和端口号
    s.listen()  # 监听状态

    while True:
        c, addr = s.accept()

        # 创建新的线程
        t = threading.Thread(target=handle_client, args=(c, addr))
        t.start()

在这里插入图片描述

json

由于模型的计算需要一个多维的列表形式,所以使用json传输所需的数据,random_state()为随机生成状态矩阵的函数。

#  client
import socket
import json
import numpy as np
import time
S_INFO = 6
S_LEN = 8

def random_state(num):
    state = [num, num, num, num, num, num]
    data_array = state
    return data_array
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.connect(("127.0.0.1", 1234))
    num = 0
    while True:
        data_array = random_state(num)
        data = json.dumps(data_array)
        s.sendall(bytes(data.encode('utf-8')))
        data = []
        time.sleep(4)
        num = num + 1

同时在server端修改handle_client()函数。

def handle_client(c, addr):
    print(addr, "connected.")
    state = np.zeros(shape=(8, 6))
    action = default_action
    num = 0
    while True:
        # c.settimeout(15)
        json_string = json.loads(c.recv(1024))  # list格式
        if not json_string:
            break
        print("地址:", addr)
        print("数据:", json_string)
        # print('\n')

        if num < S_LEN:  # 没有充满state时,采用默认的决策
            num = num + 1
            state = add_state(state, json_string)
        else:
            state = add_state(state, json_string)
            action = cal_action(state)
        print(state)
        print("action:", action)
        print('\n')
    c.close()

最后加入Tensorflow模型加载的代码,完整代码如下所示。其中add_state()函数负责对新的状态矩阵进行更新,cal_action()是调用模型进行计算的函数。

# ars负责接受sever发送过来的信息,每满6*8之后,就可以进行一次模型的计算
import socket
import threading
import numpy as np
import json
# import tensorflow as tf
import our_a3c as a3c

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

S_INFO = 6
S_LEN = 8
A_DIM = 6

ACTOR_LR_RATE = 0.0001
CRITIC_LR_RATE = 0.001
default_action = 0
NN_MODEL = './model/nn_model_ep_13300.ckpt'
rand_times = 1
RAND_RANGE = 1000

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# [[0. 0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0. 0.]
#  [1. 1. 1. 1. 1. 1.]
#  [2. 2. 2. 2. 2. 2.]
#  [3. 3. 3. 3. 3. 3.]
#  [4. 4. 4. 4. 4. 4.]
#  [5. 5. 5. 5. 5. 5.]]

def add_state(state, json_string):
    state = np.roll(state, -1, axis=0)
    state[-1, 0] = json_string[0]
    state[-1, 1] = json_string[1]
    state[-1, 2] = json_string[2]
    state[-1, 3] = json_string[3]
    state[-1, 4] = json_string[4]
    state[-1, 5] = json_string[5]
    new_state = state
    return new_state

def cal_action(state):
    # action = state[-1][-1]
    tf.reset_default_graph()    # 清除计算图
    with tf.Session(config=config) as sess:
        actor = a3c.ActorNetwork(sess,
                                 state_dim=[S_INFO, S_LEN], action_dim=A_DIM,
                                 learning_rate=ACTOR_LR_RATE)
        critic = a3c.CriticNetwork(sess,
                                   state_dim=[S_INFO, S_LEN],
                                   learning_rate=CRITIC_LR_RATE)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()  # save neural net parameters
        nn_model = NN_MODEL
        if nn_model is not None:  # nn_model is the path to file
            saver.restore(sess, nn_model)
            print("Model restored.")

        action_prob = actor.predict(np.reshape(state, (1, S_INFO, S_LEN)))
        action_cumsum = np.cumsum(action_prob)

        hitcount = np.zeros(A_DIM)
        for i in range(rand_times):
            hit = (action_cumsum > np.random.randint(1, RAND_RANGE) / float(RAND_RANGE)).argmax()
            hitcount[hit] = hitcount[hit] + 1
        action = hitcount.argmax()

    return action

def handle_client(c, addr):
    print(addr, "connected.")
    state = np.zeros(shape=(8, 6))
    action = default_action
    num = 0
    while True:
        # c.settimeout(15)
        json_string = json.loads(c.recv(1024))  # list格式
        if not json_string:
            break
        print("地址:", addr)
        print("数据:", json_string)
        # print('\n')
        if num < S_LEN:  # 没有充满state时,采用默认的码率决策
            num = num + 1
            state = add_state(state, json_string)
        else:
            state = add_state(state, json_string)
            action = cal_action(state)
        print(state)
        print("action:", action)
        print('\n')
        # add_state(json_string)
    c.close()

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("0.0.0.0", 1234))  # 绑定IP地址和端口号
    s.listen()  # 监听状态

    while True:
        c, addr = s.accept()

        # 创建新的线程
        t = threading.Thread(target=handle_client, args=(c, addr))
        t.start()

在这里插入图片描述

Ref

https://blog.csdn.net/zhangpeterx/article/details/83685758
https://blog.csdn.net/tang6457/article/details/109121387
https://blog.csdn.net/qq_37585545/article/details/82250984

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值