websocket_flask

1. 使用 TCP socket 编写 server.py 和 client.py:服务端加载 Mask R-CNN 分割模型,客户端发送图片,服务端返回分割结果。纯 socket 通信没有消息边界,接收方必须以 1024 字节为单位循环 recv,直到收满预期长度,因此传输效率较低。

代码如下

#server.py 

import socket
import torchvision
import torch
import numpy as np
import cv2
import time

# Pretrained torchvision Mask R-CNN shared by forward_image_list(); switch it to
# inference mode right away. It runs on the CPU: moving it to a GPU would also
# require moving the input tensors built by the server loop.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()

def forward_image_list(input_tensor_list, net=None, score_thresh=0.9,
                       label=1, mask_thresh=0.3):
    """Run Mask R-CNN on a batch of images and merge the selected instance masks.

    Args:
        input_tensor_list: list of image tensors passed straight to the model
            (the server loop feeds one (3, H, W) float tensor scaled to [0, 1]).
        net: detection model to call; defaults to the module-level ``model``.
        score_thresh: keep detections whose score is strictly above this.
        label: class id to keep (1 is "person" in torchvision's COCO ids).
        mask_thresh: per-pixel threshold applied to each soft instance mask.

    Returns:
        A list with one int tensor per image, shaped like a single predicted
        mask (e.g. (1, H, W)). It holds 255 where any kept instance covers
        the pixel and 0 elsewhere.
    """
    if net is None:
        net = model
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        pred_list = net(input_tensor_list)
    mask_list = []
    for pred in pred_list:
        # Boolean selection keeps every matching detection. The original
        # `pred_score.index(x)` returned the FIRST index of a score, so two
        # detections with equal scores selected the same mask twice.
        keep = (pred['scores'] > score_thresh) & (pred['labels'] == label)
        select_mask = pred['masks'][keep] > mask_thresh
        # Union of the kept masks; with nothing kept this is all zeros.
        total_mask = select_mask.any(dim=0).int() * 255
        mask_list.append(total_mask)
    return mask_list


