python 基础库实现简单websocket服务端框架

0x01 代码

  • 测试环境:linux、python2.7
# coding:utf-8
import os
import struct
import base64
import hashlib
import socket
try:
    import thread
except ImportError:
    import _thread as thread

class WebsocketServerBase(object):
    """ websocket服务端基类

    Details:
        =================================================================
        0                   1                   2                   3
        0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        +-+-+-+-+-------+-+-------------+-------------------------------+
        |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
        |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
        |N|V|V|V|       |S|             |   (if payload len==126/127)   |
        | |1|2|3|       |K|             |                               |
        +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
        |     Extended payload length continued, if payload len == 127  |
        + - - - - - - - - - - - - - - - +-------------------------------+
        |                               |Masking-key, if MASK set to 1  |
        +-------------------------------+-------------------------------+
        | Masking-key (continued)       |          Payload Data         |
        +-------------------------------- - - - - - - - - - - - - - - - +
        :                     Payload Data continued ...                :
        + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
        |                     Payload Data continued ...                |
        +---------------------------------------------------------------+    

    Attributes:
        gm_api_map: 全局存放api的散列表
        timeout: 超时时间
    
    """ 
    
    # 全局api散列表
    gm_api_map = {
        "/": True
    }
    timeout = 30

    def __init__(self):
        """ 初始化方法 """
        pass

    def recv_data(self, conn):
        """ 服务器解析浏览器发送的信息

        Args:
            conn: 接入的websocket客户端对象
        """
        try:
            all_data = conn.recv(1024, socket.MSG_WAITALL)
            data_length = 0
            if not len(all_data):
                return str()

            else:
                code_len = ord(all_data[1]) & 127
                # x值是126,则后面2个字节形成的16位无符号整型数的值是payload的真实长度。
                if code_len == 126:
                    data_length = struct.unpack('>H', str(all_data[2:4]))[0]
                    masks = all_data[4:8]
                    data = all_data[8:]
                # x值是127,则后面8个字节形成的64位无符号整型数的值是payload的真实长度。
                elif code_len == 127:
                    data_length = struct.unpack('>Q', str(msg[2:10]))[0]
                    masks = all_data[10:14]
                    data = all_data[14:]
                # x值在0-125,则是payload的真实长度。
                else:
                    data_length = code_len
                    masks = all_data[2:6]
                    data = all_data[6:]
                raw_str = str()
                i = 0
                for d in data:
                    raw_str += chr(ord(d) ^ ord(masks[i % 4]))
                    i += 1
                return raw_str
        except:
            return str()


    def send_data(self, conn, data):
        """ 服务器处理发送给客户端的信息

        Args:
            conn: 接入的websocket客户端对象
            data: 待发送的数据
        """        
        if data:
            data = str(data)
        else:
            return False
        token = "\x81"
        length = len(data)
        if length < 126:
            # struct为Python中处理二进制数的模块,二进制流为C,或网络流的形式。
            token += struct.pack("B", length)    
        elif length <= 0xFFFF:
            token += struct.pack("!BH", 126, length)
        else:
            token += struct.pack("!BQ", 127, length)
        data = '%s%s' % (token, data)
        conn.send(data)
        return True


    def handshake(self, conn, address, thread_name):
        """ 握手建立连接
        
        Args:
            conn: 接入的websocket客户端对象
            address: 客户端地址[0]:ip, [1]port
            thread_name: 线程名
        """
        headers = {}
        shake = conn.recv(1024)
        if not len(shake):
            return dict()

        header, data = shake.split('\r\n\r\n', 1)
        for line in header.split('\r\n'):
            if 'GET' in line:
                list_msg = line.split(' ')
                headers['Api'] = list_msg[1]
                headers['HttpVersion'] = list_msg[2]
            else:
                key, value = line.split(': ', 1)
                # 非空判断
                if not not key:
                    headers[key] = value

        if 'Sec-WebSocket-Key' not in headers:
            self.log("INFO", '%s : This socket is not websocket, client close.' % thread_name)
            self.close(conn)
            return dict()

        MAGIC_STRING = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        HANDSHAKE_STRING = "{0} 101 Switching Protocols\r\n" \
                        "Upgrade:websocket\r\n" \
                        "Connection: Upgrade\r\n" \
                        "Sec-WebSocket-Accept: {1}\r\n" \
                        "WebSocket-Origin: {2}\r\n" \
                        "WebSocket-Location: ws://{3}\r\n\r\n"

        sec_key = headers['Sec-WebSocket-Key']
        res_key = base64.b64encode(hashlib.sha1(sec_key + MAGIC_STRING).digest())
        str_handshake = HANDSHAKE_STRING.replace('{0}', headers['HttpVersion']).replace('{1}', res_key).replace('{2}', headers['Origin']).replace('{3}', headers['Host']+headers['Api'])
        conn.send(str_handshake)                 # 发送建立连接的信息

        # 检查该资源是否存在
        is_resource_exist = False
        if not not self.gm_api_map.get(headers['Api']):
            is_resource_exist = True

        # 若是无资源,则告知请求方该资源不存在,并断开连接
        if not is_resource_exist:
            self.log("ERROR", headers['Api'] + ": RESOURCE NOT FOUND!")
            self.send_data(conn, headers['Api'] + ": RESOURCE NOT FOUND!")
            self.close(conn)
            return dict()

        # 增加一个字段表示接收连接的套接字地址
        headers['Address'] = "{0}:{1}".format(address[0], address[1])

        self.log("INFO", headers)
        self.log("INFO", '%s : Socket handshaken with %s:%s success' % (thread_name, address[0], address[1]))
        return headers

    def close(self, conn):
        """ 关闭连接函数 """
        conn.close()

    def request_handle(self, conn, address, thread_name):
        """ 主接收消息处理方法

        Args:
            conn: 接入的websocket客户端对象
            address: 客户端地址[0]:ip, [1]port
            thread_name: 线程名
        """
        # websocket握手
        conn_headers = self.handshake(conn, address, thread_name)
        
        # 获取连接失败的处理,结束该处理线程
        if not conn_headers:
            self.log("ERROR", "get the connect fault, close!")
            return

        # 设置socket为阻塞
        conn.setblocking(True)

        # 设置30s超时断开,防止占用过久
        conn.settimeout(self.timeout)

        while True:
            try:
                clientdata = self.recv_data(conn)
                
                # 当收到断开连接时,断开连接,退出循环
                if not clientdata:
                    self.send_data(conn, 'close connect')
                    self.close(conn)
                    break
                
                # 正常接收数据
                else:
                    ret = self.on_message(conn_headers, conn, clientdata)
                    if not ret:
                        break

            except Exception as e:
                self.on_error( e )
        
        # 跳出while时,必定是断开了连接
        self.log("INFO", '%s : Socket close with %s:%s' % (thread_name, address[0], address[1]))
        # 触发close回调
        self.on_close()
        
    def set_timeout(self, timeout):
        """ 设置接收超时时间,到时间自动断开客户端连接,单位:s """
        self.timeout = timeout

    def ws_service(self, ip, port):
        """ websocket主运行方法

        Args:
            ip: 本地服务端ip
            port: 本地服务端绑定的端口
        """
        index = 1
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # 自动释放端口占用
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        sock.bind((ip, port))
        sock.listen(100)

        self.log("INFO", 'Websocket server start, wait for connect!')
        while True:
            try:
                connection, address = sock.accept()
                
                thread_name = 'thread_%s' % index
                self.log("INFO", '%s : Connection from %s:%s' % (thread_name, address[0], address[1]))

                thread.start_new_thread(self.request_handle, (connection, address, thread_name,))

                # on_open方法
                self.on_open(connection)

                index += 1
            except Exception as e:
                self.log("ERROR", e)
                break

    def register_api(self, api):
        """ 列表
        
        Args:
            api: string类型,新增api接口/端点
        """
        self.gm_api_map.update( {api:True} )

    def on_open(self, conn):
        """ 标准websocket的on_open方法
        
        Args:
            conn: 接入的websocket客户端对象
        """
        pass

    def on_message(self, conn_headers, conn, message):
        """ 接收到的有效消息处理回调
                
        Args:
            conn_headers: 连接对象的请求头信息字典
            conn: 接入的websocket客户端对象
            message: 接收到的连接客户端对象消息
        """
        return True

    def on_close(self):
        """ 关闭连接时的回调
        """
        self.log("INFO", "close this connect!")

    def on_error(self, error):
        """ 触发异常时的回调

        Args:
            error: 异常消息
        """
        self.log("ERROR", error)

    def log(self, level, msg):
        """ 日志输出信息,方便修改为不同的日志框架
        
        Args:
            level: 日志消息级别
            msg: 日志消息
        """
        msg = str(msg)
        if "INFO" == level:
            print('[INFO]:' + msg)
        elif "ERROR" == level:
            print('[ERROR]:' + msg)
        elif "WARN" == level:
            print('[WARN]:' + msg)
        elif "DEBUG" == level:
            print('[DEBUG]:' + msg)            
        else:
            print("set level error, and reset on INFO!")
            print('[INFO]:' + msg)

if __name__ == '__main__':

    class TestWebsocketServer(WebsocketServerBase):
        """ 仅在测试时候使用,不作为包
        """

        def on_message(self, conn_headers, conn, message):
            """ 继承自基类的接收消息回调方法 """
            self.log("INFO", "recv form client: " + message)
            # 收到close就断开
            if "close" in message:
                self.close(conn)
                return False

            # 将消息发回去以测试
            self.send_data(conn, message)

            # 若是该api接口,则调用相应接口
            if "/getData" == conn_headers['Api']:
                self.send_data(conn, "this is getData api")

            return True

    # 实例化对象
    server = TestWebsocketServer()
    
    # 注册api接口/端点
    server.register_api("/getData")

    # 开始运行
    server.ws_service("0.0.0.0", 9997)


0x02 测试

0x03 问题

  1. 目前存在接收大量数据时,断连的情况,可继续优化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值