feat: add shared runtime with FastAPI job server and safety pipeline
This commit is contained in:
326
src/task_manager.py
Normal file
326
src/task_manager.py
Normal file
@@ -0,0 +1,326 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from .config import AppConfig
|
||||
from .models import RuntimeOptions
|
||||
from .runtime import create_openai_client, run_job
|
||||
from .safety import assess_task_safety
|
||||
from .storage import HistoryDB
|
||||
from .utils import utc_now_iso
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RunningJob:
|
||||
thread: threading.Thread
|
||||
cancel_event: threading.Event
|
||||
started_at: str
|
||||
objective: str
|
||||
model: str
|
||||
|
||||
|
||||
class JobManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: AppConfig,
|
||||
db: HistoryDB,
|
||||
broadcast: Callable[[dict[str, Any]], None] | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.db = db
|
||||
self.broadcast = broadcast
|
||||
self._running: dict[str, _RunningJob] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def submit_job(
|
||||
self,
|
||||
*,
|
||||
objective: str,
|
||||
model: str | None = None,
|
||||
max_steps: int = 60,
|
||||
command_timeout: int = 45,
|
||||
type_interval: float = 0.02,
|
||||
click_pause: float = 0.10,
|
||||
disabled_tools: list[str] | None = None,
|
||||
safety_override: bool = False,
|
||||
no_failsafe: bool = False,
|
||||
) -> str:
|
||||
job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
created_at = utc_now_iso()
|
||||
selected_model = (model or self.config.default_model).strip() or self.config.default_model
|
||||
disabled = sorted({tool.strip() for tool in (disabled_tools or []) if tool.strip()})
|
||||
self.db.create_job(
|
||||
job_id=job_id,
|
||||
objective=objective,
|
||||
model=selected_model,
|
||||
created_at=created_at,
|
||||
safety_override=safety_override,
|
||||
disabled_tools=disabled,
|
||||
)
|
||||
self._publish(
|
||||
job_id,
|
||||
{
|
||||
"ts": created_at,
|
||||
"step": 0,
|
||||
"event_type": "job_queued",
|
||||
"payload": {
|
||||
"job_id": job_id,
|
||||
"objective": objective,
|
||||
"model": selected_model,
|
||||
"disabled_tools": disabled,
|
||||
"safety_override": bool(safety_override),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
thread = threading.Thread(
|
||||
target=self._execute_job,
|
||||
kwargs={
|
||||
"job_id": job_id,
|
||||
"objective": objective,
|
||||
"model": selected_model,
|
||||
"disabled_tools": disabled,
|
||||
"safety_override": safety_override,
|
||||
"max_steps": max_steps,
|
||||
"command_timeout": command_timeout,
|
||||
"type_interval": type_interval,
|
||||
"click_pause": click_pause,
|
||||
"no_failsafe": no_failsafe,
|
||||
"cancel_event": cancel_event,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
with self._lock:
|
||||
self._running[job_id] = _RunningJob(
|
||||
thread=thread,
|
||||
cancel_event=cancel_event,
|
||||
started_at=created_at,
|
||||
objective=objective,
|
||||
model=selected_model,
|
||||
)
|
||||
thread.start()
|
||||
return job_id
|
||||
|
||||
def _execute_job(
|
||||
self,
|
||||
*,
|
||||
job_id: str,
|
||||
objective: str,
|
||||
model: str,
|
||||
disabled_tools: list[str],
|
||||
safety_override: bool,
|
||||
max_steps: int,
|
||||
command_timeout: int,
|
||||
type_interval: float,
|
||||
click_pause: float,
|
||||
no_failsafe: bool,
|
||||
cancel_event: threading.Event,
|
||||
) -> None:
|
||||
started_at = utc_now_iso()
|
||||
self.db.update_job(job_id, status="running", started_at=started_at)
|
||||
self._publish(job_id, {"ts": started_at, "step": 0, "event_type": "job_started", "payload": {"job_id": job_id}})
|
||||
|
||||
if not safety_override:
|
||||
client = create_openai_client(self.config.openai_api_key)
|
||||
safe, reason, raw = assess_task_safety(
|
||||
client,
|
||||
model=self.config.safety_model,
|
||||
objective=objective,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
self.db.update_job(
|
||||
job_id,
|
||||
safety_checked=1,
|
||||
safety_passed=1 if safe else 0,
|
||||
safety_reason=reason,
|
||||
)
|
||||
self._publish(
|
||||
job_id,
|
||||
{
|
||||
"ts": utc_now_iso(),
|
||||
"step": 0,
|
||||
"event_type": "safety_check",
|
||||
"payload": {"safe": safe, "reason": reason, "raw": raw},
|
||||
},
|
||||
)
|
||||
if not safe:
|
||||
ended_at = utc_now_iso()
|
||||
error_text = f"Task blocked by safety gate: {reason}"
|
||||
self.db.update_job(
|
||||
job_id,
|
||||
status="failed",
|
||||
ended_at=ended_at,
|
||||
error=error_text,
|
||||
result=error_text,
|
||||
)
|
||||
self._publish(
|
||||
job_id,
|
||||
{
|
||||
"ts": ended_at,
|
||||
"step": 0,
|
||||
"event_type": "job_rejected",
|
||||
"payload": {"error": error_text},
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._running.pop(job_id, None)
|
||||
return
|
||||
else:
|
||||
self.db.update_job(
|
||||
job_id,
|
||||
safety_checked=1,
|
||||
safety_passed=1,
|
||||
safety_reason="Safety check bypassed by override.",
|
||||
)
|
||||
self._publish(
|
||||
job_id,
|
||||
{
|
||||
"ts": utc_now_iso(),
|
||||
"step": 0,
|
||||
"event_type": "safety_override",
|
||||
"payload": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
def on_event(event: dict[str, Any]) -> None:
|
||||
self._publish(job_id, event)
|
||||
if event.get("event_type") == "usage_update":
|
||||
usage = (event.get("payload") or {}).get("usage") or {}
|
||||
self.db.update_job(
|
||||
job_id,
|
||||
input_tokens=int(usage.get("input_tokens", 0) or 0),
|
||||
cached_input_tokens=int(usage.get("cached_input_tokens", 0) or 0),
|
||||
output_tokens=int(usage.get("output_tokens", 0) or 0),
|
||||
reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
|
||||
total_tokens=int(usage.get("total_tokens", 0) or 0),
|
||||
estimated_cost_usd=usage.get("estimated_cost_usd"),
|
||||
)
|
||||
|
||||
options = RuntimeOptions(
|
||||
model=model,
|
||||
max_steps=max_steps,
|
||||
command_timeout=command_timeout,
|
||||
type_interval=type_interval,
|
||||
click_pause=click_pause,
|
||||
disable_tools=set(disabled_tools),
|
||||
)
|
||||
try:
|
||||
result, artifacts = run_job(
|
||||
api_key=self.config.openai_api_key,
|
||||
objective=objective,
|
||||
options=options,
|
||||
runs_base=self.config.runs_dir,
|
||||
no_failsafe=no_failsafe,
|
||||
cancel_event=cancel_event,
|
||||
event_callback=on_event,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
ended_at = utc_now_iso()
|
||||
err = f"Fatal runtime error: {type(exc).__name__}: {exc}"
|
||||
self.db.update_job(
|
||||
job_id,
|
||||
status="failed",
|
||||
ended_at=ended_at,
|
||||
error=err,
|
||||
result=err,
|
||||
)
|
||||
self._publish(job_id, {"ts": ended_at, "step": 0, "event_type": "job_failed", "payload": {"error": err}})
|
||||
with self._lock:
|
||||
self._running.pop(job_id, None)
|
||||
return
|
||||
|
||||
ended_at = utc_now_iso()
|
||||
status = "completed" if result.completed else "failed"
|
||||
if result.cancelled:
|
||||
status = "cancelled"
|
||||
self.db.update_job(
|
||||
job_id,
|
||||
status=status,
|
||||
ended_at=ended_at,
|
||||
result=result.result,
|
||||
error=result.error,
|
||||
steps=result.steps,
|
||||
cancelled=1 if result.cancelled else 0,
|
||||
artifacts_dir=str(Path(artifacts.root_dir).resolve()),
|
||||
input_tokens=result.usage.input_tokens,
|
||||
cached_input_tokens=result.usage.cached_input_tokens,
|
||||
output_tokens=result.usage.output_tokens,
|
||||
reasoning_tokens=result.usage.reasoning_tokens,
|
||||
total_tokens=result.usage.total_tokens,
|
||||
estimated_cost_usd=result.usage.estimated_cost_usd,
|
||||
)
|
||||
self._publish(
|
||||
job_id,
|
||||
{
|
||||
"ts": ended_at,
|
||||
"step": result.steps,
|
||||
"event_type": "job_finished",
|
||||
"payload": {
|
||||
"status": status,
|
||||
"result": result.result,
|
||||
"error": result.error,
|
||||
"cancelled": result.cancelled,
|
||||
"usage": result.usage.to_dict(),
|
||||
},
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._running.pop(job_id, None)
|
||||
|
||||
def _publish(self, job_id: str, event: dict[str, Any]) -> None:
|
||||
ts = str(event.get("ts") or utc_now_iso())
|
||||
step = int(event.get("step", 0) or 0)
|
||||
event_type = str(event.get("event_type", "event"))
|
||||
payload = event.get("payload") or {}
|
||||
self.db.add_event(job_id=job_id, ts=ts, step=step, event_type=event_type, payload=payload)
|
||||
if self.broadcast is not None:
|
||||
self.broadcast(
|
||||
{
|
||||
"job_id": job_id,
|
||||
"ts": ts,
|
||||
"step": step,
|
||||
"event_type": event_type,
|
||||
"payload": payload,
|
||||
}
|
||||
)
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
with self._lock:
|
||||
running = self._running.get(job_id)
|
||||
if running is None:
|
||||
job = self.db.get_job(job_id)
|
||||
return bool(job and job.get("status") == "cancelled")
|
||||
running.cancel_event.set()
|
||||
self.db.update_job(job_id, status="cancelling")
|
||||
self._publish(job_id, {"ts": utc_now_iso(), "step": 0, "event_type": "cancel_requested", "payload": {}})
|
||||
return True
|
||||
|
||||
def get_job(self, job_id: str) -> dict[str, Any] | None:
|
||||
job = self.db.get_job(job_id)
|
||||
if job is None:
|
||||
return None
|
||||
live = self._running.get(job_id)
|
||||
if live and job["status"] in {"queued", "running", "cancelling"}:
|
||||
job["is_running_thread"] = live.thread.is_alive()
|
||||
else:
|
||||
job["is_running_thread"] = False
|
||||
return job
|
||||
|
||||
def list_jobs(self, limit: int = 100) -> list[dict[str, Any]]:
|
||||
return self.db.list_jobs(limit=limit)
|
||||
|
||||
def get_events(self, job_id: str, limit: int = 500) -> list[dict[str, Any]]:
|
||||
return self.db.get_job_events(job_id, limit=limit)
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
stats = self.db.stats()
|
||||
with self._lock:
|
||||
stats["live_running_threads"] = sum(1 for job in self._running.values() if job.thread.is_alive())
|
||||
return stats
|
||||
Reference in New Issue
Block a user