39 lines
1.3 KiB
Python
39 lines
1.3 KiB
Python
from __future__ import annotations

from collections import defaultdict
from typing import Any, DefaultDict, Dict, List

from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
|
|
|
|
|
class ScreenshotStreamer:
    """Keeps websocket listeners and pushes screenshot updates.

    Sockets are grouped by grid id. A socket connected without a grid id
    (``None`` or any other falsy value) is stored under the wildcard key
    ``"*"`` and receives broadcasts for every grid.
    """

    def __init__(self) -> None:
        # grid key -> sockets subscribed to that key ("*" = every grid).
        self._listeners: DefaultDict[str, List[WebSocket]] = defaultdict(list)

    async def connect(self, websocket: WebSocket, grid_id: str | None = None) -> str:
        """Accept *websocket* and subscribe it to *grid_id*.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited before
                the socket is registered.
            grid_id: Grid to subscribe to. A falsy value subscribes the socket
                to the wildcard key ``"*"``.

        Returns:
            The key the socket was registered under. Pass it back to
            ``disconnect`` to unsubscribe.
        """
        await websocket.accept()
        key = grid_id or "*"
        self._listeners[key].append(websocket)
        return key

    def disconnect(self, websocket: WebSocket, grid_key: str | None = None) -> None:
        """Unsubscribe *websocket* from *grid_key*.

        A falsy *grid_key* means the wildcard key ``"*"``. Unknown keys and
        sockets that are not registered under the key are ignored. A key whose
        list becomes empty is dropped from the mapping.
        """
        key = grid_key or "*"
        sockets = self._listeners.get(key)  # .get() avoids creating an entry
        if not sockets:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            self._listeners.pop(key, None)

    async def broadcast(self, grid_id: str, payload: Dict[str, Any]) -> None:
        """Send *payload* via ``send_json`` to listeners of *grid_id* and of ``"*"``.

        Sockets are sent to one at a time. If a send raises ``ConnectionClosed``
        (either the clean or the error variant), ``WebSocketDisconnect`` or
        ``RuntimeError``, that socket is unsubscribed from the key it was found
        under, and delivery continues with the remaining listeners. Any other
        exception propagates to the caller.
        """
        # Don't visit the wildcard list twice when broadcasting to "*" itself.
        keys = (grid_id,) if grid_id == "*" else (grid_id, "*")
        # Snapshot (key, socket) pairs up front. The lists can change while we
        # await sends, and the key is needed so a dead socket is removed from
        # the list it actually lives in (wildcard sockets live under "*").
        targets = [
            (key, websocket)
            for key in keys
            for websocket in self._listeners.get(key, ())
        ]
        for key, websocket in targets:
            try:
                await websocket.send_json(payload)
            except (ConnectionClosed, WebSocketDisconnect, RuntimeError):
                self.disconnect(websocket, key)