refactor: harden CLI/client/config and centralize serialization
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
"""Configuration loader — reads config.yaml and merges with defaults.
|
||||
|
||||
Uses a simple built-in YAML parser to avoid adding PyYAML as a dependency.
|
||||
"""
|
||||
"""Configuration loader with YAML parsing and normalization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import copy
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
# Default configuration
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"fetch": {
|
||||
"count": 50,
|
||||
@@ -31,131 +32,118 @@ DEFAULT_CONFIG = {
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
|
||||
def _parse_value(s):
|
||||
# type: (str) -> Union[str, int, float, bool]
|
||||
"""Parse a scalar YAML value."""
|
||||
if s == "true":
|
||||
return True
|
||||
if s == "false":
|
||||
return False
|
||||
# Remove surrounding quotes
|
||||
if (s.startswith('"') and s.endswith('"')) or (s.startswith("'") and s.endswith("'")):
|
||||
return s[1:-1]
|
||||
# Try number
|
||||
def load_config(config_path=None):
|
||||
# type: (Optional[str]) -> Dict[str, Any]
|
||||
"""Load and normalize config from YAML, merged with defaults."""
|
||||
config = copy.deepcopy(DEFAULT_CONFIG)
|
||||
path = _resolve_config_path(config_path)
|
||||
if not path:
|
||||
return config
|
||||
|
||||
try:
|
||||
if "." in s:
|
||||
return float(s)
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return s
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.warning("Failed to read config file %s: %s", path, exc)
|
||||
return config
|
||||
|
||||
try:
|
||||
parsed = yaml.safe_load(raw) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse YAML config %s: %s", path, exc)
|
||||
return config
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
logger.warning("Config root must be a mapping, got %s", type(parsed).__name__)
|
||||
return config
|
||||
|
||||
merged = _deep_merge(config, parsed)
|
||||
return _normalize_config(merged)
|
||||
|
||||
|
||||
def _parse_yaml(text):
|
||||
# type: (str) -> Dict[str, Any]
|
||||
"""Minimal YAML parser for our flat config structure.
|
||||
def _resolve_config_path(config_path):
|
||||
# type: (Optional[str]) -> Optional[Path]
|
||||
"""Find config path from explicit argument or default locations."""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
return path if path.exists() else None
|
||||
|
||||
Supports: scalars, inline arrays [...], indented "- item" arrays,
|
||||
nested objects via indentation.
|
||||
"""
|
||||
result = {} # type: Dict[str, Any]
|
||||
lines = text.split("\n")
|
||||
stack = [{"indent": -1, "obj": result}] # type: List[Dict[str, Any]]
|
||||
|
||||
for line in lines:
|
||||
# Strip comments and trailing whitespace
|
||||
trimmed = re.sub(r"#.*$", "", line).rstrip()
|
||||
if not trimmed or not trimmed.strip():
|
||||
continue
|
||||
|
||||
indent = len(line) - len(line.lstrip())
|
||||
content = trimmed.strip()
|
||||
|
||||
# Handle "- item" array entries
|
||||
if content.startswith("- "):
|
||||
parent = stack[-1]["obj"]
|
||||
keys = list(parent.keys())
|
||||
if keys:
|
||||
last_key = keys[-1]
|
||||
if not isinstance(parent[last_key], list):
|
||||
parent[last_key] = []
|
||||
parent[last_key].append(_parse_value(content[2:].strip()))
|
||||
continue
|
||||
|
||||
colon_idx = content.find(":")
|
||||
if colon_idx == -1:
|
||||
continue
|
||||
|
||||
key = content[:colon_idx].strip()
|
||||
raw_value = content[colon_idx + 1:].strip()
|
||||
|
||||
# Pop stack to find parent at correct indentation
|
||||
while len(stack) > 1 and stack[-1]["indent"] >= indent:
|
||||
stack.pop()
|
||||
parent = stack[-1]["obj"]
|
||||
|
||||
if raw_value == "" or raw_value == "|":
|
||||
# Nested object
|
||||
child = {} # type: Dict[str, Any]
|
||||
parent[key] = child
|
||||
stack.append({"indent": indent, "obj": child})
|
||||
elif raw_value.startswith("[") and raw_value.endswith("]"):
|
||||
# Inline array
|
||||
inner = raw_value[1:-1].strip()
|
||||
if inner == "":
|
||||
parent[key] = []
|
||||
else:
|
||||
parent[key] = [_parse_value(s.strip()) for s in inner.split(",")]
|
||||
else:
|
||||
parent[key] = _parse_value(raw_value)
|
||||
|
||||
return result
|
||||
candidates = [
|
||||
Path.cwd() / "config.yaml",
|
||||
Path(__file__).parent.parent / "config.yaml",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _deep_merge(target, source):
|
||||
# type: (Dict[str, Any], Dict[str, Any]) -> Dict[str, Any]
|
||||
# type: (Dict[str, Any], Mapping[str, Any]) -> Dict[str, Any]
|
||||
"""Deep merge source into target (source values override target)."""
|
||||
result = dict(target)
|
||||
for key in source:
|
||||
if (
|
||||
isinstance(source[key], dict)
|
||||
and isinstance(result.get(key), dict)
|
||||
):
|
||||
result[key] = _deep_merge(result[key], source[key])
|
||||
result = copy.deepcopy(target)
|
||||
for key, value in source.items():
|
||||
if isinstance(value, dict) and isinstance(result.get(key), dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = source[key]
|
||||
result[key] = copy.deepcopy(value)
|
||||
return result
|
||||
|
||||
|
||||
def load_config(config_path=None):
|
||||
# type: (str) -> Dict[str, Any]
|
||||
"""Load config from config.yaml, merged with defaults."""
|
||||
if config_path is None:
|
||||
# Look in current directory first, then script directory
|
||||
candidates = [
|
||||
Path.cwd() / "config.yaml",
|
||||
Path(__file__).parent.parent / "config.yaml",
|
||||
]
|
||||
for p in candidates:
|
||||
if p.exists():
|
||||
config_path = str(p)
|
||||
break
|
||||
def _normalize_config(config):
|
||||
# type: (Dict[str, Any]) -> Dict[str, Any]
|
||||
"""Normalize shape and value types."""
|
||||
normalized = copy.deepcopy(DEFAULT_CONFIG)
|
||||
merged = _deep_merge(normalized, config)
|
||||
|
||||
if config_path and Path(config_path).exists():
|
||||
try:
|
||||
raw = Path(config_path).read_text(encoding="utf-8")
|
||||
parsed = _parse_yaml(raw)
|
||||
config = _deep_merge(DEFAULT_CONFIG, parsed)
|
||||
except Exception:
|
||||
config = dict(DEFAULT_CONFIG)
|
||||
else:
|
||||
config = dict(DEFAULT_CONFIG)
|
||||
fetch = merged.get("fetch")
|
||||
if not isinstance(fetch, dict):
|
||||
fetch = {}
|
||||
fetch_count = _as_int(fetch.get("count"), DEFAULT_CONFIG["fetch"]["count"])
|
||||
fetch["count"] = max(fetch_count, 1)
|
||||
merged["fetch"] = fetch
|
||||
|
||||
# Ensure nested dicts exist
|
||||
config.setdefault("fetch", DEFAULT_CONFIG["fetch"])
|
||||
config.setdefault("filter", DEFAULT_CONFIG["filter"])
|
||||
filter_config = merged.get("filter")
|
||||
if not isinstance(filter_config, dict):
|
||||
filter_config = {}
|
||||
mode = str(filter_config.get("mode", "topN"))
|
||||
if mode not in {"topN", "score", "all"}:
|
||||
mode = "topN"
|
||||
filter_config["mode"] = mode
|
||||
filter_config["topN"] = max(_as_int(filter_config.get("topN"), 20), 1)
|
||||
filter_config["minScore"] = _as_float(filter_config.get("minScore"), 50.0)
|
||||
filter_config["excludeRetweets"] = bool(filter_config.get("excludeRetweets", False))
|
||||
|
||||
# Deep-copy filter weights if needed
|
||||
if "filter" in config and "weights" not in config["filter"]:
|
||||
config["filter"]["weights"] = dict(DEFAULT_CONFIG["filter"]["weights"])
|
||||
langs = filter_config.get("lang", [])
|
||||
if not isinstance(langs, list):
|
||||
langs = []
|
||||
filter_config["lang"] = [str(lang) for lang in langs if str(lang)]
|
||||
|
||||
return config
|
||||
weights = filter_config.get("weights", {})
|
||||
if not isinstance(weights, dict):
|
||||
weights = {}
|
||||
normalized_weights = {}
|
||||
default_weights = DEFAULT_CONFIG["filter"]["weights"]
|
||||
for key, default_value in default_weights.items():
|
||||
normalized_weights[key] = _as_float(weights.get(key), float(default_value))
|
||||
filter_config["weights"] = normalized_weights
|
||||
merged["filter"] = filter_config
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _as_int(value, default):
|
||||
# type: (Any, int) -> int
|
||||
"""Best-effort int conversion."""
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _as_float(value, default):
|
||||
# type: (Any, float) -> float
|
||||
"""Best-effort float conversion."""
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
Reference in New Issue
Block a user