Compare commits

..

3 Commits

Author SHA1 Message Date
Space-Banane
972ccce62a Add optional CSRF enforcement for write routes
All checks were successful
docker / build-and-push (push) Successful in 49s
2026-05-20 21:55:30 +02:00
Space-Banane
7c06d31ac1 Validate link payloads and icon uploads 2026-05-20 21:54:53 +02:00
Space-Banane
ed886c956d Add login rate limiting with lockout window 2026-05-20 21:54:28 +02:00
2 changed files with 111 additions and 9 deletions

View File

@@ -36,6 +36,11 @@ The app expects a MariaDB instance configured through environment variables.
- `SESSION_TTL_SECONDS` (default: `86400`)
- `SESSION_COOKIE_SECURE` (default: `false`, set `true` in production HTTPS)
- `REQUIRE_CSRF` (default: `false`, checks same-origin/same-referer for write routes when enabled)
- `LOGIN_MAX_ATTEMPTS` (default: `5`)
- `LOGIN_WINDOW_SECONDS` (default: `300`)
- `LOGIN_LOCKOUT_SECONDS` (default: `900`)
- `MAX_ICON_BYTES` (default: `2097152`)
## Gitea CI/CD

View File

@@ -2,10 +2,12 @@ from __future__ import annotations
import secrets
import os
import time
from datetime import datetime
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import bcrypt
import pymysql
@@ -20,6 +22,16 @@ PUBLIC_DIR = Path("public")
SESSION_COOKIE = "jellomator_session"
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400"))
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on")
LOGIN_MAX_ATTEMPTS = int(os.getenv("LOGIN_MAX_ATTEMPTS", "5"))
LOGIN_WINDOW_SECONDS = int(os.getenv("LOGIN_WINDOW_SECONDS", "300"))
LOGIN_LOCKOUT_SECONDS = int(os.getenv("LOGIN_LOCKOUT_SECONDS", "900"))
REQUIRE_CSRF = os.getenv("REQUIRE_CSRF", "false").lower() in ("1", "true", "yes", "on")
MAX_NAME_LEN = int(os.getenv("MAX_NAME_LEN", "255"))
MAX_CATEGORY_LEN = int(os.getenv("MAX_CATEGORY_LEN", "255"))
MAX_DESCRIPTION_LEN = int(os.getenv("MAX_DESCRIPTION_LEN", "2000"))
MAX_ICON_URL_LEN = int(os.getenv("MAX_ICON_URL_LEN", "2048"))
MAX_ICON_BYTES = int(os.getenv("MAX_ICON_BYTES", str(2 * 1024 * 1024)))
ALLOWED_ICON_MIME = {"image/png", "image/jpeg", "image/webp", "image/gif", "image/svg+xml", "image/x-icon"}
DB_HOST = os.getenv("DB_HOST", "mariadb")
DB_PORT = int(os.getenv("DB_PORT", "3306"))
DB_USER = os.getenv("DB_USER", "jellomator")
@@ -28,6 +40,8 @@ DB_NAME = os.getenv("DB_NAME", "jellomator")
app = FastAPI(title="Jellomator")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
login_attempts: dict[str, list[float]] = {}
login_lockouts: dict[str, float] = {}
@app.get("/healthz")
@@ -170,6 +184,30 @@ def current_user(request: Request):
return {"username": row["username"], "role": row["role"]}
def client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
def login_key(request: Request, username: str) -> str:
return f"{client_ip(request)}::{username.strip().lower()}"
def prune_login_tracking(now: float) -> None:
for key, until in list(login_lockouts.items()):
if until <= now:
del login_lockouts[key]
cutoff = now - LOGIN_WINDOW_SECONDS
for key, entries in list(login_attempts.items()):
filtered = [t for t in entries if t >= cutoff]
if filtered:
login_attempts[key] = filtered
else:
del login_attempts[key]
def require_admin(request: Request):
user = current_user(request)
if not user:
@@ -177,6 +215,20 @@ def require_admin(request: Request):
return user
def require_csrf(request: Request):
if not REQUIRE_CSRF:
return
origin = request.headers.get("origin")
referer = request.headers.get("referer")
target = f"{request.url.scheme}://{request.url.netloc}"
if origin and origin != target:
raise HTTPException(403, "Invalid CSRF origin")
if referer and not referer.startswith(target):
raise HTTPException(403, "Invalid CSRF referer")
if not origin and not referer:
raise HTTPException(403, "Invalid CSRF token")
def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
with conn.cursor() as cur:
if exclude_id is None:
@@ -186,6 +238,38 @@ def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
return cur.fetchone() is not None
def validate_http_url(value: str, field_name: str = "url") -> None:
parsed = urlparse((value or "").strip())
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise HTTPException(422, f"{field_name} must be a valid http(s) URL")
def validate_length(value: str | None, limit: int, field_name: str) -> None:
if value is not None and len(value) > limit:
raise HTTPException(422, f"{field_name} exceeds max length of {limit}")
def validate_link_payload(name: str, url: str, description: str, category: str, icon_url: str | None) -> None:
validate_length(name, MAX_NAME_LEN, "name")
validate_length(description, MAX_DESCRIPTION_LEN, "description")
validate_length(category, MAX_CATEGORY_LEN, "category")
if icon_url:
validate_length(icon_url, MAX_ICON_URL_LEN, "icon_url")
validate_http_url(icon_url, "icon_url")
validate_http_url(url, "url")
def read_icon_blob(icon: UploadFile | None) -> tuple[bytes | None, str | None]:
if not icon:
return None, None
if icon.content_type not in ALLOWED_ICON_MIME:
raise HTTPException(422, "Unsupported icon file type")
blob = icon.file.read(MAX_ICON_BYTES + 1)
if len(blob) > MAX_ICON_BYTES:
raise HTTPException(422, f"Icon exceeds max size of {MAX_ICON_BYTES} bytes")
return blob, icon.content_type
@app.get("/api/me")
def me(request: Request):
with db() as c:
@@ -208,13 +292,26 @@ def setup(inp: SetupIn):
@app.post("/api/login")
def login(inp: LoginIn):
def login(request: Request, inp: LoginIn):
now_ts = time.time()
prune_login_tracking(now_ts)
key = login_key(request, inp.username)
locked_until = login_lockouts.get(key)
if locked_until and locked_until > now_ts:
raise HTTPException(429, "Too many login attempts. Try again later.")
with db() as c:
with c.cursor() as cur:
cur.execute("select id,password_hash from users where username=%s", (inp.username,))
row = cur.fetchone()
if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]):
attempts = login_attempts.get(key, [])
attempts.append(now_ts)
login_attempts[key] = [t for t in attempts if t >= now_ts - LOGIN_WINDOW_SECONDS]
if len(login_attempts[key]) >= LOGIN_MAX_ATTEMPTS:
login_lockouts[key] = now_ts + LOGIN_LOCKOUT_SECONDS
raise HTTPException(401, "Invalid credentials")
login_attempts.pop(key, None)
login_lockouts.pop(key, None)
token = secrets.token_urlsafe(32)
with c.cursor() as cur:
now = utc_now_iso()
@@ -237,6 +334,7 @@ def login(inp: LoginIn):
@app.post("/api/logout")
def logout(request: Request):
require_csrf(request)
token = request.cookies.get(SESSION_COOKIE)
with db() as c:
if token:
@@ -288,13 +386,12 @@ def create_link(
icon: UploadFile | None = File(None),
):
require_admin(request)
require_csrf(request)
validate_link_payload(name, url, description, category, icon_url)
with db() as c:
if link_name_exists(c, name):
raise HTTPException(409, "Link name already exists")
icon_blob = icon_mime = None
if icon:
icon_blob = icon.file.read()
icon_mime = icon.content_type
icon_blob, icon_mime = read_icon_blob(icon)
now = datetime.utcnow().isoformat()
with db() as c:
with c.cursor() as cur:
@@ -320,6 +417,7 @@ def link_icon(link_id: int):
@app.delete("/api/links/{link_id}")
def delete_link(request: Request, link_id: int):
require_admin(request)
require_csrf(request)
with db() as c:
with c.cursor() as cur:
cur.execute("delete from links where id=%s", (link_id,))
@@ -339,10 +437,9 @@ def update_link(
icon: UploadFile | None = File(None),
):
require_admin(request)
icon_blob = icon_mime = None
if icon:
icon_blob = icon.file.read()
icon_mime = icon.content_type
require_csrf(request)
validate_link_payload(name, url, description, category, icon_url)
icon_blob, icon_mime = read_icon_blob(icon)
now = datetime.utcnow().isoformat()
with db() as c:
if link_name_exists(c, name, exclude_id=link_id):