603 lines
24 KiB
Python
603 lines
24 KiB
Python
import ctypes
|
|
import io
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from typing import Literal
|
|
|
|
from fastapi import HTTPException
|
|
from PIL import ImageChops, ImageStat
|
|
|
|
from .config import SETTINGS
|
|
from .models import (
|
|
ActionRequest,
|
|
ClickTextAction,
|
|
GridTarget,
|
|
LaunchRequest,
|
|
PixelTarget,
|
|
Target,
|
|
WindowActionRequest,
|
|
WindowQuery,
|
|
)
|
|
|
|
|
|
def api_error(status_code: int, code: str, message: str, details=None):
    """Abort the current request with a structured error payload.

    This function never returns normally.

    Args:
        status_code: HTTP status code for the response.
        code: Machine-readable identifier stored under ``detail["code"]``.
        message: Human-readable explanation stored under ``detail["message"]``.
        details: Optional extra context stored under ``detail["details"]``.

    Raises:
        HTTPException: Always.
    """
    payload = {"code": code, "message": message, "details": details}
    raise HTTPException(status_code=status_code, detail=payload)
|
|
|
|
|
|
def import_capture_libs():
    """Import the screen-capture dependencies on demand.

    Returns:
        The tuple ``(Image, ImageDraw, mss)``: the two PIL modules followed by
        the ``mss`` module.

    Raises:
        HTTPException: 500 if either import fails for any reason.
    """
    try:
        from PIL import Image, ImageDraw
        import mss
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"capture backend unavailable: {exc}") from exc
    else:
        return Image, ImageDraw, mss
|
|
|
|
|
|
def display_region(mon: dict, screen: int, mss_index: int, primary: bool) -> dict:
    """Describe one monitor using this module's display record layout.

    Args:
        mon: Monitor mapping with ``left``, ``top``, ``width`` and ``height`` keys.
        screen: Position of the display in the ordered list.
        mss_index: Index of the same monitor in the capture backend's list.
        primary: Whether this display is treated as the primary one.

    Returns:
        A dict with keys ``screen``, ``mss_index``, ``primary``, ``x``, ``y``,
        ``width`` and ``height``, in that order.
    """
    record = {"screen": screen, "mss_index": mss_index, "primary": primary}
    # ``left``/``top`` are renamed to ``x``/``y``; sizes are copied unchanged.
    record.update(
        x=mon["left"],
        y=mon["top"],
        width=mon["width"],
        height=mon["height"],
    )
    return record
|
|
|
|
|
|
def ordered_displays(sct) -> list[dict]:
    """List the displays known to *sct*, with the primary display first.

    ``sct.monitors[0]`` is skipped; every other entry keeps its original
    1-based position as ``mss_index``. The first monitor whose top-left corner
    is at (0, 0) is treated as primary; if there is none, the first monitor is
    used. The remaining monitors keep their relative order.

    Args:
        sct: Capture session exposing a ``monitors`` sequence of mappings.

    Returns:
        One record per monitor, built by ``display_region``, with ``screen``
        numbered from 0 and ``primary`` true only for entry 0.

    Raises:
        HTTPException: 500 when no monitors remain after skipping entry 0.
    """
    indexed = list(enumerate(sct.monitors[1:], start=1))
    if not indexed:
        raise HTTPException(status_code=500, detail="no displays detected")

    primary_pos = 0
    for pos, (_, mon) in enumerate(indexed):
        if mon["left"] == 0 and mon["top"] == 0:
            primary_pos = pos
            break

    # Move the primary entry to the front without disturbing the others.
    ordered = [indexed[primary_pos], *indexed[:primary_pos], *indexed[primary_pos + 1:]]

    displays = []
    for screen, (mss_index, mon) in enumerate(ordered):
        displays.append(display_region(mon, screen=screen, mss_index=mss_index, primary=(screen == 0)))
    return displays
|
|
|
|
|
|
def get_displays() -> list[dict]:
    """Return the ordered display list using a short-lived capture session.

    Raises:
        HTTPException: Propagated from ``import_capture_libs`` or
            ``ordered_displays``.
    """
    capture_libs = import_capture_libs()
    mss = capture_libs[2]
    with mss.mss() as session:
        displays = ordered_displays(session)
    return displays
|
|
|
|
|
|
def select_display(screen: int) -> tuple[dict, list[dict], dict]:
    """Pick the display for *screen*, falling back to display 0.

    Args:
        screen: Requested index into the ordered display list.

    Returns:
        ``(selected, displays, selection)`` where ``selection`` records the
        requested index, the index actually selected, and whether they differ.
    """
    displays = get_displays()
    in_range = 0 <= screen < len(displays)
    chosen = displays[screen] if in_range else displays[0]
    selection = {
        "requested": screen,
        "selected": chosen["screen"],
        "fallback": chosen["screen"] != screen,
    }
    return chosen, displays, selection
