Compare commits
3 Commits
dde83a2417
...
972ccce62a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
972ccce62a | ||
|
|
7c06d31ac1 | ||
|
|
ed886c956d |
@@ -36,6 +36,11 @@ The app expects a MariaDB instance configured through environment variables.
|
|||||||
|
|
||||||
- `SESSION_TTL_SECONDS` (default: `86400`)
|
- `SESSION_TTL_SECONDS` (default: `86400`)
|
||||||
- `SESSION_COOKIE_SECURE` (default: `false`, set `true` in production HTTPS)
|
- `SESSION_COOKIE_SECURE` (default: `false`, set `true` in production HTTPS)
|
||||||
|
- `REQUIRE_CSRF` (default: `false`, checks same-origin/same-referer for write routes when enabled)
|
||||||
|
- `LOGIN_MAX_ATTEMPTS` (default: `5`)
|
||||||
|
- `LOGIN_WINDOW_SECONDS` (default: `300`)
|
||||||
|
- `LOGIN_LOCKOUT_SECONDS` (default: `900`)
|
||||||
|
- `MAX_ICON_BYTES` (default: `2097152`)
|
||||||
|
|
||||||
## Gitea CI/CD
|
## Gitea CI/CD
|
||||||
|
|
||||||
|
|||||||
115
backend/main.py
115
backend/main.py
@@ -2,10 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import pymysql
|
import pymysql
|
||||||
@@ -20,6 +22,16 @@ PUBLIC_DIR = Path("public")
|
|||||||
SESSION_COOKIE = "jellomator_session"
|
SESSION_COOKIE = "jellomator_session"
|
||||||
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400"))
|
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", "86400"))
|
||||||
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on")
|
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "false").lower() in ("1", "true", "yes", "on")
|
||||||
|
LOGIN_MAX_ATTEMPTS = int(os.getenv("LOGIN_MAX_ATTEMPTS", "5"))
|
||||||
|
LOGIN_WINDOW_SECONDS = int(os.getenv("LOGIN_WINDOW_SECONDS", "300"))
|
||||||
|
LOGIN_LOCKOUT_SECONDS = int(os.getenv("LOGIN_LOCKOUT_SECONDS", "900"))
|
||||||
|
REQUIRE_CSRF = os.getenv("REQUIRE_CSRF", "false").lower() in ("1", "true", "yes", "on")
|
||||||
|
MAX_NAME_LEN = int(os.getenv("MAX_NAME_LEN", "255"))
|
||||||
|
MAX_CATEGORY_LEN = int(os.getenv("MAX_CATEGORY_LEN", "255"))
|
||||||
|
MAX_DESCRIPTION_LEN = int(os.getenv("MAX_DESCRIPTION_LEN", "2000"))
|
||||||
|
MAX_ICON_URL_LEN = int(os.getenv("MAX_ICON_URL_LEN", "2048"))
|
||||||
|
MAX_ICON_BYTES = int(os.getenv("MAX_ICON_BYTES", str(2 * 1024 * 1024)))
|
||||||
|
ALLOWED_ICON_MIME = {"image/png", "image/jpeg", "image/webp", "image/gif", "image/svg+xml", "image/x-icon"}
|
||||||
DB_HOST = os.getenv("DB_HOST", "mariadb")
|
DB_HOST = os.getenv("DB_HOST", "mariadb")
|
||||||
DB_PORT = int(os.getenv("DB_PORT", "3306"))
|
DB_PORT = int(os.getenv("DB_PORT", "3306"))
|
||||||
DB_USER = os.getenv("DB_USER", "jellomator")
|
DB_USER = os.getenv("DB_USER", "jellomator")
|
||||||
@@ -28,6 +40,8 @@ DB_NAME = os.getenv("DB_NAME", "jellomator")
|
|||||||
|
|
||||||
app = FastAPI(title="Jellomator")
|
app = FastAPI(title="Jellomator")
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
|
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
|
||||||
|
login_attempts: dict[str, list[float]] = {}
|
||||||
|
login_lockouts: dict[str, float] = {}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/healthz")
|
@app.get("/healthz")
|
||||||
@@ -170,6 +184,30 @@ def current_user(request: Request):
|
|||||||
return {"username": row["username"], "role": row["role"]}
|
return {"username": row["username"], "role": row["role"]}
|
||||||
|
|
||||||
|
|
||||||
|
def client_ip(request: Request) -> str:
|
||||||
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def login_key(request: Request, username: str) -> str:
|
||||||
|
return f"{client_ip(request)}::{username.strip().lower()}"
|
||||||
|
|
||||||
|
|
||||||
|
def prune_login_tracking(now: float) -> None:
|
||||||
|
for key, until in list(login_lockouts.items()):
|
||||||
|
if until <= now:
|
||||||
|
del login_lockouts[key]
|
||||||
|
cutoff = now - LOGIN_WINDOW_SECONDS
|
||||||
|
for key, entries in list(login_attempts.items()):
|
||||||
|
filtered = [t for t in entries if t >= cutoff]
|
||||||
|
if filtered:
|
||||||
|
login_attempts[key] = filtered
|
||||||
|
else:
|
||||||
|
del login_attempts[key]
|
||||||
|
|
||||||
|
|
||||||
def require_admin(request: Request):
|
def require_admin(request: Request):
|
||||||
user = current_user(request)
|
user = current_user(request)
|
||||||
if not user:
|
if not user:
|
||||||
@@ -177,6 +215,20 @@ def require_admin(request: Request):
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def require_csrf(request: Request):
|
||||||
|
if not REQUIRE_CSRF:
|
||||||
|
return
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
referer = request.headers.get("referer")
|
||||||
|
target = f"{request.url.scheme}://{request.url.netloc}"
|
||||||
|
if origin and origin != target:
|
||||||
|
raise HTTPException(403, "Invalid CSRF origin")
|
||||||
|
if referer and not referer.startswith(target):
|
||||||
|
raise HTTPException(403, "Invalid CSRF referer")
|
||||||
|
if not origin and not referer:
|
||||||
|
raise HTTPException(403, "Invalid CSRF token")
|
||||||
|
|
||||||
|
|
||||||
def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
|
def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
if exclude_id is None:
|
if exclude_id is None:
|
||||||
@@ -186,6 +238,38 @@ def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
|
|||||||
return cur.fetchone() is not None
|
return cur.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_http_url(value: str, field_name: str = "url") -> None:
|
||||||
|
parsed = urlparse((value or "").strip())
|
||||||
|
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||||
|
raise HTTPException(422, f"{field_name} must be a valid http(s) URL")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_length(value: str | None, limit: int, field_name: str) -> None:
|
||||||
|
if value is not None and len(value) > limit:
|
||||||
|
raise HTTPException(422, f"{field_name} exceeds max length of {limit}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_link_payload(name: str, url: str, description: str, category: str, icon_url: str | None) -> None:
|
||||||
|
validate_length(name, MAX_NAME_LEN, "name")
|
||||||
|
validate_length(description, MAX_DESCRIPTION_LEN, "description")
|
||||||
|
validate_length(category, MAX_CATEGORY_LEN, "category")
|
||||||
|
if icon_url:
|
||||||
|
validate_length(icon_url, MAX_ICON_URL_LEN, "icon_url")
|
||||||
|
validate_http_url(icon_url, "icon_url")
|
||||||
|
validate_http_url(url, "url")
|
||||||
|
|
||||||
|
|
||||||
|
def read_icon_blob(icon: UploadFile | None) -> tuple[bytes | None, str | None]:
|
||||||
|
if not icon:
|
||||||
|
return None, None
|
||||||
|
if icon.content_type not in ALLOWED_ICON_MIME:
|
||||||
|
raise HTTPException(422, "Unsupported icon file type")
|
||||||
|
blob = icon.file.read(MAX_ICON_BYTES + 1)
|
||||||
|
if len(blob) > MAX_ICON_BYTES:
|
||||||
|
raise HTTPException(422, f"Icon exceeds max size of {MAX_ICON_BYTES} bytes")
|
||||||
|
return blob, icon.content_type
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/me")
|
@app.get("/api/me")
|
||||||
def me(request: Request):
|
def me(request: Request):
|
||||||
with db() as c:
|
with db() as c:
|
||||||
@@ -208,13 +292,26 @@ def setup(inp: SetupIn):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/login")
|
@app.post("/api/login")
|
||||||
def login(inp: LoginIn):
|
def login(request: Request, inp: LoginIn):
|
||||||
|
now_ts = time.time()
|
||||||
|
prune_login_tracking(now_ts)
|
||||||
|
key = login_key(request, inp.username)
|
||||||
|
locked_until = login_lockouts.get(key)
|
||||||
|
if locked_until and locked_until > now_ts:
|
||||||
|
raise HTTPException(429, "Too many login attempts. Try again later.")
|
||||||
with db() as c:
|
with db() as c:
|
||||||
with c.cursor() as cur:
|
with c.cursor() as cur:
|
||||||
cur.execute("select id,password_hash from users where username=%s", (inp.username,))
|
cur.execute("select id,password_hash from users where username=%s", (inp.username,))
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]):
|
if not row or not bcrypt.checkpw(inp.password.encode(), row["password_hash"]):
|
||||||
|
attempts = login_attempts.get(key, [])
|
||||||
|
attempts.append(now_ts)
|
||||||
|
login_attempts[key] = [t for t in attempts if t >= now_ts - LOGIN_WINDOW_SECONDS]
|
||||||
|
if len(login_attempts[key]) >= LOGIN_MAX_ATTEMPTS:
|
||||||
|
login_lockouts[key] = now_ts + LOGIN_LOCKOUT_SECONDS
|
||||||
raise HTTPException(401, "Invalid credentials")
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
login_attempts.pop(key, None)
|
||||||
|
login_lockouts.pop(key, None)
|
||||||
token = secrets.token_urlsafe(32)
|
token = secrets.token_urlsafe(32)
|
||||||
with c.cursor() as cur:
|
with c.cursor() as cur:
|
||||||
now = utc_now_iso()
|
now = utc_now_iso()
|
||||||
@@ -237,6 +334,7 @@ def login(inp: LoginIn):
|
|||||||
|
|
||||||
@app.post("/api/logout")
|
@app.post("/api/logout")
|
||||||
def logout(request: Request):
|
def logout(request: Request):
|
||||||
|
require_csrf(request)
|
||||||
token = request.cookies.get(SESSION_COOKIE)
|
token = request.cookies.get(SESSION_COOKIE)
|
||||||
with db() as c:
|
with db() as c:
|
||||||
if token:
|
if token:
|
||||||
@@ -288,13 +386,12 @@ def create_link(
|
|||||||
icon: UploadFile | None = File(None),
|
icon: UploadFile | None = File(None),
|
||||||
):
|
):
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
|
require_csrf(request)
|
||||||
|
validate_link_payload(name, url, description, category, icon_url)
|
||||||
with db() as c:
|
with db() as c:
|
||||||
if link_name_exists(c, name):
|
if link_name_exists(c, name):
|
||||||
raise HTTPException(409, "Link name already exists")
|
raise HTTPException(409, "Link name already exists")
|
||||||
icon_blob = icon_mime = None
|
icon_blob, icon_mime = read_icon_blob(icon)
|
||||||
if icon:
|
|
||||||
icon_blob = icon.file.read()
|
|
||||||
icon_mime = icon.content_type
|
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
with db() as c:
|
with db() as c:
|
||||||
with c.cursor() as cur:
|
with c.cursor() as cur:
|
||||||
@@ -320,6 +417,7 @@ def link_icon(link_id: int):
|
|||||||
@app.delete("/api/links/{link_id}")
|
@app.delete("/api/links/{link_id}")
|
||||||
def delete_link(request: Request, link_id: int):
|
def delete_link(request: Request, link_id: int):
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
|
require_csrf(request)
|
||||||
with db() as c:
|
with db() as c:
|
||||||
with c.cursor() as cur:
|
with c.cursor() as cur:
|
||||||
cur.execute("delete from links where id=%s", (link_id,))
|
cur.execute("delete from links where id=%s", (link_id,))
|
||||||
@@ -339,10 +437,9 @@ def update_link(
|
|||||||
icon: UploadFile | None = File(None),
|
icon: UploadFile | None = File(None),
|
||||||
):
|
):
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
icon_blob = icon_mime = None
|
require_csrf(request)
|
||||||
if icon:
|
validate_link_payload(name, url, description, category, icon_url)
|
||||||
icon_blob = icon.file.read()
|
icon_blob, icon_mime = read_icon_blob(icon)
|
||||||
icon_mime = icon.content_type
|
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
with db() as c:
|
with db() as c:
|
||||||
if link_name_exists(c, name, exclude_id=link_id):
|
if link_name_exists(c, name, exclude_id=link_id):
|
||||||
|
|||||||
Reference in New Issue
Block a user