"""FastAPI service that runs the ``whisper`` CLI on uploaded files and returns the transcript."""
from __future__ import annotations

import asyncio
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from urllib.parse import quote

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
|
|
# Media type sent back for each transcript format; every entry declares UTF-8.
CONTENT_TYPES = {
    "txt": "text/plain; charset=utf-8",
    "vtt": "text/vtt; charset=utf-8",
    "srt": "application/x-subrip; charset=utf-8",
    "tsv": "text/tab-separated-values; charset=utf-8",
    "json": "application/json; charset=utf-8",
}

# Formats accepted in the ``output_format`` form field. Derived from the
# mapping above so a format can never be accepted without a content type.
SUPPORTED_FORMATS = set(CONTENT_TYPES)
|
|
|
|
# ASGI application; main() serves it via the import string
# "whisper_remote_backend.server:app".
app = FastAPI(
    title="whisper-remote-backend",
)
|
|
|
|
|
|
def validate_output_format(output_format: str) -> str:
    """Normalize a requested transcript format and check that it is supported.

    The value is stripped of surrounding whitespace and lower-cased before
    being looked up in ``SUPPORTED_FORMATS``.

    Args:
        output_format: Raw value of the ``output_format`` form field.

    Returns:
        The normalized format name (e.g. ``" SRT "`` -> ``"srt"``).

    Raises:
        HTTPException: 400, echoing the raw value and listing accepted formats.
    """
    candidate = output_format.strip().lower()
    if candidate in SUPPORTED_FORMATS:
        return candidate

    supported = ", ".join(sorted(SUPPORTED_FORMATS))
    raise HTTPException(
        status_code=400,
        detail=f"Unsupported output format '{output_format}'. Supported formats: {supported}.",
    )
|
|
|
|
|
|
def build_whisper_command(
    *,
    input_path: Path,
    output_dir: Path,
    model: str,
    language: str | None,
    output_format: str,
) -> list[str]:
    """Return the argv list for one ``whisper`` CLI run.

    Each flag and its value are separate list items, so the result can be
    passed to ``subprocess.run`` without a shell. ``--language`` is appended
    last, and only when *language* is truthy (``None`` and ``""`` omit it).
    """
    options = {
        "--model": model,
        "--output_format": output_format,
        "--output_dir": str(output_dir),
    }
    if language:
        options["--language"] = language

    argv = ["whisper", str(input_path)]
    for flag, value in options.items():
        argv += [flag, value]
    return argv
|
|
|
|
|
|
async def save_upload(upload: UploadFile, destination: Path) -> None:
    """Stream *upload* into *destination* in 1 MiB chunks, then close the upload.

    Args:
        upload: Incoming upload; read asynchronously until it returns no data.
        destination: File to create (or truncate) and write the bytes to.

    Raises:
        OSError: If *destination* cannot be opened or written. The upload is
            closed even in that case.
    """
    try:
        with destination.open("wb") as handle:
            # Fixed-size reads keep memory bounded regardless of upload size.
            while chunk := await upload.read(1024 * 1024):
                handle.write(chunk)
    finally:
        # Previously skipped when open()/write() raised, leaking the upload's
        # underlying file handle.
        await upload.close()
|
|
|
|
|
|
def find_transcript_file(output_dir: Path, input_name: str, output_format: str) -> Path:
    """Return the transcript file written into *output_dir*.

    The preferred path is ``<stem of input_name>.<output_format>``. If that is
    missing but exactly one ``*.<output_format>`` file exists in *output_dir*,
    that file is returned instead.

    Raises:
        HTTPException: 500 when neither lookup identifies a single file.
    """
    suffix = f".{output_format}"
    preferred = output_dir / (Path(input_name).stem + suffix)
    if preferred.exists():
        return preferred

    candidates = list(output_dir.glob(f"*{suffix}"))
    if len(candidates) != 1:
        raise HTTPException(
            status_code=500,
            detail="Whisper finished without producing the expected output file.",
        )
    return candidates[0]
|
|
|
|
|
|
@app.get("/health")
def healthcheck() -> dict[str, str]:
    """Answer GET /health with the constant body ``{"status": "ok"}``."""
    payload: dict[str, str] = {"status": "ok"}
    return payload