|
|
|
|
|
|
def capture_screen(screen: int = 0):
    """Grab a full screenshot of one display.

    Args:
        screen: Index into the ordered display list. An out-of-range value
            falls back to display 0.

    Returns:
        ``(image, mon, displays, selection)``: the RGB image, the record of the
        captured display, the full display list, and a dict recording the
        requested index, the selected index and whether a fallback happened.
    """
    Image, _, mss = import_capture_libs()
    with mss.mss() as session:
        displays = ordered_displays(session)
        index = screen if 0 <= screen < len(displays) else 0
        mon = displays[index]
        bounds = {"left": mon["x"], "top": mon["y"], "width": mon["width"], "height": mon["height"]}
        shot = session.grab(bounds)
        image = Image.frombytes("RGB", shot.size, shot.rgb)
        selection = {
            "requested": screen,
            "selected": mon["screen"],
            "fallback": mon["screen"] != screen,
        }
        return image, mon, displays, selection
|
|
|
|
|
|
def capture_region_image(screen: int, region_x: int | None, region_y: int | None, region_width: int | None, region_height: int | None):
    """Capture a display and optionally crop it to a rectangle.

    The region arguments use the same coordinate space as the display record's
    ``x``/``y``. If any of the four is ``None``, the whole display is returned.

    Returns:
        ``(image, region, mon, displays, screen_selection)`` where ``region``
        describes the area the image covers.

    Raises:
        HTTPException: 400 when the rectangle extends past the captured image.
    """
    base_img, mon, displays, screen_selection = capture_screen(screen)

    requested = (region_x, region_y, region_width, region_height)
    if any(part is None for part in requested):
        whole = {key: mon[key] for key in ("x", "y", "width", "height")}
        return base_img, whole, mon, displays, screen_selection

    # Translate into coordinates relative to the captured image.
    left, top = region_x - mon["x"], region_y - mon["y"]
    right, bottom = left + region_width, top + region_height
    img_w, img_h = base_img.size[0], base_img.size[1]
    fits = left >= 0 and top >= 0 and right <= img_w and bottom <= img_h
    if not fits:
        raise HTTPException(status_code=400, detail="requested region is outside the captured monitor")

    cropped = base_img.crop((left, top, right, bottom))
    region = {"x": region_x, "y": region_y, "width": region_width, "height": region_height}
    return cropped, region, mon, displays, screen_selection
