【通用消息通知服务】0x3 - 发送我们第一条消息(Websocket)

【通用消息通知服务】0x3 - 发送我们第一条消息

项目地址：A generic message notification system（GitHub）

实现接收/发送Websocket消息

Websocket Connection Pool

import asyncio
from asyncio.queues import Queue
from asyncio.queues import QueueEmpty
from contextlib import suppress
from typing import Any

import async_timeout
import orjson
from sanic.log import logger
from ulid import ULID

from common.depend import Dependency

# Heartbeat frames used by the pool's liveness check: the server sends PING
# and treats a PONG reply from the client as proof the connection is alive.
PING, PONG = "#ping", "#pong"


class WebsocketConnectionPoolDependency(
    Dependency, dependency_name="WebsocketPool", dependency_alias="ws_pool"
):
    """Pool of live websocket connections and their per-connection tasks.

    ``add_connection`` assigns each connection a ULID id (also stored on the
    connection object as ``_id``) and registers four named app tasks:

    * ``websocket_<id>_send_task``     - writes queued outgoing messages;
    * ``websocket_<id>_recv_task``     - queues incoming frames;
    * ``websocket_<id>_notify_task``   - hands incoming frames to listeners;
    * ``websocket_<id>_is_alive_task`` - PING/PONG heartbeat that removes the
      connection when no PONG arrives within
      ``app.config.WEBSOCKET_PING_TIMEOUT``.
    """

    def __init__(self, app) -> None:
        super().__init__(app)
        self.lock = asyncio.Lock()
        self.connections = {}  # connection id -> websocket connection
        self.send_queues = {}  # connection id -> Queue of outgoing messages
        self.recv_queues = {}  # connection id -> Queue of incoming messages
        self.close_callbacks = {}  # connection id -> [callback(connection_id)]
        self.listeners = {}  # connection id -> {listener id: handler(connection_id, data)}
        self.closed_events = {}  # connection id -> Event set once removed from the pool

    def _gen_id(self) -> str:
        return str(ULID())

    @staticmethod
    def _normalize_id(connection: Any) -> Any:
        """Return ``connection._id`` for connection objects, else ``connection`` itself."""
        return getattr(connection, "_id", connection)

    async def add_connection(self, connection) -> str:
        """Register ``connection``, start its background tasks and return its new id."""
        async with self.lock:
            connection_id = self._gen_id()
            # Tag the connection first: the tasks below resolve it through ``_id``.
            setattr(connection, "_id", connection_id)
            self.connections[connection_id] = connection
            self.closed_events[connection_id] = asyncio.Event()
            self.send_queues[connection_id] = Queue()
            self.recv_queues[connection_id] = Queue()
            self.app.add_task(
                self.send_task(self.send_queues[connection_id], connection),
                name=f"websocket_{connection_id}_send_task",
            )
            self.app.add_task(
                self.recv_task(self.recv_queues[connection_id], connection),
                name=f"websocket_{connection_id}_recv_task",
            )
            self.app.add_task(
                self.notify_task(connection_id),
                name=f"websocket_{connection_id}_notify_task",
            )
            self.app.add_task(
                self.is_alive_task(connection_id),
                name=f"websocket_{connection_id}_is_alive_task",
            )
            return connection_id

    def get_connection(self, connection_id: str):
        return self.connections.get(connection_id)

    async def add_listener(self, connection_id, handler) -> str:
        """Register ``handler(connection_id, data)`` for incoming messages; return its id."""
        async with self.lock:
            listener_id = self._gen_id()
            self.listeners.setdefault(connection_id, {})[listener_id] = handler
            return listener_id

    async def remove_listener(self, connection_id, listener_id):
        async with self.lock:
            self.listeners.get(connection_id, {}).pop(listener_id, None)

    async def add_close_callback(self, connection_id, callback):
        """Register ``callback(connection_id)`` to be scheduled when the connection is removed."""
        async with self.lock:
            self.close_callbacks.setdefault(connection_id, []).append(callback)

    def is_alive(self, connection_id: str):
        """Return True while the connection (id or object) is registered in the pool."""
        return self._normalize_id(connection_id) in self.connections

    async def remove_connection(self, connection: Any):
        """Cancel a connection's tasks, drop its state and run its close callbacks.

        Safe to call more than once (e.g. from the heartbeat and from the
        websocket handler's ``finally``): later calls are no-ops.
        """
        connection_id = self._normalize_id(connection)

        async with self.lock:
            # Checked under the lock so concurrent callers tear down only once.
            if connection_id not in self.connections:
                # removed already
                return

            logger.info(f"remove connection: {connection_id}")

            # Never cancel-and-await the task we are running in (this happens
            # when ``is_alive_task`` removes a connection whose ping timed out).
            current = asyncio.current_task()
            current_name = current.get_name() if current is not None else None
            for suffix in ("send_task", "recv_task", "notify_task", "is_alive_task"):
                task_name = f"websocket_{connection_id}_{suffix}"
                if task_name == current_name:
                    continue
                with suppress(Exception):
                    await self.app.cancel_task(task_name)

            self.send_queues.pop(connection_id, None)
            self.recv_queues.pop(connection_id, None)
            self.listeners.pop(connection_id, None)

            if connection_id in self.close_callbacks:
                await self.do_close_callbacks(connection_id)
                del self.close_callbacks[connection_id]

            # No await between these two steps: waiters see "not alive" and the
            # closed event at the same time.
            self.connections.pop(connection_id, None)
            closed = self.closed_events.pop(connection_id, None)
            if closed is not None:
                closed.set()

    async def do_close_callbacks(self, connection_id):
        """Schedule every close callback as an app task (they are not awaited here)."""
        for cb in self.close_callbacks.get(connection_id, []):
            self.app.add_task(cb(connection_id))

    async def prepare(self):
        self.is_prepared = True
        logger.info("dependency:WebsocketPool is prepared")
        return self.is_prepared

    async def check(self):
        return True

    async def send_task(self, queue, connection):
        """Write queued messages to ``connection`` until it leaves the pool or a send fails.

        str/bytes-like data is sent as-is; anything else is JSON-encoded with orjson.
        """
        while self.is_alive(connection):
            # Block instead of polling; remove_connection cancels this task.
            data = await queue.get()
            try:
                if isinstance(data, (str, bytes, bytearray, memoryview)):
                    await connection.send(data)
                else:
                    # Numbers, dicts, lists, ... go out as JSON text; the
                    # websocket send only accepts str/bytes.
                    await connection.send(orjson.dumps(data).decode())
            except Exception as err:
                logger.warning(
                    f"send failed on connection: {self._normalize_id(connection)}: {err!r}"
                )
                break
            finally:
                queue.task_done()

    async def recv_task(self, queue, connection):
        """Move frames received on ``connection`` into ``queue`` until recv fails."""
        while self.is_alive(connection):
            try:
                data = await connection.recv()
            except Exception as err:
                logger.info(
                    f"stop receiving from connection: {self._normalize_id(connection)}: {err!r}"
                )
                break
            await queue.put(data)
            logger.info(f"recv message: {data} from connection: {connection._id}")

    async def notify_task(self, connection_id):
        """Deliver each received message to every listener registered for the connection."""
        while self.is_alive(connection_id):
            queue = self.recv_queues.get(connection_id)
            if queue is None:
                break
            data = await queue.get()
            logger.info(f"notify connection: {connection_id}'s listeners")
            # Snapshot the handlers: listeners are added/removed concurrently
            # (e.g. by the heartbeat) while we await them, and mutating the dict
            # during iteration would raise and lose the message.
            handlers = list(self.listeners.get(connection_id, {}).values())
            for listener in handlers:
                try:
                    await listener(connection_id, data)
                except Exception:
                    # One failing listener must not starve the others.
                    logger.exception(f"listener failed on connection: {connection_id}")

    async def is_alive_task(self, connection_id: str):
        """Heartbeat: PING periodically, remove the connection if no PONG arrives in time."""
        connection_id = self._normalize_id(connection_id)
        got_pong = asyncio.Event()

        async def wait_pong(_connection_id, data):
            if data == PONG:
                got_pong.set()

        while self.is_alive(connection_id):
            got_pong.clear()
            # Listen before pinging so a fast PONG cannot slip past unseen.
            listener_id = await self.add_listener(connection_id, wait_pong)
            await self.send(connection_id, PING)

            with suppress(asyncio.TimeoutError):
                async with async_timeout.timeout(
                    self.app.config.WEBSOCKET_PING_TIMEOUT
                ):
                    await got_pong.wait()

            await self.remove_listener(connection_id, listener_id)
            if not got_pong.is_set():
                # No PONG within the timeout: the peer is considered gone.
                await self.remove_connection(connection_id)
                return
            # Peer answered; wait before the next heartbeat.
            await asyncio.sleep(self.app.config.WEBSOCKET_PING_INTERVAL)

    async def wait_closed(self, connection_id: str):
        """Block until the connection has been removed from the pool.

        Returns ``False`` immediately if the connection is unknown or already
        removed, otherwise once ``remove_connection`` has dropped it.
        """
        closed = self.closed_events.get(self._normalize_id(connection_id))
        if closed is not None:
            await closed.wait()
        return False

    async def send(self, connection_id: str, data: Any) -> bool:
        """Queue ``data`` for the connection; return False if it is not in the pool."""
        connection_id = self._normalize_id(connection_id)
        queue = self.send_queues.get(connection_id)
        if queue is None or not self.is_alive(connection_id):
            return False
        await queue.put(data)
        return True

