"""Provider registry. Concrete provider implementations live in the ``Provider`` package. This module is the single source of truth for discovery, metadata, and lifecycle helpers for those plugins. """ from __future__ import annotations import importlib import pkgutil import sys from dataclasses import dataclass, field from types import ModuleType from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type from urllib.parse import urlparse from SYS.logger import log, debug from ProviderCore.base import FileProvider, Provider, SearchProvider, SearchResult from Provider.soulseek import download_soulseek_file @dataclass(frozen=True) class ProviderInfo: """Metadata about a single provider entry.""" canonical_name: str provider_class: Type[Provider] module: str alias_names: Tuple[str, ...] = field(default_factory=tuple) @property def supports_search(self) -> bool: return self.provider_class.search is not Provider.search @property def supports_upload(self) -> bool: return self.provider_class.upload is not Provider.upload class ProviderRegistry: """Handles discovery, registration, and lookup of provider classes.""" def __init__(self, package_name: str) -> None: self.package_name = (package_name or "").strip() self._infos: Dict[str, ProviderInfo] = {} self._lookup: Dict[str, ProviderInfo] = {} self._modules: set[str] = set() self._discovered = False def _normalize(self, value: Any) -> str: return str(value or "").strip().lower() def _candidate_names(self, provider_class: Type[Provider], override_name: Optional[str]) -> List[str]: names: List[str] = [] seen: set[str] = set() def _add(value: Any) -> None: text = str(value or "").strip() normalized = text.lower() if not text or normalized in seen: return seen.add(normalized) names.append(text) if override_name: _add(override_name) else: # Use class name as the primary canonical name _add(getattr(provider_class, "__name__", None)) _add(getattr(provider_class, "PROVIDER_NAME", None)) _add(getattr(provider_class, "NAME", None)) for alias in getattr(provider_class, "PROVIDER_ALIASES", ()) or (): _add(alias) return names def register( self, provider_class: Type[Provider], *, override_name: Optional[str] = None, extra_aliases: Optional[Sequence[str]] = None, module_name: Optional[str] = None, replace: bool = False, ) -> ProviderInfo: """Register a provider class with canonical and alias names.""" candidates = self._candidate_names(provider_class, override_name) if not candidates: raise ValueError("provider name candidates are required") canonical = self._normalize(candidates[0]) if not canonical: raise ValueError("provider name must not be empty") alias_names: List[str] = [] alias_seen: set[str] = set() for candidate in candidates[1:]: normalized = self._normalize(candidate) if not normalized or normalized == canonical or normalized in alias_seen: continue alias_seen.add(normalized) alias_names.append(normalized) for alias in extra_aliases or (): normalized = self._normalize(alias) if not normalized or normalized == canonical or normalized in alias_seen: continue alias_seen.add(normalized) alias_names.append(normalized) info = ProviderInfo( canonical_name=canonical, provider_class=provider_class, module=module_name or getattr(provider_class, "__module__", "") or "", alias_names=tuple(alias_names), ) existing = self._infos.get(canonical) if existing is not None and not replace: return existing self._infos[canonical] = info for lookup in (canonical,) + tuple(alias_names): self._lookup[lookup] = info return info def _register_module(self, module: ModuleType) -> None: module_name = getattr(module, "__name__", "") if not module_name or module_name in self._modules: return self._modules.add(module_name) for attr in dir(module): candidate = getattr(module, attr) if not isinstance(candidate, type): continue if not issubclass(candidate, Provider): continue if candidate in {Provider, SearchProvider, FileProvider}: continue if getattr(candidate, "__module__", "") != module_name: continue try: self.register(candidate, module_name=module_name) except Exception as exc: log(f"[provider] Failed to register {module_name}.{candidate.__name__}: {exc}", file=sys.stderr) def discover(self) -> None: """Import and register providers from the package.""" if self._discovered or not self.package_name: return self._discovered = True try: package = importlib.import_module(self.package_name) except Exception as exc: log(f"[provider] Failed to import package {self.package_name}: {exc}", file=sys.stderr) return self._register_module(package) package_path = getattr(package, "__path__", None) if not package_path: return for finder, module_name, _ in pkgutil.iter_modules(package_path): if module_name.startswith("_"): continue module_path = f"{self.package_name}.{module_name}" try: module = importlib.import_module(module_path) except Exception as exc: log(f"[provider] Failed to load {module_path}: {exc}", file=sys.stderr) continue self._register_module(module) def get(self, name: str) -> Optional[ProviderInfo]: self.discover() if not name: return None return self._lookup.get(self._normalize(name)) def iter_providers(self) -> Iterable[ProviderInfo]: self.discover() return tuple(self._infos.values()) def has_name(self, name: str) -> bool: return self.get(name) is not None def _sync_subclasses(self) -> None: """Walk all Provider subclasses in memory and register them.""" def _walk(cls: Type[Provider]) -> None: for sub in cls.__subclasses__(): if sub in {SearchProvider, FileProvider}: _walk(sub) continue try: self.register(sub) except Exception: pass _walk(sub) _walk(Provider) REGISTRY = ProviderRegistry("Provider") REGISTRY.discover() REGISTRY._sync_subclasses() def register_provider( provider_class: Type[Provider], *, name: Optional[str] = None, aliases: Optional[Sequence[str]] = None, module_name: Optional[str] = None, replace: bool = False, ) -> ProviderInfo: """Register a provider class from tests or third-party packages.""" return REGISTRY.register( provider_class, override_name=name, extra_aliases=aliases, module_name=module_name, replace=replace, ) def get_provider_class(name: str) -> Optional[Type[Provider]]: info = REGISTRY.get(name) if info is None: return None return info.provider_class def selection_auto_stage_for_table( table_type: str, stage_args: Optional[Sequence[str]] = None, ) -> Optional[list[str]]: t = str(table_type or "").strip().lower() if not t: return None provider_key = t.split(".", 1)[0] if "." in t else t provider_class = get_provider_class(provider_key) or get_provider_class(t) if provider_class is None: return None try: return provider_class.selection_auto_stage(t, stage_args) except Exception: return None def is_known_provider_name(name: str) -> bool: return REGISTRY.has_name(name) def _supports_search(provider: Provider) -> bool: return provider.__class__.search is not Provider.search def _supports_upload(provider: Provider) -> bool: return provider.__class__.upload is not Provider.upload def _provider_url_patterns(provider_class: Type[Provider]) -> Sequence[str]: try: return list(provider_class.url_patterns()) except Exception: return [] def get_provider(name: str, config: Optional[Dict[str, Any]] = None) -> Optional[Provider]: info = REGISTRY.get(name) if info is None: debug(f"[provider] Unknown provider: {name}") return None try: provider = info.provider_class(config) if not provider.validate(): debug(f"[provider] Provider '{name}' is not available") return None return provider except Exception as exc: debug(f"[provider] Error initializing '{name}': {exc}") return None def list_providers(config: Optional[Dict[str, Any]] = None) -> Dict[str, bool]: availability: Dict[str, bool] = {} for info in REGISTRY.iter_providers(): try: provider = info.provider_class(config) availability[info.canonical_name] = provider.validate() except Exception: availability[info.canonical_name] = False return availability def get_search_provider(name: str, config: Optional[Dict[str, Any]] = None) -> Optional[SearchProvider]: provider = get_provider(name, config) if provider is None: return None if not _supports_search(provider): debug(f"[provider] Provider '{name}' does not support search") return None return provider # type: ignore[return-value] def list_search_providers(config: Optional[Dict[str, Any]] = None) -> Dict[str, bool]: availability: Dict[str, bool] = {} for info in REGISTRY.iter_providers(): try: provider = info.provider_class(config) availability[info.canonical_name] = bool( provider.validate() and info.supports_search ) except Exception: availability[info.canonical_name] = False return availability def get_file_provider(name: str, config: Optional[Dict[str, Any]] = None) -> Optional[FileProvider]: provider = get_provider(name, config) if provider is None: return None if not _supports_upload(provider): debug(f"[provider] Provider '{name}' does not support upload") return None return provider # type: ignore[return-value] def list_file_providers(config: Optional[Dict[str, Any]] = None) -> Dict[str, bool]: availability: Dict[str, bool] = {} for info in REGISTRY.iter_providers(): try: provider = info.provider_class(config) availability[info.canonical_name] = bool( provider.validate() and info.supports_upload ) except Exception: availability[info.canonical_name] = False return availability def match_provider_name_for_url(url: str) -> Optional[str]: raw_url = str(url or "").strip() raw_url_lower = raw_url.lower() try: parsed = urlparse(raw_url) host = (parsed.hostname or "").strip().lower() path = (parsed.path or "").strip() except Exception: host = "" path = "" def _norm_host(h: str) -> str: h_norm = str(h or "").strip().lower() if h_norm.startswith("www."): h_norm = h_norm[4:] return h_norm host_norm = _norm_host(host) if host_norm: if host_norm == "openlibrary.org" or host_norm.endswith(".openlibrary.org"): return "openlibrary" if REGISTRY.has_name("openlibrary") else None if host_norm == "archive.org" or host_norm.endswith(".archive.org"): low_path = str(path or "").lower() is_borrowish = ( low_path.startswith("/borrow/") or low_path.startswith("/stream/") or low_path.startswith("/services/loans/") or "/services/loans/" in low_path ) if is_borrowish: return "openlibrary" if REGISTRY.has_name("openlibrary") else None return "internetarchive" if REGISTRY.has_name("internetarchive") else None for info in REGISTRY.iter_providers(): domains = _provider_url_patterns(info.provider_class) if not domains: continue for domain in domains: dom_raw = str(domain or "").strip() dom = dom_raw.lower() if not dom: continue if "://" in dom or dom.startswith("magnet:"): if raw_url_lower.startswith(dom): return info.canonical_name continue dom_norm = _norm_host(dom) if not dom_norm or not host_norm: continue if host_norm == dom_norm or host_norm.endswith("." + dom_norm): return info.canonical_name return None def provider_inline_query_choices( provider_name: str, field_name: str, config: Optional[Dict[str, Any]] = None, ) -> List[str]: """Return provider-declared inline query choices for a field (e.g., system:GBA). Providers can expose a mapping via ``QUERY_ARG_CHOICES`` (preferred) or ``INLINE_QUERY_FIELD_CHOICES`` / ``inline_query_field_choices()``. The helper keeps completion logic simple and reusable. This helper keeps completion logic simple and reusable. """ pname = str(provider_name or "").strip().lower() field = str(field_name or "").strip().lower() if not pname or not field: return [] provider = get_search_provider(pname, config) if provider is None: provider = get_provider(pname, config) if provider is None: return [] def _normalize_choice_entry(entry: Any) -> Optional[Dict[str, Any]]: if entry is None: return None if isinstance(entry, dict): value = entry.get("value") text = entry.get("text") or entry.get("label") or value aliases = entry.get("alias") or entry.get("aliases") or [] value_str = str(value) if value is not None else (str(text) if text is not None else None) text_str = str(text) if text is not None else value_str if not value_str or not text_str: return None alias_list = [str(a) for a in aliases if a is not None] return {"value": value_str, "text": text_str, "aliases": alias_list} # string/other primitives return {"value": str(entry), "text": str(entry), "aliases": []} def _collect_mapping(p) -> Dict[str, List[Dict[str, Any]]]: mapping: Dict[str, List[Dict[str, Any]]] = {} base = getattr(p, "QUERY_ARG_CHOICES", None) if not isinstance(base, dict): base = getattr(p, "INLINE_QUERY_FIELD_CHOICES", None) if isinstance(base, dict): for k, v in base.items(): normalized: List[Dict[str, Any]] = [] seq = v try: if callable(seq): seq = seq() except Exception: seq = v if isinstance(seq, dict): seq = seq.get("choices") or seq.get("values") or seq if isinstance(seq, (list, tuple, set)): for entry in seq: n = _normalize_choice_entry(entry) if n: normalized.append(n) if normalized: mapping[str(k).strip().lower()] = normalized try: fn = getattr(p, "inline_query_field_choices", None) if callable(fn): extra = fn() if isinstance(extra, dict): for k, v in extra.items(): normalized: List[Dict[str, Any]] = [] seq = v try: if callable(seq): seq = seq() except Exception: seq = v if isinstance(seq, dict): seq = seq.get("choices") or seq.get("values") or seq if isinstance(seq, (list, tuple, set)): for entry in seq: n = _normalize_choice_entry(entry) if n: normalized.append(n) if normalized: mapping[str(k).strip().lower()] = normalized except Exception: pass return mapping try: mapping = _collect_mapping(provider) if not mapping: return [] entries = mapping.get(field, []) if not entries: return [] seen: set[str] = set() out: List[str] = [] for entry in entries: text = entry.get("text") or entry.get("value") if not text: continue text_str = str(text) if text_str in seen: continue seen.add(text_str) out.append(text_str) for alias in entry.get("aliases", []): alias_str = str(alias) if alias_str and alias_str not in seen: seen.add(alias_str) out.append(alias_str) return out except Exception: return [] def get_provider_for_url(url: str, config: Optional[Dict[str, Any]] = None) -> Optional[Provider]: name = match_provider_name_for_url(url) if not name: return None return get_provider(name, config) def resolve_inline_filters( provider: Provider, inline_args: Dict[str, Any], *, field_transforms: Optional[Dict[str, Any]] = None, ) -> Dict[str, str]: """Map inline query args to provider filter values using declared choices. - Uses provider's inline choice mapping (value/text/aliases) to resolve user text. - Applies optional per-field transforms (e.g., str.upper). - Returns normalized filters suitable for provider.search. """ filters: Dict[str, str] = {} if not inline_args: return filters mapping = _collect_mapping(provider) transforms = field_transforms or {} for raw_key, raw_val in inline_args.items(): if raw_val is None: continue key = str(raw_key or "").strip().lower() val_str = str(raw_val).strip() if not key or not val_str: continue entries = mapping.get(key, []) resolved: Optional[str] = None val_lower = val_str.lower() for entry in entries: text = str(entry.get("text") or "").strip() value = str(entry.get("value") or "").strip() aliases = [str(a).strip() for a in entry.get("aliases", []) if a is not None] if val_lower in {text.lower(), value.lower()} or val_lower in {a.lower() for a in aliases}: resolved = value or text or val_str break if resolved is None: resolved = val_str transform = transforms.get(key) if callable(transform): try: resolved = transform(resolved) except Exception: pass if resolved: filters[key] = str(resolved) return filters __all__ = [ "ProviderInfo", "Provider", "SearchProvider", "FileProvider", "SearchResult", "register_provider", "get_provider", "list_providers", "get_search_provider", "list_search_providers", "get_file_provider", "list_file_providers", "match_provider_name_for_url", "get_provider_for_url", "get_provider_class", "selection_auto_stage_for_table", "download_soulseek_file", "provider_inline_query_choices", ]