|
|
|
|
|
|
def extract_ocr_items(image, origin_x: int, origin_y: int, min_confidence: float, lang: str, psm: int | None) -> list[dict]:
    """Run Tesseract on *image* and return the recognised text entries.

    Args:
        image: Image passed to ``pytesseract.image_to_data``.
        origin_x: X offset added to each box to form the ``bbox`` coordinates.
        origin_y: Y offset added to each box to form the ``bbox`` coordinates.
        min_confidence: Entries with a lower confidence are dropped.
        lang: Language string forwarded to Tesseract.
        psm: Page segmentation mode; when given it is passed as ``--psm <n>``.

    Returns:
        One dict per kept entry with ``text``, ``confidence``, ``bbox``
        (offset by the origin), ``center`` and ``region_relative_bbox``
        (as reported by Tesseract).

    Raises:
        HTTPException: 503 via ``api_error`` when pytesseract cannot be
            imported or the OCR call fails.
    """
    try:
        import pytesseract
    except Exception as exc:
        api_error(503, "ocr_unavailable", f"pytesseract unavailable: {exc}")

    config = f"--psm {psm}" if psm is not None else ""
    try:
        data = pytesseract.image_to_data(image, lang=lang, config=config, output_type=pytesseract.Output.DICT)
    except Exception as exc:
        api_error(503, "ocr_failed", f"ocr failed: {exc}")

    items: list[dict] = []
    for i, raw_text in enumerate(data.get("text", [])):
        text = (raw_text or "").strip()
        if not text:
            continue
        try:
            confidence = float(data["conf"][i])
        except Exception:
            # Entries whose confidence cannot be parsed are skipped.
            continue
        if confidence < min_confidence:
            continue

        left = int(data["left"][i])
        top = int(data["top"][i])
        width = int(data["width"][i])
        height = int(data["height"][i])

        abs_x = origin_x + left
        abs_y = origin_y + top
        items.append(
            {
                "text": text,
                "confidence": confidence,
                "bbox": {"x": abs_x, "y": abs_y, "width": width, "height": height},
                "center": {"x": abs_x + (width // 2), "y": abs_y + (height // 2)},
                "region_relative_bbox": {"x": left, "y": top, "width": width, "height": height},
            }
        )
    return items
|
|
|
|
|
|
def serialize_image(image, image_format: str, jpeg_quality: int) -> bytes:
    """Encode *image* to bytes.

    Args:
        image: Object with a PIL-style ``save(fp, format=..., ...)`` method.
        image_format: ``"jpeg"`` selects JPEG; any other value produces PNG.
        jpeg_quality: Quality passed to the JPEG encoder; ignored for PNG.

    Returns:
        The encoded image bytes.
    """
    if image_format == "jpeg":
        save_options = {"format": "JPEG", "quality": jpeg_quality}
    else:
        save_options = {"format": "PNG"}

    with io.BytesIO() as sink:
        image.save(sink, **save_options)
        return sink.getvalue()
|
|
|
|
|
|
def encode_image(image, image_format: str, jpeg_quality: int) -> str:
    """Return *image* as a base64 ASCII string.

    The image is first encoded with ``serialize_image`` using the same
    arguments.
    """
    from base64 import b64encode

    raw = serialize_image(image, image_format, jpeg_quality)
    return b64encode(raw).decode("ascii")
|
|
|
|
|
|
def draw_grid(image, region_x: int, region_y: int, rows: int, cols: int, include_labels: bool):
    """Overlay a rows x cols grid on a copy of *image*.

    Interior grid lines and a 2px outer border are drawn in red. When
    *include_labels* is true, each cell gets a yellow ``"row,col"`` label near
    its centre. The input image is not modified.

    Args:
        image: Source image; only a copy is drawn on.
        region_x: X origin reported in the returned metadata.
        region_y: Y origin reported in the returned metadata.
        rows: Number of grid rows.
        cols: Number of grid columns.
        include_labels: Whether to draw the cell labels.

    Returns:
        ``(annotated_image, meta)`` where ``meta`` describes the region, the
        grid geometry and the formula for turning a cell into a pixel.
    """
    _, ImageDraw, _ = import_capture_libs()
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    width, height = canvas.size
    cell_w = width / cols
    cell_h = height / rows
    line_colour = (255, 0, 0)

    # Interior verticals, then interior horizontals, then the outer border.
    for col in range(1, cols):
        x = int(round(col * cell_w))
        pen.line([(x, 0), (x, height)], fill=line_colour, width=1)
    for row in range(1, rows):
        y = int(round(row * cell_h))
        pen.line([(0, y), (width, y)], fill=line_colour, width=1)
    pen.rectangle([(0, 0), (width - 1, height - 1)], outline=line_colour, width=2)

    if include_labels:
        for row in range(rows):
            centre_y = int((row + 0.5) * cell_h)
            for col in range(cols):
                centre_x = int((col + 0.5) * cell_w)
                # Fixed offset nudges the label up and left of the cell centre.
                pen.text((centre_x - 12, centre_y - 6), f"{row},{col}", fill=(255, 255, 0))

    point_formula = {
        "pixel_x": "region.x + ((col + 0.5 + dx*0.5) * cell_width)",
        "pixel_y": "region.y + ((row + 0.5 + dy*0.5) * cell_height)",
        "dx_range": "[-1,1]",
        "dy_range": "[-1,1]",
    }
    grid = {
        "rows": rows,
        "cols": cols,
        "cell_width": cell_w,
        "cell_height": cell_h,
        "indexing": "zero-based",
        "point_formula": point_formula,
    }
    meta = {
        "region": {"x": region_x, "y": region_y, "width": width, "height": height},
        "grid": grid,
    }
    return canvas, meta
|
|
|
|
|
|
def resolve_target(target: Target) -> tuple[int, int, dict]:
    """Turn a pixel or grid target into screen coordinates.

    A ``PixelTarget`` resolves to ``(x + dx, y + dy)``. Any other target is
    treated as a grid target: the point sits at the centre of cell
    ``(row, col)``, shifted by ``dx``/``dy`` times half a cell.

    Returns:
        ``(x, y, info)`` where ``info`` holds the mode, the dumped source
        target and, for grid targets, the derived cell size.
    """
    if isinstance(target, PixelTarget):
        pixel_x = target.x + target.dx
        pixel_y = target.y + target.dy
        return pixel_x, pixel_y, {"mode": "pixel", "source": target.model_dump()}

    cell_w = target.region_width / target.cols
    cell_h = target.region_height / target.rows
    # Position in cell units, measured from the region's top-left corner.
    col_units = target.col + 0.5 + (target.dx * 0.5)
    row_units = target.row + 0.5 + (target.dy * 0.5)
    grid_x = target.region_x + int(round(col_units * cell_w))
    grid_y = target.region_y + int(round(row_units * cell_h))
    info = {
        "mode": "grid",
        "source": target.model_dump(),
        "derived": {"cell_width": cell_w, "cell_height": cell_h},
    }
    return grid_x, grid_y, info
|
|
|
|
|
|
def enforce_allowed_region(x: int, y: int):
    """Reject a point that lies outside the configured allowed region.

    ``SETTINGS["allowed_region"]`` is either ``None`` (no restriction) or a
    4-item ``(x, y, width, height)`` rectangle. The right and bottom edges are
    exclusive.

    Raises:
        HTTPException: 403 when the point is outside the rectangle.
    """
    region = SETTINGS["allowed_region"]
    if region is None:
        return

    left, top, width, height = region
    inside_x = left <= x < left + width
    inside_y = top <= y < top + height
    if inside_x and inside_y:
        return
    raise HTTPException(status_code=403, detail="point outside allowed region")
|
|
|
|
|
|
def _text_matches(candidate: str, needle: str, mode: str, case_sensitive: bool) -> bool:
|
|
hay = candidate if case_sensitive else candidate.lower()
|
|
ndl = needle if case_sensitive else needle.lower()
|
|
if mode == "contains":
|
|
return ndl in hay
|
|
if mode == "exact":
|
|
return hay == ndl
|
|
flags = 0 if case_sensitive else re.IGNORECASE
|
|
return re.search(needle, candidate, flags=flags) is not None
|
|
|
|
|
|
def _resolve_text_match(click_text: ClickTextAction, items: list[dict]) -> dict:
    """Choose the OCR item that a click_text request refers to.

    Items are filtered with ``_text_matches`` and one is then picked according
    to ``click_text.occurrence``:

    * ``"best"``: the match with the highest confidence.
    * ``"nth"``: the 1-based ``nth`` match in OCR order (a missing or zero
      ``nth`` means the first).
    * anything else: the first match, except that several matches in
      ``"exact"`` mode are reported as ambiguous.

    Args:
        click_text: The request payload describing the text to find.
        items: OCR items with at least ``text`` and ``confidence`` keys.

    Returns:
        The selected item.

    Raises:
        HTTPException: Via ``api_error``: 404 when nothing matches, 409 when
            ``nth`` is outside ``1..len(matches)`` or an exact match is
            ambiguous.
    """
    matches = [
        item
        for item in items
        if _text_matches(item["text"], click_text.text, click_text.match, click_text.case_sensitive)
    ]
    if not matches:
        # Offer the most confident OCR strings to help the caller adjust the query.
        ranked = sorted(items, key=lambda v: v["confidence"], reverse=True)
        candidates = [item["text"] for item in ranked[:8]]
        api_error(404, "ocr_text_not_found", "no OCR text matched", {"query": click_text.text, "candidates": candidates})

    if click_text.occurrence == "best":
        return max(matches, key=lambda item: item["confidence"])

    if click_text.occurrence == "nth":
        idx = (click_text.nth or 1) - 1
        # Check both bounds: a negative nth would otherwise index from the end
        # of the list or raise IndexError.
        if not 0 <= idx < len(matches):
            api_error(409, "ocr_nth_out_of_range", "requested nth match is out of range", {"match_count": len(matches), "nth": click_text.nth})
        return matches[idx]

    if len(matches) > 1 and click_text.match == "exact":
        api_error(
            409,
            "ocr_text_ambiguous",
            "multiple OCR entries matched",
            {"match_count": len(matches), "candidates": [item["text"] for item in matches[:8]]},
        )
    return matches[0]
|
|
|
|
|
|
def import_input_lib():
    """Import pyautogui on demand and return it with ``FAILSAFE`` enabled.

    Raises:
        HTTPException: 500 if the import or the setup fails for any reason.
    """
    try:
        import pyautogui as backend

        backend.FAILSAFE = True
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"input backend unavailable: {exc}") from exc
    return backend
|
|
|
|
|
|
def exec_action(req: ActionRequest, screen: int = 0) -> dict:
    """Validate, resolve and, unless dry-running, perform one input action.

    Supported actions: ``move``, ``click``, ``right_click``, ``double_click``,
    ``middle_click``, ``scroll``, ``type``, ``hotkey`` and ``click_text``.

    For ``click_text`` the screen (or the requested region of it) is captured
    and OCR'd to find the click point. This also happens in dry-run mode, so
    the response shows what would be clicked.

    Args:
        req: The action request.
        screen: Display index to report and, for ``click_text``, to capture.
            A ``screen`` set on the click_text payload takes precedence.

    Returns:
        A dict with ``action``, ``executed``, ``dry_run``, ``screen``,
        ``display``, ``resolved_target`` and ``click_text_match``.

    Raises:
        HTTPException: 400 for a missing target, keys, text or click_text
            payload; 403 when the resolved point is outside the allowed
            region; 500 when the input backend is unavailable; plus any OCR
            or capture error raised by the helpers.
    """
    run_dry = SETTINGS["dry_run"] or req.dry_run

    action_screen = screen
    if req.action == "click_text" and req.click_text and req.click_text.screen is not None:
        action_screen = req.click_text.screen
    selected_display, _, screen_selection = select_display(action_screen)
    pyautogui = None if run_dry else import_input_lib()

    resolved_target = None
    if req.target is not None:
        x, y, info = resolve_target(req.target)
        enforce_allowed_region(x, y)
        resolved_target = {"x": x, "y": y, "target_info": info}

    duration_sec = req.duration_ms / 1000.0

    # All request validation runs before the dry-run gate, so a dry run rejects
    # exactly the requests a real run would reject.
    if req.action in {"move", "click", "right_click", "double_click", "middle_click"} and resolved_target is None:
        raise HTTPException(status_code=400, detail="target is required for pointer actions")
    if req.action == "scroll" and resolved_target is None:
        raise HTTPException(status_code=400, detail="target is required for scroll")
    if req.action == "hotkey" and not req.keys:
        # ``not req.keys`` covers both an empty list and None.
        raise HTTPException(status_code=400, detail="keys is required for hotkey")
    if req.action == "type" and req.text is None:
        raise HTTPException(status_code=400, detail="text is required for type")

    click_text_match = None
    if req.action == "click_text":
        if req.click_text is None:
            api_error(400, "click_text_payload_required", "click_text payload is required")
        region = req.click_text.region
        img, captured_region, _, _, _ = capture_region_image(
            action_screen,
            None if region is None else region.x,
            None if region is None else region.y,
            None if region is None else region.width,
            None if region is None else region.height,
        )
        items = extract_ocr_items(
            img,
            captured_region["x"],
            captured_region["y"],
            req.click_text.min_confidence,
            req.click_text.ocr_lang,
            req.click_text.ocr_psm,
        )
        matched = _resolve_text_match(req.click_text, items)
        enforce_allowed_region(matched["center"]["x"], matched["center"]["y"])
        click_text_match = {
            "query": req.click_text.model_dump(),
            "matched": matched,
            "capture_region": captured_region,
            "screen": screen_selection,
        }
        # The OCR hit replaces any explicit target supplied with the request.
        resolved_target = {"x": matched["center"]["x"], "y": matched["center"]["y"], "target_info": {"mode": "ocr_text"}}

    if not run_dry:
        if req.action == "move":
            pyautogui.moveTo(resolved_target["x"], resolved_target["y"], duration=duration_sec)
        elif req.action in {"click", "click_text"}:
            pyautogui.click(
                x=resolved_target["x"],
                y=resolved_target["y"],
                clicks=req.clicks,
                interval=req.interval_ms / 1000.0,
                button=req.button,
                duration=duration_sec,
            )
        elif req.action == "right_click":
            pyautogui.click(x=resolved_target["x"], y=resolved_target["y"], button="right", duration=duration_sec)
        elif req.action == "double_click":
            pyautogui.doubleClick(x=resolved_target["x"], y=resolved_target["y"], interval=req.interval_ms / 1000.0)
        elif req.action == "middle_click":
            pyautogui.click(x=resolved_target["x"], y=resolved_target["y"], button="middle", duration=duration_sec)
        elif req.action == "scroll":
            # Move first so the wheel events land on the intended point.
            pyautogui.moveTo(resolved_target["x"], resolved_target["y"], duration=duration_sec)
            pyautogui.scroll(req.scroll_amount)
        elif req.action == "type":
            pyautogui.write(req.text, interval=req.interval_ms / 1000.0)
        elif req.action == "hotkey":
            pyautogui.hotkey(*req.keys)

    return {
        "action": req.action,
        "executed": not run_dry,
        "dry_run": run_dry,
        "screen": screen_selection,
        "display": selected_display,
        "resolved_target": resolved_target,
        "click_text_match": click_text_match,
    }
|
|
|
|
|
|
def windows_only(feature: str):
    """Guard a feature that only works on Windows hosts.

    Args:
        feature: Name of the feature, used in the error message.

    Raises:
        HTTPException: 501 when ``sys.platform`` is not ``"win32"``.
    """
    if sys.platform == "win32":
        return
    raise HTTPException(status_code=501, detail=f"{feature} is currently supported on Windows hosts only")
|
|
|
|
|
|
def tasklist_process_name(pid: int) -> str | None:
    """Look up a process image name by PID using the ``tasklist`` command.

    This is best-effort: any failure to run the command, or output that is not
    in the expected quoted CSV form, yields ``None``.

    Args:
        pid: Process identifier to query.

    Returns:
        The first CSV field of the first output row with its quotes removed,
        or ``None`` when it cannot be determined.
    """
    command = ["tasklist", "/FI", f"PID eq {pid}", "/FO", "CSV", "/NH"]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, timeout=5, check=False)
    except Exception:
        return None

    rows = (completed.stdout or "").strip().splitlines()
    first = rows[0].strip() if rows else ""
    # tasklist prints an "INFO:" line instead of a row when nothing matches.
    if not first or first.startswith("INFO:"):
        return None
    if not first.startswith('"'):
        return None

    image_name, separator, _ = first.partition('","')
    if not separator:
        return None
    return image_name.strip('"')