def _recv_exact(conn, size, bufsize=65536):
    """Receive exactly ``size`` bytes from ``conn`` into a bytearray.

    Fewer bytes are returned only if the peer closes the connection first; an
    empty result means it closed before sending anything. Each ``recv`` asks
    for at most the bytes still missing, so data belonging to the client's
    next frame is never swallowed into this one.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = conn.recv(min(bufsize, size - len(buf)))
        if not chunk:
            break
        buf += chunk
    return buf


s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 2222))
s.listen(5)
print("waiting...")

# Each request is one raw uint8 frame of height x width x 3 bytes. The bundled
# client resizes to 1200x500 before sending, so rescale must stay 1 to match.
rescale = 1
height = int(500 * rescale)
width = int(1200 * rescale)
lenth = width*height*3

while True:
    sock, addr = s.accept()
    print("sock ", sock)
    print("addr", addr)
    # `with` closes the per-client socket once the client disconnects
    # (the original never closed it).
    with sock:
        while True:
            total_data = _recv_exact(sock, lenth)
            if not total_data:
                break  # client closed the connection cleanly
            if len(total_data) < lenth:
                # Peer closed mid-frame: the buffer cannot be reshaped.
                print("incomplete frame: got %d of %d bytes" % (len(total_data), lenth))
                break
            np_array = np.frombuffer(total_data, dtype=np.uint8)
            # HWC uint8 -> CHW float in [0, 1]. The bundled client sends
            # cv2.imread output (BGR channel order), while torchvision
            # detection models are trained on RGB, so reverse the channels.
            input_tensor = torch.from_numpy(np_array).float().view((height, width, 3))
            input_tensor = input_tensor.permute((2, 0, 1)).flip(0) / 255
            t1 = time.time()
            mask_list = forward_image_list([input_tensor])
            t2 = time.time()
            print("time is(s) :", (t2 - t1))
            # One byte per pixel (0 or 255) goes back to the client.
            mask0_numpy = mask_list[0].detach().cpu().numpy().astype(np.uint8)
            sock.sendall(mask0_numpy.tobytes())

client.py

import socket
import torchvision
import torch
import numpy as np
import time,sys
import cv2

# Number of connection attempts before giving up (the original loop ran once,
# which left its "exit after 20 retries" branch unreachable).
MAX_ATTEMPTS = 1
# Mask size returned by the server; must match its height/width (rescale = 1).
MASK_H, MASK_W = 500, 1200

for i in range(MAX_ATTEMPTS):
    try:
        # `with` guarantees the socket is closed. The original wrote
        # `client_send.close` without parentheses, which never closed it.
        with socket.socket() as client_send:
            ip_port = ("127.0.0.1", 2222)
            client_send.connect(ip_port)
            t1 = time.time()
            img_data = cv2.imread("person.jpg")
            if img_data is None:
                sys.exit("cannot read person.jpg")
            img_data = cv2.resize(img_data, (MASK_W, MASK_H))  # cv2 takes (width, height)
            cmd_data = img_data.tobytes()
            client_send.sendall(cmd_data)

            # The server replies with one uint8 per pixel; read exactly that many.
            expected = MASK_H * MASK_W
            rec_data = bytearray()
            while len(rec_data) < expected:
                data = client_send.recv(min(65536, expected - len(rec_data)))
                if not data:
                    raise ConnectionError("server closed after %d of %d bytes"
                                          % (len(rec_data), expected))
                rec_data += data

        np_array = np.frombuffer(rec_data, dtype=np.uint8)
        re_np = np_array.reshape((MASK_H, MASK_W))
        # cv2.imwrite("mask0.jpg", re_np)
        t2 = time.time()
        print("fps", 1 / (t2 - t1))
    except OSError:
        # Only network failures are retried. Other errors (unreadable image,
        # bad reshape) now propagate instead of being misreported as a
        # connection failure by a bare `except`.
        time.sleep(0.1)
        print('发送命令[{}]时与主程序连接失败,次数:{}'.format("cmd", i + 1))
        if i + 1 >= MAX_ATTEMPTS:
            print('退出')
            sys.exit(1)
    else:
        break

# re=str(data, encoding="utf-8").split("\n", 1)[0]

结果如图所示

2.Http服务器实现-基于python的简单服务器

1. 接受静态请求,`html`,`png`等文件

2. 接受动态请求,脚本类型为`python`

3. 提供`Session`服务

4. `root`是根目录,包含资源文件,脚本等

5. 使用线程池来管理请求

运行方式:先执行 python server.py 启动服务端,再执行 python client.py 发送请求

客户端发送的请求:GET / HTTP/1.1,请求头为 Host: 127.0.0.1:9999

服务端返回 root/index.html 的内容,客户端将响应正文保存为 res.html

实现原理 线程池管理+socket通信

server.py
# -*- coding=utf-8 -*-
import socket
import threading
import queue
from HttpHead import HttpRequest


# 每个任务线程
class WorkThread(threading.Thread):
    """Daemon worker that runs ``(func, args)`` tasks pulled from a shared queue."""

    def __init__(self, work_queue):
        super().__init__()
        self.work_queue = work_queue
        self.daemon = True  # don't keep the process alive after the main thread exits

    def run(self):
        """Execute queued tasks forever, surviving exceptions raised by a task."""
        import traceback

        while True:
            func, args = self.work_queue.get()
            try:
                func(*args)
            except Exception:
                # Previously an exception killed this thread and skipped
                # task_done(): the pool silently shrank and queue.join() hung.
                traceback.print_exc()
            finally:
                self.work_queue.task_done()


# 线程池
class ThreadPoolManger():
    """Fixed-size pool of WorkThread daemons that share one FIFO task queue."""

    def __init__(self, thread_number):
        self.thread_number = thread_number
        self.work_queue = queue.Queue()
        # Spawn every worker up front; each blocks on the queue until work arrives.
        workers = [WorkThread(self.work_queue) for _ in range(thread_number)]
        for worker in workers:
            worker.start()

    def add_work(self, func, *args):
        """Schedule ``func(*args)`` to run on the next idle worker."""
        task = (func, args)
        self.work_queue.put(task)


def _recv_http_request(sock, bufsize=1024, max_size=1 << 20):
    """Read one HTTP request: the header block plus a Content-Length body.

    Stops early if the peer closes or once ``max_size`` bytes have arrived
    (bounds memory use). A request without Content-Length keeps whatever body
    bytes arrived together with its headers, as the original single recv did.
    """
    data = bytearray()
    while b'\r\n\r\n' not in data and len(data) < max_size:
        chunk = sock.recv(bufsize)
        if not chunk:
            return bytes(data)
        data += chunk
    head, sep, body = bytes(data).partition(b'\r\n\r\n')
    length = 0
    for line in head.split(b'\r\n')[1:]:
        name, colon, value = line.partition(b':')
        if colon and name.strip().lower() == b'content-length':
            try:
                length = int(value.strip())
            except ValueError:
                length = 0
    length = min(length, max_size)
    body = bytearray(body)
    while len(body) < length:
        chunk = sock.recv(bufsize)
        if not chunk:
            break
        body += chunk
    return head + sep + bytes(body)


def tcp_link(sock, addr):
    """Serve one HTTP request on an accepted connection, then close it.

    Runs on a thread-pool worker (start_server queues it via add_work).
    The socket is always closed, even when handling the request raises.
    """
    print('Accept new connection from %s:%s...' % addr)
    try:
        # Don't let a silent client pin one of the few pool threads forever.
        sock.settimeout(10)
        request = _recv_http_request(sock)
        http_req = HttpRequest()
        http_req.passRequest(request)
        response = http_req.getResponse()
        # Text responses are str and get encoded here; a binary body
        # (e.g. a .png) may already come back as bytes.
        if isinstance(response, str):
            response = response.encode('utf-8')
        # sendall: a single send() may transmit only part of the response.
        sock.sendall(response)
    except OSError as err:
        print('Connection %s:%s failed: %s' % (addr[0], addr[1], err))
    finally:
        sock.close()


def start_server(host='127.0.0.1', port=9999, workers=5, backlog=10):
    """Accept connections forever and hand each one to a thread pool.

    Args:
        host, port: address to listen on (defaults match the original).
        workers: number of pool threads serving requests concurrently.
        backlog: queue length passed to listen().
    """
    ip_addr = (host, port)
    # `with` closes the listening socket when the loop is interrupted.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(ip_addr)
        s.listen(backlog)
        thread_pool = ThreadPoolManger(workers)
        print('listen in %s:%d' % ip_addr)
        while True:
            sock, addr = s.accept()
            print("sock ", sock)
            print("addr", addr)
            thread_pool.add_work(tcp_link, sock, addr)


if __name__ == '__main__':
    # Serve requests until the process is interrupted.
    start_server()

client.py
#!E:\python\venv\Scripts

# -*- coding:utf-8 -*-

import socket
# from flask import template_rendered
import numpy as np
from importlib_metadata import re


def post_request():
    """Return a raw form POST to '/?ni=00' on the local test server.

    Content-Length and Content-Type are now included. HTTP/1.1 servers rely on
    Content-Length to know where the request body ends, and the original sent
    a body without it.
    """
    body = 'name=linyi&data=163'
    req = 'POST /?ni=00 HTTP/1.1\r\n'
    req += 'Host: 127.0.0.1:9999\r\n'
    req += 'Content-Type: application/x-www-form-urlencoded\r\n'
    req += 'Content-Length: %d\r\n\r\n' % len(body.encode('utf-8'))
    return req + body


def get_request():
    """Return a minimal raw GET request for "/" on the local test server."""
    parts = ['GET / HTTP/1.1\r\n', 'Host: 127.0.0.1:9999\r\n\r\n']
    return ''.join(parts)


def start_request():
    """Send a GET for "/" to 127.0.0.1:9999 and save the response body to res.html.

    Raises:
        ValueError: if the response has no blank line between headers and body.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect(('127.0.0.1', 9999))
        req = get_request()
        print("req:", req)

        # temp_data=np.array([2,3,4])
        # req=temp_data.tobytes()

        s.sendall(req.encode())
        buff = []
        while True:
            d = s.recv(1024)
            if not d:
                break  # the server closes the connection after one response
            buff.append(d)
    # Decode once after joining. Decoding each 1024-byte chunk separately
    # raises UnicodeDecodeError when a multi-byte UTF-8 character (e.g.
    # Chinese text) straddles a chunk boundary.
    data = b''.join(buff).decode('utf-8')
    header, sep, html = data.partition('\r\n\r\n')
    if not sep:
        raise ValueError('malformed HTTP response: no header/body separator')
    # utf-8 matches how the server encodes pages; newline='' writes the
    # body's line endings unchanged on every platform.
    with open('res.html', 'w', encoding='utf-8', newline='') as f:
        f.write(html)


