feat: add shared runtime with FastAPI job server and safety pipeline

This commit is contained in:
Space-Banane
2026-05-27 17:43:51 +02:00
parent 84b0df520c
commit 10355bf11a
14 changed files with 1516 additions and 157 deletions

326
src/task_manager.py Normal file
View File

@@ -0,0 +1,326 @@
from __future__ import annotations
import threading
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from .config import AppConfig
from .models import RuntimeOptions
from .runtime import create_openai_client, run_job
from .safety import assess_task_safety
from .storage import HistoryDB
from .utils import utc_now_iso
@dataclass
class _RunningJob:
    """Bookkeeping record for one job that has been handed to a worker thread."""

    # Worker thread created for the job.
    thread: threading.Thread
    # Event that is set to ask the worker to stop.
    cancel_event: threading.Event
    # Timestamp string captured when the job was registered.
    started_at: str
    # Objective text the job was submitted with.
    objective: str
    # Name of the model selected for the job.
    model: str
class JobManager:
    """Submit objectives as background jobs and track them until they finish.

    Each job runs on its own daemon thread. Status changes are written through
    ``db``; every event is stored through ``db`` and, when a ``broadcast``
    callback was supplied, also forwarded to it.
    """

    def __init__(
        self,
        *,
        config: AppConfig,
        db: HistoryDB,
        broadcast: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        """Store the collaborators and start with an empty job registry.

        Args:
            config: Supplies ``default_model``, ``safety_model``,
                ``openai_api_key`` and ``runs_dir``.
            db: History store that receives job rows and events.
            broadcast: Optional callback invoked with every published event.
        """
        self.config = config
        self.db = db
        self.broadcast = broadcast
        # job_id -> bookkeeping for jobs whose worker has not finished yet.
        self._running: dict[str, _RunningJob] = {}
        # Guards ``_running`` and the status writes that must stay consistent
        # with it. It is never held while ``broadcast`` runs: the lock is not
        # reentrant, so a subscriber calling back into this object would hang.
        self._lock = threading.Lock()

    def submit_job(
        self,
        *,
        objective: str,
        model: str | None = None,
        max_steps: int = 60,
        command_timeout: int = 45,
        type_interval: float = 0.02,
        click_pause: float = 0.10,
        disabled_tools: list[str] | None = None,
        safety_override: bool = False,
        no_failsafe: bool = False,
    ) -> str:
        """Record a new job, announce it, and start its worker thread.

        Args:
            objective: Task description handed to the runtime.
            model: Model name; blank or ``None`` selects ``config.default_model``.
            max_steps: Forwarded to ``RuntimeOptions``.
            command_timeout: Forwarded to ``RuntimeOptions``.
            type_interval: Forwarded to ``RuntimeOptions``.
            click_pause: Forwarded to ``RuntimeOptions``.
            disabled_tools: Tool names to disable; trimmed, de-duplicated and
                sorted before use.
            safety_override: When true the safety assessment is skipped.
            no_failsafe: Forwarded to ``run_job``.

        Returns:
            The generated job id.
        """
        job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        created_at = utc_now_iso()
        # A missing or whitespace-only model falls back to the configured default.
        selected_model = (model or self.config.default_model).strip() or self.config.default_model
        disabled = sorted({tool.strip() for tool in (disabled_tools or []) if tool.strip()})
        self.db.create_job(
            job_id=job_id,
            objective=objective,
            model=selected_model,
            created_at=created_at,
            safety_override=safety_override,
            disabled_tools=disabled,
        )
        self._publish(
            job_id,
            {
                "ts": created_at,
                "step": 0,
                "event_type": "job_queued",
                "payload": {
                    "job_id": job_id,
                    "objective": objective,
                    "model": selected_model,
                    "disabled_tools": disabled,
                    "safety_override": bool(safety_override),
                },
            },
        )
        cancel_event = threading.Event()
        thread = threading.Thread(
            target=self._execute_job,
            kwargs={
                "job_id": job_id,
                "objective": objective,
                "model": selected_model,
                "disabled_tools": disabled,
                "safety_override": safety_override,
                "max_steps": max_steps,
                "command_timeout": command_timeout,
                "type_interval": type_interval,
                "click_pause": click_pause,
                "no_failsafe": no_failsafe,
                "cancel_event": cancel_event,
            },
            daemon=True,
        )
        # Register before starting so cancel_job can find the job immediately.
        with self._lock:
            self._running[job_id] = _RunningJob(
                thread=thread,
                cancel_event=cancel_event,
                started_at=created_at,
                objective=objective,
                model=selected_model,
            )
        thread.start()
        return job_id

    def _execute_job(
        self,
        *,
        job_id: str,
        objective: str,
        model: str,
        disabled_tools: list[str],
        safety_override: bool,
        max_steps: int,
        command_timeout: int,
        type_interval: float,
        click_pause: float,
        no_failsafe: bool,
        cancel_event: threading.Event,
    ) -> None:
        """Worker-thread entry point: safety gate, run, then record the outcome.

        The job is removed from the registry on every exit path, including
        unexpected exceptions, which are left to propagate to the thread.
        """
        try:
            started_at = utc_now_iso()
            self.db.update_job(job_id, status="running", started_at=started_at)
            self._publish(
                job_id,
                {"ts": started_at, "step": 0, "event_type": "job_started", "payload": {"job_id": job_id}},
            )

            allowed = self._pass_safety_gate(
                job_id=job_id,
                objective=objective,
                disabled_tools=disabled_tools,
                safety_override=safety_override,
            )
            if not allowed:
                return

            def on_event(event: dict[str, Any]) -> None:
                """Store/forward a runtime event and mirror usage counters to the job row."""
                self._publish(job_id, event)
                if event.get("event_type") == "usage_update":
                    usage = (event.get("payload") or {}).get("usage") or {}
                    self.db.update_job(
                        job_id,
                        input_tokens=int(usage.get("input_tokens", 0) or 0),
                        cached_input_tokens=int(usage.get("cached_input_tokens", 0) or 0),
                        output_tokens=int(usage.get("output_tokens", 0) or 0),
                        reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
                        total_tokens=int(usage.get("total_tokens", 0) or 0),
                        estimated_cost_usd=usage.get("estimated_cost_usd"),
                    )

            options = RuntimeOptions(
                model=model,
                max_steps=max_steps,
                command_timeout=command_timeout,
                type_interval=type_interval,
                click_pause=click_pause,
                disable_tools=set(disabled_tools),
            )
            try:
                result, artifacts = run_job(
                    api_key=self.config.openai_api_key,
                    objective=objective,
                    options=options,
                    runs_base=self.config.runs_dir,
                    no_failsafe=no_failsafe,
                    cancel_event=cancel_event,
                    event_callback=on_event,
                )
            except Exception as exc:  # noqa: BLE001
                self._fail_job(
                    job_id,
                    event_type="job_failed",
                    error_text=f"Fatal runtime error: {type(exc).__name__}: {exc}",
                )
                return

            ended_at = utc_now_iso()
            status = "completed" if result.completed else "failed"
            if result.cancelled:
                # Cancellation takes precedence over completed/failed.
                status = "cancelled"
            self._write_final_state(
                job_id,
                status=status,
                ended_at=ended_at,
                result=result.result,
                error=result.error,
                steps=result.steps,
                cancelled=1 if result.cancelled else 0,
                artifacts_dir=str(Path(artifacts.root_dir).resolve()),
                input_tokens=result.usage.input_tokens,
                cached_input_tokens=result.usage.cached_input_tokens,
                output_tokens=result.usage.output_tokens,
                reasoning_tokens=result.usage.reasoning_tokens,
                total_tokens=result.usage.total_tokens,
                estimated_cost_usd=result.usage.estimated_cost_usd,
            )
            self._publish(
                job_id,
                {
                    "ts": ended_at,
                    "step": result.steps,
                    "event_type": "job_finished",
                    "payload": {
                        "status": status,
                        "result": result.result,
                        "error": result.error,
                        "cancelled": result.cancelled,
                        "usage": result.usage.to_dict(),
                    },
                },
            )
        finally:
            # Safety net: whatever happened above, never leave a finished
            # worker registered as running.
            with self._lock:
                self._running.pop(job_id, None)

    def _pass_safety_gate(
        self,
        *,
        job_id: str,
        objective: str,
        disabled_tools: list[str],
        safety_override: bool,
    ) -> bool:
        """Record the safety decision for a job; return True if it may run.

        When the assessment itself raises, the job is marked failed instead of
        being run (fail closed) and False is returned.
        """
        if safety_override:
            self.db.update_job(
                job_id,
                safety_checked=1,
                safety_passed=1,
                safety_reason="Safety check bypassed by override.",
            )
            self._publish(
                job_id,
                {
                    "ts": utc_now_iso(),
                    "step": 0,
                    "event_type": "safety_override",
                    "payload": {"enabled": True},
                },
            )
            return True

        try:
            client = create_openai_client(self.config.openai_api_key)
            safe, reason, raw = assess_task_safety(
                client,
                model=self.config.safety_model,
                objective=objective,
                disabled_tools=disabled_tools,
            )
        except Exception as exc:  # noqa: BLE001
            # Without this the worker would die here and the job would stay
            # "running" forever.
            self._fail_job(
                job_id,
                event_type="job_failed",
                error_text=f"Safety check error: {type(exc).__name__}: {exc}",
            )
            return False

        self.db.update_job(
            job_id,
            safety_checked=1,
            safety_passed=1 if safe else 0,
            safety_reason=reason,
        )
        self._publish(
            job_id,
            {
                "ts": utc_now_iso(),
                "step": 0,
                "event_type": "safety_check",
                "payload": {"safe": safe, "reason": reason, "raw": raw},
            },
        )
        if not safe:
            self._fail_job(
                job_id,
                event_type="job_rejected",
                error_text=f"Task blocked by safety gate: {reason}",
            )
            return False
        return True

    def _fail_job(self, job_id: str, *, event_type: str, error_text: str) -> None:
        """Mark a job as failed with ``error_text`` and publish ``event_type``."""
        ended_at = utc_now_iso()
        self._write_final_state(
            job_id,
            status="failed",
            ended_at=ended_at,
            error=error_text,
            result=error_text,
        )
        self._publish(
            job_id,
            {
                "ts": ended_at,
                "step": 0,
                "event_type": event_type,
                "payload": {"error": error_text},
            },
        )

    def _write_final_state(self, job_id: str, **fields: Any) -> None:
        """Persist a terminal job update and unregister the job as one locked step.

        Doing both under the lock means cancel_job either sees the job as
        still registered (and its "cancelling" write is later replaced by this
        terminal status) or as already gone; it can never overwrite a terminal
        status with "cancelling".
        """
        with self._lock:
            self.db.update_job(job_id, **fields)
            self._running.pop(job_id, None)

    def _publish(self, job_id: str, event: dict[str, Any]) -> None:
        """Normalise an event, store it through ``db``, and forward it to ``broadcast``.

        Missing fields are defaulted: ``ts`` to the current time, ``step`` to
        0, ``event_type`` to ``"event"`` and ``payload`` to an empty dict.
        """
        ts = str(event.get("ts") or utc_now_iso())
        step = int(event.get("step", 0) or 0)
        event_type = str(event.get("event_type", "event"))
        payload = event.get("payload") or {}
        self.db.add_event(job_id=job_id, ts=ts, step=step, event_type=event_type, payload=payload)
        if self.broadcast is not None:
            self.broadcast(
                {
                    "job_id": job_id,
                    "ts": ts,
                    "step": step,
                    "event_type": event_type,
                    "payload": payload,
                }
            )

    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation of a job.

        Returns:
            True when a cancellation request was delivered to a registered
            job, or when the stored job already has status ``"cancelled"``;
            False otherwise (including unknown job ids).
        """
        with self._lock:
            running = self._running.get(job_id)
            if running is not None:
                running.cancel_event.set()
                # Written under the lock so it cannot land after the worker's
                # terminal status (see _write_final_state).
                self.db.update_job(job_id, status="cancelling")
        if running is None:
            job = self.db.get_job(job_id)
            return bool(job and job.get("status") == "cancelled")
        # Published outside the lock: broadcast subscribers may call back in.
        self._publish(
            job_id,
            {"ts": utc_now_iso(), "step": 0, "event_type": "cancel_requested", "payload": {}},
        )
        return True

    def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return the stored job row plus an ``is_running_thread`` flag, or None.

        The flag is true only when the job is still registered here, its
        stored status is queued/running/cancelling, and its thread is alive.
        """
        job = self.db.get_job(job_id)
        if job is None:
            return None
        with self._lock:
            live = self._running.get(job_id)
        job["is_running_thread"] = bool(
            live is not None
            and job["status"] in {"queued", "running", "cancelling"}
            and live.thread.is_alive()
        )
        return job

    def list_jobs(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return up to ``limit`` job rows as provided by ``db.list_jobs``."""
        return self.db.list_jobs(limit=limit)

    def get_events(self, job_id: str, limit: int = 500) -> list[dict[str, Any]]:
        """Return up to ``limit`` stored events for ``job_id``."""
        return self.db.get_job_events(job_id, limit=limit)

    def stats(self) -> dict[str, Any]:
        """Return ``db.stats()`` extended with ``live_running_threads``."""
        stats = self.db.stats()
        with self._lock:
            stats["live_running_threads"] = sum(
                1 for job in self._running.values() if job.thread.is_alive()
            )
        return stats