|
|
|
|
|
|
def list_windows(query: WindowQuery | None = None) -> list[dict]:
    """Enumerate top-level windows that satisfy *query* (Windows only).

    Filters are applied in this order: ``hwnd``, ``visible_only``,
    ``title_contains`` (case-insensitive), ``title_regex`` (case-insensitive
    search) and ``process_name`` (case-insensitive equality).

    Args:
        query: Filter criteria. ``None`` uses a default ``WindowQuery()``.

    Returns:
        One dict per matching window with ``hwnd``, ``title``, ``class_name``,
        ``pid``, ``process_name``, ``visible``, ``enabled``, ``minimized``,
        ``maximized``, ``foreground`` and ``rect``. The foreground window is
        listed first, then the rest sorted by lower-cased title and hwnd.

    Raises:
        HTTPException: 501 on non-Windows hosts; 400 when ``title_regex`` is
            not a valid regular expression.
    """
    windows_only("window endpoints")
    # ``import ctypes`` does not load the ``wintypes`` submodule, so
    # ``ctypes.wintypes`` is only reachable if another module imported it
    # first. Import it explicitly, after the platform guard.
    from ctypes import wintypes

    query = query or WindowQuery()

    # Compile the pattern up front. An exception raised inside the EnumWindows
    # callback cannot propagate to the caller, so a bad pattern must be
    # rejected here.
    title_pattern = None
    if query.title_regex:
        try:
            title_pattern = re.compile(query.title_regex, flags=re.IGNORECASE)
        except re.error as exc:
            raise HTTPException(status_code=400, detail=f"invalid title_regex: {exc}") from exc

    user32 = ctypes.windll.user32
    kernel32 = ctypes.windll.kernel32
    psapi = ctypes.windll.psapi

    user32.GetWindowTextLengthW.argtypes = [ctypes.c_void_p]
    user32.GetWindowTextLengthW.restype = ctypes.c_int
    user32.GetWindowTextW.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_int]
    user32.GetWindowTextW.restype = ctypes.c_int
    user32.IsWindowVisible.argtypes = [ctypes.c_void_p]
    user32.IsWindowVisible.restype = ctypes.c_bool
    user32.IsWindowEnabled.argtypes = [ctypes.c_void_p]
    user32.IsWindowEnabled.restype = ctypes.c_bool
    user32.IsIconic.argtypes = [ctypes.c_void_p]
    user32.IsIconic.restype = ctypes.c_bool
    user32.IsZoomed.argtypes = [ctypes.c_void_p]
    user32.IsZoomed.restype = ctypes.c_bool
    user32.GetForegroundWindow.restype = ctypes.c_void_p
    user32.GetWindowRect.argtypes = [ctypes.c_void_p, ctypes.POINTER(wintypes.RECT)]
    user32.GetWindowRect.restype = ctypes.c_bool
    user32.GetClassNameW.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_int]
    user32.GetClassNameW.restype = ctypes.c_int
    # Give this call a prototype like the others so the handle is passed as a
    # pointer-sized value rather than a default C int.
    user32.GetWindowThreadProcessId.argtypes = [ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)]
    user32.GetWindowThreadProcessId.restype = wintypes.DWORD

    # NOTE(review): these prototypes are declared but the functions are not
    # called in this function; kept because they configure shared ctypes
    # function objects.
    kernel32.OpenProcess.argtypes = [wintypes.DWORD, wintypes.BOOL, wintypes.DWORD]
    kernel32.OpenProcess.restype = wintypes.HANDLE
    kernel32.CloseHandle.argtypes = [wintypes.HANDLE]
    kernel32.CloseHandle.restype = wintypes.BOOL
    psapi.GetModuleBaseNameW.argtypes = [wintypes.HANDLE, wintypes.HMODULE, ctypes.c_wchar_p, wintypes.DWORD]
    psapi.GetModuleBaseNameW.restype = wintypes.DWORD

    foreground = int(user32.GetForegroundWindow() or 0)
    results: list[dict] = []
    # Each lookup spawns a ``tasklist`` subprocess, so resolve every PID at
    # most once per enumeration.
    process_names: dict[int, str | None] = {}

    def callback(hwnd, _lparam):
        # Returning True tells EnumWindows to continue with the next window.
        hwnd_int = int(hwnd)
        if query.hwnd and hwnd_int != query.hwnd:
            return True
        visible = bool(user32.IsWindowVisible(hwnd))
        if query.visible_only and not visible:
            return True

        length = user32.GetWindowTextLengthW(hwnd)
        title_buf = ctypes.create_unicode_buffer(max(1, length + 1))
        user32.GetWindowTextW(hwnd, title_buf, len(title_buf))
        title = title_buf.value or ""

        if query.title_contains and query.title_contains.lower() not in title.lower():
            return True
        if title_pattern is not None and title_pattern.search(title) is None:
            return True

        pid = wintypes.DWORD(0)
        user32.GetWindowThreadProcessId(hwnd, ctypes.byref(pid))
        pid_value = int(pid.value)
        if pid_value not in process_names:
            process_names[pid_value] = tasklist_process_name(pid_value)
        process_name = process_names[pid_value]
        if query.process_name and (process_name or "").lower() != query.process_name.lower():
            return True

        class_buf = ctypes.create_unicode_buffer(256)
        user32.GetClassNameW(hwnd, class_buf, len(class_buf))
        rect = wintypes.RECT()
        user32.GetWindowRect(hwnd, ctypes.byref(rect))

        results.append(
            {
                "hwnd": hwnd_int,
                "title": title,
                "class_name": class_buf.value,
                "pid": pid_value,
                "process_name": process_name,
                "visible": visible,
                "enabled": bool(user32.IsWindowEnabled(hwnd)),
                "minimized": bool(user32.IsIconic(hwnd)),
                "maximized": bool(user32.IsZoomed(hwnd)),
                "foreground": hwnd_int == foreground,
                "rect": {"x": int(rect.left), "y": int(rect.top), "width": int(rect.right - rect.left), "height": int(rect.bottom - rect.top)},
            }
        )
        return True

    # Keep a reference to the callback object for the duration of EnumWindows.
    enum_proc = ctypes.WINFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p)(callback)
    user32.EnumWindows(enum_proc, 0)
    results.sort(key=lambda item: (not item["foreground"], item["title"].lower(), item["hwnd"]))
    return results