if __name__ == '__main__':
    start_request()
    # Wait for Enter before exiting so the console output stays visible.
    input("press any key to exit;")

HttpHead.py(文件名须与 server.py 中的 `from HttpHead import HttpRequest` 大小写一致)
# -*- coding:utf-8 -*-
import os
import xml.dom.minidom


# 返回码
class ErrorCode:
    """HTTP/1.1 status lines used in responses; each already ends with CRLF."""

    OK = "HTTP/1.1 200 OK\r\n"
    NOT_FOUND = "HTTP/1.1 404 Not Found\r\n"


# 将字典转成字符串
def dict2str(d):
    """Serialize a header dict as CRLF-terminated ``Key: value`` lines.

    Lines follow the dict's insertion order. Values are converted with str(),
    so e.g. an int Content-Length works; the original raised TypeError. The
    lines are joined once instead of with repeated quadratic ``+=``.
    """
    return ''.join('%s: %s\r\n' % (key, value) for key, value in d.items())

class Session(object):
    """Per-client key/value store persisted as a small XML file.

    ``cook_file`` must be set to the backing file path before calling
    loadFromXML() or write2XML(); HttpRequest.processSession does this.
    """

    def __init__(self):
        self.data = dict()      # session key -> value
        self.cook_file = None   # path of the backing XML file

    def getCookie(self, key):
        """Return the value stored under *key*, or None if it is absent."""
        return self.data.get(key)

    def setCookie(self, key, value):
        """Store *value* under *key*; it is persisted by the next write2XML()."""
        self.data[key] = value

    def loadFromXML(self):
        """Load each child element of the file's root element as ``tag -> text``."""
        root = xml.dom.minidom.parse(self.cook_file).documentElement
        for node in root.childNodes:
            # Skip the whitespace text nodes produced by the indented writer.
            if node.nodeType != node.ELEMENT_NODE:
                continue
            # Join the text children. An empty value is written as <key></key>,
            # which parses with no children, so the original childNodes[0]
            # raised IndexError.
            value = ''.join(child.data for child in node.childNodes
                            if child.nodeType == child.TEXT_NODE)
            self.setCookie(node.nodeName, value)

    def write2XML(self):
        """Write all entries to cook_file as ``<Root><key>value</key>...</Root>``.

        Keys must be valid XML element names. Values are stored as text via
        str(), so non-string values no longer crash the writer.
        """
        dom = xml.dom.minidom.getDOMImplementation().createDocument(None, 'Root', None)
        root = dom.documentElement
        for key, value in self.data.items():
            node = dom.createElement(key)
            node.appendChild(dom.createTextNode(str(value)))
            root.appendChild(node)
        # Write the file in the same encoding its XML declaration names.
        with open(self.cook_file, 'w', encoding='utf-8') as f:
            dom.writexml(f, addindent='\t', newl='\n', encoding='utf-8')


