Add planner previews and streaming
Some checks failed
CI / test (push) Failing after 45s

This commit is contained in:
2026-04-05 19:33:24 +02:00
parent b1d2b6b321
commit 1b0b9cfdef
12 changed files with 332 additions and 31 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any
from typing import Any, Dict, List, Tuple
import uuid
from .actions import ActionEngine
@@ -62,24 +62,36 @@ class VisionGrid:
rows=self.rows,
columns=self.columns,
cells=[cell.model for cell in self.cells.values()],
metadata={
"memo": self.memo or "",
"width": self.width,
"height": self.height,
},
metadata=self.metadata,
)
@property
def metadata(self) -> Dict[str, Any]:
    """Return the grid's memo and pixel dimensions as a plain dict.

    A missing or empty memo is reported as an empty string.
    """
    memo_text = self.memo if self.memo else ""
    return dict(memo=memo_text, width=self.width, height=self.height)
def resolve_cell_center(self, cell_id: str) -> Tuple[int, int]:
    """Look up a cell by id and return its ``center`` attribute.

    Raises:
        KeyError: if no (truthy) cell is stored under ``cell_id``.
    """
    found = self.cells.get(cell_id)
    if found:
        return found.center
    raise KeyError(f"Unknown cell {cell_id}")
def preview_action(self, payload: ActionPayload) -> ActionResult:
    """Ask the engine to plan *payload* and return the outcome.

    Unlike ``apply_action``, nothing is appended to the action history.
    """
    planned = self._engine.plan(payload)
    return planned
def apply_action(self, payload: ActionPayload) -> ActionResult:
    """Plan *payload* with the engine, record the outcome, and return it.

    The result's ``model_dump()`` output is appended to the action history.
    """
    outcome = self._engine.plan(payload)
    record = outcome.model_dump()
    self._action_history.append(record)
    return outcome
def update_screenshot(self, screenshot_base64: str, memo: str | None = None) -> None:
    """Replace the stored screenshot and, optionally, the memo.

    The memo is only overwritten when a truthy value is supplied; ``None``
    or an empty string leaves the current memo untouched.
    """
    self.screenshot = screenshot_base64
    if not memo:
        return
    self.memo = memo
@property
def action_history(self) -> List[dict[str, Any]]:
    """Return a new list holding the recorded action entries.

    The returned list is a shallow copy, so appending to or removing from
    it does not affect the internal history.
    """
    return [*self._action_history]

View File

@@ -1,19 +1,29 @@
from fastapi import FastAPI, HTTPException
import time
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from .config import ServerSettings
from .grid import GridManager
from .models import ActionPayload, GridDescriptor, GridInitRequest
from .models import (
ActionPayload,
GridDescriptor,
GridInitRequest,
GridPlanRequest,
GridRefreshRequest,
)
from .planner import GridPlanner
from .streamer import ScreenshotStreamer
# Module-level singletons shared by the route handlers in this module.
settings = ServerSettings()
# The grid manager is constructed from the settings object above.
manager = GridManager(settings)
# Used by the plan route to build an action payload from a grid descriptor.
planner = GridPlanner()
# Used by the refresh and websocket routes to register listeners and push updates.
streamer = ScreenshotStreamer()
app = FastAPI(
title="Clickthrough",
description="Grid-aware surface that lets an agent plan clicks, drags, and typing on a fake screenshot",
version="0.2.0",
version="0.3.0",
)
@@ -59,3 +69,51 @@ def grid_history(grid_id: str):
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return {"grid_id": grid_id, "history": history}
@app.post("/grid/{grid_id}/plan")
def plan_grid(grid_id: str, request: GridPlanRequest):
    """Build an action payload for a grid and preview it.

    Responds with the dumped payload, the preview result and the grid
    descriptor. An unknown ``grid_id`` (``KeyError`` from the manager) is
    translated into a 404 response.
    """
    try:
        target_grid = manager.get_grid(grid_id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail=str(missing)) from missing

    snapshot = target_grid.describe()
    planner_args = {
        "action": request.action,
        "preferred_label": request.preferred_label,
        "text": request.text,
        "comment": request.comment,
    }
    planned = planner.build_payload(snapshot, **planner_args)
    # Preview only: the grid's action history is not touched here.
    preview = target_grid.preview_action(planned)
    return {"plan": planned.model_dump(), "result": preview, "descriptor": snapshot}