|
|
|
|
|
|
def _pick_single_window(query: WindowQuery) -> dict:
    """Return the one window matching *query*.

    Raises:
        HTTPException: 404 when nothing matches; 409 when several windows
            match (the first ten are included in the error detail).
    """
    matches = list_windows(query)
    if len(matches) == 1:
        return matches[0]
    if not matches:
        raise HTTPException(status_code=404, detail="no window matched")
    raise HTTPException(status_code=409, detail={"message": "multiple windows matched", "matches": matches[:10]})
|
|
|
|
|
|
def apply_window_action(req: WindowActionRequest) -> dict:
    """Focus, restore, minimize, maximize or close a single window.

    The request doubles as the window query and must match exactly one window.
    After the action is issued, the window is polled until the expected state
    is seen or ``req.timeout_ms`` elapses.

    Args:
        req: Window query fields plus ``action`` and ``timeout_ms``.

    Returns:
        ``{"matched", "closed", "final"}``: the window as first matched,
        whether it was seen to disappear after a close, and its last observed
        state (``None`` if it was never observed again).

    Raises:
        HTTPException: 501 on non-Windows hosts; 404 or 409 from window
            selection; 500 when ``SetForegroundWindow`` reports failure.
    """
    windows_only("window endpoints")
    match = _pick_single_window(req)
    hwnd = match["hwnd"]
    user32 = ctypes.windll.user32

    # ShowWindow commands and the close message, named as in the Win32 headers.
    SW_RESTORE, SW_MINIMIZE, SW_MAXIMIZE = 9, 6, 3
    WM_CLOSE = 0x0010

    if req.action == "focus":
        # Restore before focusing, presumably so a minimized window can be
        # brought to the foreground.
        user32.ShowWindow(hwnd, SW_RESTORE)
        ok = bool(user32.SetForegroundWindow(hwnd))
        if not ok:
            raise HTTPException(status_code=500, detail="failed to focus window")
    elif req.action == "restore":
        user32.ShowWindow(hwnd, SW_RESTORE)
    elif req.action == "minimize":
        user32.ShowWindow(hwnd, SW_MINIMIZE)
    elif req.action == "maximize":
        user32.ShowWindow(hwnd, SW_MAXIMIZE)
    elif req.action == "close":
        # Posted rather than sent, so the call returns without waiting.
        user32.PostMessageW(hwnd, WM_CLOSE, 0, 0)

    # Use the monotonic clock so a system clock adjustment cannot shorten or
    # extend the wait.
    deadline = time.monotonic() + (req.timeout_ms / 1000.0)
    final = None
    while time.monotonic() <= deadline:
        current = list_windows(WindowQuery(hwnd=hwnd, visible_only=False))
        if not current:
            if req.action == "close":
                return {"matched": match, "closed": True, "final": None}
            time.sleep(0.05)
            continue
        final = current[0]
        if req.action == "focus" and final.get("foreground"):
            break
        if req.action in {"restore", "minimize", "maximize"}:
            # These are reported after the first successful observation.
            break
        time.sleep(0.05)

    return {"matched": match, "closed": False, "final": final}