class HttpRequest(object):
    """Parse one raw HTTP request and build the response for it.

    Typical use (see tcp_link): ``req = HttpRequest(); req.passRequest(raw);
    req.getResponse()``. A GET without a query string is served as a static
    file from RootDir. A POST, or a GET with a '?', is dispatched to a ``.py``
    script under RootDir by dynamicRequest().
    """
    RootDir = 'root'
    NotFoundHtml = RootDir+'/404.html'
    CookieDir = 'root/cookie/'
    # Content-Type for each text extension served by staticRequest(). The
    # original labelled .css/.js as text/html, which browsers refuse to apply
    # as stylesheets/scripts under strict MIME checking.
    TextTypes = {
        '.html': 'text/html',
        '.css': 'text/css',
        '.js': 'application/javascript',
    }

    def __init__(self):
        self.method = None           # upper-cased request method, e.g. 'GET'
        self.url = None              # request target; '/' is rewritten to '/index.html'
        self.protocol = None         # e.g. 'HTTP/1.1'
        self.head = dict()           # request header name -> value
        self.Cookie = None           # raw Cookie header value, used as the session id
        self.request_data = dict()   # form/query parameters (values are not URL-decoded)
        self.response_line = ''      # status line, one of the ErrorCode constants
        self.response_head = dict()  # response header name -> value
        self.response_body = ''      # str for text, bytes for binary files
        self.session = None          # Session built by processSession()

    def passRequestLine(self, request_line):
        """Parse ``METHOD URL PROTOCOL``; leave the fields unset if malformed."""
        header_list = request_line.split(' ')
        if len(header_list) < 3:
            return  # method stays None, so passRequest() builds no response
        self.method = header_list[0].upper()
        self.url = header_list[1]
        if self.url == '/':
            self.url = '/index.html'
        self.protocol = header_list[2]

    def passRequestHead(self, request_head):
        """Parse CRLF-separated ``Name: value`` lines into self.head."""
        for option in request_head.split('\r\n'):
            key, sep, val = option.partition(': ')
            if not sep:
                # Blank or malformed line (e.g. a request with no headers);
                # the original raised ValueError here.
                continue
            self.head[key] = val
        if 'Cookie' in self.head:
            self.Cookie = self.head['Cookie']

    @staticmethod
    def _parseParams(query):
        """Split ``a=1&b=2`` into a dict; values are kept raw (not URL-decoded).

        Empty fields are skipped. A field without '=' maps to '' instead of
        raising ValueError.
        """
        params = {}
        for field in query.split('&'):
            if not field:
                continue
            key, _, val = field.partition('=')
            params[key] = val
        return params

    def passRequest(self, request):
        """Parse raw request bytes and fill in the response fields.

        Every POST is treated as dynamic. A GET with a query string is also
        dynamic; a GET without one is a static file request. Other methods and
        malformed requests leave the response empty.
        """
        # errors='replace': undecodable bytes must not abort request handling.
        request = request.decode('utf-8', errors='replace')
        if len(request.split('\r\n', 1)) != 2:
            return
        request_line, body = request.split('\r\n', 1)
        request_head, _, request_body = body.partition('\r\n\r\n')  # headers / body
        self.passRequestLine(request_line)
        if self.method is None:
            return
        self.passRequestHead(request_head)

        # Strip any query string so '/x.py?a=1' maps to root/x.py (the
        # original kept it on POST URLs, which always produced a 404).
        path, has_query, query = self.url.partition('?')
        if self.method == 'POST':
            self.request_data = self._parseParams(request_body)
            self.dynamicRequest(HttpRequest.RootDir + path)
        elif self.method == 'GET':
            if has_query:
                self.request_data = self._parseParams(query)
                self.dynamicRequest(HttpRequest.RootDir + path)
            else:
                self.staticRequest(HttpRequest.RootDir + self.url)

    @classmethod
    def _isUnderRoot(cls, path):
        """Return True if *path* resolves to a location inside RootDir.

        Blocks '..' traversal in client-supplied URLs such as '/../secret.html'.
        """
        root = os.path.realpath(cls.RootDir)
        return os.path.realpath(path).startswith(root + os.sep)

    def _notFound(self):
        """Fill the response with the 404 page."""
        with open(HttpRequest.NotFoundHtml, 'r', encoding='utf-8') as f:
            self.response_body = f.read()
        self.response_line = ErrorCode.NOT_FOUND
        self.response_head['Content-Type'] = 'text/html'

    def staticRequest(self, path):
        """Serve a whitelisted static file; anything else gets the 404 page.

        .png is sent as bytes, .html/.css/.js as UTF-8 text, and .py is
        forwarded to dynamicRequest().
        """
        if not self._isUnderRoot(path) or not os.path.isfile(path):
            self._notFound()
            return
        extension_name = os.path.splitext(path)[1]
        if extension_name == '.png':
            with open(path, 'rb') as f:
                self.response_body = f.read()
            self.response_line = ErrorCode.OK
            # The original sent 'text/png', which is not a registered MIME type.
            self.response_head['Content-Type'] = 'image/png'
        elif extension_name in self.TextTypes:
            with open(path, 'r', encoding='utf-8') as f:
                self.response_body = f.read()
            self.response_line = ErrorCode.OK
            self.response_head['Content-Type'] = self.TextTypes[extension_name]
        elif extension_name == '.py':
            self.dynamicRequest(path)
        else:
            # Other file types are never served.
            self._notFound()

    @staticmethod
    def _isValidCookie(cookie):
        """Accept only ASCII alphanumeric ids, because the value becomes a file name."""
        return cookie.isascii() and cookie.isalnum()

    def processSession(self):
        """Load the session named by the request cookie, or create a new one.

        A missing, malformed, or unknown cookie gets a freshly generated id
        and an empty session file.
        """
        self.session = Session()
        cookie_file = None
        # Never use a cookie containing path characters as a file name.
        if self.Cookie is not None and self._isValidCookie(self.Cookie):
            cookie_file = self.CookieDir + self.Cookie
        if cookie_file is not None and os.path.exists(cookie_file):
            self.session.cook_file = cookie_file
            self.session.loadFromXML()
        else:
            self.Cookie = self.generateCookie()
            self.session.cook_file = self.CookieDir + self.Cookie
            self.session.write2XML()
        return self.session

    def generateCookie(self):
        """Return a new random session id (32 hex characters).

        The original returned the millisecond timestamp (its md5 was computed
        and discarded). That id is guessable and collides for requests in the
        same millisecond.
        """
        import secrets
        return secrets.token_hex(16)

    def dynamicRequest(self, path):
        """Call ``main.app()`` of the imported script and use its result as the body.

        Paths outside RootDir, missing files and non-.py files get the 404 page.
        """
        if (not self._isUnderRoot(path) or not os.path.isfile(path)
                or os.path.splitext(path)[1] != '.py'):
            self._notFound()
            return
        # 'root/x.py' -> 'root.x' (a module path relative to the working dir).
        file_path = path.split('.', 1)[0].replace('/', '.')
        self.response_line = ErrorCode.OK
        # NOTE(review): __import__('root.x') returns the top-level package
        # `root`, so `m.main` is the submodule root/main.py (or a `main`
        # attribute of the package), not necessarily the requested script.
        # Confirm this is intended. The module is also cached and shared, so
        # concurrent requests on pool threads overwrite each other's
        # SESSION/POST/GET globals.
        m = __import__(file_path)
        m.main.SESSION = self.processSession()
        if self.method == 'POST':
            m.main.POST = self.request_data
            m.main.GET = None
        else:
            m.main.POST = None
            m.main.GET = self.request_data
        self.response_body = m.main.app()
        self.response_head['Content-Type'] = 'text/html'
        self.response_head['Set-Cookie'] = self.Cookie

    def getResponse(self):
        """Return status line + headers + blank line + body.

        A str body yields a str. A bytes body (binary static files such as
        .png) yields bytes; the original str + bytes concatenation raised
        TypeError.
        """
        head = self.response_line + dict2str(self.response_head) + '\r\n'
        if isinstance(self.response_body, bytes):
            return head.encode('utf-8') + self.response_body
        return head + self.response_body

