Add login rate limiting with lockout window

This commit is contained in:
Space-Banane
2026-05-20 21:54:28 +02:00
parent dde83a2417
commit ed886c956d

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import secrets
import os
import time
from datetime import datetime
from contextlib import contextmanager
from pathlib import Path
@@ -20,6 +21,9 @@ PUBLIC_DIR = Path("public")
SESSION_COOKIE = "jellomator_session"
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400"))
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on")
LOGIN_MAX_ATTEMPTS = int(os.getenv("LOGIN_MAX_ATTEMPTS", "5"))
LOGIN_WINDOW_SECONDS = int(os.getenv("LOGIN_WINDOW_SECONDS", "300"))
LOGIN_LOCKOUT_SECONDS = int(os.getenv("LOGIN_LOCKOUT_SECONDS", "900"))
DB_HOST = os.getenv("DB_HOST", "mariadb")
DB_PORT = int(os.getenv("DB_PORT", "3306"))
DB_USER = os.getenv("DB_USER", "jellomator")
@@ -28,6 +32,8 @@ DB_NAME = os.getenv("DB_NAME", "jellomator")
app = FastAPI(title="Jellomator")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
login_attempts: dict[str, list[float]] = {}
login_lockouts: dict[str, float] = {}
@app.get("/healthz")
@@ -170,6 +176,30 @@ def current_user(request: Request):
return {"username": row["username"], "role": row["role"]}
def client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
def login_key(request: Request, username: str) -> str:
return f"{client_ip(request)}::{username.strip().lower()}"
def prune_login_tracking(now: float) -> None:
for key, until in list(login_lockouts.items()):
if until <= now:
del login_lockouts[key]
cutoff = now - LOGIN_WINDOW_SECONDS
for key, entries in list(login_attempts.items()):
filtered = [t for t in entries if t >= cutoff]
if filtered:
login_attempts[key] = filtered
else:
del login_attempts[key]
def require_admin(request: Request):
user = current_user(request)
if not user:
@@ -208,13 +238,26 @@ def setup(inp: SetupIn):
@app.post("/api/login")
def login(inp: LoginIn):
def login(request: Request, inp: LoginIn):
now_ts = time.time()
prune_login_tracking(now_ts)
key = login_key(request, inp.username)
locked_until = login_lockouts.get(key)
if locked_until and locked_until > now_ts:
raise HTTPException(429, "Too many login attempts. Try again later.")
with db() as c:
with c.cursor() as cur:
cur.execute("select id,password_hash from users where username=%s", (inp.username,))
row = cur.fetchone()
if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]):
attempts = login_attempts.get(key, [])
attempts.append(now_ts)
login_attempts[key] = [t for t in attempts if t >= now_ts - LOGIN_WINDOW_SECONDS]
if len(login_attempts[key]) >= LOGIN_MAX_ATTEMPTS:
login_lockouts[key] = now_ts + LOGIN_LOCKOUT_SECONDS
raise HTTPException(401, "Invalid credentials")
login_attempts.pop(key, None)
login_lockouts.pop(key, None)
token = secrets.token_urlsafe(32)
with c.cursor() as cur:
now = utc_now_iso()