# File: screenjob/src/task_manager.py
# Python · 354 lines · 12 KiB

from __future__ import annotations
import json
import threading
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from .config import AppConfig
from .models import RuntimeOptions
from .runtime import create_openai_client, run_job
from .safety import assess_task_safety
from .storage import HistoryDB
from .utils import utc_now_iso
@dataclass()
class _RunningJob:
    """In-memory bookkeeping for one job whose worker thread was started.

    Instances are stored in ``JobManager._running`` keyed by job id and are
    removed when the worker finishes.
    """

    # Worker thread executing the job.
    thread: threading.Thread
    # Set to ask the running job to stop.
    cancel_event: threading.Event
    # Timestamp string recorded when the job was submitted.
    started_at: str
    # Task text the job was submitted with.
    objective: str
    # Model name selected for the job.
    model: str
class JobManager:
    """Run jobs on background threads and keep their history up to date.

    Every submitted job gets a row in ``db``, a stream of events (persisted
    with ``db.add_event`` and, if configured, pushed to ``broadcast``) and a
    daemon worker thread. While a worker is registered, its bookkeeping lives
    in ``self._running``, which is guarded by ``self._lock``.
    """

    def __init__(
        self,
        *,
        config: AppConfig,
        db: HistoryDB,
        broadcast: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        """Create a manager.

        Args:
            config: Application configuration (model names, API key, runs dir).
            db: History store used for job rows and events.
            broadcast: Optional callback invoked with every published event.
        """
        self.config = config
        self.db = db
        self.broadcast = broadcast
        self._running: dict[str, _RunningJob] = {}
        self._lock = threading.Lock()

    def submit_job(
        self,
        *,
        objective: str,
        model: str | None = None,
        max_steps: int = 60,
        command_timeout: int = 45,
        type_interval: float = 0.02,
        click_pause: float = 0.10,
        disabled_tools: list[str] | None = None,
        safety_override: bool = False,
        no_failsafe: bool = False,
    ) -> str:
        """Record a new job, publish ``job_queued`` and start its worker thread.

        Returns:
            The generated job id (``job_<unix seconds>_<8 hex chars>``).
        """
        job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        created_at = utc_now_iso()
        # Blank/whitespace-only model names fall back to the configured default.
        selected_model = (model or self.config.default_model).strip() or self.config.default_model
        # De-duplicate, strip and sort tool names; drop empty entries.
        disabled = sorted({tool.strip() for tool in (disabled_tools or []) if tool.strip()})
        self.db.create_job(
            job_id=job_id,
            objective=objective,
            model=selected_model,
            created_at=created_at,
            safety_override=safety_override,
            disabled_tools=disabled,
        )
        self._publish(
            job_id,
            {
                "ts": created_at,
                "step": 0,
                "event_type": "job_queued",
                "payload": {
                    "job_id": job_id,
                    "objective": objective,
                    "model": selected_model,
                    "disabled_tools": disabled,
                    "safety_override": bool(safety_override),
                },
            },
        )
        cancel_event = threading.Event()
        thread = threading.Thread(
            target=self._execute_job,
            kwargs={
                "job_id": job_id,
                "objective": objective,
                "model": selected_model,
                "disabled_tools": disabled,
                "safety_override": safety_override,
                "max_steps": max_steps,
                "command_timeout": command_timeout,
                "type_interval": type_interval,
                "click_pause": click_pause,
                "no_failsafe": no_failsafe,
                "cancel_event": cancel_event,
            },
            daemon=True,
        )
        # Register before starting so the worker's final pop always finds it.
        with self._lock:
            self._running[job_id] = _RunningJob(
                thread=thread,
                cancel_event=cancel_event,
                started_at=created_at,
                objective=objective,
                model=selected_model,
            )
        thread.start()
        return job_id

    def _execute_job(
        self,
        *,
        job_id: str,
        objective: str,
        model: str,
        disabled_tools: list[str],
        safety_override: bool,
        max_steps: int,
        command_timeout: int,
        type_interval: float,
        click_pause: float,
        no_failsafe: bool,
        cancel_event: threading.Event,
    ) -> None:
        """Worker-thread entry point: run the safety gate, run the job, record the outcome.

        The job is removed from ``self._running`` on every exit path, including
        unexpected exceptions. An error raised by the safety check is recorded
        as a failed job instead of escaping the thread, which previously left
        the job in ``running`` status forever.
        """
        try:
            started_at = utc_now_iso()
            self.db.update_job(job_id, status="running", started_at=started_at)
            self._publish(
                job_id,
                {"ts": started_at, "step": 0, "event_type": "job_started", "payload": {"job_id": job_id}},
            )

            try:
                allowed = self._apply_safety_gate(
                    job_id,
                    objective=objective,
                    disabled_tools=disabled_tools,
                    safety_override=safety_override,
                )
            except Exception as exc:  # noqa: BLE001
                self._fail_job(
                    job_id,
                    f"Safety check error: {type(exc).__name__}: {exc}",
                    event_type="job_failed",
                )
                return
            if not allowed:
                return

            options = RuntimeOptions(
                model=model,
                max_steps=max_steps,
                command_timeout=command_timeout,
                type_interval=type_interval,
                click_pause=click_pause,
                disable_tools=set(disabled_tools),
            )
            try:
                result, artifacts = run_job(
                    api_key=self.config.openai_api_key,
                    objective=objective,
                    options=options,
                    runs_base=self.config.runs_dir,
                    no_failsafe=no_failsafe,
                    cancel_event=cancel_event,
                    event_callback=self._make_event_handler(job_id),
                )
            except Exception as exc:  # noqa: BLE001
                self._fail_job(
                    job_id,
                    f"Fatal runtime error: {type(exc).__name__}: {exc}",
                    event_type="job_failed",
                )
                return

            self._record_result(job_id, result, artifacts)
        finally:
            with self._lock:
                self._running.pop(job_id, None)

    def _apply_safety_gate(
        self,
        job_id: str,
        *,
        objective: str,
        disabled_tools: list[str],
        safety_override: bool,
    ) -> bool:
        """Run (or bypass) the safety check and return True if the job may proceed.

        With ``safety_override`` the check is skipped and recorded as bypassed.
        Otherwise ``assess_task_safety`` decides. On rejection the job is marked
        failed and a ``job_rejected`` event is published before returning False.
        """
        if safety_override:
            self.db.update_job(
                job_id,
                safety_checked=1,
                safety_passed=1,
                safety_reason="Safety check bypassed by override.",
            )
            self._publish(
                job_id,
                {
                    "ts": utc_now_iso(),
                    "step": 0,
                    "event_type": "safety_override",
                    "payload": {"enabled": True},
                },
            )
            return True

        client = create_openai_client(self.config.openai_api_key)
        safe, reason, raw = assess_task_safety(
            client,
            model=self.config.safety_model,
            objective=objective,
            disabled_tools=disabled_tools,
        )
        self.db.update_job(
            job_id,
            safety_checked=1,
            safety_passed=1 if safe else 0,
            safety_reason=reason,
        )
        self._publish(
            job_id,
            {
                "ts": utc_now_iso(),
                "step": 0,
                "event_type": "safety_check",
                "payload": {"safe": safe, "reason": reason, "raw": raw},
            },
        )
        if safe:
            return True
        self._fail_job(job_id, f"Task blocked by safety gate: {reason}", event_type="job_rejected")
        return False

    def _fail_job(self, job_id: str, error_text: str, *, event_type: str) -> None:
        """Mark ``job_id`` failed with ``error_text`` and publish a terminal ``event_type`` event."""
        ended_at = utc_now_iso()
        self.db.update_job(
            job_id,
            status="failed",
            ended_at=ended_at,
            error=error_text,
            result=error_text,
            response_json=json.dumps({"return": error_text, "data": None}, ensure_ascii=False),
        )
        self._publish(
            job_id,
            {"ts": ended_at, "step": 0, "event_type": event_type, "payload": {"error": error_text}},
        )

    def _make_event_handler(self, job_id: str) -> Callable[[dict[str, Any]], None]:
        """Build the ``event_callback`` passed to ``run_job`` for ``job_id``.

        Every runtime event is republished. ``job_started`` events carrying a
        ``run_id`` update the job's artifacts directory, and ``usage_update``
        events refresh the token/cost columns.
        """

        def on_event(event: dict[str, Any]) -> None:
            self._publish(job_id, event)
            event_type = event.get("event_type")
            payload = event.get("payload") or {}
            if event_type == "job_started":
                run_id = str(payload.get("run_id") or "").strip()
                if run_id:
                    self.db.update_job(
                        job_id,
                        artifacts_dir=str((self.config.runs_dir / f"run_{run_id}").resolve()),
                    )
            if event_type == "usage_update":
                usage = payload.get("usage") or {}
                self.db.update_job(
                    job_id,
                    input_tokens=int(usage.get("input_tokens", 0) or 0),
                    cached_input_tokens=int(usage.get("cached_input_tokens", 0) or 0),
                    output_tokens=int(usage.get("output_tokens", 0) or 0),
                    reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
                    total_tokens=int(usage.get("total_tokens", 0) or 0),
                    estimated_cost_usd=usage.get("estimated_cost_usd"),
                )

        return on_event

    def _record_result(self, job_id: str, result: Any, artifacts: Any) -> None:
        """Persist the final outcome returned by ``run_job`` and publish ``job_finished``.

        Status is ``cancelled`` if ``result.cancelled``, else ``completed`` or
        ``failed`` according to ``result.completed``.
        """
        ended_at = utc_now_iso()
        if result.cancelled:
            status = "cancelled"
        elif result.completed:
            status = "completed"
        else:
            status = "failed"
        response = {"return": result.return_message, "data": result.data}
        self.db.update_job(
            job_id,
            status=status,
            ended_at=ended_at,
            result=result.return_message,
            response_json=json.dumps(response, ensure_ascii=False),
            error=result.error,
            steps=result.steps,
            cancelled=1 if result.cancelled else 0,
            artifacts_dir=str(Path(artifacts.root_dir).resolve()),
            input_tokens=result.usage.input_tokens,
            cached_input_tokens=result.usage.cached_input_tokens,
            output_tokens=result.usage.output_tokens,
            reasoning_tokens=result.usage.reasoning_tokens,
            total_tokens=result.usage.total_tokens,
            estimated_cost_usd=result.usage.estimated_cost_usd,
        )
        self._publish(
            job_id,
            {
                "ts": ended_at,
                "step": result.steps,
                "event_type": "job_finished",
                "payload": {
                    "status": status,
                    "result": result.return_message,
                    "response": {"return": result.return_message, "data": result.data},
                    "error": result.error,
                    "cancelled": result.cancelled,
                    "usage": result.usage.to_dict(),
                },
            },
        )

    def _publish(self, job_id: str, event: dict[str, Any]) -> None:
        """Normalize ``event``, persist it with ``db.add_event`` and forward it to ``broadcast``.

        Missing fields default to: ``ts`` = now, ``step`` = 0,
        ``event_type`` = ``"event"``, ``payload`` = ``{}``.
        """
        ts = str(event.get("ts") or utc_now_iso())
        step = int(event.get("step", 0) or 0)
        event_type = str(event.get("event_type", "event"))
        payload = event.get("payload") or {}
        self.db.add_event(job_id=job_id, ts=ts, step=step, event_type=event_type, payload=payload)
        if self.broadcast is not None:
            self.broadcast(
                {
                    "job_id": job_id,
                    "ts": ts,
                    "step": step,
                    "event_type": event_type,
                    "payload": payload,
                }
            )

    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation of ``job_id``.

        Returns:
            True if a registered job was signalled, or if the job is not
            registered but its stored status is already ``cancelled``.
            False otherwise (unknown or already-finished job).
        """
        with self._lock:
            running = self._running.get(job_id)
        if running is None:
            job = self.db.get_job(job_id)
            return bool(job and job.get("status") == "cancelled")
        running.cancel_event.set()
        # NOTE(review): the worker writes its final status before it
        # unregisters itself. If it finishes between the lookup above and this
        # write, the final status can be overwritten with "cancelling".
        self.db.update_job(job_id, status="cancelling")
        self._publish(job_id, {"ts": utc_now_iso(), "step": 0, "event_type": "cancel_requested", "payload": {}})
        return True

    def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return the normalized job row, or None if ``db`` has no such job.

        Adds ``is_running_thread``, which is True only while the job is in an
        active status and its worker thread is alive.
        """
        job = self.db.get_job(job_id)
        if job is None:
            return None
        # Read shared state under the lock, as stats()/cancel_job() do.
        with self._lock:
            live = self._running.get(job_id)
        if live is not None and job["status"] in {"queued", "running", "cancelling"}:
            job["is_running_thread"] = live.thread.is_alive()
        else:
            job["is_running_thread"] = False
        return self._normalize_job_payload(job)

    def list_jobs(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return up to ``limit`` normalized job rows from ``db.list_jobs``."""
        return [self._normalize_job_payload(job) for job in self.db.list_jobs(limit=limit)]

    def get_events(self, job_id: str, limit: int = 500) -> list[dict[str, Any]]:
        """Return up to ``limit`` stored events for ``job_id``."""
        return self.db.get_job_events(job_id, limit=limit)

    def stats(self) -> dict[str, Any]:
        """Return ``db.stats()`` plus ``live_running_threads`` (registered, alive workers)."""
        stats = self.db.stats()
        with self._lock:
            stats["live_running_threads"] = sum(1 for job in self._running.values() if job.thread.is_alive())
        return stats

    def _normalize_job_payload(self, job: dict[str, Any]) -> dict[str, Any]:
        """Ensure ``job`` has dict ``response`` plus top-level ``return``/``data`` keys.

        A missing or non-dict ``response`` is rebuilt from ``job["result"]``.
        The dict is mutated in place and returned.
        """
        response = job.get("response")
        if not isinstance(response, dict):
            response = {"return": str(job.get("result") or ""), "data": None}
        job["response"] = response
        job["return"] = str(response.get("return") or "")
        job["data"] = response.get("data")
        return job