feat: add shared runtime with FastAPI job server and safety pipeline
This commit is contained in:
188
src/server.py
Normal file
188
src/server.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .config import AppConfig, load_app_config
|
||||
from .storage import HistoryDB
|
||||
from .task_manager import JobManager
|
||||
from .ui import monitoring_page_html
|
||||
|
||||
|
||||
class CreateJobRequest(BaseModel):
|
||||
job: str = Field(..., min_length=1)
|
||||
model: str | None = None
|
||||
max_steps: int = Field(60, ge=1, le=400)
|
||||
command_timeout: int = Field(45, ge=1, le=600)
|
||||
type_interval: float = Field(0.02, ge=0.0, le=1.0)
|
||||
click_pause: float = Field(0.10, ge=0.0, le=2.0)
|
||||
disabled_tools: list[str] = Field(default_factory=list)
|
||||
safety_override: bool = False
|
||||
no_failsafe: bool = False
|
||||
|
||||
|
||||
class _WebSocketHub:
|
||||
def __init__(self) -> None:
|
||||
self._connections: set[WebSocket] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._loop = loop
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
self._connections.add(websocket)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket) -> None:
|
||||
async with self._lock:
|
||||
self._connections.discard(websocket)
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
||||
async with self._lock:
|
||||
clients = list(self._connections)
|
||||
dead: list[WebSocket] = []
|
||||
for ws in clients:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception: # noqa: BLE001
|
||||
dead.append(ws)
|
||||
if dead:
|
||||
async with self._lock:
|
||||
for ws in dead:
|
||||
self._connections.discard(ws)
|
||||
|
||||
def broadcast_from_thread(self, message: dict[str, Any]) -> None:
|
||||
if self._loop is None:
|
||||
return
|
||||
asyncio.run_coroutine_threadsafe(self.broadcast(message), self._loop)
|
||||
|
||||
|
||||
def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
app_config = config or load_app_config(cwd=Path.cwd())
|
||||
if not app_config.openai_api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY is required in environment or .env.")
|
||||
if not app_config.screenjob_token:
|
||||
raise RuntimeError("SCREENJOB_TOKEN is required in environment or .env.")
|
||||
|
||||
app = FastAPI(title="ScreenJob API", version="1.0.0")
|
||||
db = HistoryDB(app_config.db_path)
|
||||
ws_hub = _WebSocketHub()
|
||||
manager = JobManager(config=app_config, db=db, broadcast=ws_hub.broadcast_from_thread)
|
||||
|
||||
app.state.config = app_config
|
||||
app.state.db = db
|
||||
app.state.ws_hub = ws_hub
|
||||
app.state.manager = manager
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _on_startup() -> None:
|
||||
ws_hub.set_loop(asyncio.get_running_loop())
|
||||
|
||||
def _extract_token(authorization: str | None, x_screenjob_token: str | None) -> str:
|
||||
if x_screenjob_token:
|
||||
return x_screenjob_token.strip()
|
||||
if authorization:
|
||||
token = authorization.strip()
|
||||
if token.lower().startswith("bearer "):
|
||||
return token[7:].strip()
|
||||
return token
|
||||
return ""
|
||||
|
||||
def require_token(
|
||||
authorization: str | None = Header(default=None),
|
||||
x_screenjob_token: str | None = Header(default=None),
|
||||
) -> None:
|
||||
token = _extract_token(authorization, x_screenjob_token)
|
||||
if not token or not secrets.compare_digest(token, app_config.screenjob_token):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
@app.post("/api/jobs")
|
||||
def create_job(payload: CreateJobRequest, _: None = Depends(require_token)) -> dict[str, str]:
|
||||
job_id = manager.submit_job(
|
||||
objective=payload.job,
|
||||
model=payload.model,
|
||||
max_steps=payload.max_steps,
|
||||
command_timeout=payload.command_timeout,
|
||||
type_interval=payload.type_interval,
|
||||
click_pause=payload.click_pause,
|
||||
disabled_tools=payload.disabled_tools,
|
||||
safety_override=payload.safety_override,
|
||||
no_failsafe=payload.no_failsafe,
|
||||
)
|
||||
return {"job_id": job_id}
|
||||
|
||||
@app.get("/api/jobs")
|
||||
def list_jobs(limit: int = Query(default=100, ge=1, le=500), _: None = Depends(require_token)) -> dict[str, Any]:
|
||||
return {"jobs": manager.list_jobs(limit=limit)}
|
||||
|
||||
@app.get("/api/jobs/{job_id}")
|
||||
def get_job(job_id: str, _: None = Depends(require_token)) -> dict[str, Any]:
|
||||
job = manager.get_job(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return job
|
||||
|
||||
@app.get("/api/jobs/{job_id}/events")
|
||||
def get_job_events(
|
||||
job_id: str,
|
||||
limit: int = Query(default=500, ge=1, le=5000),
|
||||
_: None = Depends(require_token),
|
||||
) -> dict[str, Any]:
|
||||
job = manager.get_job(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return {"events": manager.get_events(job_id, limit=limit)}
|
||||
|
||||
@app.post("/api/jobs/{job_id}/cancel")
|
||||
def cancel_job(job_id: str, _: None = Depends(require_token)) -> dict[str, Any]:
|
||||
job = manager.get_job(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
accepted = manager.cancel_job(job_id)
|
||||
return {"job_id": job_id, "cancel_requested": bool(accepted)}
|
||||
|
||||
@app.get("/api/stats")
|
||||
def stats(_: None = Depends(require_token)) -> dict[str, Any]:
|
||||
return manager.stats()
|
||||
|
||||
if not app_config.disable_ui:
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def ui_root() -> str:
|
||||
return monitoring_page_html()
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def ws_endpoint(websocket: WebSocket, token: str = Query(default="")) -> None:
|
||||
if not token or not secrets.compare_digest(token, app_config.screenjob_token):
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await ws_hub.connect(websocket)
|
||||
try:
|
||||
await websocket.send_json({"event_type": "connected", "payload": {"ok": True}})
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
await ws_hub.disconnect(websocket)
|
||||
except Exception:
|
||||
await ws_hub.disconnect(websocket)
|
||||
else:
|
||||
@app.get("/", response_class=JSONResponse)
|
||||
def ui_disabled() -> dict[str, Any]:
|
||||
return {"ok": True, "ui_disabled": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import uvicorn
|
||||
|
||||
app = create_app(load_app_config(Path.cwd()))
|
||||
config = app.state.config
|
||||
uvicorn.run(app, host=config.host, port=config.port, log_level="info")
|
||||
Reference in New Issue
Block a user