This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Any
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import uuid
|
||||
|
||||
from .actions import ActionEngine
|
||||
@@ -62,24 +62,36 @@ class VisionGrid:
|
||||
rows=self.rows,
|
||||
columns=self.columns,
|
||||
cells=[cell.model for cell in self.cells.values()],
|
||||
metadata={
|
||||
"memo": self.memo or "",
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
},
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
@property
def metadata(self) -> Dict[str, Any]:
    """Return the memo, width and height of this grid as a plain dict.

    A falsy memo (``None`` or an empty string) is reported as ``""``.
    """
    memo_text = self.memo or ""
    return dict(memo=memo_text, width=self.width, height=self.height)
|
||||
|
||||
def resolve_cell_center(self, cell_id: str) -> Tuple[int, int]:
    """Look up ``cell_id`` in ``self.cells`` and return that cell's ``center``.

    Raises:
        KeyError: if no (truthy) cell is stored under ``cell_id``.
    """
    found = self.cells.get(cell_id)
    if found:
        return found.center
    raise KeyError(f"Unknown cell {cell_id}")
|
||||
|
||||
def preview_action(self, payload: ActionPayload) -> ActionResult:
    """Ask the action engine to plan ``payload`` and return its result.

    Nothing is appended to the action history here.
    """
    planned = self._engine.plan(payload)
    return planned
|
||||
|
||||
def apply_action(self, payload: ActionPayload) -> ActionResult:
    """Plan ``payload`` with the action engine and record the outcome.

    The result's ``model_dump()`` is appended to the action history, and the
    result object itself is returned to the caller.
    """
    outcome = self._engine.plan(payload)
    record = outcome.model_dump()
    self._action_history.append(record)
    return outcome
|
||||
|
||||
def update_screenshot(self, screenshot_base64: str, memo: str | None = None) -> None:
    """Replace the stored screenshot and, when a memo is given, the memo.

    A falsy ``memo`` (``None`` or ``""``) leaves the current memo untouched.
    """
    self.screenshot = screenshot_base64
    if not memo:
        return
    self.memo = memo
|
||||
|
||||
@property
def action_history(self) -> List[dict[str, Any]]:
    """Return a new list holding the recorded action entries.

    The returned list is a shallow copy, so appending to or removing from it
    does not touch the internal history.
    """
    return [*self._action_history]
|
||||
|
||||
@@ -1,19 +1,29 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import time
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from .config import ServerSettings
|
||||
from .grid import GridManager
|
||||
from .models import ActionPayload, GridDescriptor, GridInitRequest
|
||||
from .models import (
|
||||
ActionPayload,
|
||||
GridDescriptor,
|
||||
GridInitRequest,
|
||||
GridPlanRequest,
|
||||
GridRefreshRequest,
|
||||
)
|
||||
from .planner import GridPlanner
|
||||
from .streamer import ScreenshotStreamer
|
||||
|
||||
|
||||
settings = ServerSettings()
|
||||
manager = GridManager(settings)
|
||||
planner = GridPlanner()
|
||||
streamer = ScreenshotStreamer()
|
||||
|
||||
app = FastAPI(
|
||||
title="Clickthrough",
|
||||
description="Grid-aware surface that lets an agent plan clicks, drags, and typing on a fake screenshot",
|
||||
version="0.2.0",
|
||||
version="0.3.0",
|
||||
)
|
||||
|
||||
|
||||
@@ -59,3 +69,51 @@ def grid_history(grid_id: str):
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
return {"grid_id": grid_id, "history": history}
|
||||
|
||||
|
||||
# Builds an action payload for the requested grid via the planner, previews it
# on the grid, and returns the payload, the preview result and the descriptor.
# Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays unchanged.
@app.post("/grid/{grid_id}/plan")
def plan_grid(grid_id: str, request: GridPlanRequest):
    try:
        target_grid = manager.get_grid(grid_id)
    except KeyError as missing:
        # Unknown grid ids surface to the client as 404.
        raise HTTPException(status_code=404, detail=str(missing)) from missing

    snapshot = target_grid.describe()
    planned = planner.build_payload(
        snapshot,
        action=request.action,
        preferred_label=request.preferred_label,
        text=request.text,
        comment=request.comment,
    )
    preview = target_grid.preview_action(planned)
    return dict(plan=planned.model_dump(), result=preview, descriptor=snapshot)
|
||||
|
||||
|
||||
# Stores a new screenshot (and optional memo) on the grid, then broadcasts the
# refreshed descriptor and screenshot to the streamer's listeners.
# Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays unchanged.
@app.post("/grid/{grid_id}/refresh")
async def refresh_grid(grid_id: str, payload: GridRefreshRequest):
    try:
        target_grid = manager.get_grid(grid_id)
    except KeyError as missing:
        # Unknown grid ids surface to the client as 404.
        raise HTTPException(status_code=404, detail=str(missing)) from missing

    target_grid.update_screenshot(payload.screenshot_base64, payload.memo)
    # Describe after the update so the descriptor is built from the new state.
    snapshot = target_grid.describe()
    message = dict(
        grid_id=grid_id,
        timestamp=time.time(),
        descriptor=snapshot,
        screenshot_base64=payload.screenshot_base64,
    )
    await streamer.broadcast(grid_id, message)
    return dict(status="updated", grid_id=grid_id)
