Compare commits

..

3 Commits

Author SHA1 Message Date
Space-Banane
972ccce62a Add optional CSRF enforcement for write routes
All checks were successful
docker / build-and-push (push) Successful in 49s
2026-05-20 21:55:30 +02:00
Space-Banane
7c06d31ac1 Validate link payloads and icon uploads 2026-05-20 21:54:53 +02:00
Space-Banane
ed886c956d Add login rate limiting with lockout window 2026-05-20 21:54:28 +02:00
2 changed files with 111 additions and 9 deletions

View File

@@ -36,6 +36,11 @@ The app expects a MariaDB instance configured through environment variables.
- `SESSION_TTL_SECONDS` (default: `86400`) - `SESSION_TTL_SECONDS` (default: `86400`)
- `SESSION_COOKIE_SECURE` (default: `false`, set `true` in production HTTPS) - `SESSION_COOKIE_SECURE` (default: `false`, set `true` in production HTTPS)
- `REQUIRE_CSRF` (default: `false`, checks same-origin/same-referer for write routes when enabled)
- `LOGIN_MAX_ATTEMPTS` (default: `5`)
- `LOGIN_WINDOW_SECONDS` (default: `300`)
- `LOGIN_LOCKOUT_SECONDS` (default: `900`)
- `MAX_ICON_BYTES` (default: `2097152`)
## Gitea CI/CD ## Gitea CI/CD

View File

