python-websocket协议解析(一)

任何使用websocket的项目第一步都是协议解析。协议具体内容这里不在赘述,网上有各种详细资料。这里主要介绍python如何实现。我先将完整的协议处理类粘贴在这,然后再分布讲解,完整代码如下:
再贴代码之前多一句嘴,我在网上看了很多别人写的代码,完全按照websocket协议来解析的几乎没有,使用的时候经常遇到数据量大就会出错的问题,都是因为payload_len解析错误。
这个协议解析代码是我自己按照协议规范一步步解析而来,经过大量测试还没发现问题。

import base64
import hashlib


class WebSocketProtocolHandler:
    def __init__(self, conn):
        self.__conn = conn
        self.__close = False
        self.shark_hands()

    def shark_hands(self):
        request_header_data = self.__conn.recv(8096)
        request_header = WebSocketProtocolHandler.get_headers(request_header_data)

        # 对请求头中的sec-websocket-key进行加密
        response_tpl = "HTTP/1.1 101 Switching Protocols\r\n" \
                       "Upgrade:websocket\r\n" \
                       "Connection: Upgrade\r\n" \
                       "Sec-WebSocket-Accept: %s\r\n" \
                       "WebSocket-Location: ws://%s%s\r\n\r\n"

        magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        value = request_header['Sec-WebSocket-Key'] + magic_string
        ac = base64.b64encode(hashlib.sha1(value.encode('utf-8')).digest())

        response_str = response_tpl % (ac.decode('utf-8'), request_header['Host'], request_header['url'])
        self.__conn.send(bytes(response_str, encoding='utf-8'))

    @staticmethod
    def get_headers(data):
        header_dict = {}
        data = str(data, encoding="utf-8")

        header, body = data.split("\r\n\r\n", 1)
        header_list = header.split("\r\n")
        for i in range(0, len(header_list)):
            if i == 0:
                if len(header_list[0].split(" ")) == 3:
                    header_dict['method'], header_dict['url'],header_dict['protocol'] = header_list[0].split(" ")
            else:
                k, v = header_list[i].split(":", 1)
                header_dict[k] = v.strip()
        return header_dict

    def __actual_read(self, size):
        already_read_len = 0
        buf = b""
        least_size = size
        while True:
            read_buf = self.__conn.recv(least_size)
            actual_read_len = len(read_buf)

            if actual_read_len == 0:
                # disconnect
                self.__close = True
                already_read_len = 0
                buf = b""
                break

            already_read_len += actual_read_len
            least_size -= already_read_len
            buf += read_buf
            if already_read_len == size:
                break
        return already_read_len, buf

    def __read_frame(self):
        fin = 0
        op_code = 0
        mask = b""
        mask_flag = 0
        body_bytes = b""

        try:
            # 先接收前2个字节
            buf = self.__actual_read(2)[1]

            fin = int(buf[0] & 128) >> 7
            op_code = int(buf[0] & 15)
            rsv = int(buf[0] & 112) >> 4
            payload_len = buf[1] & 127
            mask_flag = int(buf[1] & 128) >> 7

            if payload_len == 126:
                if mask_flag == 1:
                    # 前8个字节是header
                    buf += self.__actual_read(6)[1]
                    extend_payload_len = buf[2:4]
                    mask = buf[4:8]
                else:
                    # 前4
                    buf += self.__actual_read(2)[1]
                    extend_payload_len = buf[2:4]
                    mask = b""
            elif payload_len == 127:
                if mask_flag == 1:
                    # 前14个字节是header
                    buf += self.__actual_read(12)[1]
                    extend_payload_len = buf[2:10]
                    mask = buf[10:14]
                else:
                    # 前10
                    buf += self.__actual_read(8)[1]
                    extend_payload_len = buf[2:10]
                    mask = b""
            else:
                if mask_flag == 1:
                    # 前6个字节是header
                    buf += self.__actual_read(4)[1]
                    mask = buf[2:6]
                    extend_payload_len = payload_len
                else:
                    mask = b""
                    extend_payload_len = payload_len

            if isinstance(extend_payload_len, int):
                data_len = payload_len
            else:
                data_len = int.from_bytes(extend_payload_len, byteorder='big', signed=False)

            # print(f"fin:{fin}, op_code:{op_code}, mask:{mask_flag}, rsv:{rsv}, payload_length:{data_len}")
            body_bytes = self.__actual_read(data_len)[1]
        except Exception as e:
            # 如果read frame过程中 socket关闭 就会异常
            # print(e)
            self.__close = True
            pass

        return fin, op_code, mask, mask_flag, body_bytes

    def recv_msg(self):
        protocol_body_bytes = bytearray()

        while True:
            if self.__close:
                # 链接已经关闭
                break

            fin, op_code, mask, mask_flag, body_bytes = self.__read_frame()

            if mask_flag == 1 and mask:
                tmp_bytes = bytearray(body_bytes)
                for i in range(len(tmp_bytes)):
                    chunk = tmp_bytes[i] ^ mask[i % 4]
                    protocol_body_bytes.append(chunk)
            else:
                protocol_body_bytes += body_bytes

            if fin == 1 and (op_code == 0 or op_code == 1 or op_code == 2):
                # 读取到尾帧数据
                break
            elif op_code > 2:
                # 非数据帧
                break

        try:
            body = str(protocol_body_bytes, encoding='utf-8')
        except Exception:
            body = ""

        return body

    def send_msg(self, msg_bytes):
        """
        WebSocket服务端向客户端发送消息
        :param conn: 客户端连接到服务器端的socket对象,即: conn,address = socket.accept()
        :param msg_bytes: 向客户端发送的字节
        :return:
        """
        import struct

        token = b"\x81" #接收的第一字节,一般都是x81不变
        length = len(msg_bytes)
        if length < 126:
            token += struct.pack("B", length)
        elif length <= 0xFFFF:
            token += struct.pack("!BH", 126, length)
        else:
            token += struct.pack("!BQ", 127, length)

        msg = token + msg_bytes
        try:
            self.__conn.send(msg)
            return True
        except Exception:
            return False


这个类接收一个conn作为入参,conn是一个tcpsocket连接,了解网络编程的都不陌生,这里不介绍,接下来的文章会稍微讲一些网络编程的内容。进入正题,具体介绍代码
整个代码分为三部分,握手,写,读,握手和写都不易出错。
首先介绍的是处理websocket的握手请求。这里用了sharkhands 和 getheaders两个函数。首先从conn中获取握手数据,再返回对应的内容即可完成握手,这里无需理解。
再介绍写,根据要发送内容的长度,封装成不同的包,也不易出错。

最后介绍读取,也是整个处理中最复杂的部分,我将读取封装成了3个函数。
recv_msg, read_frame, actual_read。
recv_msg 确保接收到对端发送的一条完整的数据,因为当数据量很大时,这些数据将会被封装成多个frame, recv_msg 将这些frame拼接在一起,形成普通程序员可以操作的信息。
read_frame是读取一个完整的帧,里面的精髓的payload_len的判断,确保读出的数据就是一帧的数据,不会多一个字节也不会少一个字节。
actual_read是网络编程的内容,不赘述。

这个是一个很正确的python websocket协议解析。在网上找了很久,确实没有发现python写的可以使用处理大数据的协议解析,所以自己研究了一个。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值