返回的html文件

res.html
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="utf-8">
    <style type="text/css">
        div{
            width: 100%;
            text-align:center;
        }
    </style>
</head>
<body>
    <div>
        <h1>this is index html</h1>
    </div>
</body>
</html>

3 使用flask构建web服务

ResNet 分类任务

export_jit_model.py:导出 TorchScript 模型
# export_jit_model.py
import torch
import torchvision.models as models

# torchvision's resnet50() builder does not take a weights-file path.
# Depending on the torchvision version, a positional argument is read as
# `pretrained` (so the local file was ignored and the ImageNet weights were
# downloaded instead) or is rejected as an unknown weights name. Build the bare
# architecture, then load the local weights explicitly.
WEIGHTS_PATH = "/Users/ludongsheng/code/pycode/web_http/flask/resnet50.pth"
model = models.resnet50()
# NOTE(review): assumes resnet50.pth holds a state_dict (as torchvision's
# released ResNet-50 checkpoints do) rather than a pickled model; confirm.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
model.eval()

# Tracing records the ops run on this dummy batch. Its (1, 3, 224, 224) shape
# matches the batches built by process_image() in app.py.
example_input = torch.rand(1, 3, 224, 224)

jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

两个 HTML 模板文件(Flask 的 render_template 默认从 templates/ 目录加载)