Websocket Provider

from typing import Dict
from typing import List
from typing import Union

from pydantic import BaseModel
from pydantic import field_serializer
from sanic.log import logger

from apps.message.common.constants import MessageProviderType
from apps.message.common.constants import MessageStatus
from apps.message.common.interfaces import SendResult
from apps.message.providers.base import MessageProviderModel
from apps.message.validators.types import EndpointExID
from apps.message.validators.types import EndpointTag
from apps.message.validators.types import ETag
from apps.message.validators.types import ExID
from utils import get_app


class WebsocketMessageProviderModel(MessageProviderModel):
    """Message provider that delivers a message over registered websockets.

    ``Message.connections`` may mix endpoint tags, endpoint external ids and
    raw websocket connection ids; tags/ids are resolved through the
    ``"websockets"`` list of each decoded endpoint.
    """

    class Info:
        name = "websocket"
        # NOTE(review): "Bio-Channel" looks like a typo for "Bi-directional"; confirm before changing.
        description = "Bio-Channel Communication"
        type = MessageProviderType.WEBSOCKET

    class Capability:
        is_enabled = True
        can_send = True

    class Message(BaseModel):
        connections: List[Union[EndpointTag, EndpointExID, str]]
        action: str
        payload: Union[List, Dict, str, bytes]

        @field_serializer("connections")
        def serialize_connections(self, connections):
            return list(set(map(str, connections)))

    async def send(self, provider_id, message: Message) -> SendResult:
        """Send ``message`` (minus ``connections``) to every alive target connection.

        Returns SUCCEEDED if at least one connection accepted the message,
        FAILED otherwise.
        """
        app = get_app()
        websocket_pool = app.ctx.ws_pool

        # Resolve every target into raw websocket connection ids.
        candidates = []
        for target in message.connections:
            if isinstance(target, ETag):
                # A tag may match several endpoints, each with its own websockets.
                for endpoint in await target.decode():
                    candidates.extend(endpoint.get("websockets", []))
            elif isinstance(target, ExID):
                endpoint = await target.decode()
                if endpoint:
                    candidates.extend(endpoint.get("websockets", []))
            else:
                candidates.append(target)

        # De-duplicate (first-seen order) and keep only connections that are
        # alive -- each candidate is checked on its own.
        alive = [c for c in dict.fromkeys(candidates) if websocket_pool.is_alive(c)]

        # The wire payload is the same for every recipient: serialize once.
        data = message.model_dump_json(exclude={"connections"})

        sent_list = set()
        for connection_id in alive:
            if await websocket_pool.send(connection_id, data=data):
                sent_list.add(connection_id)

        status = MessageStatus.SUCCEEDED if sent_list else MessageStatus.FAILED
        return SendResult(provider_id=provider_id, message=message, status=status)

