diff --git a/README.md b/README.md index 63f8eda..c7519e9 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ Let an Agent interact with your Computer. - `POST /grid/action`: Takes a plan (`grid_id`, optional target cell, and an action like `click`/`drag`/`type`) and returns a structured `ActionResult` with computed coordinates for tooling to consume. - `GET /grid/{grid_id}/summary`: Returns both a heuristic description (`GridPlanner`) and a rich descriptor so the skill can summarize what it sees. - `GET /grid/{grid_id}/history`: Streams back the action history for that grid so an agent or operator can audit what was done. +- `POST /grid/{grid_id}/plan`: Lets `GridPlanner` select the target and return a preview action plan without committing to it, so we can inspect coordinates before triggering events. +- `POST /grid/{grid_id}/refresh` + `GET /stream/screenshots`: Refresh the cached screenshot/metadata and broadcast the updated scene over a websocket so clients can redraw overlays in near real time. - `GET /health`: A minimal health check for deployments. Vision metadata is kept on a per-grid basis, including history, layout dimensions, and any appended memo. Each `VisionGrid` also exposes a short textual summary so the skill layer can turn sensory data into sentences directly. @@ -23,11 +25,20 @@ The `skill/` package wraps the server calls and exposes helpers: - `ClickthroughSkill.describe_grid()` builds a grid session and returns the descriptor. - `ClickthroughSkill.plan_action()` drives the `/grid/action` endpoint. +- `ClickthroughSkill.plan_with_planner()` calls `/grid/{grid_id}/plan`, so you can preview the `GridPlanner` suggestion before executing it. - `ClickthroughSkill.grid_summary()` and `.grid_history()` surface the new metadata endpoints. -- `ClickthroughAgentRunner` simulates a tiny agent loop that chooses a cell (optionally by label), submits an action, and fetches the summary/history. +- `ClickthroughSkill.refresh_grid()` pushes a new screenshot and memo, triggering websocket listeners. +- `ClickthroughAgentRunner` simulates a tiny agent loop that asks the planner for a preview, executes the resulting action, and then gathers the summary/history so you can iterate on reasoning loops in tests. Future work can swap the stub runner for a full OpenClaw skill that keeps reasoning inside the agent and uses these primitives to steer the mouse/keyboard. +## Screenshot streaming + +Capture loops can now talk to FastAPI in two ways: + +1. POST `/grid/{grid_id}/refresh` with fresh base64 screenshots and an optional memo; the server updates the cached grid metadata and broadcasts the change. +2. Open a websocket to `GET /stream/screenshots` (optionally passing `grid_id` as a query param) to receive realtime deltas whenever a refresh happens. Clients can use the descriptor/payload to redraw overlays or trigger new planner runs without polling. + ## Testing 1. `python3 -m pip install -r requirements.txt` diff --git a/server/grid.py b/server/grid.py index e8e914b..720ef8c 100644 --- a/server/grid.py +++ b/server/grid.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple import uuid from .actions import ActionEngine @@ -62,24 +62,36 @@ class VisionGrid: rows=self.rows, columns=self.columns, cells=[cell.model for cell in self.cells.values()], - metadata={ - "memo": self.memo or "", - "width": self.width, - "height": self.height, - }, + metadata=self.metadata, ) + @property + def metadata(self) -> Dict[str, Any]: + return { + "memo": self.memo or "", + "width": self.width, + "height": self.height, + } + def resolve_cell_center(self, cell_id: str) -> Tuple[int, int]: cell = self.cells.get(cell_id) if not cell: raise KeyError(f"Unknown cell {cell_id}") return cell.center + def preview_action(self, payload: ActionPayload) -> ActionResult: + return self._engine.plan(payload) + def apply_action(self, payload: ActionPayload) -> ActionResult: result = self._engine.plan(payload) self._action_history.append(result.model_dump()) return result + def update_screenshot(self, screenshot_base64: str, memo: str | None = None) -> None: + self.screenshot = screenshot_base64 + if memo: + self.memo = memo + @property def action_history(self) -> List[dict[str, Any]]: return list(self._action_history) diff --git a/server/main.py b/server/main.py index ae74a3c..5c2f62b 100644 --- a/server/main.py +++ b/server/main.py @@ -1,19 +1,29 @@ -from fastapi import FastAPI, HTTPException +import time + +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from .config import ServerSettings from .grid import GridManager -from .models import ActionPayload, GridDescriptor, GridInitRequest +from .models import ( + ActionPayload, + GridDescriptor, + GridInitRequest, + GridPlanRequest, + GridRefreshRequest, +) from .planner import GridPlanner +from .streamer import ScreenshotStreamer settings = ServerSettings() manager = GridManager(settings) planner = GridPlanner() +streamer = ScreenshotStreamer() app = FastAPI( title="Clickthrough", description="Grid-aware surface that lets an agent plan clicks, drags, and typing on a fake screenshot", - version="0.2.0", + version="0.3.0", ) @@ -59,3 +69,51 @@ def grid_history(grid_id: str): except KeyError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc return {"grid_id": grid_id, "history": history} + + +@app.post("/grid/{grid_id}/plan") +def plan_grid(grid_id: str, request: GridPlanRequest): + try: + grid = manager.get_grid(grid_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + descriptor = grid.describe() + payload = planner.build_payload( + descriptor, + action=request.action, + preferred_label=request.preferred_label, + text=request.text, + comment=request.comment, + ) + result = grid.preview_action(payload) + return {"plan": payload.model_dump(), "result": result, "descriptor": descriptor} + + +@app.post("/grid/{grid_id}/refresh") +async def refresh_grid(grid_id: str, payload: GridRefreshRequest): + try: + grid = manager.get_grid(grid_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + grid.update_screenshot(payload.screenshot_base64, payload.memo) + descriptor = grid.describe() + await streamer.broadcast( + grid_id, + { + "grid_id": grid_id, + "timestamp": time.time(), + "descriptor": descriptor, + "screenshot_base64": payload.screenshot_base64, + }, + ) + return {"status": "updated", "grid_id": grid_id} + + +@app.websocket("/stream/screenshots") +async def stream_screenshots(websocket: WebSocket, grid_id: str | None = None): + key = await streamer.connect(websocket, grid_id) + try: + while True: + await websocket.receive_text() + except WebSocketDisconnect: + streamer.disconnect(websocket, key) diff --git a/server/models.py b/server/models.py index d5c75da..141680e 100644 --- a/server/models.py +++ b/server/models.py @@ -53,3 +53,15 @@ class ActionResult(BaseModel): detail: str coordinates: Optional[Tuple[int, int]] = None payload: Dict[str, Any] = Field(default_factory=dict) + + +class GridPlanRequest(BaseModel): + preferred_label: Optional[str] = None + action: ActionType = ActionType.CLICK + text: Optional[str] = None + comment: Optional[str] = None + + +class GridRefreshRequest(BaseModel): + screenshot_base64: str + memo: Optional[str] = None diff --git a/server/planner.py b/server/planner.py index cdaaa67..e8f73bb 100644 --- a/server/planner.py +++ b/server/planner.py @@ -3,7 +3,7 @@ from __future__ import annotations from math import hypot from typing import Sequence -from .models import GridCellModel, GridDescriptor +from .models import ActionPayload, ActionType, GridCellModel, GridDescriptor class GridPlanner: @@ -23,6 +23,23 @@ class GridPlanner: center_point = self._grid_center(descriptor) return min(descriptor.cells, key=lambda cell: self._distance(self._cell_center(cell), center_point)) + def build_payload( + self, + descriptor: GridDescriptor, + action: ActionType = ActionType.CLICK, + preferred_label: str | None = None, + text: str | None = None, + comment: str | None = None, + ) -> ActionPayload: + target = self.select_cell(descriptor, preferred_label) + return ActionPayload( + grid_id=descriptor.grid_id, + action=action, + target_cell=target.cell_id if target else None, + text=text, + comment=comment, + ) + def describe(self, descriptor: GridDescriptor) -> str: cell_count = len(descriptor.cells) return ( diff --git a/server/streamer.py b/server/streamer.py new file mode 100644 index 0000000..5990def --- /dev/null +++ b/server/streamer.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List + +from fastapi import WebSocket +from websockets.exceptions import ConnectionClosedError + + +class ScreenshotStreamer: + """Keeps websocket listeners and pushes screenshot updates.""" + + def __init__(self) -> None: + self._listeners: DefaultDict[str, List[WebSocket]] = defaultdict(list) + + async def connect(self, websocket: WebSocket, grid_id: str | None = None) -> str: + await websocket.accept() + key = grid_id or "*" + self._listeners[key].append(websocket) + return key + + def disconnect(self, websocket: WebSocket, grid_key: str | None = None) -> None: + key = grid_key or "*" + sockets = self._listeners.get(key) + if not sockets: + return + if websocket in sockets: + sockets.remove(websocket) + if not sockets: + self._listeners.pop(key, None) + + async def broadcast(self, grid_id: str, payload: Dict[str, Any]) -> None: + listeners = list(self._listeners.get(grid_id, [])) + list(self._listeners.get("*", [])) + for websocket in listeners: + try: + await websocket.send_json(payload) + except (ConnectionClosedError, RuntimeError): + self.disconnect(websocket, grid_id) diff --git a/skill/agent_runner.py b/skill/agent_runner.py index 2cfdde7..bcd8e6f 100644 --- a/skill/agent_runner.py +++ b/skill/agent_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, Sequence +from typing import Any, Dict from .clickthrough_skill import ActionPlan, ClickthroughSkill @@ -10,6 +10,7 @@ class AgentRunResult: action: Dict[str, Any] history: Dict[str, Any] grid: Dict[str, Any] + plan_preview: Dict[str, Any] class ClickthroughAgentRunner: @@ -34,29 +35,26 @@ class ClickthroughAgentRunner: rows=rows, columns=columns, ) - cells = grid.get("cells") or [] - target_cell = self._choose_cell(cells, preferred_label) - plan = ActionPlan( + plan_response = self.skill.plan_with_planner( grid_id=grid["grid_id"], - target_cell=target_cell, + preferred_label=preferred_label, action=action, text=text, ) + plan_payload = plan_response["plan"] + plan = ActionPlan( + grid_id=plan_payload["grid_id"], + target_cell=plan_payload.get("target_cell"), + action=plan_payload["action"], + text=plan_payload.get("text"), + ) action_result = self.skill.plan_action(plan) summary = self.skill.grid_summary(grid["grid_id"]) history = self.skill.grid_history(grid["grid_id"]) - return AgentRunResult(summary=summary, action=action_result, history=history, grid=grid) - - def _choose_cell( - self, cells: Sequence[dict[str, Any]], preferred_label: str | None - ) -> str: - if not cells: - raise ValueError("Grid contains no cells") - if preferred_label: - search = preferred_label.lower() - for cell in cells: - label_value = cell.get("label") - if label_value and search in label_value.lower(): - return cell["cell_id"] - center_index = len(cells) // 2 - return cells[center_index]["cell_id"] + return AgentRunResult( + summary=summary, + action=action_result, + history=history, + grid=grid, + plan_preview=plan_response, + ) diff --git a/skill/clickthrough_skill.py b/skill/clickthrough_skill.py index 5f487c7..48891b7 100644 --- a/skill/clickthrough_skill.py +++ b/skill/clickthrough_skill.py @@ -60,6 +60,30 @@ class ClickthroughSkill: response.raise_for_status() return response.json() + def plan_with_planner( + self, + grid_id: str, + preferred_label: str | None = None, + action: str = "click", + text: str | None = None, + comment: str | None = None, + ) -> Dict[str, Any]: + payload = { + "preferred_label": preferred_label, + "action": action, + "text": text, + "comment": comment or "planner-generated", + } + response = self._client.post(f"/grid/{grid_id}/plan", json=payload) + response.raise_for_status() + return response.json() + + def refresh_grid(self, grid_id: str, screenshot_base64: str, memo: str | None = None) -> Dict[str, Any]: + payload = {"screenshot_base64": screenshot_base64, "memo": memo} + response = self._client.post(f"/grid/{grid_id}/refresh", json=payload) + response.raise_for_status() + return response.json() + if __name__ == "__main__": import base64 diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index e1b1e29..c1e3ae5 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -24,6 +24,32 @@ class DummySkill(ClickthroughSkill): ], } + def plan_with_planner( + self, + grid_id: str, + preferred_label: str | None = None, + action: str = "click", + text: str | None = None, + comment: str | None = None, + ) -> Dict[str, Any]: + cells = ["dummy-grid-1", "dummy-grid-2"] + if preferred_label == "target": + target = "dummy-grid-2" + else: + target = cells[len(cells) // 2] + plan = { + "grid_id": grid_id, + "target_cell": target, + "action": action, + "text": text, + "comment": comment, + } + return { + "plan": plan, + "result": {"success": True, "detail": "preview"}, + "descriptor": {"grid_id": grid_id}, + } + def plan_action(self, plan: ActionPlan) -> Dict[str, Any]: self.last_plan = plan return {"success": True, "target_cell": plan.target_cell} diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py new file mode 100644 index 0000000..501a86d --- /dev/null +++ b/tests/test_endpoints.py @@ -0,0 +1,32 @@ +from fastapi.testclient import TestClient + +from server.main import app, manager + +test_client = TestClient(app) + + +def test_plan_endpoint(default_grid_request): + init_response = test_client.post("/grid/init", json=default_grid_request) + grid_id = init_response.json()["grid_id"] + + plan_response = test_client.post( + f"/grid/{grid_id}/plan", + json={"preferred_label": None, "action": "click", "text": "hello"}, + ) + assert plan_response.status_code == 200 + payload = plan_response.json() + assert payload["plan"]["grid_id"] == grid_id + assert payload["result"]["success"] + + +def test_refresh_endpoint(default_grid_request): + init_response = test_client.post("/grid/init", json=default_grid_request) + grid_id = init_response.json()["grid_id"] + + refresh_response = test_client.post( + f"/grid/{grid_id}/refresh", json={"screenshot_base64": "AAA", "memo": "updated"} + ) + assert refresh_response.status_code == 200 + grid = manager.get_grid(grid_id) + assert grid.screenshot == "AAA" + assert grid.memo == "updated" diff --git a/tests/test_planner.py b/tests/test_planner.py new file mode 100644 index 0000000..72a25ea --- /dev/null +++ b/tests/test_planner.py @@ -0,0 +1,32 @@ +from server.config import ServerSettings +from server.grid import GridManager +from server.planner import GridPlanner +from server.models import ActionType, GridInitRequest + + +def test_planner_preferred_label(default_grid_request): + settings = ServerSettings() + manager = GridManager(settings) + request = GridInitRequest(**default_grid_request) + grid = manager.create_grid(request) + descriptor = grid.describe() + descriptor.cells[0].label = "target" + + planner = GridPlanner() + payload = planner.build_payload(descriptor, preferred_label="target", action=ActionType.CLICK) + + assert payload.target_cell == descriptor.cells[0].cell_id + + +def test_planner_falls_back_to_center(default_grid_request): + settings = ServerSettings() + manager = GridManager(settings) + request = GridInitRequest(**default_grid_request) + grid = manager.create_grid(request) + descriptor = grid.describe() + + planner = GridPlanner() + payload = planner.build_payload(descriptor, action=ActionType.CLICK) + + assert payload.target_cell is not None + assert payload.grid_id == descriptor.grid_id diff --git a/tests/test_streamer.py b/tests/test_streamer.py new file mode 100644 index 0000000..a65239c --- /dev/null +++ b/tests/test_streamer.py @@ -0,0 +1,41 @@ +import asyncio + +from server.streamer import ScreenshotStreamer + + +class DummyWebSocket: + def __init__(self): + self.sent = [] + self.accepted = False + + async def accept(self) -> None: + self.accepted = True + + async def send_json(self, payload): + self.sent.append(payload) + + +def test_streamer_broadcasts_to_grid(): + streamer = ScreenshotStreamer() + socket = DummyWebSocket() + + async def scenario(): + key = await streamer.connect(socket, "grid-123") + await streamer.broadcast("grid-123", {"frame": 1}) + streamer.disconnect(socket, key) + + asyncio.run(scenario()) + assert socket.sent == [{"frame": 1}] + + +def test_streamer_wildcard_listener_receives_updates(): + streamer = ScreenshotStreamer() + socket = DummyWebSocket() + + async def scenario(): + key = await streamer.connect(socket, None) + await streamer.broadcast("grid-456", {"frame": 2}) + streamer.disconnect(socket, key) + + asyncio.run(scenario()) + assert socket.sent == [{"frame": 2}]