|
|
|
|
|
|
def launch_app(req: LaunchRequest) -> dict:
    """Start an executable and optionally wait for one of its windows.

    Args:
        req: Launch request with ``executable``, ``args``, ``cwd``,
            ``dry_run``, ``wait_for_window``, ``match`` and ``timeout_ms``.

    Returns:
        For a dry run: ``executed=False``, ``dry_run=True``, ``argv``, ``cwd``.
        Otherwise ``executed=True``, ``dry_run=False``, ``argv``, ``cwd`` and
        ``pid``; with ``wait_for_window`` also ``window`` (first match or
        ``None``) and ``window_found``.

    Raises:
        HTTPException: 400 when ``cwd`` is not a directory or the process
            cannot be started; 501 when ``wait_for_window`` is requested on a
            non-Windows host.
    """
    if req.cwd and not os.path.isdir(req.cwd):
        raise HTTPException(status_code=400, detail="cwd does not exist or is not a directory")
    argv = [req.executable, *req.args]
    cwd = req.cwd or None

    if req.dry_run or SETTINGS["dry_run"]:
        return {"executed": False, "dry_run": True, "argv": argv, "cwd": cwd}

    if req.wait_for_window:
        # Window matching is Windows-only. Check before spawning so an
        # unsupported host does not get a running process plus an error.
        windows_only("window endpoints")

    try:
        proc = subprocess.Popen(argv, cwd=cwd)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=400, detail=f"executable not found: {exc}") from exc
    except OSError as exc:
        raise HTTPException(status_code=400, detail=f"failed to launch process: {exc}") from exc

    result = {"executed": True, "dry_run": False, "argv": argv, "cwd": cwd, "pid": proc.pid}
    if req.wait_for_window:
        # Without an explicit matcher, look for a visible window whose process
        # name equals the executable's base name.
        query = req.match or WindowQuery(process_name=os.path.basename(req.executable), visible_only=True)
        # Use the monotonic clock so a system clock adjustment cannot distort
        # the timeout.
        deadline = time.monotonic() + (req.timeout_ms / 1000.0)
        match = None
        while time.monotonic() <= deadline:
            matches = list_windows(query)
            if matches:
                match = matches[0]
                break
            time.sleep(0.2)
        result["window"] = match
        result["window_found"] = match is not None
    return result