WebSocket 接口


@app.websocket("/websocket")
async def handle_websocket(request, ws):
    """Register ``ws`` in the websocket pool and keep the handler open until it is removed.

    On connect the client is sent
    ``{"action": "on.connect", "payload": {"connection_id": <id>}}``.
    ``register_websocket_endpoint`` is added as a message listener and
    ``unregister_websocket_endpoint`` as a close callback for the connection.
    """
    from apps.endpoint.listeners import register_websocket_endpoint
    from apps.endpoint.listeners import unregister_websocket_endpoint

    ws_pool = request.app.ctx.ws_pool
    con_id = None
    try:
        con_id = await ws_pool.add_connection(ws)
        logger.info(f"new connection connected -> {con_id}")
        await ws_pool.add_listener(con_id, register_websocket_endpoint)
        await ws_pool.add_close_callback(con_id, unregister_websocket_endpoint)
        await ws_pool.send(
            con_id, data={"action": "on.connect", "payload": {"connection_id": con_id}}
        )
        await ws_pool.wait_closed(con_id)  # wait until the connection is gone
    finally:
        # If the client disconnects, this handler is torn down directly, so
        # cleanup has to live in ``finally``. Removal is scheduled as an app
        # task rather than awaited inline. Nothing to clean up if registration
        # failed before an id was assigned.
        if con_id is not None:
            request.app.add_task(ws_pool.remove_connection(con_id))

结果截图

websocket connected

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值