|
|
|
|
|
|
def _upload_basename(filename: str) -> str:
    """Return the final path component of a client-supplied upload filename.

    Only the basename is kept so a name such as ``../../x.wav`` cannot escape
    the temporary upload directory.

    Raises:
        HTTPException: 400 when no usable name remains. ``Path(...).name`` is
            ``""`` for ``"."`` or ``"/"`` and stays ``".."`` for ``".."``; any
            of those joined onto the upload directory is a directory path that
            cannot be opened for writing. NUL bytes make ``open()`` raise
            ``ValueError``.
    """
    name = Path(filename).name
    if name in {"", ".", ".."} or "\x00" in name:
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
    return name


def _content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition header value for *filename*.

    Starlette encodes header values as latin-1, so interpolating an arbitrary
    client-derived name can raise ``UnicodeEncodeError``. Quotes, backslashes
    or control characters would instead emit a malformed header. Plain
    printable-ASCII names are emitted exactly as before. Anything else gets an
    ASCII ``filename`` fallback plus an RFC 5987 ``filename*`` carrying the
    exact UTF-8 name.
    """
    fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename
    )
    if fallback == filename:
        return f'attachment; filename="{filename}"'
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{encoded}"


@app.post("/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: str = Form(...),
    language: str | None = Form(default=None),
    output_format: str = Form(...),
) -> Response:
    """Transcribe an uploaded file with the ``whisper`` CLI and return the result.

    The upload and whisper's output are kept in per-request temporary
    directories, which are removed before the response is built.

    Form fields:
        file: Media to transcribe; only the basename of its filename is used.
        model: Passed to ``whisper --model``.
        language: Passed to ``whisper --language`` when non-empty.
        output_format: One of ``SUPPORTED_FORMATS``. Matching ignores case and
            surrounding whitespace.

    Returns:
        The transcript bytes as an attachment named ``<input stem>.<format>``,
        with ``X-Whisper-Output-Format`` and ``X-Whisper-Model`` headers.

    Raises:
        HTTPException: 400 for an unsupported format or a missing/invalid
            filename. 500 when the CLI is not on PATH or produced no output
            file. 502 when the CLI exits non-zero; the detail is its stderr,
            else its stdout.
    """
    normalized_format = validate_output_format(output_format)

    if not file.filename:
        raise HTTPException(status_code=400, detail="Uploaded file must have a filename.")
    upload_name = _upload_basename(file.filename)

    with TemporaryDirectory(prefix="whisper-remote-upload-") as upload_root, TemporaryDirectory(
        prefix="whisper-remote-output-"
    ) as output_root:
        input_path = Path(upload_root) / upload_name
        output_dir = Path(output_root)
        await save_upload(file, input_path)

        command = build_whisper_command(
            input_path=input_path,
            output_dir=output_dir,
            model=model,
            language=language,
            output_format=normalized_format,
        )

        # subprocess.run blocks until whisper exits. Calling it directly in
        # this coroutine would freeze the event loop (including /health), so
        # it runs in a worker thread instead.
        # NOTE(review): concurrent requests therefore run separate whisper
        # processes in parallel; gate this with a semaphore if the host cannot
        # hold several models in memory at once.
        try:
            completed = await asyncio.to_thread(
                subprocess.run,
                command,
                check=False,
                capture_output=True,
                text=True,
                # Undecodable bytes in the CLI's output must not become a 500.
                errors="replace",
            )
        except FileNotFoundError as exc:
            raise HTTPException(
                status_code=500,
                detail="The 'whisper' CLI was not found on PATH on the backend host.",
            ) from exc

        if completed.returncode != 0:
            detail = completed.stderr.strip() or completed.stdout.strip() or "Whisper CLI failed."
            raise HTTPException(status_code=502, detail=detail)

        transcript_path = find_transcript_file(output_dir, upload_name, normalized_format)
        content = transcript_path.read_bytes()

    download_name = f"{Path(upload_name).stem}.{normalized_format}"
    return Response(
        content=content,
        media_type=CONTENT_TYPES[normalized_format],
        headers={
            "Content-Disposition": _content_disposition(download_name),
            "X-Whisper-Output-Format": normalized_format,
            "X-Whisper-Model": model,
        },
    )
|
|
|
|
|
|
def main() -> None:
    """Serve the application with uvicorn, listening on 0.0.0.0:8000."""
    from uvicorn import run as serve

    serve(
        "whisper_remote_backend.server:app",
        host="0.0.0.0",
        port=8000,
    )
|