@app.post("/grid/{grid_id}/refresh")
async def refresh_grid(grid_id: str, payload: GridRefreshRequest):
    """Store a new screenshot for a grid and broadcast it to listeners.

    An unknown ``grid_id`` (``KeyError`` from the manager) is translated
    into a 404 response. On success the stream listeners receive a message
    holding the grid id, a wall-clock timestamp, the grid descriptor and
    the new screenshot.
    """
    try:
        grid = manager.get_grid(grid_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    grid.update_screenshot(payload.screenshot_base64, payload.memo)
    descriptor = grid.describe()
    message = {
        "grid_id": grid_id,
        "timestamp": time.time(),
        # The streamer forwards this dict to websocket.send_json(), which
        # uses the plain JSON encoder and cannot serialise a model
        # instance. Dump the descriptor to JSON-safe primitives first.
        # NOTE(review): assumes GridDescriptor is a pydantic v2 model, like
        # the other models this module calls model_dump() on.
        "descriptor": descriptor.model_dump(mode="json"),
        "screenshot_base64": payload.screenshot_base64,
    }
    await streamer.broadcast(grid_id, message)
    return {"status": "updated", "grid_id": grid_id}
@app.websocket("/stream/screenshots")
async def stream_screenshots(websocket: WebSocket, grid_id: str | None = None):
    """Register a websocket as a screenshot listener until it goes away.

    The socket is registered with the streamer under ``grid_id`` (or the
    streamer's default key when omitted). Incoming text frames are read and
    discarded; the loop exists only to notice when the peer disconnects.
    """
    key = await streamer.connect(websocket, grid_id)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal termination: the client closed the connection.
        pass
    finally:
        # Deregister on every exit path (disconnect, unexpected frame type,
        # cancellation) so a dead socket never lingers in the listener table.
        streamer.disconnect(websocket, key)

View File

@@ -53,3 +53,15 @@ class ActionResult(BaseModel):
detail: str
coordinates: Optional[Tuple[int, int]] = None
payload: Dict[str, Any] = Field(default_factory=dict)
class GridPlanRequest(BaseModel):
    """Request body for planning an action on a grid.

    Every field is optional; ``action`` defaults to a click.
    """

    # Label hint forwarded to the planner when it picks a target cell.
    preferred_label: Optional[str] = None
    # Kind of action to plan.
    action: ActionType = ActionType.CLICK
    # Optional text and free-form comment carried into the action payload.
    text: Optional[str] = None
    comment: Optional[str] = None
class GridRefreshRequest(BaseModel):
    """Request body for replacing a grid's screenshot."""

    # New screenshot content (required).
    screenshot_base64: str
    # Optional memo to accompany the new screenshot.
    memo: Optional[str] = None

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from math import hypot
from typing import Sequence
from .models import GridCellModel, GridDescriptor
from .models import ActionPayload, ActionType, GridCellModel, GridDescriptor
class GridPlanner:
@@ -23,6 +23,23 @@ class GridPlanner:
center_point = self._grid_center(descriptor)
return min(descriptor.cells, key=lambda cell: self._distance(self._cell_center(cell), center_point))
def build_payload(
    self,
    descriptor: GridDescriptor,
    action: ActionType = ActionType.CLICK,
    preferred_label: str | None = None,
    text: str | None = None,
    comment: str | None = None,
) -> ActionPayload:
    """Assemble an action payload aimed at the cell chosen by ``select_cell``.

    The target cell id is ``None`` when ``select_cell`` yields no
    (truthy) cell; ``text`` and ``comment`` are passed through unchanged.
    """
    chosen = self.select_cell(descriptor, preferred_label)
    grid_ref = descriptor.grid_id
    cell_ref = None if not chosen else chosen.cell_id
    return ActionPayload(
        grid_id=grid_ref,
        action=action,
        target_cell=cell_ref,
        text=text,
        comment=comment,
    )
def describe(self, descriptor: GridDescriptor) -> str:
cell_count = len(descriptor.cells)
return (

38
server/streamer.py Normal file
View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedError
class ScreenshotStreamer:
    """Keeps websocket listeners and pushes screenshot updates.

    Listeners are grouped by key: either a specific grid id, or ``"*"`` for
    sockets that want updates for every grid.
    """

    def __init__(self) -> None:
        self._listeners: DefaultDict[str, List[WebSocket]] = defaultdict(list)

    async def connect(self, websocket: WebSocket, grid_id: str | None = None) -> str:
        """Accept *websocket* and register it; return the key it was stored under.

        A missing or empty ``grid_id`` registers the socket under ``"*"``.
        """
        await websocket.accept()
        key = grid_id or "*"
        self._listeners[key].append(websocket)
        return key

    def disconnect(self, websocket: WebSocket, grid_key: str | None = None) -> None:
        """Remove *websocket* from the listeners stored under ``grid_key``.

        A missing or empty ``grid_key`` means ``"*"``. Unknown keys and
        sockets are ignored, and a key whose list becomes empty is dropped.
        """
        key = grid_key or "*"
        sockets = self._listeners.get(key)
        if not sockets:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            self._listeners.pop(key, None)

    async def broadcast(self, grid_id: str, payload: Dict[str, Any]) -> None:
        """Send *payload* as JSON to the listeners of ``grid_id`` and of ``"*"``.

        A socket whose send fails with a closed-connection error is removed
        from the key it was registered under.
        """
        # Do not visit the wildcard bucket twice when grid_id is itself "*".
        keys = (grid_id,) if grid_id == "*" else (grid_id, "*")
        # Snapshot (key, socket) pairs so that disconnect() can mutate the
        # listener lists while we iterate, and so that each socket is
        # removed from the bucket it actually lives in.
        targets = [
            (key, websocket)
            for key in keys
            for websocket in self._listeners.get(key, ())
        ]
        for key, websocket in targets:
            try:
                await websocket.send_json(payload)
            except (ConnectionClosedError, RuntimeError):
                self.disconnect(websocket, key)