diff --git a/backend/src/server.py b/backend/src/server.py index 1368c91..d871a76 100644 --- a/backend/src/server.py +++ b/backend/src/server.py @@ -17,6 +17,7 @@ CONTENT_TYPES = { } app = FastAPI(title="whisper-remote-backend") +WHISPER_PROCESS_TIMEOUT_SECONDS = 300 def validate_output_format(output_format: str) -> str: @@ -112,12 +113,21 @@ async def transcribe( check=False, capture_output=True, text=True, + timeout=WHISPER_PROCESS_TIMEOUT_SECONDS, ) except FileNotFoundError as exc: raise HTTPException( status_code=500, detail="The 'whisper' CLI was not found on PATH on the backend host.", ) from exc + except subprocess.TimeoutExpired as exc: + raise HTTPException( + status_code=504, + detail=( + "Whisper CLI timed out after " + f"{WHISPER_PROCESS_TIMEOUT_SECONDS}s and was terminated." + ), + ) from exc if completed.returncode != 0: detail = completed.stderr.strip() or completed.stdout.strip() or "Whisper CLI failed." diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 59baf5d..1e95638 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -64,3 +64,22 @@ def test_transcriptions_maps_subprocess_failure(monkeypatch) -> None: assert response.status_code == 502 assert response.json()["detail"] == "bad whisper day" + + +def test_transcriptions_maps_subprocess_timeout(monkeypatch) -> None: + def fake_run(command: list[str], check: bool, capture_output: bool, text: bool, timeout: int): + raise server.subprocess.TimeoutExpired(cmd=command, timeout=timeout) + + monkeypatch.setattr(server.subprocess, "run", fake_run) + + response = client.post( + "/transcriptions", + data={"model": "base", "output_format": "txt"}, + files={"file": ("clip.wav", b"audio", "audio/wav")}, + ) + + assert response.status_code == 504 + assert ( + response.json()["detail"] + == f"Whisper CLI timed out after {server.WHISPER_PROCESS_TIMEOUT_SECONDS}s and was terminated." + ) diff --git a/cli/src/main.py b/cli/src/main.py index 5ddbe54..ae533db 100644 --- a/cli/src/main.py +++ b/cli/src/main.py @@ -64,6 +64,14 @@ def format_http_error(response: httpx.Response, endpoint: str) -> str: return f"HTTP {response.status_code} from {endpoint}: {body}" +def format_request_error(exc: httpx.RequestError, endpoint: str) -> str: + if isinstance(exc, httpx.TimeoutException): + return f"Request to {endpoint} timed out." + + reason = str(exc).strip() or exc.__class__.__name__ + return f"Request to {endpoint} failed: {reason}" + + def main() -> int: parser = build_parser() args = parser.parse_args() @@ -75,16 +83,19 @@ def main() -> int: server = resolve_server(args) endpoint = f"{server}/transcriptions" - with input_file.open("rb") as handle, httpx.Client(timeout=300.0) as client: - response = client.post( - endpoint, - data={ - "model": args.model, - "language": args.language or "", - "output_format": args.output_format, - }, - files={"file": (input_file.name, handle, "application/octet-stream")}, - ) + try: + with input_file.open("rb") as handle, httpx.Client(timeout=300.0) as client: + response = client.post( + endpoint, + data={ + "model": args.model, + "language": args.language or "", + "output_format": args.output_format, + }, + files={"file": (input_file.name, handle, "application/octet-stream")}, + ) + except httpx.RequestError as exc: + parser.exit(1, f"{format_request_error(exc, endpoint)}\n") try: response.raise_for_status() diff --git a/cli/tests/test_main.py b/cli/tests/test_main.py index b1c4c1e..42fca89 100644 --- a/cli/tests/test_main.py +++ b/cli/tests/test_main.py @@ -44,3 +44,20 @@ def test_format_http_error_with_empty_body() -> None: response = httpx.Response(500, text="", request=request) message = main.format_http_error(response, "http://localhost:8000/transcriptions") assert message == "HTTP 500 from http://localhost:8000/transcriptions: " + + +def test_format_request_error_timeout() -> None: + request = httpx.Request("POST", "http://localhost:8000/transcriptions") + exc = httpx.ReadTimeout("read timed out", request=request) + message = main.format_request_error(exc, "http://localhost:8000/transcriptions") + assert message == "Request to http://localhost:8000/transcriptions timed out." + + +def test_format_request_error_network_failure() -> None: + request = httpx.Request("POST", "http://localhost:8000/transcriptions") + exc = httpx.ConnectError("connection refused", request=request) + message = main.format_request_error(exc, "http://localhost:8000/transcriptions") + assert ( + message + == "Request to http://localhost:8000/transcriptions failed: connection refused" + )