from __future__ import annotations import logging from datetime import datetime, timedelta, timezone from dataclasses import dataclass from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from gitea_codex_bot.models import JobStatus, ReviewJob, ReviewRun, RunStatus, WebhookEvent from gitea_codex_bot.services.security import payload_digest from gitea_codex_bot.types import ParsedCommand logger = logging.getLogger(__name__) LEASE_TIMEOUT_ERROR_PREFIX = "Job lease timed out" @dataclass class RecoveryOutcome: repo: str pr_number: int job_id: int retries_used: int failed: bool message: str def persist_webhook_event( session: Session, *, delivery_id: str | None, event_name: str, repo: str, comment_id: int | None, payload: bytes, ) -> bool: event = WebhookEvent( delivery_id=delivery_id, event_name=event_name, repo=repo, comment_id=comment_id, payload_sha256=payload_digest(payload), ) session.add(event) try: session.commit() return True except IntegrityError: session.rollback() return False def cooldown_remaining_seconds(session: Session, repo: str, pr_number: int, cooldown_seconds: int) -> int: cutoff = datetime.now(timezone.utc) - timedelta(seconds=cooldown_seconds) row = session.execute( select(ReviewJob) .where(ReviewJob.repo == repo, ReviewJob.pr_number == pr_number, ReviewJob.created_at >= cutoff) .order_by(ReviewJob.created_at.desc()) .limit(1) ).scalar_one_or_none() if not row: return 0 created_at = row.created_at if created_at.tzinfo is None: created_at = created_at.replace(tzinfo=timezone.utc) age = (datetime.now(timezone.utc) - created_at).total_seconds() remaining = int(max(cooldown_seconds - age, 0)) return remaining def enqueue_job( session: Session, *, repo: str, pr_number: int, head_sha: str, trigger_comment_id: int, trigger_comment_body: str | None, requested_by: str, command: ParsedCommand, ) -> ReviewJob: job = ReviewJob( repo=repo, pr_number=pr_number, head_sha=head_sha, trigger_comment_id=trigger_comment_id, trigger_comment_body=trigger_comment_body, command=command.name, command_args=" ".join(command.arguments) if command.arguments else None, requested_by=requested_by, status=JobStatus.queued, ) session.add(job) session.commit() session.refresh(job) logger.info( "Job enqueued id=%s repo=%s pr=%s command=%s head_sha=%s trigger_comment_id=%s requested_by=%s", job.id, job.repo, job.pr_number, job.command, job.head_sha, job.trigger_comment_id, job.requested_by, ) return job def claim_next_job(session: Session) -> ReviewJob | None: job = session.execute( select(ReviewJob).where(ReviewJob.status == JobStatus.queued).order_by(ReviewJob.created_at.asc()).limit(1).with_for_update(skip_locked=True) ).scalar_one_or_none() if not job: session.rollback() return None job.status = JobStatus.running job.started_at = datetime.now(timezone.utc) run = ReviewRun(job_id=job.id, status=RunStatus.running) session.add(run) session.commit() session.refresh(job) logger.info( "Job claimed id=%s repo=%s pr=%s command=%s head_sha=%s status=%s", job.id, job.repo, job.pr_number, job.command, job.head_sha, job.status.value if hasattr(job.status, "value") else job.status, ) return job def recover_stuck_jobs(session: Session, *, lease_timeout_seconds: int, action: str, max_retries: int) -> list[RecoveryOutcome]: if lease_timeout_seconds <= 0: return [] now = datetime.now(timezone.utc) cutoff = now - timedelta(seconds=lease_timeout_seconds) stale_jobs = session.execute( select(ReviewJob) .where( ReviewJob.status == JobStatus.running, ReviewJob.started_at.is_not(None), ReviewJob.started_at < cutoff, ) .order_by(ReviewJob.started_at.asc()) .with_for_update(skip_locked=True) ).scalars() outcomes: list[RecoveryOutcome] = [] for job in stale_jobs: prior_retries = session.execute( select(ReviewRun) .where( ReviewRun.job_id == job.id, ReviewRun.status == RunStatus.failed, ReviewRun.error_message.is_not(None), ) .order_by(ReviewRun.id.asc()) ).scalars() lease_retries_used = sum(1 for run in prior_retries if (run.error_message or "").startswith(LEASE_TIMEOUT_ERROR_PREFIX)) retries_used_after_this_timeout = lease_retries_used + 1 should_fail = action == "fail" or lease_retries_used >= max_retries message = ( f"{LEASE_TIMEOUT_ERROR_PREFIX} after {lease_timeout_seconds}s while in running state; " f"retries_used={retries_used_after_this_timeout}, max_retries={max_retries}." ) latest_run = ( session.execute(select(ReviewRun).where(ReviewRun.job_id == job.id).order_by(ReviewRun.id.desc()).limit(1)).scalar_one_or_none() ) if latest_run and latest_run.status == RunStatus.running: latest_run.status = RunStatus.failed latest_run.finished_at = now latest_run.error_message = message job.last_error = message if should_fail: job.status = JobStatus.failed job.finished_at = now else: job.status = JobStatus.queued job.started_at = None job.finished_at = None outcomes.append( RecoveryOutcome( repo=job.repo, pr_number=job.pr_number, job_id=job.id, retries_used=retries_used_after_this_timeout, failed=should_fail, message=message, ) ) session.commit() return outcomes def finish_job( session: Session, *, job_id: int, success: bool, skipped: bool, result: dict | None, error_message: str | None, ) -> None: job = session.get(ReviewJob, job_id) if not job: return latest_run = ( session.execute(select(ReviewRun).where(ReviewRun.job_id == job_id).order_by(ReviewRun.id.desc()).limit(1)).scalar_one_or_none() ) if skipped: job.status = JobStatus.skipped run_status = RunStatus.skipped elif success: job.status = JobStatus.succeeded run_status = RunStatus.succeeded else: job.status = JobStatus.failed run_status = RunStatus.failed now = datetime.now(timezone.utc) job.finished_at = now job.last_error = error_message if result is not None: job.result_json = result if latest_run: latest_run.status = run_status latest_run.finished_at = now latest_run.result_json = result latest_run.error_message = error_message session.commit() logger.info( "Job finished id=%s repo=%s pr=%s status=%s run_status=%s skipped=%s error_present=%s", job.id, job.repo, job.pr_number, job.status.value if hasattr(job.status, "value") else job.status, run_status.value if hasattr(run_status, "value") else run_status, skipped, bool(error_message), )