354 lines
12 KiB
Python
354 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Callable
|
|
|
|
from .config import AppConfig
|
|
from .models import RuntimeOptions
|
|
from .runtime import create_openai_client, run_job
|
|
from .safety import assess_task_safety
|
|
from .storage import HistoryDB
|
|
from .utils import utc_now_iso
|
|
|
|
|
|
@dataclass()
class _RunningJob:
    """Bookkeeping for a job whose worker thread was created by ``JobManager``.

    ``JobManager.submit_job`` stores one instance per job id in
    ``JobManager._running``; the entry is removed once the worker finishes.
    """

    # Worker thread that executes the job (created as a daemon thread).
    thread: threading.Thread
    # Set by ``cancel_job`` to request cancellation; handed to ``run_job``.
    cancel_event: threading.Event
    # Timestamp recorded when the job was submitted (the job's ``created_at``).
    started_at: str
    # Objective text the job was submitted with.
    objective: str
    # Model name selected for this job.
    model: str
|
|
|
|
|
|
class JobManager:
    """Run agent jobs on background threads and record them in ``HistoryDB``.

    Each job is persisted through ``db`` (status, safety verdict, token usage,
    final result). Every lifecycle event is appended with ``db.add_event`` and,
    when a ``broadcast`` callback is configured, also pushed to it.

    Locking: ``self._lock`` guards ``self._running``. A job's terminal status
    write and the removal of its live entry happen together under the lock.
    ``cancel_job`` writes ``"cancelling"`` under the same lock, so a late
    cancel request cannot overwrite a status that is already final.
    """

    # Job statuses for which a live worker thread may still exist.
    _LIVE_STATUSES = frozenset({"queued", "running", "cancelling"})

    def __init__(
        self,
        *,
        config: AppConfig,
        db: HistoryDB,
        broadcast: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        """Create a manager.

        Args:
            config: Application settings (default model, API key, safety model,
                runs directory).
            db: Persistence layer for jobs and their events.
            broadcast: Optional callback that receives every published event
                envelope (``job_id``, ``ts``, ``step``, ``event_type``,
                ``payload``).
        """
        self.config = config
        self.db = db
        self.broadcast = broadcast
        self._running: dict[str, _RunningJob] = {}
        self._lock = threading.Lock()

    def submit_job(
        self,
        *,
        objective: str,
        model: str | None = None,
        max_steps: int = 60,
        command_timeout: int = 45,
        type_interval: float = 0.02,
        click_pause: float = 0.10,
        disabled_tools: list[str] | None = None,
        safety_override: bool = False,
        no_failsafe: bool = False,
    ) -> str:
        """Persist a new job, publish ``job_queued``, and start its worker thread.

        Args:
            objective: Task text for the agent.
            model: Model name. Falls back to ``config.default_model`` when
                missing or blank.
            max_steps, command_timeout, type_interval, click_pause: Forwarded
                to ``RuntimeOptions``.
            disabled_tools: Tool names to disable. Names are trimmed,
                de-duplicated and sorted, and blanks are dropped.
            safety_override: Skip the safety assessment. The bypass is
                recorded on the job.
            no_failsafe: Forwarded to ``run_job``.

        Returns:
            The generated job id, ``job_<unix seconds>_<8 hex chars>``.

        Raises:
            RuntimeError: The worker thread could not be started. Before
                re-raising, the job is recorded as failed so it does not stay
                ``queued``.
        """
        job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        created_at = utc_now_iso()
        selected_model = (model or self.config.default_model).strip() or self.config.default_model
        disabled = sorted({tool.strip() for tool in (disabled_tools or []) if tool.strip()})

        self.db.create_job(
            job_id=job_id,
            objective=objective,
            model=selected_model,
            created_at=created_at,
            safety_override=safety_override,
            disabled_tools=disabled,
        )
        self._publish(
            job_id,
            {
                "ts": created_at,
                "step": 0,
                "event_type": "job_queued",
                "payload": {
                    "job_id": job_id,
                    "objective": objective,
                    "model": selected_model,
                    "disabled_tools": disabled,
                    "safety_override": bool(safety_override),
                },
            },
        )

        cancel_event = threading.Event()
        thread = threading.Thread(
            target=self._execute_job,
            kwargs={
                "job_id": job_id,
                "objective": objective,
                "model": selected_model,
                "disabled_tools": disabled,
                "safety_override": safety_override,
                "max_steps": max_steps,
                "command_timeout": command_timeout,
                "type_interval": type_interval,
                "click_pause": click_pause,
                "no_failsafe": no_failsafe,
                "cancel_event": cancel_event,
            },
            daemon=True,
        )
        with self._lock:
            self._running[job_id] = _RunningJob(
                thread=thread,
                cancel_event=cancel_event,
                started_at=created_at,
                objective=objective,
                model=selected_model,
            )
        try:
            thread.start()
        except RuntimeError as exc:
            # Thread creation can fail (e.g. resource exhaustion). Without this
            # the job would sit in "queued" with a dead entry in _running.
            fields, event = self._failure_outcome(
                "job_failed", f"Fatal runtime error: {type(exc).__name__}: {exc}"
            )
            self._finish_job(job_id, fields)
            self._publish(job_id, event)
            raise
        return job_id

    def _execute_job(
        self,
        *,
        job_id: str,
        objective: str,
        model: str,
        disabled_tools: list[str],
        safety_override: bool,
        max_steps: int,
        command_timeout: int,
        type_interval: float,
        click_pause: float,
        no_failsafe: bool,
        cancel_event: threading.Event,
    ) -> None:
        """Worker-thread entry point: run one job and record its terminal state.

        Any exception raised while marking the job running, during the safety
        gate, or inside ``run_job`` becomes a ``failed`` status plus a
        ``job_failed`` event. The job never stays stuck in ``running``, and
        its ``_running`` entry is always released.
        """
        try:
            try:
                fields, final_event = self._run_to_completion(
                    job_id=job_id,
                    objective=objective,
                    model=model,
                    disabled_tools=disabled_tools,
                    safety_override=safety_override,
                    max_steps=max_steps,
                    command_timeout=command_timeout,
                    type_interval=type_interval,
                    click_pause=click_pause,
                    no_failsafe=no_failsafe,
                    cancel_event=cancel_event,
                )
            except Exception as exc:  # noqa: BLE001 - a worker must always reach a terminal state
                fields, final_event = self._failure_outcome(
                    "job_failed", f"Fatal runtime error: {type(exc).__name__}: {exc}"
                )
            self._finish_job(job_id, fields)
        finally:
            # Release the live entry even if persisting the outcome itself failed.
            with self._lock:
                self._running.pop(job_id, None)
        self._publish(job_id, final_event)

    def _run_to_completion(
        self,
        *,
        job_id: str,
        objective: str,
        model: str,
        disabled_tools: list[str],
        safety_override: bool,
        max_steps: int,
        command_timeout: int,
        type_interval: float,
        click_pause: float,
        no_failsafe: bool,
        cancel_event: threading.Event,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Mark the job started, apply the safety gate, and execute ``run_job``.

        Returns:
            ``(fields, event)``: the terminal DB fields and the final event.
            This method does not write the terminal status itself; the caller
            does that through ``_finish_job``.
        """
        started_at = utc_now_iso()
        with self._lock:
            # cancel_job (which writes under this lock) may already have marked
            # the job "cancelling" before the worker got here; don't clobber it.
            if cancel_event.is_set():
                self.db.update_job(job_id, started_at=started_at)
            else:
                self.db.update_job(job_id, status="running", started_at=started_at)
        self._publish(job_id, {"ts": started_at, "step": 0, "event_type": "job_started", "payload": {"job_id": job_id}})

        rejection = self._check_safety(
            job_id,
            objective=objective,
            disabled_tools=disabled_tools,
            safety_override=safety_override,
        )
        if rejection is not None:
            return self._failure_outcome("job_rejected", rejection)

        options = RuntimeOptions(
            model=model,
            max_steps=max_steps,
            command_timeout=command_timeout,
            type_interval=type_interval,
            click_pause=click_pause,
            disable_tools=set(disabled_tools),
        )
        result, artifacts = run_job(
            api_key=self.config.openai_api_key,
            objective=objective,
            options=options,
            runs_base=self.config.runs_dir,
            no_failsafe=no_failsafe,
            cancel_event=cancel_event,
            event_callback=lambda event: self._on_runtime_event(job_id, event),
        )
        return self._completion_outcome(result, artifacts)

    def _check_safety(
        self,
        job_id: str,
        *,
        objective: str,
        disabled_tools: list[str],
        safety_override: bool,
    ) -> str | None:
        """Run the pre-flight safety assessment, or record that it was bypassed.

        Returns:
            ``None`` when the job may proceed. Otherwise, the error text the job
            should fail with.
        """
        if safety_override:
            self.db.update_job(
                job_id,
                safety_checked=1,
                safety_passed=1,
                safety_reason="Safety check bypassed by override.",
            )
            self._publish(
                job_id,
                {"ts": utc_now_iso(), "step": 0, "event_type": "safety_override", "payload": {"enabled": True}},
            )
            return None

        client = create_openai_client(self.config.openai_api_key)
        safe, reason, raw = assess_task_safety(
            client,
            model=self.config.safety_model,
            objective=objective,
            disabled_tools=disabled_tools,
        )
        self.db.update_job(
            job_id,
            safety_checked=1,
            safety_passed=1 if safe else 0,
            safety_reason=reason,
        )
        self._publish(
            job_id,
            {
                "ts": utc_now_iso(),
                "step": 0,
                "event_type": "safety_check",
                "payload": {"safe": safe, "reason": reason, "raw": raw},
            },
        )
        return None if safe else f"Task blocked by safety gate: {reason}"

    def _on_runtime_event(self, job_id: str, event: dict[str, Any]) -> None:
        """Forward a ``run_job`` event and mirror selected details onto the job row.

        ``job_started`` events with a ``run_id`` record the run's artifacts
        directory. ``usage_update`` events record the running token and cost
        totals.
        """
        self._publish(job_id, event)
        event_type = event.get("event_type")
        payload = event.get("payload") or {}
        if event_type == "job_started":
            run_id = str(payload.get("run_id") or "").strip()
            if run_id:
                self.db.update_job(
                    job_id,
                    artifacts_dir=str((self.config.runs_dir / f"run_{run_id}").resolve()),
                )
        elif event_type == "usage_update":
            usage = payload.get("usage") or {}
            self.db.update_job(
                job_id,
                input_tokens=int(usage.get("input_tokens", 0) or 0),
                cached_input_tokens=int(usage.get("cached_input_tokens", 0) or 0),
                output_tokens=int(usage.get("output_tokens", 0) or 0),
                reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
                total_tokens=int(usage.get("total_tokens", 0) or 0),
                estimated_cost_usd=usage.get("estimated_cost_usd"),
            )

    def _completion_outcome(self, result: Any, artifacts: Any) -> tuple[dict[str, Any], dict[str, Any]]:
        """Build the terminal DB fields and ``job_finished`` event from a ``run_job`` result."""
        ended_at = utc_now_iso()
        # A cancellation takes precedence over the completed/failed verdict.
        if result.cancelled:
            status = "cancelled"
        elif result.completed:
            status = "completed"
        else:
            status = "failed"
        usage = result.usage
        response = {"return": result.return_message, "data": result.data}
        fields = {
            "status": status,
            "ended_at": ended_at,
            "result": result.return_message,
            "response_json": json.dumps(response, ensure_ascii=False),
            "error": result.error,
            "steps": result.steps,
            "cancelled": 1 if result.cancelled else 0,
            "artifacts_dir": str(Path(artifacts.root_dir).resolve()),
            "input_tokens": usage.input_tokens,
            "cached_input_tokens": usage.cached_input_tokens,
            "output_tokens": usage.output_tokens,
            "reasoning_tokens": usage.reasoning_tokens,
            "total_tokens": usage.total_tokens,
            "estimated_cost_usd": usage.estimated_cost_usd,
        }
        event = {
            "ts": ended_at,
            "step": result.steps,
            "event_type": "job_finished",
            "payload": {
                "status": status,
                "result": result.return_message,
                "response": response,
                "error": result.error,
                "cancelled": result.cancelled,
                "usage": usage.to_dict(),
            },
        }
        return fields, event

    def _failure_outcome(self, event_type: str, error_text: str) -> tuple[dict[str, Any], dict[str, Any]]:
        """Build terminal DB fields and the final event for a job that failed.

        Args:
            event_type: Event to publish, e.g. ``"job_rejected"`` or ``"job_failed"``.
            error_text: Stored as the job's error, its result, and its response ``return``.
        """
        ended_at = utc_now_iso()
        fields = {
            "status": "failed",
            "ended_at": ended_at,
            "error": error_text,
            "result": error_text,
            "response_json": json.dumps({"return": error_text, "data": None}, ensure_ascii=False),
        }
        event = {"ts": ended_at, "step": 0, "event_type": event_type, "payload": {"error": error_text}}
        return fields, event

    def _finish_job(self, job_id: str, fields: dict[str, Any]) -> None:
        """Persist a job's terminal fields and drop its live entry, atomically.

        Both steps happen under ``self._lock``. This is what stops
        ``cancel_job`` from writing ``"cancelling"`` over a status that is
        already terminal. The entry is removed even if the DB write raises.
        """
        with self._lock:
            try:
                self.db.update_job(job_id, **fields)
            finally:
                self._running.pop(job_id, None)

    def _publish(self, job_id: str, event: dict[str, Any]) -> None:
        """Persist an event and, if configured, broadcast it.

        Missing fields are defaulted: ``ts`` becomes now, ``step`` becomes 0
        (coerced with ``int``), ``event_type`` becomes ``"event"``, and
        ``payload`` becomes ``{}``.

        The broadcast is a best-effort side channel. The event is already
        stored, so a failing subscriber is logged instead of aborting the
        job that emitted the event.
        """
        ts = str(event.get("ts") or utc_now_iso())
        step = int(event.get("step", 0) or 0)
        event_type = str(event.get("event_type", "event"))
        payload = event.get("payload") or {}
        self.db.add_event(job_id=job_id, ts=ts, step=step, event_type=event_type, payload=payload)
        if self.broadcast is None:
            return
        try:
            self.broadcast(
                {
                    "job_id": job_id,
                    "ts": ts,
                    "step": step,
                    "event_type": event_type,
                    "payload": payload,
                }
            )
        except Exception:  # noqa: BLE001 - see docstring: broadcast is best-effort
            import logging

            logging.getLogger(__name__).exception(
                "Broadcast of %s event for job %s failed", event_type, job_id
            )

    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation of a job.

        Returns:
            ``True`` if a live job was signalled, or if a job that is no longer
            live is already recorded as ``cancelled``. ``False`` otherwise.
        """
        with self._lock:
            running = self._running.get(job_id)
            if running is not None:
                # Written under the lock so it cannot land after _finish_job.
                running.cancel_event.set()
                self.db.update_job(job_id, status="cancelling")
        if running is None:
            job = self.db.get_job(job_id)
            return bool(job and job.get("status") == "cancelled")
        self._publish(job_id, {"ts": utc_now_iso(), "step": 0, "event_type": "cancel_requested", "payload": {}})
        return True

    def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return the normalized job record, or ``None`` if the id is unknown.

        Adds ``is_running_thread``. It is true only when the job has a live
        entry, a queued/running/cancelling status, and an alive worker thread.
        """
        job = self.db.get_job(job_id)
        if job is None:
            return None
        with self._lock:
            live = self._running.get(job_id)
        job["is_running_thread"] = bool(
            live is not None and job.get("status") in self._LIVE_STATUSES and live.thread.is_alive()
        )
        return self._normalize_job_payload(job)

    def list_jobs(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return ``db.list_jobs(limit=limit)`` with normalized response fields."""
        return [self._normalize_job_payload(job) for job in self.db.list_jobs(limit=limit)]

    def get_events(self, job_id: str, limit: int = 500) -> list[dict[str, Any]]:
        """Return the events stored for ``job_id`` via ``db.get_job_events``."""
        return self.db.get_job_events(job_id, limit=limit)

    def stats(self) -> dict[str, Any]:
        """Return ``db.stats()`` plus the count of worker threads still alive."""
        stats = self.db.stats()
        with self._lock:
            stats["live_running_threads"] = sum(1 for job in self._running.values() if job.thread.is_alive())
        return stats

    def _normalize_job_payload(self, job: dict[str, Any]) -> dict[str, Any]:
        """Ensure ``job`` has ``response``, ``return`` and ``data`` keys (mutates in place).

        When ``response`` is not a dict, it is rebuilt from the ``result``
        column with ``data`` set to ``None``.
        """
        response = job.get("response")
        if not isinstance(response, dict):
            response = {"return": str(job.get("result") or ""), "data": None}
        job["response"] = response
        job["return"] = str(response.get("return") or "")
        job["data"] = response.get("data")
        return job
|