Add explicit detectors and optional auth
This commit is contained in:
@@ -1 +1,4 @@
|
|||||||
ENV=dev
|
ENV=prod
|
||||||
|
# Optional auth
|
||||||
|
# FACE_LOCK_AUTH_TOKEN=change-me
|
||||||
|
# FACE_LOCK_AUTH_HEADER=X-API-Key
|
||||||
|
|||||||
34
README.md
34
README.md
@@ -2,7 +2,37 @@
|
|||||||
|
|
||||||
FastAPI microservice that finds the primary subject in an image, draws a square around it, and returns a buffered crop.
|
FastAPI microservice that finds the primary subject in an image, draws a square around it, and returns a buffered crop.
|
||||||
|
|
||||||
## Dev
|
## UI
|
||||||
|
|
||||||
|
The Tailwind UI is always available at `/`.
|
||||||
|
|
||||||
|
## Auth
|
||||||
|
|
||||||
|
Optional header auth is enabled when `FACE_LOCK_AUTH_TOKEN` is set.
|
||||||
|
|
||||||
|
- Default header: `X-API-Key`
|
||||||
|
- Alternate: `Authorization: Bearer <token>`
|
||||||
|
- Override the header name with `FACE_LOCK_AUTH_HEADER`
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
- `POST /api/focus`
|
||||||
|
- `POST /api/focus/image`
|
||||||
|
- `GET /health`
|
||||||
|
|
||||||
|
## Detectors
|
||||||
|
|
||||||
|
- `face`
|
||||||
|
- `animal`
|
||||||
|
- `person`
|
||||||
|
- `subject`
|
||||||
|
|
||||||
|
## Docs
|
||||||
|
|
||||||
|
- OpenAPI UI: `/docs`
|
||||||
|
- Project docs: `docs/README.md`
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
@@ -10,8 +40,6 @@ pip install -r requirements.txt
|
|||||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
|
|
||||||
Set `ENV=dev` to enable the Tailwind UI at `/`.
|
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -1,13 +1,20 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from dotenv import load_dotenv
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Settings:
|
class Settings:
|
||||||
env: str = os.getenv("ENV", "prod").strip().lower()
|
env: str = os.getenv("ENV", "prod").strip().lower()
|
||||||
|
auth_token: str = os.getenv("FACE_LOCK_AUTH_TOKEN", "").strip()
|
||||||
|
auth_header_name: str = os.getenv("FACE_LOCK_AUTH_HEADER", "X-API-Key").strip()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_enabled(self) -> bool:
|
||||||
|
return bool(self.auth_token)
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
61
app/main.py
61
app/main.py
@@ -1,21 +1,36 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||||
from fastapi.responses import HTMLResponse, StreamingResponse
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
app = FastAPI(title="face-lock", version="0.2.0")
|
app = FastAPI(title="face-lock", version="0.3.0")
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(request: Request) -> None:
|
||||||
|
if not settings.auth_enabled:
|
||||||
|
return
|
||||||
|
header_name = settings.auth_header_name.lower()
|
||||||
|
provided = request.headers.get(header_name)
|
||||||
|
if not provided and request.headers.get("authorization", "").lower().startswith("bearer "):
|
||||||
|
provided = request.headers.get("authorization", "")[7:].strip()
|
||||||
|
if provided != settings.auth_token:
|
||||||
|
raise HTTPException(status_code=401, detail="unauthorized")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health():
|
def health():
|
||||||
return {"ok": True, "env": settings.env}
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"env": settings.env,
|
||||||
|
"auth_enabled": settings.auth_enabled,
|
||||||
|
"auth_header": settings.auth_header_name if settings.auth_enabled else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/", response_class=HTMLResponse)
|
@app.get("/", response_class=HTMLResponse)
|
||||||
def index():
|
def index():
|
||||||
if settings.env != "dev":
|
|
||||||
return HTMLResponse("<h1>face-lock</h1><p>Set ENV=dev for the UI.</p>")
|
|
||||||
return HTMLResponse(
|
return HTMLResponse(
|
||||||
"""
|
"""
|
||||||
<!doctype html>
|
<!doctype html>
|
||||||
@@ -28,9 +43,12 @@ def index():
|
|||||||
</head>
|
</head>
|
||||||
<body class="bg-slate-950 text-slate-100 min-h-screen">
|
<body class="bg-slate-950 text-slate-100 min-h-screen">
|
||||||
<main class="mx-auto max-w-6xl p-6">
|
<main class="mx-auto max-w-6xl p-6">
|
||||||
<div class="mb-6">
|
<div class="mb-6 flex items-center justify-between gap-4">
|
||||||
|
<div>
|
||||||
<h1 class="text-3xl font-bold">face-lock</h1>
|
<h1 class="text-3xl font-bold">face-lock</h1>
|
||||||
<p class="text-slate-400">Auto-detect the subject, square it up, and crop with buffer.</p>
|
<p class="text-slate-400">Square the subject, crop it, and keep the raw blobs out of sight.</p>
|
||||||
|
</div>
|
||||||
|
<a class="rounded-lg border border-slate-700 px-3 py-2 text-sm text-cyan-300 hover:bg-slate-900" href="/docs" target="_blank">Docs</a>
|
||||||
</div>
|
</div>
|
||||||
<div class="grid gap-6 md:grid-cols-2">
|
<div class="grid gap-6 md:grid-cols-2">
|
||||||
<section class="rounded-2xl border border-slate-800 bg-slate-900 p-4">
|
<section class="rounded-2xl border border-slate-800 bg-slate-900 p-4">
|
||||||
@@ -40,10 +58,10 @@ def index():
|
|||||||
<div>
|
<div>
|
||||||
<label class="block text-sm text-slate-400">Detector</label>
|
<label class="block text-sm text-slate-400">Detector</label>
|
||||||
<select id="detector" class="mt-2 block w-full rounded-lg border border-slate-700 bg-slate-950 p-3">
|
<select id="detector" class="mt-2 block w-full rounded-lg border border-slate-700 bg-slate-950 p-3">
|
||||||
<option value="auto">Auto</option>
|
|
||||||
<option value="face">Face</option>
|
<option value="face">Face</option>
|
||||||
|
<option value="animal">Animal</option>
|
||||||
<option value="person">Person</option>
|
<option value="person">Person</option>
|
||||||
<option value="salient">Subject</option>
|
<option value="subject" selected>Subject</option>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
@@ -51,6 +69,10 @@ def index():
|
|||||||
<input id="buffer_ratio" type="number" step="0.05" min="0" max="0.6" value="0.20" class="mt-2 block w-full rounded-lg border border-slate-700 bg-slate-950 p-3" />
|
<input id="buffer_ratio" type="number" step="0.05" min="0" max="0.6" value="0.20" class="mt-2 block w-full rounded-lg border border-slate-700 bg-slate-950 p-3" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="mt-4">
|
||||||
|
<label class="block text-sm text-slate-400">Auth token (only if enabled)</label>
|
||||||
|
<input id="auth_token" type="password" placeholder="paste token here" class="mt-2 block w-full rounded-lg border border-slate-700 bg-slate-950 p-3" />
|
||||||
|
</div>
|
||||||
<button id="go" class="mt-4 rounded-lg bg-cyan-500 px-4 py-2 font-semibold text-slate-950">Process</button>
|
<button id="go" class="mt-4 rounded-lg bg-cyan-500 px-4 py-2 font-semibold text-slate-950">Process</button>
|
||||||
<pre id="meta" class="mt-4 whitespace-pre-wrap rounded-lg bg-slate-950 p-3 text-xs text-slate-300"></pre>
|
<pre id="meta" class="mt-4 whitespace-pre-wrap rounded-lg bg-slate-950 p-3 text-xs text-slate-300"></pre>
|
||||||
</section>
|
</section>
|
||||||
@@ -82,8 +104,15 @@ def index():
|
|||||||
form.append('detector', document.getElementById('detector').value);
|
form.append('detector', document.getElementById('detector').value);
|
||||||
form.append('buffer_ratio', document.getElementById('buffer_ratio').value);
|
form.append('buffer_ratio', document.getElementById('buffer_ratio').value);
|
||||||
meta.textContent = 'Working...';
|
meta.textContent = 'Working...';
|
||||||
const resp = await fetch('/api/focus', { method: 'POST', body: form });
|
const headers = {};
|
||||||
|
const token = document.getElementById('auth_token').value.trim();
|
||||||
|
if (token) headers['__AUTH_HEADER_NAME__'] = token;
|
||||||
|
const resp = await fetch('/api/focus', { method: 'POST', body: form, headers });
|
||||||
const data = await resp.json();
|
const data = await resp.json();
|
||||||
|
if (!resp.ok) {
|
||||||
|
meta.textContent = JSON.stringify(data, null, 2);
|
||||||
|
return;
|
||||||
|
}
|
||||||
meta.textContent = JSON.stringify({
|
meta.textContent = JSON.stringify({
|
||||||
filename: data.filename,
|
filename: data.filename,
|
||||||
detector: data.detector,
|
detector: data.detector,
|
||||||
@@ -101,15 +130,17 @@ def index():
|
|||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
"""
|
""".replace("__AUTH_HEADER_NAME__", settings.auth_header_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/focus")
|
@app.post("/api/focus")
|
||||||
async def focus(
|
async def focus(
|
||||||
|
request: Request,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
buffer_ratio: float = Form(0.15),
|
buffer_ratio: float = Form(0.15),
|
||||||
detector: str = Form("auto"),
|
detector: str = Form("subject"),
|
||||||
|
_auth: None = Depends(require_auth),
|
||||||
):
|
):
|
||||||
from app.vision import process_image
|
from app.vision import process_image
|
||||||
|
|
||||||
@@ -133,15 +164,17 @@ async def focus(
|
|||||||
|
|
||||||
@app.post("/api/focus/image")
|
@app.post("/api/focus/image")
|
||||||
async def focus_image(
|
async def focus_image(
|
||||||
|
request: Request,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
buffer_ratio: float = Form(0.15),
|
buffer_ratio: float = Form(0.15),
|
||||||
detector: str = Form("auto"),
|
detector: str = Form("subject"),
|
||||||
|
_auth: None = Depends(require_auth),
|
||||||
):
|
):
|
||||||
from app.vision import process_image
|
from app.vision import process_image
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = await file.read()
|
payload = await file.read()
|
||||||
result = process_image(payload, file.filename or "upload", buffer_ratio=buffer_ratio, detector=detector)
|
result = process_image(payload, file.filename or "upload", buffer_ratio=buffer_ratio, detector=detector)
|
||||||
return StreamingResponse(BytesIO(result["crop_bytes"]), media_type=result["mime_type"])
|
return StreamingResponse(BytesIO(result["crop_bytes"]), media_type="image/jpeg")
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ class BBox:
|
|||||||
return self.y + self.h
|
return self.y + self.h
|
||||||
|
|
||||||
|
|
||||||
FACE_CASCADE = cv2.CascadeClassifier(
|
HAAR_DIR = Path(cv2.data.haarcascades)
|
||||||
str(Path(cv2.data.haarcascades) / "haarcascade_frontalface_default.xml")
|
FACE_CASCADE = cv2.CascadeClassifier(str(HAAR_DIR / "haarcascade_frontalface_default.xml"))
|
||||||
)
|
CAT_CASCADE = cv2.CascadeClassifier(str(HAAR_DIR / "haarcascade_frontalcatface_extended.xml"))
|
||||||
HOG = cv2.HOGDescriptor()
|
HOG = cv2.HOGDescriptor()
|
||||||
HOG.setSVMDetector(cv2.HOGDescriptor_getDefaultPeopleDetector())
|
HOG.setSVMDetector(cv2.HOGDescriptor_getDefaultPeopleDetector())
|
||||||
|
|
||||||
@@ -39,40 +39,23 @@ def decode_image(image_bytes: bytes) -> np.ndarray:
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def select_primary_bbox(image: np.ndarray, detector: str = "auto") -> tuple[BBox, str]:
|
def select_primary_bbox(image: np.ndarray, detector: str = "subject") -> tuple[BBox, str]:
|
||||||
detector = (detector or "auto").strip().lower()
|
detector = (detector or "subject").strip().lower()
|
||||||
|
|
||||||
if detector == "face":
|
if detector == "face":
|
||||||
face_bbox = detect_face(image)
|
bbox = detect_face(image)
|
||||||
if face_bbox is not None:
|
return (bbox, "face_cascade") if bbox is not None else (fallback_bbox(image), "center_fallback")
|
||||||
return face_bbox, "face_cascade"
|
|
||||||
return fallback_bbox(image), "center_fallback"
|
|
||||||
|
|
||||||
if detector == "person":
|
if detector == "person":
|
||||||
person_bbox = detect_person(image)
|
bbox = detect_person(image)
|
||||||
if person_bbox is not None:
|
return (bbox, "person_hog") if bbox is not None else (fallback_bbox(image), "center_fallback")
|
||||||
return person_bbox, "person_hog"
|
|
||||||
return fallback_bbox(image), "center_fallback"
|
|
||||||
|
|
||||||
if detector == "salient":
|
if detector == "animal":
|
||||||
salient_bbox = detect_salient_object(image)
|
bbox = detect_animal(image)
|
||||||
if salient_bbox is not None:
|
return (bbox, "animal_cascade") if bbox is not None else (fallback_bbox(image), "center_fallback")
|
||||||
return salient_bbox, "salient_contour"
|
|
||||||
return fallback_bbox(image), "center_fallback"
|
|
||||||
|
|
||||||
face_bbox = detect_face(image)
|
bbox = detect_subject(image)
|
||||||
if face_bbox is not None:
|
return (bbox, "subject_contour") if bbox is not None else (fallback_bbox(image), "center_fallback")
|
||||||
return face_bbox, "face_cascade"
|
|
||||||
|
|
||||||
person_bbox = detect_person(image)
|
|
||||||
if person_bbox is not None:
|
|
||||||
return person_bbox, "person_hog"
|
|
||||||
|
|
||||||
salient_bbox = detect_salient_object(image)
|
|
||||||
if salient_bbox is not None:
|
|
||||||
return salient_bbox, "salient_contour"
|
|
||||||
|
|
||||||
return fallback_bbox(image), "center_fallback"
|
|
||||||
|
|
||||||
|
|
||||||
def fallback_bbox(image: np.ndarray) -> BBox:
|
def fallback_bbox(image: np.ndarray) -> BBox:
|
||||||
@@ -95,19 +78,39 @@ def detect_face(image: np.ndarray) -> BBox | None:
|
|||||||
|
|
||||||
|
|
||||||
def detect_person(image: np.ndarray) -> BBox | None:
|
def detect_person(image: np.ndarray) -> BBox | None:
|
||||||
rects, weights = HOG.detectMultiScale(image, winStride=(8, 8), padding=(8, 8), scale=1.05)
|
rects, _ = HOG.detectMultiScale(image, winStride=(8, 8), padding=(8, 8), scale=1.05)
|
||||||
if len(rects) == 0:
|
if len(rects) == 0:
|
||||||
return None
|
return None
|
||||||
best = max(zip(rects, weights), key=lambda item: int(item[0][2]) * int(item[0][3]))[0]
|
best = max((tuple(map(int, rect)) for rect in rects), key=lambda rect: rect[2] * rect[3])
|
||||||
x, y, w, h = map(int, best)
|
return BBox(x=best[0], y=best[1], w=best[2], h=best[3])
|
||||||
|
|
||||||
|
|
||||||
|
def detect_animal(image: np.ndarray) -> BBox | None:
|
||||||
|
if CAT_CASCADE.empty():
|
||||||
|
return detect_subject(image, min_area_ratio=0.02, blur_size=7, dilate_size=11)
|
||||||
|
|
||||||
|
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||||
|
gray = cv2.equalizeHist(gray)
|
||||||
|
cats = CAT_CASCADE.detectMultiScale(gray, scaleFactor=1.08, minNeighbors=4, minSize=(24, 24))
|
||||||
|
if len(cats) > 0:
|
||||||
|
x, y, w, h = max((tuple(map(int, cat)) for cat in cats), key=lambda rect: rect[2] * rect[3])
|
||||||
return BBox(x=x, y=y, w=w, h=h)
|
return BBox(x=x, y=y, w=w, h=h)
|
||||||
|
|
||||||
|
return detect_subject(image, min_area_ratio=0.02, blur_size=7, dilate_size=11)
|
||||||
|
|
||||||
def detect_salient_object(image: np.ndarray) -> BBox | None:
|
|
||||||
|
def detect_subject(
|
||||||
|
image: np.ndarray,
|
||||||
|
min_area_ratio: float = 0.015,
|
||||||
|
blur_size: int = 9,
|
||||||
|
dilate_size: int = 13,
|
||||||
|
) -> BBox | None:
|
||||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||||
blurred = cv2.GaussianBlur(gray, (9, 9), 0)
|
blur_size = blur_size + (blur_size % 2 == 0)
|
||||||
|
dilate_size = max(3, dilate_size)
|
||||||
|
kernel = np.ones((dilate_size, dilate_size), np.uint8)
|
||||||
|
blurred = cv2.GaussianBlur(gray, (blur_size, blur_size), 0)
|
||||||
edges = cv2.Canny(blurred, 30, 110)
|
edges = cv2.Canny(blurred, 30, 110)
|
||||||
kernel = np.ones((13, 13), np.uint8)
|
|
||||||
expanded = cv2.dilate(edges, kernel, iterations=1)
|
expanded = cv2.dilate(edges, kernel, iterations=1)
|
||||||
closed = cv2.morphologyEx(expanded, cv2.MORPH_CLOSE, kernel, iterations=1)
|
closed = cv2.morphologyEx(expanded, cv2.MORPH_CLOSE, kernel, iterations=1)
|
||||||
contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
@@ -120,7 +123,7 @@ def detect_salient_object(image: np.ndarray) -> BBox | None:
|
|||||||
for contour in contours:
|
for contour in contours:
|
||||||
x, y, bw, bh = cv2.boundingRect(contour)
|
x, y, bw, bh = cv2.boundingRect(contour)
|
||||||
area = bw * bh
|
area = bw * bh
|
||||||
if area < max(1000, int(image_area * 0.015)):
|
if area < max(1000, int(image_area * min_area_ratio)):
|
||||||
continue
|
continue
|
||||||
candidates.append((area, BBox(x=x, y=y, w=bw, h=bh)))
|
candidates.append((area, BBox(x=x, y=y, w=bw, h=bh)))
|
||||||
|
|
||||||
@@ -170,7 +173,7 @@ def _data_url(image_bytes: bytes, mime_type: str) -> str:
|
|||||||
return f"data:{mime_type};base64,{base64.b64encode(image_bytes).decode('ascii')}"
|
return f"data:{mime_type};base64,{base64.b64encode(image_bytes).decode('ascii')}"
|
||||||
|
|
||||||
|
|
||||||
def process_image(image_bytes: bytes, filename: str, buffer_ratio: float = 0.15, detector: str = "auto") -> dict[str, Any]:
|
def process_image(image_bytes: bytes, filename: str, buffer_ratio: float = 0.15, detector: str = "subject") -> dict[str, Any]:
|
||||||
image = decode_image(image_bytes)
|
image = decode_image(image_bytes)
|
||||||
bbox, method = select_primary_bbox(image, detector=detector)
|
bbox, method = select_primary_bbox(image, detector=detector)
|
||||||
square = square_bbox(bbox, image.shape, buffer_ratio=buffer_ratio)
|
square = square_bbox(bbox, image.shape, buffer_ratio=buffer_ratio)
|
||||||
|
|||||||
42
docs/README.md
Normal file
42
docs/README.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# face-lock docs
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
face-lock is a FastAPI service that detects a primary subject, makes a square crop, and returns both a crop and an annotated preview.
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
|
||||||
|
- `GET /health`
|
||||||
|
- `GET /`
|
||||||
|
- `POST /api/focus`
|
||||||
|
- `POST /api/focus/image`
|
||||||
|
- `GET /docs`
|
||||||
|
|
||||||
|
## Detectors
|
||||||
|
|
||||||
|
- `face` for human faces
|
||||||
|
- `animal` for pets / animals, with a contour fallback
|
||||||
|
- `person` for full-body person detection
|
||||||
|
- `subject` for general foreground subjects
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
Set `FACE_LOCK_AUTH_TOKEN` to require a header token.
|
||||||
|
|
||||||
|
Supported headers:
|
||||||
|
|
||||||
|
- `X-API-Key: <token>`
|
||||||
|
- `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
Optional override:
|
||||||
|
|
||||||
|
- `FACE_LOCK_AUTH_HEADER` changes the expected header name.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -H 'X-API-Key: your-token' \
|
||||||
|
-F 'file=@image.jpg' \
|
||||||
|
-F 'detector=animal' \
|
||||||
|
http://localhost:8000/api/focus
|
||||||
|
```
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from app.vision import BBox, crop_image, detect_salient_object, select_primary_bbox, square_bbox
|
from app.vision import BBox, crop_image, detect_animal, detect_subject, select_primary_bbox, square_bbox
|
||||||
|
|
||||||
|
|
||||||
def test_square_bbox_is_square_and_inside_bounds():
|
def test_square_bbox_is_square_and_inside_bounds():
|
||||||
@@ -21,17 +21,25 @@ def test_crop_image_uses_bbox():
|
|||||||
assert crop.shape[:2] == (20, 30)
|
assert crop.shape[:2] == (20, 30)
|
||||||
|
|
||||||
|
|
||||||
def test_detect_salient_object_finds_rectangle():
|
def test_detect_subject_finds_rectangle():
|
||||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
image[25:75, 30:80] = 255
|
image[25:75, 30:80] = 255
|
||||||
bbox = detect_salient_object(image)
|
bbox = detect_subject(image)
|
||||||
assert bbox is not None
|
assert bbox is not None
|
||||||
assert bbox.w >= 45
|
assert bbox.w >= 45
|
||||||
assert bbox.h >= 45
|
assert bbox.h >= 45
|
||||||
|
|
||||||
|
|
||||||
def test_select_primary_bbox_falls_back_when_detector_disabled():
|
def test_detect_animal_uses_contour_fallback():
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
image[20:70, 15:85] = 255
|
||||||
|
bbox = detect_animal(image)
|
||||||
|
assert bbox is not None
|
||||||
|
assert bbox.w >= 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_primary_bbox_defaults_to_subject():
|
||||||
image = np.zeros((100, 120, 3), dtype=np.uint8)
|
image = np.zeros((100, 120, 3), dtype=np.uint8)
|
||||||
bbox, method = select_primary_bbox(image, detector="center")
|
bbox, method = select_primary_bbox(image)
|
||||||
assert method == "center_fallback"
|
assert method == "center_fallback"
|
||||||
assert bbox.w == bbox.h
|
assert bbox.w == bbox.h
|
||||||
|
|||||||
Reference in New Issue
Block a user