435 lines
15 KiB
Python
435 lines
15 KiB
Python
from __future__ import annotations

import os
import secrets
import time
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

import bcrypt
import pymysql
from fastapi import FastAPI, File, Form, HTTPException, Request, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
|
|
|
|
def _env_int(name: str, default: str) -> int:
    """Read an integer setting from the environment, falling back to ``default``."""
    return int(os.getenv(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Read a boolean setting; "1", "true", "yes" or "on" (any case) mean True."""
    return os.getenv(name, default).lower() in ("1", "true", "yes", "on")


# Frontend build output and extra public files, relative to the working directory.
STATIC_DIR = Path("frontend/dist")
PUBLIC_DIR = Path("public")

# Session cookie name, lifetime in seconds, and whether the cookie is Secure-only.
SESSION_COOKIE = "jellomator_session"
SESSION_TTL_SECONDS = _env_int("SESSION_TTL_SECONDS", "86400")
SESSION_COOKIE_SECURE = _env_flag("SESSION_COOKIE_SECURE", "false")

# Login throttling: LOGIN_MAX_ATTEMPTS failures inside LOGIN_WINDOW_SECONDS
# lock a key out for LOGIN_LOCKOUT_SECONDS.
LOGIN_MAX_ATTEMPTS = _env_int("LOGIN_MAX_ATTEMPTS", "5")
LOGIN_WINDOW_SECONDS = _env_int("LOGIN_WINDOW_SECONDS", "300")
LOGIN_LOCKOUT_SECONDS = _env_int("LOGIN_LOCKOUT_SECONDS", "900")

# Database connection settings (PyMySQL).
DB_HOST = os.getenv("DB_HOST", "mariadb")
DB_PORT = _env_int("DB_PORT", "3306")
DB_USER = os.getenv("DB_USER", "jellomator")
DB_PASSWORD = os.getenv("DB_PASSWORD", "jellomator")
DB_NAME = os.getenv("DB_NAME", "jellomator")
|
|
|
|
# Comma-separated origins allowed to call the API cross-origin; "*" (the
# default) allows any origin.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app = FastAPI(title="Jellomator")
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    # With allow_origins=["*"] and allow_credentials=True, Starlette echoes the
    # caller's Origin back on requests that carry cookies. Any origin whose
    # requests include the session cookie could then make authenticated API
    # calls and read the responses. Credentials are therefore only honoured
    # for an explicitly configured origin list.
    allow_credentials="*" not in CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-process login throttling state, keyed by login_key():
# - login_attempts: time.time() stamps of recent failed attempts
# - login_lockouts: time.time() until which a key is locked out
login_attempts: dict[str, list[float]] = {}
login_lockouts: dict[str, float] = {}
|
|
|
|
|
|
@app.get("/healthz")
def healthz():
    """Liveness probe: report that the process is up without touching the database."""
    payload = {"ok": True}
    return payload
|
|
|
|
|
|
@app.get("/readyz")
def readyz():
    """Readiness probe: succeed only if a trivial query round-trips to the database.

    Raises:
        HTTPException: 503 when connecting to or querying the database fails
        for any reason.
    """
    try:
        with db() as conn, conn.cursor() as cursor:
            cursor.execute("select 1 as ok")
            cursor.fetchone()
    except Exception:
        # Any failure (refused connection, bad credentials, ...) means "not ready".
        raise HTTPException(503, "Database not ready")
    return {"ok": True}
|
|
|
|
|
|
# Set to True once this process has created (or confirmed) DB_NAME. After that,
# new connections select the database directly instead of re-running the
# CREATE DATABASE statement on every request.
_database_ready = False


@contextmanager
def db():
    """Yield an autocommit PyMySQL connection to ``DB_NAME`` that uses a ``DictCursor``.

    The first connection made by the process creates ``DB_NAME`` (utf8mb4)
    if it does not exist. Later connections pass ``database=`` to
    ``pymysql.connect`` and skip that DDL round-trip. If the database is
    dropped while the process runs, later connections will fail instead of
    recreating it.

    The connection is closed when the ``with`` block exits, including on error.
    """
    global _database_ready
    connect_kwargs = dict(
        host=DB_HOST,
        port=DB_PORT,
        user=DB_USER,
        password=DB_PASSWORD,
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=True,
    )
    # Read the flag once so the connect and bootstrap steps below agree, even
    # if the module-level flag changes in the meantime.
    bootstrap = not _database_ready
    if bootstrap:
        conn = pymysql.connect(**connect_kwargs)
    else:
        conn = pymysql.connect(database=DB_NAME, **connect_kwargs)
    try:
        if bootstrap:
            # Escape backticks so the name is always a single quoted identifier.
            quoted_name = DB_NAME.replace("`", "``")
            with conn.cursor() as cur:
                cur.execute(f"create database if not exists `{quoted_name}` default charset=utf8mb4")
            conn.select_db(DB_NAME)
            _database_ready = True
        yield conn
    finally:
        conn.close()
|
|
|
|
|
|
def init_db():
    """Create the schema if needed and add ``sessions`` columns older tables may lack.

    Safe to run repeatedly. Tables use ``create table if not exists``, and
    each backfilled ``sessions`` column is only added when
    ``show columns`` does not report it.
    """
    users_ddl = """
    create table if not exists users(
        id bigint auto_increment primary key,
        username varchar(255) not null unique,
        password_hash varbinary(255) not null,
        role varchar(32) not null
    ) engine=InnoDB default charset=utf8mb4
    """
    sessions_ddl = """
    create table if not exists sessions(
        token varchar(255) primary key,
        user_id bigint not null,
        created_at varchar(64) not null,
        expires_at varchar(64) null,
        last_seen_at varchar(64) null,
        index (user_id),
        constraint sessions_user_fk foreign key (user_id) references users(id) on delete cascade
    ) engine=InnoDB default charset=utf8mb4
    """
    links_ddl = """
    create table if not exists links(
        id bigint auto_increment primary key,
        name varchar(255) not null unique,
        url text not null,
        description text,
        category varchar(255),
        icon_blob longblob,
        icon_mime varchar(255),
        icon_url text,
        enabled tinyint(1) not null default 1,
        created_at varchar(64) not null,
        updated_at varchar(64) not null
    ) engine=InnoDB default charset=utf8mb4
    """
    # (column, ALTER statement) pairs. Order matters: last_seen_at is placed
    # "after expires_at", so expires_at must exist first.
    session_backfills = (
        ("expires_at", "alter table sessions add column expires_at varchar(64) null after created_at"),
        ("last_seen_at", "alter table sessions add column last_seen_at varchar(64) null after expires_at"),
    )

    with db() as conn, conn.cursor() as cur:
        cur.execute(users_ddl)
        cur.execute(sessions_ddl)
        for column, alter_sql in session_backfills:
            cur.execute(f"show columns from sessions like '{column}'")
            if cur.fetchone() is None:
                cur.execute(alter_sql)
        cur.execute(links_ddl)


# Ensure the schema exists as soon as the module is imported.
init_db()
|
|
|
|
|
|
class SetupIn(BaseModel):
    """JSON body of ``POST /api/setup``: credentials for the first (admin) account."""

    username: str  # stored as-is in users.username
    password: str  # plaintext; only its bcrypt hash is persisted
|
|
|
|
|
|
class LinkIn(BaseModel):
    """Link fields matching the form parameters of the link create/update endpoints.

    NOTE(review): the endpoints in this file read ``Form(...)`` fields rather
    than this model. Confirm whether it is still used elsewhere.
    """

    name: str
    url: str
    description: str = ""
    category: str = "General"
    icon_url: Optional[str] = None  # external icon address (uploaded icons are stored as blobs instead)
    enabled: bool = True
|
|
|
|
|
|
class LoginIn(BaseModel):
    """JSON body of ``POST /api/login``."""

    username: str  # matched exactly against users.username
    password: str  # checked against the stored bcrypt hash
|
|
|
|
|
|
def utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string (no offset suffix).

    The result stays timezone-naive so it remains comparable with the naive
    timestamps already stored in the database. The aware ``now`` is used only
    because ``datetime.utcnow()`` is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
|
|
|
|
|
|
def expires_at_iso(ttl_seconds: int | None = None) -> str:
    """Return the session expiry instant as a naive UTC ISO-8601 string.

    Args:
        ttl_seconds: session lifetime in seconds; defaults to ``SESSION_TTL_SECONDS``.
    """
    ttl = SESSION_TTL_SECONDS if ttl_seconds is None else ttl_seconds
    # Do the arithmetic on an aware UTC datetime. Calling .timestamp() on the
    # naive value of datetime.utcnow() interprets it as *local* time, which
    # shifts the expiry by the host's UTC offset.
    expires = datetime.now(timezone.utc) + timedelta(seconds=ttl)
    return expires.replace(tzinfo=None).isoformat()
|
|
|
|
|
|
def _session_expiry(row: dict) -> datetime | None:
    """Return when the session in ``row`` expires, as a naive UTC datetime.

    A row with an ``expires_at`` value uses it directly. A row without one
    (such as sessions written before that column existed) falls back to
    ``created_at + SESSION_TTL_SECONDS``, so it cannot stay valid forever.
    Returns ``None`` when the stored timestamps cannot be parsed.
    """
    raw_expiry = row.get("expires_at")
    try:
        if raw_expiry:
            expiry = datetime.fromisoformat(raw_expiry)
        else:
            expiry = datetime.fromisoformat(row["created_at"]) + timedelta(seconds=SESSION_TTL_SECONDS)
    except (KeyError, TypeError, ValueError):
        return None
    if expiry.tzinfo is not None:
        # Normalise offset-bearing strings so the naive comparison below cannot raise.
        expiry = expiry.astimezone(timezone.utc).replace(tzinfo=None)
    return expiry


def current_user(request: Request):
    """Resolve the request's session cookie to the logged-in user.

    Returns:
        ``{"username": ..., "role": ...}`` for a valid session. Otherwise
        ``None``: no cookie, unknown token, or an expired or unparseable session.

    Side effects: expired or unparseable sessions are deleted, and valid
    sessions get ``last_seen_at`` refreshed.
    """
    token = request.cookies.get(SESSION_COOKIE)
    if not token:
        return None
    with db() as c:
        with c.cursor() as cur:
            cur.execute(
                "select s.created_at,s.expires_at,u.username,u.role "
                "from sessions s join users u on u.id=s.user_id where s.token=%s",
                (token,),
            )
            row = cur.fetchone()
            if not row:
                return None
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            expiry = _session_expiry(row)
            if expiry is None or now >= expiry:
                cur.execute("delete from sessions where token=%s", (token,))
                return None
            cur.execute(
                "update sessions set last_seen_at=%s where token=%s",
                (now.isoformat(), token),
            )
            return {"username": row["username"], "role": row["role"]}
|
|
|
|
|
|
# Number of reverse-proxy hops in front of the app whose X-Forwarded-For entries
# are trusted. The default of 0 ignores the header: a directly connected client
# can put anything in it, and could otherwise rotate fake IPs to dodge the
# per-IP login throttling. Set to 1 behind a single reverse proxy.
TRUSTED_PROXY_HOPS = int(os.getenv("TRUSTED_PROXY_HOPS", "0"))


def client_ip(request: Request, trusted_hops: int | None = None) -> str:
    """Return the best-known client address for ``request``.

    Args:
        request: incoming request (reads ``headers`` and ``client``).
        trusted_hops: number of trusted proxies; defaults to ``TRUSTED_PROXY_HOPS``.
            When positive, the address is taken ``trusted_hops`` entries from the
            *right* of X-Forwarded-For. Proxies append to that header, so the
            left-hand entries are whatever the client sent and cannot be trusted.

    Falls back to the socket peer address, or ``"unknown"`` when that is absent.
    """
    hops = TRUSTED_PROXY_HOPS if trusted_hops is None else trusted_hops
    if hops > 0:
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            entries = [part.strip() for part in forwarded.split(",")]
            if len(entries) >= hops and entries[-hops]:
                return entries[-hops]
    return request.client.host if request.client else "unknown"
|
|
|
|
|
|
def login_key(request: Request, username: str) -> str:
    """Build the throttling key for a login attempt: ``"<client ip>::<normalized username>"``.

    The username is stripped and lower-cased so case/whitespace variants share a key.
    """
    ip = client_ip(request)
    normalized = username.strip().lower()
    return "::".join((ip, normalized))
|
|
|
|
|
|
def prune_login_tracking(now: float) -> None:
    """Forget lapsed lockouts and failed-attempt stamps older than the rolling window.

    Mutates the module-level ``login_lockouts`` and ``login_attempts`` dicts in
    place. Both are snapshotted with ``list(...)`` before iterating, because
    entries are deleted during the loop.
    """
    lockout_snapshot = list(login_lockouts.items())
    for key in [k for k, until in lockout_snapshot if until <= now]:
        del login_lockouts[key]

    oldest_kept = now - LOGIN_WINDOW_SECONDS
    for key, stamps in list(login_attempts.items()):
        recent = [stamp for stamp in stamps if stamp >= oldest_kept]
        if not recent:
            del login_attempts[key]
            continue
        login_attempts[key] = recent
|
|
|
|
|
|
def require_admin(request: Request):
    """Return the logged-in user, requiring a valid session with the ``"admin"`` role.

    Previously any authenticated user passed. The only role this file creates
    is ``"admin"`` (see ``setup``), so existing accounts are unaffected.

    Raises:
        HTTPException: 401 without a valid session; 403 if the user's role is
        not ``"admin"``.
    """
    user = current_user(request)
    if not user:
        raise HTTPException(401, "Unauthorized")
    if user.get("role") != "admin":
        raise HTTPException(403, "Forbidden")
    return user
|
|
|
|
|
|
def link_name_exists(conn, name: str, *, exclude_id: int | None = None) -> bool:
    """Return True if some link is already named ``name``, compared case-insensitively.

    Args:
        conn: open database connection (its ``cursor()`` is used as a context manager).
        name: candidate link name.
        exclude_id: id of a row to ignore, so a link being updated does not
            collide with itself.
    """
    if exclude_id is not None:
        sql = "select 1 from links where lower(name)=lower(%s) and id<>%s limit 1"
        params = (name, exclude_id)
    else:
        sql = "select 1 from links where lower(name)=lower(%s) limit 1"
        params = (name,)
    with conn.cursor() as cur:
        cur.execute(sql, params)
        match = cur.fetchone()
    return match is not None
|
|
|
|
|
|
@app.get("/api/me")
def me(request: Request):
    """Report whether first-run setup is still pending and who, if anyone, is logged in."""
    with db() as conn, conn.cursor() as cur:
        cur.execute("select count(*) as count from users")
        user_count = cur.fetchone()["count"]
    # current_user opens its own connection; the one above is closed by now.
    return {"needs_setup": user_count == 0, "current_user": current_user(request)}
|
|
|
|
|
|
@app.post("/api/setup")
def setup(inp: SetupIn):
    """Create the first account, with role ``"admin"``, while the users table is empty.

    Raises:
        HTTPException: 400 if the username or password is blank, or if setup
        has already been completed.
    """
    # An admin account with an empty password (or blank name) must never be created.
    if not inp.username.strip() or not inp.password:
        raise HTTPException(400, "Username and password are required")
    with db() as c:
        with c.cursor() as cur:
            # Cheap early exit before spending time on bcrypt.
            cur.execute("select count(*) as count from users")
            if cur.fetchone()["count"] > 0:
                raise HTTPException(400, "Setup already complete")
            pw = bcrypt.hashpw(inp.password.encode(), bcrypt.gensalt())
            # Re-check emptiness inside the INSERT itself. The check and the
            # write are then one statement, which closes most of the window in
            # which two concurrent setup calls could both see an empty table.
            cur.execute(
                "insert into users(username,password_hash,role) "
                "select %s,%s,%s from dual where not exists (select 1 from users)",
                (inp.username, pw, "admin"),
            )
            if cur.rowcount == 0:
                raise HTTPException(400, "Setup already complete")
    return {"ok": True}
|
|
|
|
|
|
# Throwaway bcrypt hash, generated on first use. It is checked against when the
# username is unknown, so both failure paths cost one bcrypt comparison.
_DUMMY_PASSWORD_HASH: bytes | None = None


def _dummy_password_hash() -> bytes:
    """Return the lazily generated dummy bcrypt hash (same default cost as real ones)."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = bcrypt.hashpw(secrets.token_bytes(16), bcrypt.gensalt())
    return _DUMMY_PASSWORD_HASH


def _record_failed_login(key: str, now_ts: float) -> None:
    """Record a failed attempt for ``key`` and lock it out once the window limit is reached."""
    cutoff = now_ts - LOGIN_WINDOW_SECONDS
    recent = [t for t in login_attempts.get(key, []) if t >= cutoff]
    recent.append(now_ts)
    login_attempts[key] = recent
    if len(recent) >= LOGIN_MAX_ATTEMPTS:
        login_lockouts[key] = now_ts + LOGIN_LOCKOUT_SECONDS


@app.post("/api/login")
def login(request: Request, inp: LoginIn):
    """Verify credentials and start a session, returned as an HttpOnly cookie.

    Failed attempts are counted per ``login_key`` (client IP + normalized
    username). ``LOGIN_MAX_ATTEMPTS`` failures within ``LOGIN_WINDOW_SECONDS``
    lock that key out for ``LOGIN_LOCKOUT_SECONDS``. A successful login clears
    both counters for the key.

    Raises:
        HTTPException: 429 while the key is locked out; 401 on bad credentials.
    """
    now_ts = time.time()
    prune_login_tracking(now_ts)
    key = login_key(request, inp.username)
    locked_until = login_lockouts.get(key)
    if locked_until and locked_until > now_ts:
        raise HTTPException(429, "Too many login attempts. Try again later.")
    with db() as c:
        with c.cursor() as cur:
            cur.execute("select id,password_hash from users where username=%s", (inp.username,))
            row = cur.fetchone()
        # Run one bcrypt comparison whether or not the user exists, so response
        # time does not reveal which usernames are registered.
        stored_hash = row["password_hash"] if row else _dummy_password_hash()
        password_ok = bcrypt.checkpw(inp.password.encode(), stored_hash)
        if not row or not password_ok:
            _record_failed_login(key, now_ts)
            raise HTTPException(401, "Invalid credentials")
        login_attempts.pop(key, None)
        login_lockouts.pop(key, None)
        token = secrets.token_urlsafe(32)
        now = utc_now_iso()
        with c.cursor() as cur:
            cur.execute(
                "insert into sessions(token,user_id,created_at,expires_at,last_seen_at) values (%s,%s,%s,%s,%s)",
                (token, row["id"], now, expires_at_iso(), now),
            )
    response = JSONResponse({"ok": True})
    response.set_cookie(
        SESSION_COOKIE,
        token,
        httponly=True,
        samesite="lax",
        secure=SESSION_COOKIE_SECURE,
        max_age=SESSION_TTL_SECONDS,
        path="/",
    )
    return response
|
|
|
|
|
|
@app.post("/api/logout")
def logout(request: Request):
    """Delete the caller's session, if any, and clear the session cookie.

    Always returns ``{"ok": True}``. A database connection is only opened when
    a session cookie is actually present, so cookie-less logouts do not touch
    (or depend on) the database.
    """
    token = request.cookies.get(SESSION_COOKIE)
    if token:
        with db() as c:
            with c.cursor() as cur:
                cur.execute("delete from sessions where token=%s", (token,))
    resp = JSONResponse({"ok": True})
    resp.delete_cookie(SESSION_COOKIE, path="/")
    return resp
|
|
|
|
|
|
@app.get("/api/links")
def links():
    """List every link: enabled first, then by category and name.

    Each item has ``id``, ``name``, ``url``, ``description``, ``category``,
    ``enabled`` and a resolved ``icon_url``. That is the icon endpoint when an
    icon was uploaded, otherwise the stored external URL, or ``None``.
    """
    with db() as c:
        with c.cursor() as cur:
            # Ask the server whether an uploaded icon is present, instead of
            # transferring every icon blob just to test it for emptiness.
            cur.execute(
                "select id,name,url,description,category,enabled,icon_url,"
                "(icon_blob is not null and length(icon_blob) > 0) as has_icon "
                "from links order by enabled desc, category, name"
            )
            rows = cur.fetchall()
    out = []
    for r in rows:
        if r["has_icon"]:
            icon_url = f"/api/links/{r['id']}/icon"
        else:
            icon_url = r["icon_url"] or None
        out.append(
            {k: r[k] for k in ("id", "name", "url", "description", "category", "enabled")}
            | {"icon_url": icon_url}
        )
    return out
|
|
|
|
|
|
@app.get("/api/links/{link_id}")
def get_link(link_id: int):
    """Return one link by id, with ``icon_url`` resolved the same way as ``/api/links``.

    Raises:
        HTTPException: 404 if no link has that id.
    """
    with db() as c:
        with c.cursor() as cur:
            cur.execute(
                "select id,name,url,description,category,enabled,icon_url,"
                "(icon_blob is not null and length(icon_blob) > 0) as has_icon "
                "from links where id=%s",
                (link_id,),
            )
            row = cur.fetchone()
    if not row:
        raise HTTPException(404, "Not found")
    # An uploaded icon wins over the stored external URL. Previously this value
    # was computed and then ignored in favour of the raw icon_url column.
    icon_url = f"/api/links/{row['id']}/icon" if row["has_icon"] else row["icon_url"]
    return {k: row[k] for k in ("id", "name", "url", "description", "category", "enabled")} | {
        "icon_url": icon_url
    }
|
|
|
|
|
|
@app.post("/api/links")
def create_link(
    request: Request,
    name: str = Form(...),
    url: str = Form(...),
    description: str = Form(""),
    category: str = Form("General"),
    icon_url: str | None = Form(None),
    enabled: bool = Form(True),
    icon: UploadFile | None = File(None),
):
    """Create a link from multipart form data (admin only).

    An uploaded ``icon`` file is stored in the database together with its
    content type. An empty upload is treated as no icon. ``icon_url`` is
    stored as given.

    Raises:
        HTTPException: whatever ``require_admin`` raises for unauthenticated
        callers; 409 if a link with the same name (case-insensitive) exists.
    """
    require_admin(request)
    with db() as c:
        if link_name_exists(c, name):
            raise HTTPException(409, "Link name already exists")
        icon_blob = icon_mime = None
        if icon:
            icon_blob = icon.file.read() or None
            icon_mime = icon.content_type if icon_blob else None
        now = utc_now_iso()
        try:
            with c.cursor() as cur:
                cur.execute(
                    """insert into links(name,url,description,category,icon_blob,icon_mime,icon_url,enabled,created_at,updated_at)
                    values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)""",
                    (name, url, description, category, icon_blob, icon_mime, icon_url, int(enabled), now, now),
                )
        except pymysql.IntegrityError:
            # The unique index on links.name rejects a concurrent create that
            # slipped in between the check above and this insert.
            raise HTTPException(409, "Link name already exists") from None
    return {"ok": True}
|
|
|
|
|
|
@app.get("/api/links/{link_id}/icon")
def link_icon(link_id: int):
    """Serve a link's uploaded icon bytes.

    The stored content type was supplied by the uploading client. Only
    ``image/*`` types are echoed back; anything else is served as
    ``application/octet-stream``. A missing type falls back to ``image/png``.
    ``nosniff`` stops browsers from re-guessing the type, and a sandbox CSP
    keeps scripts from running if the file is opened directly (e.g. SVG).

    Raises:
        HTTPException: 404 if the link does not exist or has no uploaded icon.
    """
    with db() as c:
        with c.cursor() as cur:
            cur.execute("select icon_blob,icon_mime from links where id=%s", (link_id,))
            row = cur.fetchone()
    if not row or not row["icon_blob"]:
        raise HTTPException(404, "Not found")
    media_type = row["icon_mime"] or "image/png"
    if not media_type.lower().startswith("image/"):
        media_type = "application/octet-stream"
    return Response(
        content=row["icon_blob"],
        media_type=media_type,
        headers={
            "X-Content-Type-Options": "nosniff",
            "Content-Security-Policy": "sandbox",
        },
    )
|
|
|
|
|
|
@app.delete("/api/links/{link_id}")
def delete_link(request: Request, link_id: int):
    """Delete the link with ``link_id`` (admin only).

    Responds ``{"ok": True}`` whether or not a matching row existed.
    """
    require_admin(request)
    with db() as conn, conn.cursor() as cur:
        cur.execute("delete from links where id=%s", (link_id,))
    return {"ok": True}
|
|
|
|
|
|
@app.patch("/api/links/{link_id}")
def update_link(
    request: Request,
    link_id: int,
    name: str = Form(...),
    url: str = Form(...),
    description: str = Form(""),
    category: str = Form("General"),
    icon_url: str | None = Form(None),
    enabled: bool = Form(True),
    icon: UploadFile | None = File(None),
):
    """Replace a link's fields from multipart form data (admin only).

    The stored icon blob and content type are only overwritten when a
    non-empty ``icon`` file is uploaded. Every other column, including
    ``icon_url``, is set from the form.

    Raises:
        HTTPException: whatever ``require_admin`` raises for unauthenticated
        callers; 404 if ``link_id`` does not exist; 409 if another link
        already uses ``name`` (case-insensitive).
    """
    require_admin(request)
    icon_blob = icon_mime = None
    if icon:
        icon_blob = icon.file.read()
        icon_mime = icon.content_type
    now = utc_now_iso()
    if icon_blob:
        sql = """update links set name=%s,url=%s,description=%s,category=%s,icon_blob=%s,icon_mime=%s,icon_url=%s,enabled=%s,updated_at=%s where id=%s"""
        params = (name, url, description, category, icon_blob, icon_mime, icon_url, int(enabled), now, link_id)
    else:
        sql = """update links set name=%s,url=%s,description=%s,category=%s,icon_url=%s,enabled=%s,updated_at=%s where id=%s"""
        params = (name, url, description, category, icon_url, int(enabled), now, link_id)
    with db() as c:
        with c.cursor() as cur:
            cur.execute("select 1 from links where id=%s", (link_id,))
            if cur.fetchone() is None:
                raise HTTPException(404, "Not found")
        if link_name_exists(c, name, exclude_id=link_id):
            raise HTTPException(409, "Link name already exists")
        try:
            with c.cursor() as cur:
                cur.execute(sql, params)
        except pymysql.IntegrityError:
            # The unique index on links.name rejects a rename that raced with
            # another create/update between the check above and this write.
            raise HTTPException(409, "Link name already exists") from None
    return {"ok": True}
|
|
|
|
|
|
# Mount the built frontend's asset bundle and the public files. StaticFiles
# raises at construction when its directory does not exist. The assets
# directory itself is therefore checked, not just its parent; a build without
# frontend/dist/assets used to crash the app at import time.
assets_dir = STATIC_DIR / "assets"
if assets_dir.is_dir():
    app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")
if PUBLIC_DIR.is_dir():
    app.mount("/static", StaticFiles(directory=PUBLIC_DIR), name="public")
elif STATIC_DIR.is_dir():
    # No public/ directory: fall back to serving the build output under /static.
    app.mount("/static", StaticFiles(directory=STATIC_DIR), name="public-dist")
|
|
|
|
|
|
@app.get("/jellomator.png")
def root_icon():
    """Serve the app logo, preferring ``public/`` over the built ``frontend/dist``.

    Raises:
        HTTPException: 404 if neither directory contains ``jellomator.png``.
    """
    for candidate in (PUBLIC_DIR / "jellomator.png", STATIC_DIR / "jellomator.png"):
        if candidate.exists():
            return FileResponse(candidate)
    raise HTTPException(404, "Not found")
|
|
|
|
|
|
@app.get("/{path:path}")
def spa(path: str):
    """Catch-all route that serves the single-page app's ``index.html``.

    Paths under ``api/``, and any path containing a dot (file-like), get a 404
    instead of HTML. When the frontend build has no ``index.html``, a minimal
    placeholder page is returned.
    """
    looks_like_api = path.startswith("api/")
    looks_like_file = "." in path
    if looks_like_api or looks_like_file:
        raise HTTPException(404)
    index = STATIC_DIR / "index.html"
    return FileResponse(index) if index.exists() else HTMLResponse("<h1>Jellomator</h1>")
|