feat(wait): add structured wait endpoint
All checks were successful
python-syntax / syntax-check (push) Successful in 7s

This commit is contained in:
2026-05-01 15:55:29 +02:00
parent 493e5499e8
commit 5122d416e8
4 changed files with 304 additions and 2 deletions

View File

@@ -3,6 +3,7 @@ import ctypes
import hmac
import io
import os
import re
import subprocess
import sys
import time
@@ -11,6 +12,7 @@ from typing import Literal, Optional
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Header, HTTPException, Response
from PIL import ImageChops, ImageStat
from pydantic import BaseModel, Field, model_validator
@@ -193,6 +195,50 @@ class LaunchRequest(BaseModel):
dry_run: bool = False
class WaitTextCondition(BaseModel):
kind: Literal["text"]
mode: Literal["screen", "region"] = "screen"
text: str = Field(min_length=1, max_length=512)
match: Literal["contains", "exact", "regex"] = "contains"
present: bool = True
region_x: int | None = Field(default=None, ge=0)
region_y: int | None = Field(default=None, ge=0)
region_width: int | None = Field(default=None, gt=0)
region_height: int | None = Field(default=None, gt=0)
language_hint: str | None = Field(default=None, min_length=1, max_length=64)
min_confidence: float = Field(default=0.0, ge=0.0, le=1.0)
@model_validator(mode="after")
def _validate_region(self):
if self.mode == "region":
required = [self.region_x, self.region_y, self.region_width, self.region_height]
if any(v is None for v in required):
raise ValueError("region_x, region_y, region_width, region_height are required for mode=region")
return self
class WaitWindowCondition(WindowQuery):
kind: Literal["window"]
state: Literal["exists", "focused", "closed"] = "exists"
class WaitVisualCondition(BaseModel):
kind: Literal["visual"]
state: Literal["change", "stable"] = "change"
region_x: int | None = Field(default=None, ge=0)
region_y: int | None = Field(default=None, ge=0)
region_width: int | None = Field(default=None, gt=0)
region_height: int | None = Field(default=None, gt=0)
diff_threshold: float = Field(default=0.01, ge=0.0, le=1.0)
stable_for_ms: int = Field(default=800, ge=0, le=60000)
class WaitRequest(BaseModel):
condition: WaitTextCondition | WaitWindowCondition | WaitVisualCondition
timeout_ms: int = Field(default=5000, ge=0, le=120000)
poll_interval_ms: int = Field(default=250, ge=50, le=10000)
def _auth(x_clickthrough_token: Optional[str] = Header(default=None)):
@@ -483,6 +529,18 @@ def _run_ocr(image, language_hint: str | None, min_confidence: float, offset_x:
return blocks
def _normalize_text(value: str) -> str:
return re.sub(r"\s+", " ", value).strip()
def _matches_text(haystack: str, needle: str, match_mode: str) -> bool:
if match_mode == "exact":
return haystack == needle
if match_mode == "regex":
return re.search(needle, haystack) is not None
return needle.lower() in haystack.lower()
def _windows_only(feature: str):
if sys.platform != "win32":
raise HTTPException(status_code=501, detail=f"{feature} is currently supported on Windows hosts only")
@@ -698,6 +756,163 @@ def _launch_app(req: LaunchRequest) -> dict:
return result
def _capture_region_image(screen: int, region_x: int | None, region_y: int | None, region_width: int | None, region_height: int | None):
base_img, mon, displays, screen_selection = _capture_screen(screen)
if None in {region_x, region_y, region_width, region_height}:
return base_img, {"x": mon["x"], "y": mon["y"], "width": mon["width"], "height": mon["height"]}, mon, displays, screen_selection
left = region_x - mon["x"]
top = region_y - mon["y"]
right = left + region_width
bottom = top + region_height
if left < 0 or top < 0 or right > base_img.size[0] or bottom > base_img.size[1]:
raise HTTPException(status_code=400, detail="requested region is outside the captured monitor")
crop = base_img.crop((left, top, right, bottom))
region = {"x": region_x, "y": region_y, "width": region_width, "height": region_height}
return crop, region, mon, displays, screen_selection
def _image_diff_ratio(before, after) -> float:
diff = ImageChops.difference(before, after)
stat = ImageStat.Stat(diff)
channel_means = stat.mean if isinstance(stat.mean, list) else [stat.mean]
return float(sum(channel_means) / (len(channel_means) * 255.0))
def _wait_for_condition(req: WaitRequest, screen: int = 0) -> dict:
condition = req.condition
deadline = time.time() + (req.timeout_ms / 1000.0)
polls = 0
if isinstance(condition, WaitVisualCondition):
baseline, region, mon, displays, screen_selection = _capture_region_image(
screen,
condition.region_x,
condition.region_y,
condition.region_width,
condition.region_height,
)
stable_since = None
last_diff = 0.0
while True:
if time.time() > deadline:
return {
"satisfied": False,
"kind": condition.kind,
"state": condition.state,
"polls": polls,
"region": region,
"diff_ratio": last_diff,
"screen": screen_selection,
"display": mon,
}
time.sleep(req.poll_interval_ms / 1000.0)
current, _, _, _, _ = _capture_region_image(
screen,
region["x"],
region["y"],
region["width"],
region["height"],
)
polls += 1
last_diff = _image_diff_ratio(baseline, current)
if condition.state == "change":
if last_diff >= condition.diff_threshold:
return {
"satisfied": True,
"kind": condition.kind,
"state": condition.state,
"polls": polls,
"region": region,
"diff_ratio": last_diff,
"screen": screen_selection,
"display": mon,
}
else:
if last_diff <= condition.diff_threshold:
stable_since = stable_since or time.time()
if (time.time() - stable_since) * 1000 >= condition.stable_for_ms:
return {
"satisfied": True,
"kind": condition.kind,
"state": condition.state,
"polls": polls,
"region": region,
"diff_ratio": last_diff,
"stable_for_ms": int((time.time() - stable_since) * 1000),
"screen": screen_selection,
"display": mon,
}
else:
stable_since = None
baseline = current
while True:
if isinstance(condition, WaitWindowCondition):
matches = _list_windows(condition)
polls += 1
satisfied = False
if condition.state == "exists":
satisfied = bool(matches)
elif condition.state == "focused":
satisfied = any(item["foreground"] for item in matches)
elif condition.state == "closed":
satisfied = not matches
if satisfied:
return {
"satisfied": True,
"kind": condition.kind,
"state": condition.state,
"polls": polls,
"matches": matches[:10],
}
elif isinstance(condition, WaitTextCondition):
image, region, mon, displays, screen_selection = _capture_region_image(
screen,
condition.region_x,
condition.region_y,
condition.region_width,
condition.region_height,
)
blocks = _run_ocr(
image,
condition.language_hint,
condition.min_confidence,
region["x"],
region["y"],
)
polls += 1
matched = []
for block in blocks:
normalized = _normalize_text(block["text"])
target = _normalize_text(condition.text)
if _matches_text(normalized, target, condition.match):
matched.append(block)
satisfied = bool(matched) if condition.present else not bool(matched)
if satisfied:
return {
"satisfied": True,
"kind": condition.kind,
"mode": condition.mode,
"polls": polls,
"region": region,
"matches": matched,
"screen": screen_selection,
"display": mon,
}
else:
raise HTTPException(status_code=400, detail="unsupported wait condition")
if time.time() > deadline:
return {
"satisfied": False,
"kind": condition.kind,
"polls": polls,
}
time.sleep(req.poll_interval_ms / 1000.0)
def _pick_shell(explicit_shell: str | None) -> str:
shell_name = (explicit_shell or SETTINGS["exec_default_shell"] or "powershell").lower().strip()
if shell_name not in {"powershell", "bash", "cmd"}:
@@ -1089,6 +1304,17 @@ def launch(req: LaunchRequest, _: None = Depends(_auth)):
}
@app.post("/wait")
def wait(req: WaitRequest, screen: int = 0, _: None = Depends(_auth)):
result = _wait_for_condition(req, screen)
return {
"ok": result.get("satisfied", False),
"request_id": _request_id(),
"time_ms": _now_ms(),
"result": result,
}
@app.post("/ocr")
def ocr(req: OCRRequest, screen: int = 0, _: None = Depends(_auth)):
source = req.mode