from __future__ import annotations from collections import defaultdict from typing import Any, DefaultDict, Dict, List from fastapi import WebSocket from websockets.exceptions import ConnectionClosedError class ScreenshotStreamer: """Keeps websocket listeners and pushes screenshot updates.""" def __init__(self) -> None: self._listeners: DefaultDict[str, List[WebSocket]] = defaultdict(list) async def connect(self, websocket: WebSocket, grid_id: str | None = None) -> str: await websocket.accept() key = grid_id or "*" self._listeners[key].append(websocket) return key def disconnect(self, websocket: WebSocket, grid_key: str | None = None) -> None: key = grid_key or "*" sockets = self._listeners.get(key) if not sockets: return if websocket in sockets: sockets.remove(websocket) if not sockets: self._listeners.pop(key, None) async def broadcast(self, grid_id: str, payload: Dict[str, Any]) -> None: listeners = list(self._listeners.get(grid_id, [])) + list(self._listeners.get("*", [])) for websocket in listeners: try: await websocket.send_json(payload) except (ConnectionClosedError, RuntimeError): self.disconnect(websocket, grid_id)