Compare commits
2 Commits
44af756bd3
...
869a70b621
| Author | SHA1 | Date | |
|---|---|---|---|
| 869a70b621 | |||
| 1c6415d306 |
@@ -17,6 +17,7 @@ CONTENT_TYPES = {
|
||||
}
|
||||
|
||||
app = FastAPI(title="whisper-remote-backend")
|
||||
WHISPER_PROCESS_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
def validate_output_format(output_format: str) -> str:
|
||||
@@ -112,12 +113,21 @@ async def transcribe(
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=WHISPER_PROCESS_TIMEOUT_SECONDS,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="The 'whisper' CLI was not found on PATH on the backend host.",
|
||||
) from exc
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail=(
|
||||
"Whisper CLI timed out after "
|
||||
f"{WHISPER_PROCESS_TIMEOUT_SECONDS}s and was terminated."
|
||||
),
|
||||
) from exc
|
||||
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip() or "Whisper CLI failed."
|
||||
|
||||
@@ -64,3 +64,22 @@ def test_transcriptions_maps_subprocess_failure(monkeypatch) -> None:
|
||||
|
||||
assert response.status_code == 502
|
||||
assert response.json()["detail"] == "bad whisper day"
|
||||
|
||||
|
||||
def test_transcriptions_maps_subprocess_timeout(monkeypatch) -> None:
|
||||
def fake_run(command: list[str], check: bool, capture_output: bool, text: bool, timeout: int):
|
||||
raise server.subprocess.TimeoutExpired(cmd=command, timeout=timeout)
|
||||
|
||||
monkeypatch.setattr(server.subprocess, "run", fake_run)
|
||||
|
||||
response = client.post(
|
||||
"/transcriptions",
|
||||
data={"model": "base", "output_format": "txt"},
|
||||
files={"file": ("clip.wav", b"audio", "audio/wav")},
|
||||
)
|
||||
|
||||
assert response.status_code == 504
|
||||
assert (
|
||||
response.json()["detail"]
|
||||
== f"Whisper CLI timed out after {server.WHISPER_PROCESS_TIMEOUT_SECONDS}s and was terminated."
|
||||
)
|
||||
|
||||
@@ -64,6 +64,14 @@ def format_http_error(response: httpx.Response, endpoint: str) -> str:
|
||||
return f"HTTP {response.status_code} from {endpoint}: {body}"
|
||||
|
||||
|
||||
def format_request_error(exc: httpx.RequestError, endpoint: str) -> str:
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
return f"Request to {endpoint} timed out."
|
||||
|
||||
reason = str(exc).strip() or exc.__class__.__name__
|
||||
return f"Request to {endpoint} failed: {reason}"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -75,6 +83,7 @@ def main() -> int:
|
||||
server = resolve_server(args)
|
||||
endpoint = f"{server}/transcriptions"
|
||||
|
||||
try:
|
||||
with input_file.open("rb") as handle, httpx.Client(timeout=300.0) as client:
|
||||
response = client.post(
|
||||
endpoint,
|
||||
@@ -85,6 +94,8 @@ def main() -> int:
|
||||
},
|
||||
files={"file": (input_file.name, handle, "application/octet-stream")},
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
parser.exit(1, f"{format_request_error(exc, endpoint)}\n")
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -44,3 +44,20 @@ def test_format_http_error_with_empty_body() -> None:
|
||||
response = httpx.Response(500, text="", request=request)
|
||||
message = main.format_http_error(response, "http://localhost:8000/transcriptions")
|
||||
assert message == "HTTP 500 from http://localhost:8000/transcriptions: <empty response body>"
|
||||
|
||||
|
||||
def test_format_request_error_timeout() -> None:
|
||||
request = httpx.Request("POST", "http://localhost:8000/transcriptions")
|
||||
exc = httpx.ReadTimeout("read timed out", request=request)
|
||||
message = main.format_request_error(exc, "http://localhost:8000/transcriptions")
|
||||
assert message == "Request to http://localhost:8000/transcriptions timed out."
|
||||
|
||||
|
||||
def test_format_request_error_network_failure() -> None:
|
||||
request = httpx.Request("POST", "http://localhost:8000/transcriptions")
|
||||
exc = httpx.ConnectError("connection refused", request=request)
|
||||
message = main.format_request_error(exc, "http://localhost:8000/transcriptions")
|
||||
assert (
|
||||
message
|
||||
== "Request to http://localhost:8000/transcriptions failed: connection refused"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user