diff --git a/shsf/contact-api/main.py b/shsf/contact-api/main.py index 7c514a2..b81709e 100644 --- a/shsf/contact-api/main.py +++ b/shsf/contact-api/main.py @@ -8,6 +8,7 @@ from _db_com import database RATE_LIMIT_FILE = "/app/ratelimit.json" RATE_LIMIT_WINDOW_SECONDS = 60 * 60 RATE_LIMIT_MAX_REQUESTS = 5 +MESSAGES_STORAGE = "portfolio_contact_messages" ALLOWED_ORIGINS = { "https://luna.reversed.dev", "http://localhost:5173", @@ -20,19 +21,22 @@ def _cors_headers(origin=""): allowed_origin = origin if origin in ALLOWED_ORIGINS else "https://luna.reversed.dev" return { "Access-Control-Allow-Origin": allowed_origin, - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, X-Lunas-Key", "Access-Control-Max-Age": "86400", "Vary": "Origin", "Content-Type": "application/json", } -def _response(origin, status_code, payload): +def _response(origin, status_code, payload, extra_headers=None): + headers = _cors_headers(origin) + if extra_headers: + headers.update(extra_headers) return { "_shsf": "v2", "_code": status_code, - "_headers": _cors_headers(origin), + "_headers": headers, "_res": payload, } @@ -113,24 +117,162 @@ def _check_rate_limit(ip_address, now): return True, None -def main(args): - origin = (args.get("headers", {}) or {}).get("origin", "") - method = str(args.get("method", "POST")).upper() - - if method == "OPTIONS": - return _response(origin, 204, "") - - if method != "POST": - return _response(origin, 405, {"ok": False, "error": "Method not allowed."}) - +def _parse_body(args): raw_body = args.get("body", "{}") try: payload = raw_body if isinstance(raw_body, dict) else json.loads(raw_body) if not isinstance(payload, dict): raise ValueError("JSON body must be an object") + return payload, None except (TypeError, json.JSONDecodeError, ValueError): - return _response(origin, 400, {"ok": False, "error": "Invalid JSON payload."}) + return None, "Invalid JSON payload." + +def _ensure_storage(db): + try: + db.create_storage(MESSAGES_STORAGE, purpose="Portfolio contact form submissions") + except Exception: + pass + + +def _parse_message_value(value): + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + return None + return None + + +def _list_message_keys(db): + items = db.list_items(MESSAGES_STORAGE) + if isinstance(items, dict): + return list(items.keys()) + if isinstance(items, list): + keys = [] + for item in items: + if isinstance(item, str): + keys.append(item) + elif isinstance(item, dict): + key = item.get("key") or item.get("name") or item.get("id") + if key: + keys.append(key) + return keys + return [] + + +def _get_message(db, message_id): + try: + value = db.get(MESSAGES_STORAGE, message_id) + except Exception: + return None + return _parse_message_value(value) + + +def _require_key(args): + expected = os.getenv("LUNAS_KEY", "").strip() + headers = {str(key).lower(): value for key, value in (args.get("headers", {}) or {}).items()} + provided = str(headers.get("x-lunas-key", "")).strip() + + if not expected: + return False, "LUNAS_KEY is not configured on the function." + if not provided or provided != expected: + return False, "Unauthorized." + return True, None + + +def _message_ids_from_payload(payload): + if not isinstance(payload, dict): + return [] + + ids = [] + single_id = payload.get("id") + multiple_ids = payload.get("ids") + + if isinstance(single_id, str) and single_id.strip(): + ids.append(single_id.strip()) + if isinstance(multiple_ids, list): + ids.extend(str(item).strip() for item in multiple_ids if str(item).strip()) + + deduped = [] + seen = set() + for message_id in ids: + if message_id not in seen: + seen.add(message_id) + deduped.append(message_id) + return deduped + + +def _handle_new(args, origin, db): + authorized, error_message = _require_key(args) + if not authorized: + return _response(origin, 401, {"ok": False, "error": error_message}) + + messages = [] + for message_id in _list_message_keys(db): + record = _get_message(db, message_id) + if not isinstance(record, dict): + continue + if record.get("seen") is True: + continue + messages.append(record) + + messages.sort(key=lambda item: item.get("created_at", ""), reverse=True) + return _response(origin, 200, {"ok": True, "messages": messages, "count": len(messages)}) + + +def _handle_seen(args, origin, db, payload): + authorized, error_message = _require_key(args) + if not authorized: + return _response(origin, 401, {"ok": False, "error": error_message}) + + message_ids = _message_ids_from_payload(payload) + if not message_ids: + return _response(origin, 400, {"ok": False, "error": "Provide `id` or `ids`."}) + + updated = [] + missing = [] + for message_id in message_ids: + record = _get_message(db, message_id) + if not isinstance(record, dict): + missing.append(message_id) + continue + record["seen"] = True + record["seen_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + db.set(MESSAGES_STORAGE, message_id, json.dumps(record)) + updated.append(message_id) + + return _response(origin, 200, {"ok": True, "updated": updated, "missing": missing}) + + +def _handle_delete(args, origin, db, payload): + authorized, error_message = _require_key(args) + if not authorized: + return _response(origin, 401, {"ok": False, "error": error_message}) + + message_ids = _message_ids_from_payload(payload) + if not message_ids: + return _response(origin, 400, {"ok": False, "error": "Provide `id` or `ids`."}) + + deleted = [] + missing = [] + for message_id in message_ids: + record = _get_message(db, message_id) + if not isinstance(record, dict): + missing.append(message_id) + continue + try: + db.delete_item(MESSAGES_STORAGE, message_id) + deleted.append(message_id) + except Exception: + missing.append(message_id) + + return _response(origin, 200, {"ok": True, "deleted": deleted, "missing": missing}) + + +def _handle_submit(args, origin, db, payload): if not payload and all(key in args for key in ("username", "email", "message")): payload = { "username": args.get("username", ""), @@ -146,24 +288,12 @@ def main(args): now = int(time.time()) allowed, retry_after = _check_rate_limit(ip_address, now) if not allowed: - return { - "_shsf": "v2", - "_code": 429, - "_headers": { - **_cors_headers(origin), - "Retry-After": str(retry_after), - }, - "_res": { - "ok": False, - "error": "Too many messages from this IP. Please try again later.", - }, - } - - db = database() - try: - db.create_storage("portfolio_contact_messages", purpose="Portfolio contact form submissions") - except Exception: - pass + return _response( + origin, + 429, + {"ok": False, "error": "Too many messages from this IP. Please try again later."}, + {"Retry-After": str(retry_after)}, + ) submission_id = f"message_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}" record = { @@ -174,7 +304,55 @@ def main(args): "ip": ip_address, "user_agent": (args.get("headers", {}) or {}).get("user-agent", ""), "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)), + "seen": False, } - db.set("portfolio_contact_messages", submission_id, json.dumps(record)) + db.set(MESSAGES_STORAGE, submission_id, json.dumps(record)) return _response(origin, 201, {"ok": True, "message": "Message received.", "id": submission_id}) + + +def main(args): + headers = args.get("headers", {}) or {} + origin = headers.get("origin", "") + method = str(args.get("method", "POST")).upper() + route = str(args.get("route", "default") or "default").strip("/") or "default" + + if method == "OPTIONS": + return _response(origin, 204, "") + + db = database() + _ensure_storage(db) + + if route == "new": + if method not in {"GET", "POST"}: + return _response(origin, 405, {"ok": False, "error": "Method not allowed."}) + return _handle_new(args, origin, db) + + payload, payload_error = _parse_body(args) + if payload_error: + payload = {} + + if route == "seen": + if method != "POST": + return _response(origin, 405, {"ok": False, "error": "Method not allowed."}) + if payload_error: + return _response(origin, 400, {"ok": False, "error": payload_error}) + return _handle_seen(args, origin, db, payload) + + if route == "delete": + if method != "POST": + return _response(origin, 405, {"ok": False, "error": "Method not allowed."}) + if payload_error: + return _response(origin, 400, {"ok": False, "error": payload_error}) + return _handle_delete(args, origin, db, payload) + + if route != "default": + return _response(origin, 404, {"ok": False, "error": "Route not found."}) + + if method != "POST": + return _response(origin, 405, {"ok": False, "error": "Method not allowed."}) + + if payload_error: + return _response(origin, 400, {"ok": False, "error": payload_error}) + + return _handle_submit(args, origin, db, payload)