home.html
<html>
    <head>
      <title>PyTorch Image Classification</title>
    </head>
    <body>
      <h1>PyTorch Image Classification</h1>
      <!-- Upload form: multipart/form-data is required for file uploads, and the
           field name "image" must match request.files['image'] read by
           predict() behind the /predict route in app.py. -->
      <form method="POST" enctype="multipart/form-data" action="/predict">
        <input type="file" name="image">
        <input type="submit" value="Predict">
      </form>
    </body>
  </html>

predict.html
<html>
   <head>
     <title>Prediction Results</title>
   </head>
   <body>
     {# Rendered by predict() in app.py. predicted_class/probability are the
        top-1 result. class_probs holds every (class_name, probability) pair
        sorted by descending probability, so its first entry repeats the top-1
        class. #}
     <h1>Prediction Results</h1>
     <p>Predicted Class: {{ predicted_class }}</p>
     <p>Probability: {{ probability }}</p>
     <h2>Other Classes</h2>
     <ul>
      {% for class_name, prob in class_probs %}
         <li>{{ class_name }}: {{ prob }}</li>
       {% endfor %}
     </ul>
   </body>
 </html>
app.py
from flask import Flask, request, render_template
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms

# TorchScript ResNet-50 written by the export script.
model = torch.jit.load('resnet50_jit.pth')
model.eval()  # inference only
app = Flask(__name__)


# The route was commented out, which left the upload form (home.html, which
# posts to /predict) unreachable through the server.
@app.route('/')
def home():
    """Serve the upload form."""
    return render_template('home.html')

def process_image(image):
    """Convert a PIL image into a normalized (1, 3, 224, 224) float batch.

    Resizes the short side to 224, center-crops 224x224, scales to [0, 1] and
    normalizes with the ImageNet mean/std.
    """
    # Uploads may be RGBA (PNG), grayscale or palette images. The model needs
    # exactly 3 channels, and the in-place Normalize below fails on any other
    # channel count.
    image = image.convert('RGB')
    transformation = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transformation(image).unsqueeze(0)  # add the batch dimension
    return image_tensor


# Label for each of the model's 1000 outputs. The original first assigned a
# two-entry placeholder that was immediately overwritten. Replace these indices
# with real class names (e.g. the ImageNet labels) when available.
class_names = [str(i) for i in range(1000)]

@app.route('/predict', methods=['POST'])
def predict():
    """Classify the uploaded "image" file and render predict.html.

    Template variables: predicted_class and probability describe the top-1
    class. class_probs holds every (class_name, probability) pair sorted by
    descending probability. Responds 400 if the upload is not a readable image.
    """
    # Get the uploaded image file; a missing field makes Flask answer 400.
    image = request.files['image']

    try:
        image_tensor = process_image(Image.open(image))
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; report a client error instead of a 500.
        return 'Uploaded file is not a readable image', 400

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        output = model(image_tensor)

    # Class probabilities for the single image in the batch.
    probabilities = torch.nn.functional.softmax(output, dim=1)
    probabilities = probabilities.detach().numpy()[0]

    class_index = probabilities.argmax()
    predicted_class = class_names[class_index]
    probability = probabilities[class_index]

    class_probs = sorted(zip(class_names, probabilities), key=lambda x: x[1], reverse=True)

    return render_template('predict.html', class_probs=class_probs,
                           predicted_class=predicted_class, probability=probability)
if __name__ == '__main__':
    # Flask development server; it listens on http://127.0.0.1:5000 by default,
    # the URL the client script posts to.
    app.run()
client.py
#can work 2023-1-8
import requests
import time
PyTorch_REST_API_URL = 'http://127.0.0.1:5000/predict'


def predict_result(image_path):
    """POST the image at *image_path* to the Flask service and print the result.

    The service answers with the rendered predict.html page rather than JSON,
    so the predicted class is extracted from the HTML text.

    Returns:
        The response body (HTML) on success, otherwise None.
    """
    import re

    with open(image_path, 'rb') as f:
        payload = {'image': f.read()}

    r = requests.post(PyTorch_REST_API_URL, files=payload)

    # The original indexed the Response like a dict (r['success']), which
    # raises TypeError because requests.Response is not subscriptable.
    if not r.ok:
        print('Request failed', r.status_code)
        return None
    match = re.search(r'Predicted Class:\s*([^<]*)', r.text)
    if match:
        print('Predicted Class:', match.group(1).strip())
    return r.text
# Time one full round trip: read the file, upload it, and receive the page.
t1 = time.time()
res = predict_result("/Users/ludongsheng/code/pycode/web_http/flask/flower.jpg")
t2 = time.time()
elapsed = round(t2 - t1, 5)
print("----------------")
print("time is :", elapsed)
 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值