Merge feature branch feat/contact-message-routes
This commit is contained in:
36
README.md
36
README.md
@@ -20,3 +20,39 @@ npm run dev
|
|||||||
```bash
|
```bash
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Contact form backend
|
||||||
|
|
||||||
|
The contact form posts to an SHSF function in `shsf/contact-api/main.py`.
|
||||||
|
|
||||||
|
### Public route
|
||||||
|
|
||||||
|
- `POST /` submits a message
|
||||||
|
- Body: `{ "username": string, "email": string, "message": string }`
|
||||||
|
- Rate limited per IP
|
||||||
|
|
||||||
|
### Admin routes
|
||||||
|
|
||||||
|
These routes require `X-Lunas-Key`, which must match the function's `LUNAS_KEY` env var.
|
||||||
|
|
||||||
|
- `GET /new` returns unread messages
|
||||||
|
- `POST /seen` marks one or more messages as seen
|
||||||
|
- `POST /delete` deletes one or more messages
|
||||||
|
|
||||||
|
Example bodies:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "id": "message_..." }
|
||||||
|
```
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "ids": ["message_1", "message_2"] }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run backend tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 shsf/contact-api/test_main.py
|
||||||
|
```
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from _db_com import database
|
|||||||
RATE_LIMIT_FILE = "/app/ratelimit.json"
|
RATE_LIMIT_FILE = "/app/ratelimit.json"
|
||||||
RATE_LIMIT_WINDOW_SECONDS = 60 * 60
|
RATE_LIMIT_WINDOW_SECONDS = 60 * 60
|
||||||
RATE_LIMIT_MAX_REQUESTS = 5
|
RATE_LIMIT_MAX_REQUESTS = 5
|
||||||
|
MESSAGES_STORAGE = "portfolio_contact_messages"
|
||||||
ALLOWED_ORIGINS = {
|
ALLOWED_ORIGINS = {
|
||||||
"https://luna.reversed.dev",
|
"https://luna.reversed.dev",
|
||||||
"http://localhost:5173",
|
"http://localhost:5173",
|
||||||
@@ -20,19 +21,22 @@ def _cors_headers(origin=""):
|
|||||||
allowed_origin = origin if origin in ALLOWED_ORIGINS else "https://luna.reversed.dev"
|
allowed_origin = origin if origin in ALLOWED_ORIGINS else "https://luna.reversed.dev"
|
||||||
return {
|
return {
|
||||||
"Access-Control-Allow-Origin": allowed_origin,
|
"Access-Control-Allow-Origin": allowed_origin,
|
||||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||||
"Access-Control-Allow-Headers": "Content-Type",
|
"Access-Control-Allow-Headers": "Content-Type, X-Lunas-Key",
|
||||||
"Access-Control-Max-Age": "86400",
|
"Access-Control-Max-Age": "86400",
|
||||||
"Vary": "Origin",
|
"Vary": "Origin",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _response(origin, status_code, payload):
|
def _response(origin, status_code, payload, extra_headers=None):
|
||||||
|
headers = _cors_headers(origin)
|
||||||
|
if extra_headers:
|
||||||
|
headers.update(extra_headers)
|
||||||
return {
|
return {
|
||||||
"_shsf": "v2",
|
"_shsf": "v2",
|
||||||
"_code": status_code,
|
"_code": status_code,
|
||||||
"_headers": _cors_headers(origin),
|
"_headers": headers,
|
||||||
"_res": payload,
|
"_res": payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,24 +117,162 @@ def _check_rate_limit(ip_address, now):
|
|||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def _parse_body(args):
|
||||||
origin = (args.get("headers", {}) or {}).get("origin", "")
|
|
||||||
method = str(args.get("method", "POST")).upper()
|
|
||||||
|
|
||||||
if method == "OPTIONS":
|
|
||||||
return _response(origin, 204, "")
|
|
||||||
|
|
||||||
if method != "POST":
|
|
||||||
return _response(origin, 405, {"ok": False, "error": "Method not allowed."})
|
|
||||||
|
|
||||||
raw_body = args.get("body", "{}")
|
raw_body = args.get("body", "{}")
|
||||||
try:
|
try:
|
||||||
payload = raw_body if isinstance(raw_body, dict) else json.loads(raw_body)
|
payload = raw_body if isinstance(raw_body, dict) else json.loads(raw_body)
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
raise ValueError("JSON body must be an object")
|
raise ValueError("JSON body must be an object")
|
||||||
|
return payload, None
|
||||||
except (TypeError, json.JSONDecodeError, ValueError):
|
except (TypeError, json.JSONDecodeError, ValueError):
|
||||||
return _response(origin, 400, {"ok": False, "error": "Invalid JSON payload."})
|
return None, "Invalid JSON payload."
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_storage(db):
|
||||||
|
try:
|
||||||
|
db.create_storage(MESSAGES_STORAGE, purpose="Portfolio contact form submissions")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_message_value(value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _list_message_keys(db):
|
||||||
|
items = db.list_items(MESSAGES_STORAGE)
|
||||||
|
if isinstance(items, dict):
|
||||||
|
return list(items.keys())
|
||||||
|
if isinstance(items, list):
|
||||||
|
keys = []
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, str):
|
||||||
|
keys.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
key = item.get("key") or item.get("name") or item.get("id")
|
||||||
|
if key:
|
||||||
|
keys.append(key)
|
||||||
|
return keys
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_message(db, message_id):
|
||||||
|
try:
|
||||||
|
value = db.get(MESSAGES_STORAGE, message_id)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return _parse_message_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_key(args):
|
||||||
|
expected = os.getenv("LUNAS_KEY", "").strip()
|
||||||
|
headers = {str(key).lower(): value for key, value in (args.get("headers", {}) or {}).items()}
|
||||||
|
provided = str(headers.get("x-lunas-key", "")).strip()
|
||||||
|
|
||||||
|
if not expected:
|
||||||
|
return False, "LUNAS_KEY is not configured on the function."
|
||||||
|
if not provided or provided != expected:
|
||||||
|
return False, "Unauthorized."
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def _message_ids_from_payload(payload):
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return []
|
||||||
|
|
||||||
|
ids = []
|
||||||
|
single_id = payload.get("id")
|
||||||
|
multiple_ids = payload.get("ids")
|
||||||
|
|
||||||
|
if isinstance(single_id, str) and single_id.strip():
|
||||||
|
ids.append(single_id.strip())
|
||||||
|
if isinstance(multiple_ids, list):
|
||||||
|
ids.extend(str(item).strip() for item in multiple_ids if str(item).strip())
|
||||||
|
|
||||||
|
deduped = []
|
||||||
|
seen = set()
|
||||||
|
for message_id in ids:
|
||||||
|
if message_id not in seen:
|
||||||
|
seen.add(message_id)
|
||||||
|
deduped.append(message_id)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_new(args, origin, db):
|
||||||
|
authorized, error_message = _require_key(args)
|
||||||
|
if not authorized:
|
||||||
|
return _response(origin, 401, {"ok": False, "error": error_message})
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for message_id in _list_message_keys(db):
|
||||||
|
record = _get_message(db, message_id)
|
||||||
|
if not isinstance(record, dict):
|
||||||
|
continue
|
||||||
|
if record.get("seen") is True:
|
||||||
|
continue
|
||||||
|
messages.append(record)
|
||||||
|
|
||||||
|
messages.sort(key=lambda item: item.get("created_at", ""), reverse=True)
|
||||||
|
return _response(origin, 200, {"ok": True, "messages": messages, "count": len(messages)})
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_seen(args, origin, db, payload):
|
||||||
|
authorized, error_message = _require_key(args)
|
||||||
|
if not authorized:
|
||||||
|
return _response(origin, 401, {"ok": False, "error": error_message})
|
||||||
|
|
||||||
|
message_ids = _message_ids_from_payload(payload)
|
||||||
|
if not message_ids:
|
||||||
|
return _response(origin, 400, {"ok": False, "error": "Provide `id` or `ids`."})
|
||||||
|
|
||||||
|
updated = []
|
||||||
|
missing = []
|
||||||
|
for message_id in message_ids:
|
||||||
|
record = _get_message(db, message_id)
|
||||||
|
if not isinstance(record, dict):
|
||||||
|
missing.append(message_id)
|
||||||
|
continue
|
||||||
|
record["seen"] = True
|
||||||
|
record["seen_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||||
|
db.set(MESSAGES_STORAGE, message_id, json.dumps(record))
|
||||||
|
updated.append(message_id)
|
||||||
|
|
||||||
|
return _response(origin, 200, {"ok": True, "updated": updated, "missing": missing})
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_delete(args, origin, db, payload):
|
||||||
|
authorized, error_message = _require_key(args)
|
||||||
|
if not authorized:
|
||||||
|
return _response(origin, 401, {"ok": False, "error": error_message})
|
||||||
|
|
||||||
|
message_ids = _message_ids_from_payload(payload)
|
||||||
|
if not message_ids:
|
||||||
|
return _response(origin, 400, {"ok": False, "error": "Provide `id` or `ids`."})
|
||||||
|
|
||||||
|
deleted = []
|
||||||
|
missing = []
|
||||||
|
for message_id in message_ids:
|
||||||
|
record = _get_message(db, message_id)
|
||||||
|
if not isinstance(record, dict):
|
||||||
|
missing.append(message_id)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
db.delete_item(MESSAGES_STORAGE, message_id)
|
||||||
|
deleted.append(message_id)
|
||||||
|
except Exception:
|
||||||
|
missing.append(message_id)
|
||||||
|
|
||||||
|
return _response(origin, 200, {"ok": True, "deleted": deleted, "missing": missing})
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_submit(args, origin, db, payload):
|
||||||
if not payload and all(key in args for key in ("username", "email", "message")):
|
if not payload and all(key in args for key in ("username", "email", "message")):
|
||||||
payload = {
|
payload = {
|
||||||
"username": args.get("username", ""),
|
"username": args.get("username", ""),
|
||||||
@@ -146,24 +288,12 @@ def main(args):
|
|||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
allowed, retry_after = _check_rate_limit(ip_address, now)
|
allowed, retry_after = _check_rate_limit(ip_address, now)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
return {
|
return _response(
|
||||||
"_shsf": "v2",
|
origin,
|
||||||
"_code": 429,
|
429,
|
||||||
"_headers": {
|
{"ok": False, "error": "Too many messages from this IP. Please try again later."},
|
||||||
**_cors_headers(origin),
|
{"Retry-After": str(retry_after)},
|
||||||
"Retry-After": str(retry_after),
|
)
|
||||||
},
|
|
||||||
"_res": {
|
|
||||||
"ok": False,
|
|
||||||
"error": "Too many messages from this IP. Please try again later.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
db = database()
|
|
||||||
try:
|
|
||||||
db.create_storage("portfolio_contact_messages", purpose="Portfolio contact form submissions")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
submission_id = f"message_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
|
submission_id = f"message_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
|
||||||
record = {
|
record = {
|
||||||
@@ -174,7 +304,55 @@ def main(args):
|
|||||||
"ip": ip_address,
|
"ip": ip_address,
|
||||||
"user_agent": (args.get("headers", {}) or {}).get("user-agent", ""),
|
"user_agent": (args.get("headers", {}) or {}).get("user-agent", ""),
|
||||||
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)),
|
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)),
|
||||||
|
"seen": False,
|
||||||
}
|
}
|
||||||
db.set("portfolio_contact_messages", submission_id, json.dumps(record))
|
db.set(MESSAGES_STORAGE, submission_id, json.dumps(record))
|
||||||
|
|
||||||
return _response(origin, 201, {"ok": True, "message": "Message received.", "id": submission_id})
|
return _response(origin, 201, {"ok": True, "message": "Message received.", "id": submission_id})
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
headers = args.get("headers", {}) or {}
|
||||||
|
origin = headers.get("origin", "")
|
||||||
|
method = str(args.get("method", "POST")).upper()
|
||||||
|
route = str(args.get("route", "default") or "default").strip("/") or "default"
|
||||||
|
|
||||||
|
if method == "OPTIONS":
|
||||||
|
return _response(origin, 204, "")
|
||||||
|
|
||||||
|
db = database()
|
||||||
|
_ensure_storage(db)
|
||||||
|
|
||||||
|
if route == "new":
|
||||||
|
if method not in {"GET", "POST"}:
|
||||||
|
return _response(origin, 405, {"ok": False, "error": "Method not allowed."})
|
||||||
|
return _handle_new(args, origin, db)
|
||||||
|
|
||||||
|
payload, payload_error = _parse_body(args)
|
||||||
|
if payload_error:
|
||||||
|
payload = {}
|
||||||
|
|
||||||
|
if route == "seen":
|
||||||
|
if method != "POST":
|
||||||
|
return _response(origin, 405, {"ok": False, "error": "Method not allowed."})
|
||||||
|
if payload_error:
|
||||||
|
return _response(origin, 400, {"ok": False, "error": payload_error})
|
||||||
|
return _handle_seen(args, origin, db, payload)
|
||||||
|
|
||||||
|
if route == "delete":
|
||||||
|
if method != "POST":
|
||||||
|
return _response(origin, 405, {"ok": False, "error": "Method not allowed."})
|
||||||
|
if payload_error:
|
||||||
|
return _response(origin, 400, {"ok": False, "error": payload_error})
|
||||||
|
return _handle_delete(args, origin, db, payload)
|
||||||
|
|
||||||
|
if route != "default":
|
||||||
|
return _response(origin, 404, {"ok": False, "error": "Route not found."})
|
||||||
|
|
||||||
|
if method != "POST":
|
||||||
|
return _response(origin, 405, {"ok": False, "error": "Method not allowed."})
|
||||||
|
|
||||||
|
if payload_error:
|
||||||
|
return _response(origin, 400, {"ok": False, "error": payload_error})
|
||||||
|
|
||||||
|
return _handle_submit(args, origin, db, payload)
|
||||||
|
|||||||
158
shsf/contact-api/test_main.py
Normal file
158
shsf/contact-api/test_main.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
MODULE_PATH = Path(__file__).with_name("main.py")
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDB:
|
||||||
|
storages = {}
|
||||||
|
|
||||||
|
def create_storage(self, name, purpose=None):
|
||||||
|
self.storages.setdefault(name, {})
|
||||||
|
|
||||||
|
def set(self, storage, key, value):
|
||||||
|
self.storages.setdefault(storage, {})[key] = value
|
||||||
|
|
||||||
|
def get(self, storage, key):
|
||||||
|
return self.storages.get(storage, {}).get(key)
|
||||||
|
|
||||||
|
def list_items(self, storage):
|
||||||
|
return list(self.storages.get(storage, {}).keys())
|
||||||
|
|
||||||
|
def delete_item(self, storage, key):
|
||||||
|
self.storages.get(storage, {}).pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
class ContactApiTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
FakeDB.storages = {}
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
self.rate_limit_file = os.path.join(self.temp_dir.name, "ratelimit.json")
|
||||||
|
os.environ["LUNAS_KEY"] = "topsecret"
|
||||||
|
self.module = self._load_module()
|
||||||
|
self.module.RATE_LIMIT_FILE = self.rate_limit_file
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
os.environ.pop("LUNAS_KEY", None)
|
||||||
|
sys.modules.pop("contact_api_main_under_test", None)
|
||||||
|
sys.modules.pop("_db_com", None)
|
||||||
|
|
||||||
|
def _load_module(self):
|
||||||
|
fake_db_module = types.ModuleType("_db_com")
|
||||||
|
fake_db_module.database = FakeDB
|
||||||
|
sys.modules["_db_com"] = fake_db_module
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location("contact_api_main_under_test", MODULE_PATH)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def _res(self, response):
|
||||||
|
return response["_res"]
|
||||||
|
|
||||||
|
def test_submit_message_persists_record(self):
|
||||||
|
response = self.module.main(
|
||||||
|
{
|
||||||
|
"method": "POST",
|
||||||
|
"headers": {"origin": "https://luna.reversed.dev", "user-agent": "pytest"},
|
||||||
|
"body": json.dumps(
|
||||||
|
{
|
||||||
|
"username": "Space",
|
||||||
|
"email": "space@example.com",
|
||||||
|
"message": "Hello Luna, this is a valid test message.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response["_code"], 201)
|
||||||
|
payload = self._res(response)
|
||||||
|
self.assertTrue(payload["ok"])
|
||||||
|
stored = FakeDB.storages["portfolio_contact_messages"][payload["id"]]
|
||||||
|
record = json.loads(stored)
|
||||||
|
self.assertEqual(record["username"], "Space")
|
||||||
|
self.assertFalse(record["seen"])
|
||||||
|
|
||||||
|
def test_new_route_requires_key(self):
|
||||||
|
response = self.module.main({"method": "GET", "route": "new", "headers": {}})
|
||||||
|
self.assertEqual(response["_code"], 401)
|
||||||
|
self.assertFalse(self._res(response)["ok"])
|
||||||
|
|
||||||
|
def test_new_seen_and_delete_flow(self):
|
||||||
|
created = self.module.main(
|
||||||
|
{
|
||||||
|
"method": "POST",
|
||||||
|
"headers": {"origin": "https://luna.reversed.dev"},
|
||||||
|
"body": json.dumps(
|
||||||
|
{
|
||||||
|
"username": "Space",
|
||||||
|
"email": "space@example.com",
|
||||||
|
"message": "Hello Luna, this should show up in the admin inbox.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
message_id = self._res(created)["id"]
|
||||||
|
|
||||||
|
auth_headers = {"x-lunas-key": "topsecret", "origin": "https://luna.reversed.dev"}
|
||||||
|
listed = self.module.main({"method": "GET", "route": "new", "headers": auth_headers})
|
||||||
|
self.assertEqual(listed["_code"], 200)
|
||||||
|
self.assertEqual(self._res(listed)["count"], 1)
|
||||||
|
self.assertEqual(self._res(listed)["messages"][0]["id"], message_id)
|
||||||
|
|
||||||
|
seen = self.module.main(
|
||||||
|
{
|
||||||
|
"method": "POST",
|
||||||
|
"route": "seen",
|
||||||
|
"headers": auth_headers,
|
||||||
|
"body": json.dumps({"id": message_id}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(seen["_code"], 200)
|
||||||
|
self.assertEqual(self._res(seen)["updated"], [message_id])
|
||||||
|
|
||||||
|
listed_after_seen = self.module.main({"method": "GET", "route": "new", "headers": auth_headers})
|
||||||
|
self.assertEqual(self._res(listed_after_seen)["count"], 0)
|
||||||
|
|
||||||
|
deleted = self.module.main(
|
||||||
|
{
|
||||||
|
"method": "POST",
|
||||||
|
"route": "delete",
|
||||||
|
"headers": auth_headers,
|
||||||
|
"body": json.dumps({"id": message_id}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(deleted["_code"], 200)
|
||||||
|
self.assertEqual(self._res(deleted)["deleted"], [message_id])
|
||||||
|
|
||||||
|
def test_rate_limit_returns_retry_after(self):
|
||||||
|
payload = {
|
||||||
|
"username": "Space",
|
||||||
|
"email": "space@example.com",
|
||||||
|
"message": "Hello Luna, this message is long enough to pass validation.",
|
||||||
|
}
|
||||||
|
args = {
|
||||||
|
"method": "POST",
|
||||||
|
"headers": {"x-forwarded-for": "1.2.3.4", "origin": "https://luna.reversed.dev"},
|
||||||
|
"body": json.dumps(payload),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in range(self.module.RATE_LIMIT_MAX_REQUESTS):
|
||||||
|
response = self.module.main(args)
|
||||||
|
self.assertEqual(response["_code"], 201)
|
||||||
|
|
||||||
|
limited = self.module.main(args)
|
||||||
|
self.assertEqual(limited["_code"], 429)
|
||||||
|
self.assertIn("Retry-After", limited["_headers"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import { useEffect, useRef, useState } from 'react';
|
import { useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
export function useIntersectionObserver(options = {}) {
|
export function useIntersectionObserver({ threshold = 0.1 } = {}) {
|
||||||
const ref = useRef(null);
|
const ref = useRef(null);
|
||||||
const [isIntersecting, setIsIntersecting] = useState(false);
|
const [isIntersecting, setIsIntersecting] = useState(false);
|
||||||
|
|
||||||
@@ -8,20 +8,17 @@ export function useIntersectionObserver(options = {}) {
|
|||||||
const element = ref.current;
|
const element = ref.current;
|
||||||
if (!element) return;
|
if (!element) return;
|
||||||
|
|
||||||
const observer = new IntersectionObserver(
|
const observer = new IntersectionObserver(([entry]) => {
|
||||||
([entry]) => {
|
if (entry.isIntersecting) {
|
||||||
if (entry.isIntersecting) {
|
setIsIntersecting(true);
|
||||||
setIsIntersecting(true);
|
observer.unobserve(element);
|
||||||
observer.unobserve(element);
|
}
|
||||||
}
|
}, { threshold });
|
||||||
},
|
|
||||||
{ threshold: 0.1, ...options }
|
|
||||||
);
|
|
||||||
|
|
||||||
observer.observe(element);
|
observer.observe(element);
|
||||||
|
|
||||||
return () => observer.disconnect();
|
return () => observer.disconnect();
|
||||||
}, []);
|
}, [threshold]);
|
||||||
|
|
||||||
return [ref, isIntersecting];
|
return [ref, isIntersecting];
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user