|
|
|
|
|
|
def _truncate_text(text: str, limit: int) -> tuple[str, bool]:
|
|
if len(text) <= limit:
|
|
return text, False
|
|
return text[:limit], True
|
|
|
|
|
|
def _resolve_exec_program(shell_name: str, command: str) -> list[str]:
|
|
if shell_name == "powershell":
|
|
return ["powershell", "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-Command", command]
|
|
if shell_name == "bash":
|
|
return ["bash", "-lc", command]
|
|
if shell_name == "cmd":
|
|
return ["cmd", "/c", command]
|
|
raise HTTPException(status_code=400, detail="unsupported shell")
|
|
|
|
|
|
def _decode_captured_output(value) -> str:
    """Return captured subprocess output as ``str``; ``None`` becomes ``""``."""
    if value is None:
        return ""
    if isinstance(value, (bytes, bytearray)):
        import locale

        # Decode with the locale encoding, presumably the one text mode uses.
        # Bad bytes are replaced rather than failing the request.
        return bytes(value).decode(locale.getpreferredencoding(False), errors="replace")
    return str(value)


def exec_command(req):
    """Run a shell command and return its captured, size-limited output.

    Execution must be enabled and a secret configured in ``SETTINGS``. This
    function only checks that the secret is set; it does not compare it with
    anything.

    Args:
        req: Request with ``command``, ``shell``, ``cwd``, ``timeout_s`` and
            ``dry_run``.

    Returns:
        For a dry run: ``executed=False`` plus the resolved shell, argv,
        timeout and cwd. Otherwise a dict with ``executed``, ``timed_out``,
        ``shell``, ``command``, ``argv``, ``timeout_s``, ``cwd``,
        ``duration_ms``, ``exit_code`` (``None`` on timeout), ``stdout``,
        ``stderr`` and the two ``*_truncated`` flags.

    Raises:
        HTTPException: 403 when exec is disabled or no secret is configured;
            400 for an unsupported shell, an invalid ``cwd`` or a missing
            shell executable.
    """
    if not SETTINGS["exec_enabled"]:
        raise HTTPException(status_code=403, detail="exec endpoint disabled")
    if not SETTINGS["exec_secret"]:
        raise HTTPException(status_code=403, detail="exec secret not configured")

    shell_name = (req.shell or SETTINGS["exec_default_shell"] or "powershell").lower().strip()
    if shell_name not in {"powershell", "bash", "cmd"}:
        raise HTTPException(status_code=400, detail="unsupported shell")

    run_dry = SETTINGS["dry_run"] or req.dry_run
    timeout_s = req.timeout_s if req.timeout_s is not None else SETTINGS["exec_default_timeout_s"]
    # The configured maximum caps whatever the request asks for.
    timeout_s = min(timeout_s, SETTINGS["exec_max_timeout_s"])

    cwd = None
    if req.cwd:
        cwd = os.path.abspath(req.cwd)
        if not os.path.isdir(cwd):
            raise HTTPException(status_code=400, detail="cwd does not exist or is not a directory")

    argv = _resolve_exec_program(shell_name, req.command)
    if run_dry:
        return {"executed": False, "dry_run": True, "shell": shell_name, "command": req.command, "argv": argv, "timeout_s": timeout_s, "cwd": cwd}

    max_chars = SETTINGS["exec_max_output_chars"]
    # Use the monotonic clock so the duration is unaffected by clock changes.
    start = time.monotonic()
    try:
        completed = subprocess.run(argv, cwd=cwd, capture_output=True, text=True, timeout=timeout_s, check=False)
    except subprocess.TimeoutExpired as exc:
        # The partial output on TimeoutExpired can be bytes even with
        # text=True; calling str() on bytes would return the "b'...'" repr.
        stdout, stdout_truncated = _truncate_text(_decode_captured_output(exc.stdout), max_chars)
        stderr, stderr_truncated = _truncate_text(_decode_captured_output(exc.stderr), max_chars)
        return {"executed": True, "timed_out": True, "shell": shell_name, "command": req.command, "argv": argv, "timeout_s": timeout_s, "cwd": cwd, "duration_ms": int((time.monotonic() - start) * 1000), "exit_code": None, "stdout": stdout, "stderr": stderr, "stdout_truncated": stdout_truncated, "stderr_truncated": stderr_truncated}
    except FileNotFoundError as exc:
        raise HTTPException(status_code=400, detail=f"shell executable not found: {exc}") from exc

    stdout, stdout_truncated = _truncate_text(completed.stdout or "", max_chars)
    stderr, stderr_truncated = _truncate_text(completed.stderr or "", max_chars)
    return {"executed": True, "timed_out": False, "shell": shell_name, "command": req.command, "argv": argv, "timeout_s": timeout_s, "cwd": cwd, "duration_ms": int((time.monotonic() - start) * 1000), "exit_code": completed.returncode, "stdout": stdout, "stderr": stderr, "stdout_truncated": stdout_truncated, "stderr_truncated": stderr_truncated}
|