|
||||
|
||||
|
||||
@app.websocket("/stream/screenshots")
async def stream_screenshots(websocket: WebSocket, grid_id: str | None = None):
    """Register a websocket as a screenshot listener until it goes away.

    The socket is handed to the streamer, which returns the key it was
    registered under. Incoming text frames are read and discarded; the loop
    exists only to notice when the client disconnects.

    The listener is unregistered in ``finally`` so that it is removed however
    the receive loop ends, not only on a clean ``WebSocketDisconnect``. Any
    other exception still propagates after the cleanup.
    """
    key = await streamer.connect(websocket, grid_id)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal end of the stream: nothing to report.
        pass
    finally:
        streamer.disconnect(websocket, key)
|
||||
|
||||
@@ -53,3 +53,15 @@ class ActionResult(BaseModel):
|
||||
detail: str
|
||||
coordinates: Optional[Tuple[int, int]] = None
|
||||
payload: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GridPlanRequest(BaseModel):
    # Request body for planning an action on a grid. Every field is optional;
    # the action falls back to ActionType.CLICK.
    # Kept as comments (no class docstring) so the generated JSON schema
    # description is unchanged.
    preferred_label: Optional[str] = Field(default=None)
    action: ActionType = Field(default=ActionType.CLICK)
    text: Optional[str] = Field(default=None)
    comment: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class GridRefreshRequest(BaseModel):
    # Request body for refreshing a grid: a required screenshot string and an
    # optional memo.
    # Kept as comments (no class docstring) so the generated JSON schema
    # description is unchanged.
    screenshot_base64: str = Field(...)
    memo: Optional[str] = Field(default=None)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from math import hypot
|
||||
from typing import Sequence
|
||||
|
||||
from .models import GridCellModel, GridDescriptor
|
||||
from .models import ActionPayload, ActionType, GridCellModel, GridDescriptor
|
||||
|
||||
|
||||
class GridPlanner:
|
||||
@@ -23,6 +23,23 @@ class GridPlanner:
|
||||
center_point = self._grid_center(descriptor)
|
||||
return min(descriptor.cells, key=lambda cell: self._distance(self._cell_center(cell), center_point))
|
||||
|
||||
def build_payload(
    self,
    descriptor: GridDescriptor,
    action: ActionType = ActionType.CLICK,
    preferred_label: str | None = None,
    text: str | None = None,
    comment: str | None = None,
) -> ActionPayload:
    """Assemble an ``ActionPayload`` aimed at the cell picked by ``select_cell``.

    Args:
        descriptor: Grid description; supplies the grid id and is passed to
            ``select_cell``.
        action: Action to put in the payload (defaults to a click).
        preferred_label: Forwarded to ``select_cell``.
        text: Copied into the payload unchanged.
        comment: Copied into the payload unchanged.

    Returns:
        The payload; ``target_cell`` is ``None`` when ``select_cell`` yields
        nothing truthy.
    """
    chosen = self.select_cell(descriptor, preferred_label)
    fields = {
        "grid_id": descriptor.grid_id,
        "action": action,
        "target_cell": chosen.cell_id if chosen else None,
        "text": text,
        "comment": comment,
    }
    return ActionPayload(**fields)
|
||||
|
||||
def describe(self, descriptor: GridDescriptor) -> str:
|
||||
cell_count = len(descriptor.cells)
|
||||
return (
|
||||
|
||||
38
server/streamer.py
Normal file
38
server/streamer.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List
|
||||
|
||||
from fastapi import WebSocket
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
|
||||
class ScreenshotStreamer:
    """Keeps websocket listeners and pushes screenshot updates.

    Listeners are grouped by key: the grid id they subscribed to, or ``"*"``
    when they connected without one. Sockets under ``"*"`` receive every
    broadcast.
    """

    def __init__(self) -> None:
        self._listeners: DefaultDict[str, List[WebSocket]] = defaultdict(list)

    async def connect(self, websocket: WebSocket, grid_id: str | None = None) -> str:
        """Accept ``websocket`` and register it as a listener.

        Returns:
            The key the socket was stored under (``grid_id``, or ``"*"`` when
            no grid id was given). Pass it back to ``disconnect``.
        """
        await websocket.accept()
        key = grid_id or "*"
        self._listeners[key].append(websocket)
        return key

    def disconnect(self, websocket: WebSocket, grid_key: str | None = None) -> None:
        """Remove ``websocket`` from the listeners stored under ``grid_key``.

        A missing or falsy ``grid_key`` means the ``"*"`` group. Unknown keys
        and sockets that are not registered are ignored. A group that becomes
        empty is dropped from the registry.
        """
        key = grid_key or "*"
        sockets = self._listeners.get(key)
        if not sockets:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            self._listeners.pop(key, None)

    async def broadcast(self, grid_id: str, payload: Dict[str, Any]) -> None:
        """Send ``payload`` as JSON to the listeners of ``grid_id`` and of ``"*"``.

        A socket whose send fails with ``ConnectionClosedError`` or
        ``RuntimeError`` is unregistered from the group it was found in, so a
        dead ``"*"`` listener is actually removed instead of being looked up
        under ``grid_id`` and left in place.
        """
        # dict.fromkeys keeps order and collapses the two keys into one when
        # grid_id is itself "*", so those listeners are not sent to twice.
        for key in dict.fromkeys((grid_id, "*")):
            # Iterate over a copy: disconnect() mutates the underlying list.
            for websocket in list(self._listeners.get(key, ())):
                try:
                    await websocket.send_json(payload)
                except (ConnectionClosedError, RuntimeError):
                    self.disconnect(websocket, key)
|
||||
Reference in New Issue
Block a user