@@ -2,10 +2,12 @@ from __future__ import annotations
import secrets import secrets
import os import os
import time
from datetime import datetime from datetime import datetime
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import urlparse
import bcrypt import bcrypt
import pymysql import pymysql
@@ -20,6 +22,16 @@ PUBLIC_DIR = Path("public")
SESSION_COOKIE = "jellomator_session" SESSION_COOKIE = "jellomator_session"
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400")) SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400"))
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on") SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on")
LOGIN_MAX_ATTEMPTS = int(os.getenv("LOGIN_MAX_ATTEMPTS", "5"))
LOGIN_WINDOW_SECONDS = int(os.getenv("LOGIN_WINDOW_SECONDS", "300"))
LOGIN_LOCKOUT_SECONDS = int(os.getenv("LOGIN_LOCKOUT_SECONDS", "900"))
REQUIRE_CSRF = os.getenv("REQUIRE_CSRF", "false").lower() in ("1", "true", "yes", "on")
MAX_NAME_LEN = int(os.getenv("MAX_NAME_LEN", "255"))
MAX_CATEGORY_LEN = int(os.getenv("MAX_CATEGORY_LEN", "255"))
MAX_DESCRIPTION_LEN = int(os.getenv("MAX_DESCRIPTION_LEN", "2000"))
MAX_ICON_URL_LEN = int(os.getenv("MAX_ICON_URL_LEN", "2048"))
MAX_ICON_BYTES = int(os.getenv("MAX_ICON_BYTES", str(2 * 1024 * 1024)))
ALLOWED_ICON_MIME = {"image/png", "image/jpeg", "image/webp", "image/gif", "image/svg+xml", "image/x-icon"}
DB_HOST = os.getenv("DB_HOST", "mariadb") DB_HOST = os.getenv("DB_HOST", "mariadb")
DB_PORT = int(os.getenv("DB_PORT", "3306")) DB_PORT = int(os.getenv("DB_PORT", "3306"))
DB_USER = os.getenv("DB_USER", "jellomator") DB_USER = os.getenv("DB_USER", "jellomator")
@@ -28,6 +40,8 @@ DB_NAME = os.getenv("DB_NAME", "jellomator")
app = FastAPI(title="Jellomator") app = FastAPI(title="Jellomator")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
login_attempts: dict[str, list[float]] = {}
login_lockouts: dict[str, float] = {}
@app.get("/healthz") @app.get("/healthz")
@@ -170,6 +184,30 @@ def current_user(request: Request):
return {"username": row["username"], "role": row["role"]} return {"username": row["username"], "role": row["role"]}
def client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
def login_key(request: Request, username: str) -> str:
return f"{client_ip(request)}::{username.strip().lower()}"
def prune_login_tracking(now: float) -> None:
for key, until in list(login_lockouts.items()):
if until <= now:
del login_lockouts[key]
cutoff = now - LOGIN_WINDOW_SECONDS
for key, entries in list(login_attempts.items()):
filtered = [t for t in entries if t >= cutoff]
if filtered:
login_attempts[key] = filtered
else:
del login_attempts[key]
def require_admin(request: Request): def require_admin(request: Request):
user = current_user(request) user = current_user(request)
if not user: if not user:
@@ -177,6 +215,20 @@ def require_admin(request: Request):
return user return user
def require_csrf(request: Request):
if not REQUIRE_CSRF:
return
origin = request.headers.get("origin")
referer = request.headers.get("referer")
target = f"{request.url.scheme}://{request.url.netloc}"
if origin and origin != target:
raise HTTPException(403, "Invalid CSRF origin")
if referer and not referer.startswith(target):
raise HTTPException(403, "Invalid CSRF referer")
if not origin and not referer:
raise HTTPException(403, "Invalid CSRF token")
def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool: def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
with conn.cursor() as cur: with conn.cursor() as cur:
if exclude_id is None: if exclude_id is None:
@@ -186,6 +238,38 @@ def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
return cur.fetchone() is not None return cur.fetchone() is not None
def validate_http_url(value: str, field_name: str = "url") -> None:
parsed = urlparse((value or "").strip())
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise HTTPException(422, f"{field_name} must be a valid http(s) URL")
def validate_length(value: str | None, limit: int, field_name: str) -> None:
if value is not None and len(value) > limit:
raise HTTPException(422, f"{field_name} exceeds max length of {limit}")
def validate_link_payload(name: str, url: str, description: str, category: str, icon_url: str | None) -> None:
validate_length(name, MAX_NAME_LEN, "name")
validate_length(description, MAX_DESCRIPTION_LEN, "description")
validate_length(category, MAX_CATEGORY_LEN, "category")
if icon_url:
validate_length(icon_url, MAX_ICON_URL_LEN, "icon_url")
validate_http_url(icon_url, "icon_url")
validate_http_url(url, "url")
def read_icon_blob(icon: UploadFile | None) -> tuple[bytes | None, str | None]:
if not icon:
return None, None
if icon.content_type not in ALLOWED_ICON_MIME:
raise HTTPException(422, "Unsupported icon file type")
blob = icon.file.read(MAX_ICON_BYTES + 1)
if len(blob) > MAX_ICON_BYTES:
raise HTTPException(422, f"Icon exceeds max size of {MAX_ICON_BYTES} bytes")
return blob, icon.content_type
@app.get("/api/me") @app.get("/api/me")
def me(request: Request): def me(request: Request):
with db() as c: with db() as c:
@@ -208,13 +292,26 @@ def setup(inp: SetupIn):
@app.post("/api/login") @app.post("/api/login")
def login(inp: LoginIn): def login(request: Request, inp: LoginIn):
now_ts = time.time()
prune_login_tracking(now_ts)
key = login_key(request, inp.username)
locked_until = login_lockouts.get(key)
if locked_until and locked_until > now_ts:
raise HTTPException(429, "Too many login attempts. Try again later.")
with db() as c: with db() as c:
with c.cursor() as cur: with c.cursor() as cur:
cur.execute("select id,password_hash from users where username=%s", (inp.username,)) cur.execute("select id,password_hash from users where username=%s", (inp.username,))
row = cur.fetchone() row = cur.fetchone()
if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]): if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]):
attempts = login_attempts.get(key, [])
attempts.append(now_ts)
login_attempts[key] = [t for t in attempts if t >= now_ts - LOGIN_WINDOW_SECONDS]
if len(login_attempts[key]) >= LOGIN_MAX_ATTEMPTS:
login_lockouts[key] = now_ts + LOGIN_LOCKOUT_SECONDS
raise HTTPException(401, "Invalid credentials") raise HTTPException(401, "Invalid credentials")
login_attempts.pop(key, None)
login_lockouts.pop(key, None)
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
with c.cursor() as cur: with c.cursor() as cur:
now = utc_now_iso() now = utc_now_iso()
@@ -237,6 +334,7 @@ def login(inp: LoginIn):
@app.post("/api/logout") @app.post("/api/logout")
def logout(request: Request): def logout(request: Request):
require_csrf(request)
token = request.cookies.get(SESSION_COOKIE) token = request.cookies.get(SESSION_COOKIE)
with db() as c: with db() as c:
if token: if token:
@@ -288,13 +386,12 @@ def create_link(
icon: UploadFile | None = File(None), icon: UploadFile | None = File(None),
): ):
require_admin(request) require_admin(request)
require_csrf(request)
validate_link_payload(name, url, description, category, icon_url)
with db() as c: with db() as c:
if link_name_exists(c, name): if link_name_exists(c, name):
raise HTTPException(409, "Link name already exists") raise HTTPException(409, "Link name already exists")
icon_blob = icon_mime = None icon_blob, icon_mime = read_icon_blob(icon)
if icon:
icon_blob = icon.file.read()
icon_mime = icon.content_type
now = datetime.utcnow().isoformat() now = datetime.utcnow().isoformat()
with db() as c: with db() as c:
with c.cursor() as cur: with c.cursor() as cur:
@@ -320,6 +417,7 @@ def link_icon(link_id: int):
@app.delete("/api/links/{link_id}") @app.delete("/api/links/{link_id}")
def delete_link(request: Request, link_id: int): def delete_link(request: Request, link_id: int):
require_admin(request) require_admin(request)
require_csrf(request)
with db() as c: with db() as c:
with c.cursor() as cur: with c.cursor() as cur:
cur.execute("delete from links where id=%s", (link_id,)) cur.execute("delete from links where id=%s", (link_id,))
@@ -339,10 +437,9 @@ def update_link(
icon: UploadFile | None = File(None), icon: UploadFile | None = File(None),
): ):
require_admin(request) require_admin(request)
icon_blob = icon_mime = None require_csrf(request)
if icon: validate_link_payload(name, url, description, category, icon_url)
icon_blob = icon.file.read() icon_blob, icon_mime = read_icon_blob(icon)
icon_mime = icon.content_type
now = datetime.utcnow().isoformat() now = datetime.utcnow().isoformat()
with db() as c: with db() as c:
if link_name_exists(c, name, exclude_id=link_id): if link_name_exists(c, name, exclude_id=link_id):