Add pytesseract OCR, click_text interact action, and interact verify endpoint
All checks were successful
python-syntax / syntax-check (push) Successful in 6s
All checks were successful
python-syntax / syntax-check (push) Successful in 6s
This commit is contained in:
@@ -8,13 +8,15 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from .config import SETTINGS
|
||||
from .models import ExecRequest, InteractRequest, LaunchRequest, SeeRequest, SeeZoomRequest, WindowActionRequest, WindowQuery
|
||||
from .models import ExecRequest, InteractRequest, InteractVerifyRequest, LaunchRequest, SeeRequest, SeeZoomRequest, WindowActionRequest, WindowQuery
|
||||
from .services import (
|
||||
apply_window_action,
|
||||
capture_region_image,
|
||||
capture_screen,
|
||||
draw_grid,
|
||||
encode_image,
|
||||
execute_and_verify,
|
||||
extract_ocr_items,
|
||||
exec_action,
|
||||
exec_command as run_exec_command,
|
||||
get_displays,
|
||||
@@ -65,7 +67,8 @@ async def _http_exception_handler(_: Request, exc: HTTPException):
|
||||
detail = exc.detail
|
||||
if isinstance(detail, dict):
|
||||
message = str(detail.get("message", "request failed"))
|
||||
return _err("http_error", message, exc.status_code, detail)
|
||||
code = str(detail.get("code", "http_error"))
|
||||
return _err(code, message, exc.status_code, detail.get("details"))
|
||||
return _err("http_error", str(detail), exc.status_code)
|
||||
|
||||
|
||||
@@ -99,6 +102,8 @@ def see(req: SeeRequest, _: None = Depends(_auth)):
|
||||
if req.with_grid:
|
||||
out_img, grid_meta = draw_grid(image, region["x"], region["y"], req.grid_rows, req.grid_cols, req.include_labels)
|
||||
meta.update(grid_meta)
|
||||
if req.ocr:
|
||||
meta["ocr"] = extract_ocr_items(image, region["x"], region["y"], req.ocr_min_confidence, req.ocr_lang, req.ocr_psm)
|
||||
return _ok(
|
||||
{
|
||||
"image": {
|
||||
@@ -154,6 +159,11 @@ def interact(req: InteractRequest, _: None = Depends(_auth)):
|
||||
return _ok(exec_action(req.action, req.screen))
|
||||
|
||||
|
||||
@app.post("/interact/verify")
|
||||
def interact_verify(req: InteractVerifyRequest, _: None = Depends(_auth)):
|
||||
return _ok(execute_and_verify(req))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health(_: None = Depends(_auth)):
|
||||
return _ok(
|
||||
|
||||
@@ -48,6 +48,7 @@ class ActionRequest(BaseModel):
|
||||
"scroll",
|
||||
"type",
|
||||
"hotkey",
|
||||
"click_text",
|
||||
]
|
||||
target: Optional[Target] = None
|
||||
duration_ms: int = Field(default=0, ge=0, le=20000)
|
||||
@@ -58,6 +59,13 @@ class ActionRequest(BaseModel):
|
||||
keys: list[str] = Field(default_factory=list)
|
||||
interval_ms: int = Field(default=20, ge=0, le=5000)
|
||||
dry_run: bool = False
|
||||
click_text: "ClickTextAction | None" = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_click_text(self):
|
||||
if self.action == "click_text" and self.click_text is None:
|
||||
raise ValueError("click_text payload is required when action=click_text")
|
||||
return self
|
||||
|
||||
|
||||
class ExecRequest(BaseModel):
|
||||
@@ -103,6 +111,10 @@ class SeeRequest(BaseModel):
|
||||
include_labels: bool = True
|
||||
image_format: Literal["png", "jpeg"] = "png"
|
||||
jpeg_quality: int = Field(default=85, ge=1, le=100)
|
||||
ocr: bool = False
|
||||
ocr_min_confidence: float = Field(default=0.0, ge=0.0, le=100.0)
|
||||
ocr_lang: str = Field(default="eng", min_length=1, max_length=64)
|
||||
ocr_psm: int | None = Field(default=None, ge=0, le=13)
|
||||
|
||||
|
||||
class SeeZoomRequest(BaseModel):
|
||||
@@ -122,3 +134,55 @@ class SeeZoomRequest(BaseModel):
|
||||
class InteractRequest(BaseModel):
|
||||
screen: int = 0
|
||||
action: ActionRequest
|
||||
|
||||
|
||||
class OCRRegion(BaseModel):
|
||||
x: int = Field(ge=0)
|
||||
y: int = Field(ge=0)
|
||||
width: int = Field(gt=0)
|
||||
height: int = Field(gt=0)
|
||||
|
||||
|
||||
class ClickTextAction(BaseModel):
|
||||
text: str = Field(min_length=1, max_length=1000)
|
||||
match: Literal["contains", "exact", "regex"] = "contains"
|
||||
region: OCRRegion | None = None
|
||||
screen: int | None = None
|
||||
case_sensitive: bool = False
|
||||
min_confidence: float = Field(default=0.0, ge=0.0, le=100.0)
|
||||
occurrence: Literal["first", "best", "nth"] = "first"
|
||||
nth: int | None = Field(default=None, ge=1, le=10000)
|
||||
ocr_lang: str = Field(default="eng", min_length=1, max_length=64)
|
||||
ocr_psm: int | None = Field(default=None, ge=0, le=13)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_nth(self):
|
||||
if self.occurrence == "nth" and self.nth is None:
|
||||
raise ValueError("nth is required when occurrence=nth")
|
||||
if self.occurrence != "nth" and self.nth is not None:
|
||||
raise ValueError("nth is only allowed when occurrence=nth")
|
||||
return self
|
||||
|
||||
|
||||
class VerifyOCRTextNearPoint(BaseModel):
|
||||
type: Literal["ocr_text_near_point"]
|
||||
text: str = Field(min_length=1, max_length=1000)
|
||||
x: int = Field(ge=0)
|
||||
y: int = Field(ge=0)
|
||||
radius: int = Field(default=80, ge=1, le=1000)
|
||||
screen: int = 0
|
||||
match: Literal["contains", "exact", "regex"] = "contains"
|
||||
case_sensitive: bool = False
|
||||
min_confidence: float = Field(default=0.0, ge=0.0, le=100.0)
|
||||
ocr_lang: str = Field(default="eng", min_length=1, max_length=64)
|
||||
ocr_psm: int | None = Field(default=None, ge=0, le=13)
|
||||
|
||||
|
||||
class InteractVerifyRequest(BaseModel):
|
||||
action: InteractRequest
|
||||
verify: VerifyOCRTextNearPoint
|
||||
check_interval_ms: int = Field(default=250, ge=50, le=5000)
|
||||
timeout_ms: int = Field(default=3000, ge=100, le=60000)
|
||||
|
||||
|
||||
ActionRequest.model_rebuild()
|
||||
|
||||
@@ -11,7 +11,22 @@ from fastapi import HTTPException
|
||||
from PIL import ImageChops, ImageStat
|
||||
|
||||
from .config import SETTINGS
|
||||
from .models import ActionRequest, GridTarget, LaunchRequest, PixelTarget, Target, WindowActionRequest, WindowQuery
|
||||
from .models import (
|
||||
ActionRequest,
|
||||
ClickTextAction,
|
||||
GridTarget,
|
||||
InteractVerifyRequest,
|
||||
LaunchRequest,
|
||||
PixelTarget,
|
||||
Target,
|
||||
VerifyOCRTextNearPoint,
|
||||
WindowActionRequest,
|
||||
WindowQuery,
|
||||
)
|
||||
|
||||
|
||||
def api_error(status_code: int, code: str, message: str, details=None):
|
||||
raise HTTPException(status_code=status_code, detail={"code": code, "message": message, "details": details})
|
||||
|
||||
|
||||
def import_capture_libs():
|
||||
@@ -85,6 +100,50 @@ def capture_region_image(screen: int, region_x: int | None, region_y: int | None
|
||||
return crop, {"x": region_x, "y": region_y, "width": region_width, "height": region_height}, mon, displays, screen_selection
|
||||
|
||||
|
||||
def extract_ocr_items(image, origin_x: int, origin_y: int, min_confidence: float, lang: str, psm: int | None) -> list[dict]:
|
||||
try:
|
||||
import pytesseract
|
||||
except Exception as exc:
|
||||
api_error(503, "ocr_unavailable", f"pytesseract unavailable: {exc}")
|
||||
|
||||
config = ""
|
||||
if psm is not None:
|
||||
config = f"--psm {psm}"
|
||||
try:
|
||||
data = pytesseract.image_to_data(image, lang=lang, config=config, output_type=pytesseract.Output.DICT)
|
||||
except Exception as exc:
|
||||
api_error(503, "ocr_failed", f"ocr failed: {exc}")
|
||||
|
||||
out: list[dict] = []
|
||||
n = len(data.get("text", []))
|
||||
for i in range(n):
|
||||
text = (data["text"][i] or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
try:
|
||||
confidence = float(data["conf"][i])
|
||||
except Exception:
|
||||
continue
|
||||
if confidence < min_confidence:
|
||||
continue
|
||||
left = int(data["left"][i])
|
||||
top = int(data["top"][i])
|
||||
width = int(data["width"][i])
|
||||
height = int(data["height"][i])
|
||||
bbox = {"x": origin_x + left, "y": origin_y + top, "width": width, "height": height}
|
||||
center = {"x": bbox["x"] + (width // 2), "y": bbox["y"] + (height // 2)}
|
||||
out.append(
|
||||
{
|
||||
"text": text,
|
||||
"confidence": confidence,
|
||||
"bbox": bbox,
|
||||
"center": center,
|
||||
"region_relative_bbox": {"x": left, "y": top, "width": width, "height": height},
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def serialize_image(image, image_format: str, jpeg_quality: int) -> bytes:
|
||||
buf = io.BytesIO()
|
||||
if image_format == "jpeg":
|
||||
@@ -164,6 +223,39 @@ def enforce_allowed_region(x: int, y: int):
|
||||
raise HTTPException(status_code=403, detail="point outside allowed region")
|
||||
|
||||
|
||||
def _text_matches(candidate: str, needle: str, mode: str, case_sensitive: bool) -> bool:
|
||||
hay = candidate if case_sensitive else candidate.lower()
|
||||
ndl = needle if case_sensitive else needle.lower()
|
||||
if mode == "contains":
|
||||
return ndl in hay
|
||||
if mode == "exact":
|
||||
return hay == ndl
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
return re.search(needle, candidate, flags=flags) is not None
|
||||
|
||||
|
||||
def _resolve_text_match(click_text: ClickTextAction, items: list[dict]) -> dict:
|
||||
matches = [item for item in items if _text_matches(item["text"], click_text.text, click_text.match, click_text.case_sensitive)]
|
||||
if not matches:
|
||||
candidates = [item["text"] for item in sorted(items, key=lambda v: v["confidence"], reverse=True)[:8]]
|
||||
api_error(404, "ocr_text_not_found", "no OCR text matched", {"query": click_text.text, "candidates": candidates})
|
||||
if click_text.occurrence == "best":
|
||||
return max(matches, key=lambda item: item["confidence"])
|
||||
if click_text.occurrence == "nth":
|
||||
idx = (click_text.nth or 1) - 1
|
||||
if idx >= len(matches):
|
||||
api_error(409, "ocr_nth_out_of_range", "requested nth match is out of range", {"match_count": len(matches), "nth": click_text.nth})
|
||||
return matches[idx]
|
||||
if len(matches) > 1 and click_text.match == "exact":
|
||||
api_error(
|
||||
409,
|
||||
"ocr_text_ambiguous",
|
||||
"multiple OCR entries matched",
|
||||
{"match_count": len(matches), "candidates": [item["text"] for item in matches[:8]]},
|
||||
)
|
||||
return matches[0]
|
||||
|
||||
|
||||
def import_input_lib():
|
||||
try:
|
||||
import pyautogui
|
||||
@@ -176,7 +268,10 @@ def import_input_lib():
|
||||
|
||||
def exec_action(req: ActionRequest, screen: int = 0) -> dict:
|
||||
run_dry = SETTINGS["dry_run"] or req.dry_run
|
||||
selected_display, _, screen_selection = select_display(screen)
|
||||
action_screen = screen
|
||||
if req.action == "click_text" and req.click_text and req.click_text.screen is not None:
|
||||
action_screen = req.click_text.screen
|
||||
selected_display, _, screen_selection = select_display(action_screen)
|
||||
pyautogui = None if run_dry else import_input_lib()
|
||||
resolved_target = None
|
||||
|
||||
@@ -191,6 +286,36 @@ def exec_action(req: ActionRequest, screen: int = 0) -> dict:
|
||||
if req.action == "scroll" and resolved_target is None:
|
||||
raise HTTPException(status_code=400, detail="target is required for scroll")
|
||||
|
||||
click_text_match = None
|
||||
if req.action == "click_text":
|
||||
if req.click_text is None:
|
||||
api_error(400, "click_text_payload_required", "click_text payload is required")
|
||||
region = req.click_text.region
|
||||
img, captured_region, _, _, _ = capture_region_image(
|
||||
action_screen,
|
||||
None if region is None else region.x,
|
||||
None if region is None else region.y,
|
||||
None if region is None else region.width,
|
||||
None if region is None else region.height,
|
||||
)
|
||||
items = extract_ocr_items(
|
||||
img,
|
||||
captured_region["x"],
|
||||
captured_region["y"],
|
||||
req.click_text.min_confidence,
|
||||
req.click_text.ocr_lang,
|
||||
req.click_text.ocr_psm,
|
||||
)
|
||||
matched = _resolve_text_match(req.click_text, items)
|
||||
enforce_allowed_region(matched["center"]["x"], matched["center"]["y"])
|
||||
click_text_match = {
|
||||
"query": req.click_text.model_dump(),
|
||||
"matched": matched,
|
||||
"capture_region": captured_region,
|
||||
"screen": screen_selection,
|
||||
}
|
||||
resolved_target = {"x": matched["center"]["x"], "y": matched["center"]["y"], "target_info": {"mode": "ocr_text"}}
|
||||
|
||||
if not run_dry:
|
||||
if req.action == "move":
|
||||
pyautogui.moveTo(resolved_target["x"], resolved_target["y"], duration=duration_sec)
|
||||
@@ -211,8 +336,71 @@ def exec_action(req: ActionRequest, screen: int = 0) -> dict:
|
||||
if len(req.keys) < 1:
|
||||
raise HTTPException(status_code=400, detail="keys is required for hotkey")
|
||||
pyautogui.hotkey(*req.keys)
|
||||
elif req.action == "click_text":
|
||||
pyautogui.click(
|
||||
x=resolved_target["x"],
|
||||
y=resolved_target["y"],
|
||||
clicks=req.clicks,
|
||||
interval=req.interval_ms / 1000.0,
|
||||
button=req.button,
|
||||
duration=duration_sec,
|
||||
)
|
||||
|
||||
return {"action": req.action, "executed": not run_dry, "dry_run": run_dry, "screen": screen_selection, "display": selected_display, "resolved_target": resolved_target}
|
||||
return {
|
||||
"action": req.action,
|
||||
"executed": not run_dry,
|
||||
"dry_run": run_dry,
|
||||
"screen": screen_selection,
|
||||
"display": selected_display,
|
||||
"resolved_target": resolved_target,
|
||||
"click_text_match": click_text_match,
|
||||
}
|
||||
|
||||
|
||||
def _verify_ocr_text_near_point(spec: VerifyOCRTextNearPoint) -> dict:
|
||||
radius = spec.radius
|
||||
display, _, _ = select_display(spec.screen)
|
||||
region_x = max(display["x"], spec.x - radius)
|
||||
region_y = max(display["y"], spec.y - radius)
|
||||
max_right = display["x"] + display["width"]
|
||||
max_bottom = display["y"] + display["height"]
|
||||
region_right = min(max_right, spec.x + radius)
|
||||
region_bottom = min(max_bottom, spec.y + radius)
|
||||
region_w = max(1, region_right - region_x)
|
||||
region_h = max(1, region_bottom - region_y)
|
||||
img, region, _, _, screen_selection = capture_region_image(spec.screen, region_x, region_y, region_w, region_h)
|
||||
items = extract_ocr_items(img, region["x"], region["y"], spec.min_confidence, spec.ocr_lang, spec.ocr_psm)
|
||||
matches = [item for item in items if _text_matches(item["text"], spec.text, spec.match, spec.case_sensitive)]
|
||||
return {"ok": len(matches) > 0, "matches": matches[:8], "items_count": len(items), "screen": screen_selection, "region": region}
|
||||
|
||||
|
||||
def execute_and_verify(req: InteractVerifyRequest) -> dict:
|
||||
started = time.time()
|
||||
action_result = exec_action(req.action.action, req.action.screen)
|
||||
attempts = 0
|
||||
last_check = None
|
||||
deadline = started + (req.timeout_ms / 1000.0)
|
||||
while True:
|
||||
attempts += 1
|
||||
check = _verify_ocr_text_near_point(req.verify)
|
||||
last_check = check
|
||||
if check["ok"]:
|
||||
return {
|
||||
"action_result": action_result,
|
||||
"verified": True,
|
||||
"attempts": attempts,
|
||||
"last_check": last_check,
|
||||
"duration_ms": int((time.time() - started) * 1000),
|
||||
}
|
||||
if time.time() >= deadline:
|
||||
return {
|
||||
"action_result": action_result,
|
||||
"verified": False,
|
||||
"attempts": attempts,
|
||||
"last_check": last_check,
|
||||
"duration_ms": int((time.time() - started) * 1000),
|
||||
}
|
||||
time.sleep(req.check_interval_ms / 1000.0)
|
||||
|
||||
|
||||
def windows_only(feature: str):
|
||||
|
||||
Reference in New Issue
Block a user