Add login rate limiting with lockout window

This commit is contained in:
Space-Banane
2026-05-20 21:54:28 +02:00
parent dde83a2417
commit ed886c956d

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import secrets import secrets
import os import os
import time
from datetime import datetime from datetime import datetime
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
@@ -20,6 +21,9 @@ PUBLIC_DIR = Path("public")
SESSION_COOKIE = "jellomator_session" SESSION_COOKIE = "jellomator_session"
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400")) SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400"))
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on") SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on")
LOGIN_MAX_ATTEMPTS = int(os.getenv("LOGIN_MAX_ATTEMPTS", "5"))
LOGIN_WINDOW_SECONDS = int(os.getenv("LOGIN_WINDOW_SECONDS", "300"))
LOGIN_LOCKOUT_SECONDS = int(os.getenv("LOGIN_LOCKOUT_SECONDS", "900"))
DB_HOST = os.getenv("DB_HOST", "mariadb") DB_HOST = os.getenv("DB_HOST", "mariadb")
DB_PORT = int(os.getenv("DB_PORT", "3306")) DB_PORT = int(os.getenv("DB_PORT", "3306"))
DB_USER = os.getenv("DB_USER", "jellomator") DB_USER = os.getenv("DB_USER", "jellomator")
@@ -28,6 +32,8 @@ DB_NAME = os.getenv("DB_NAME", "jellomator")
app = FastAPI(title="Jellomator") app = FastAPI(title="Jellomator")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
login_attempts: dict[str, list[float]] = {}
login_lockouts: dict[str, float] = {}
@app.get("/healthz") @app.get("/healthz")
@@ -170,6 +176,30 @@ def current_user(request: Request):
return {"username": row["username"], "role": row["role"]} return {"username": row["username"], "role": row["role"]}
def client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
def login_key(request: Request, username: str) -> str:
return f"{client_ip(request)}::{username.strip().lower()}"
def prune_login_tracking(now: float) -> None:
for key, until in list(login_lockouts.items()):
if until <= now:
del login_lockouts[key]
cutoff = now - LOGIN_WINDOW_SECONDS
for key, entries in list(login_attempts.items()):
filtered = [t for t in entries if t >= cutoff]
if filtered:
login_attempts[key] = filtered
else:
del login_attempts[key]
def require_admin(request: Request): def require_admin(request: Request):
user = current_user(request) user = current_user(request)
if not user: if not user:
@@ -208,13 +238,26 @@ def setup(inp: SetupIn):
@app.post("/api/login") @app.post("/api/login")
def login(inp: LoginIn): def login(request: Request, inp: LoginIn):
now_ts = time.time()
prune_login_tracking(now_ts)
key = login_key(request, inp.username)
locked_until = login_lockouts.get(key)
if locked_until and locked_until > now_ts:
raise HTTPException(429, "Too many login attempts. Try again later.")
with db() as c: with db() as c:
with c.cursor() as cur: with c.cursor() as cur:
cur.execute("select id,password_hash from users where username=%s", (inp.username,)) cur.execute("select id,password_hash from users where username=%s", (inp.username,))
row = cur.fetchone() row = cur.fetchone()
if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]): if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]):
attempts = login_attempts.get(key, [])
attempts.append(now_ts)
login_attempts[key] = [t for t in attempts if t >= now_ts - LOGIN_WINDOW_SECONDS]
if len(login_attempts[key]) >= LOGIN_MAX_ATTEMPTS:
login_lockouts[key] = now_ts + LOGIN_LOCKOUT_SECONDS
raise HTTPException(401, "Invalid credentials") raise HTTPException(401, "Invalid credentials")
login_attempts.pop(key, None)
login_lockouts.pop(key, None)
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
with c.cursor() as cur: with c.cursor() as cur:
now = utc_now_iso() now = utc_now_iso()