Files
Medios-Macina/ProviderCore/registry.py
2026-02-11 18:16:07 -08:00

678 lines
22 KiB
Python

"""Provider registry.
Concrete provider implementations live in the ``Provider`` package. This module
is the single source of truth for discovery, metadata, and lifecycle helpers
for those plugins.
"""
from __future__ import annotations
from functools import lru_cache
import importlib
import pkgutil
import sys
from dataclasses import dataclass, field
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type
from urllib.parse import urlparse
from SYS.logger import log, debug
from ProviderCore.base import FileProvider, Provider, SearchProvider, SearchResult
def download_soulseek_file(*args: Any, **kwargs: Any) -> Any:
    """Forward the call to ``Provider.soulseek.download_soulseek_file``.

    The soulseek module is imported on each call rather than at module import
    time, so importing the registry does not pay for loading that provider.
    All positional and keyword arguments are passed through unchanged and the
    downloader's return value is returned as-is.
    """
    soulseek = importlib.import_module("Provider.soulseek")
    return soulseek.download_soulseek_file(*args, **kwargs)
@dataclass(frozen=True)
class ProviderInfo:
    """Immutable record describing one registered provider class."""

    # Primary name the provider is registered under.
    canonical_name: str
    # The provider class itself (not an instance).
    provider_class: Type[Provider]
    # Dotted name of the module the class was registered from.
    module: str
    # Additional names that resolve to this provider.
    alias_names: Tuple[str, ...] = field(default_factory=tuple)

    @property
    def supports_search(self) -> bool:
        """Return True when the class overrides ``Provider.search``."""
        own_search = self.provider_class.search
        return own_search is not Provider.search

    @property
    def supports_upload(self) -> bool:
        """Return True when the class overrides ``Provider.upload`` and is exposed.

        A class opts out by setting ``EXPOSE_AS_FILE_PROVIDER`` to a falsy
        value; if reading the flag fails the class is treated as exposed.
        """
        cls = self.provider_class
        try:
            flag = getattr(cls, "EXPOSE_AS_FILE_PROVIDER", True)
            exposed = bool(flag)
        except Exception:
            exposed = True
        if not exposed:
            return False
        return cls.upload is not Provider.upload
class ProviderRegistry:
    """Handles discovery, registration, and lookup of provider classes.

    Providers are stored under a normalized (stripped, lower-cased) canonical
    name. Every canonical name and alias is also indexed in a lookup table so
    ``get`` can resolve either. Discovery is lazy: ``get`` first tries to
    import only the module named after the requested provider and falls back
    to scanning the whole package.
    """

    # Module names that are never imported, neither by the targeted import in
    # ``_try_import_for_name`` nor by the full scan in ``discover``.
    _SKIPPED_MODULES = frozenset({"hifi"})

    def __init__(self, package_name: str) -> None:
        """Create an empty registry for the package named *package_name*."""
        self.package_name = (package_name or "").strip()
        # canonical name -> info
        self._infos: Dict[str, ProviderInfo] = {}
        # canonical name or alias -> info
        self._lookup: Dict[str, ProviderInfo] = {}
        # dotted names of modules already scanned by _register_module
        self._modules: set[str] = set()
        self._discovered = False

    def _normalize(self, value: Any) -> str:
        """Return *value* as a stripped, lower-cased string ('' for falsy input)."""
        return str(value or "").strip().lower()

    def _candidate_names(self,
                         provider_class: Type[Provider],
                         override_name: Optional[str]) -> List[str]:
        """Return the names a provider class may be known by, primary first.

        The primary name is *override_name* when given; otherwise the class's
        ``NAME``, ``PROVIDER_NAME`` and ``__name__`` are used in that order.
        ``PROVIDER_ALIASES`` declared on the class follow. Names are stripped,
        and duplicates are dropped case-insensitively while preserving order.
        """
        names: List[str] = []
        seen: set[str] = set()

        def _add(value: Any) -> None:
            text = str(value or "").strip()
            key = text.lower()
            if text and key not in seen:
                seen.add(key)
                names.append(text)

        if override_name:
            _add(override_name)
        else:
            # Use explicit NAME or PROVIDER_NAME if available, else class name.
            for attr in ("NAME", "PROVIDER_NAME", "__name__"):
                _add(getattr(provider_class, attr, None))
        for alias in getattr(provider_class, "PROVIDER_ALIASES", ()) or ():
            _add(alias)
        return names

    def register(
        self,
        provider_class: Type[Provider],
        *,
        override_name: Optional[str] = None,
        extra_aliases: Optional[Sequence[str]] = None,
        module_name: Optional[str] = None,
        replace: bool = False,
    ) -> ProviderInfo:
        """Register a provider class with canonical and alias names.

        Args:
            provider_class: The class to register.
            override_name: Canonical name to use instead of the class-declared one.
            extra_aliases: Additional aliases on top of those the class declares.
            module_name: Module to record; defaults to ``provider_class.__module__``.
            replace: When True, an existing registration under the same
                canonical name is replaced; otherwise it is kept.

        Returns:
            The ``ProviderInfo`` now registered under the canonical name. When
            the name was already taken and *replace* is False this is the
            pre-existing entry, not one for *provider_class*.

        Raises:
            ValueError: If no usable name can be derived for the class.
        """
        candidates = self._candidate_names(provider_class, override_name)
        if not candidates:
            raise ValueError("provider name candidates are required")
        canonical = self._normalize(candidates[0])
        if not canonical:
            raise ValueError("provider name must not be empty")

        existing = self._infos.get(canonical)
        if existing is not None and not replace:
            return existing

        # Class-declared aliases first, then caller-supplied ones; skip blanks,
        # the canonical name itself, and repeats.
        alias_names: List[str] = []
        seen: set[str] = {canonical}
        for raw_alias in (*candidates[1:], *(extra_aliases or ())):
            alias = self._normalize(raw_alias)
            if not alias or alias in seen:
                continue
            seen.add(alias)
            alias_names.append(alias)

        info = ProviderInfo(
            canonical_name=canonical,
            provider_class=provider_class,
            module=module_name or getattr(provider_class, "__module__", "") or "",
            alias_names=tuple(alias_names),
        )

        if existing is not None:
            # Replacing: drop every lookup key that still points at the old
            # registration, so aliases the replacement does not declare stop
            # resolving to the stale entry.
            stale_keys = [key for key, value in self._lookup.items() if value is existing]
            for key in stale_keys:
                del self._lookup[key]

        self._infos[canonical] = info
        for lookup in (canonical, *alias_names):
            self._lookup[lookup] = info
        return info

    def _register_module(self, module: ModuleType) -> None:
        """Register every Provider subclass defined directly in *module*.

        Each module is scanned at most once. The base classes themselves and
        classes merely imported into the module are skipped. Registration
        failures are logged to stderr and do not abort the scan.
        """
        module_name = getattr(module, "__name__", "")
        if not module_name or module_name in self._modules:
            return
        self._modules.add(module_name)
        # Iterate module dict directly (faster than dir()+getattr()).
        for candidate in vars(module).values():
            if not isinstance(candidate, type):
                continue
            if not issubclass(candidate, Provider):
                continue
            if candidate in {Provider, SearchProvider, FileProvider}:
                continue
            if getattr(candidate, "__module__", "") != module_name:
                continue
            try:
                self.register(candidate, module_name=module_name)
            except Exception as exc:
                log(f"[provider] Failed to register {module_name}.{candidate.__name__}: {exc}", file=sys.stderr)

    def discover(self) -> None:
        """Import and register providers from the package.

        Runs at most once per registry. The flag is set before importing, so a
        failed package import is logged and not retried. Submodules whose name
        starts with ``_`` or is listed in ``_SKIPPED_MODULES`` are not imported.
        """
        if self._discovered or not self.package_name:
            return
        self._discovered = True
        try:
            package = importlib.import_module(self.package_name)
        except Exception as exc:
            log(f"[provider] Failed to import package {self.package_name}: {exc}", file=sys.stderr)
            return
        self._register_module(package)
        package_path = getattr(package, "__path__", None)
        if not package_path:
            return
        for _finder, module_name, _is_pkg in pkgutil.iter_modules(package_path):
            if module_name.startswith("_"):
                continue
            if module_name.strip().lower() in self._SKIPPED_MODULES:
                continue
            module_path = f"{self.package_name}.{module_name}"
            try:
                module = importlib.import_module(module_path)
            except Exception as exc:
                log(f"[provider] Failed to load {module_path}: {exc}", file=sys.stderr)
                continue
            self._register_module(module)
        # Pick up any Provider subclasses loaded via other mechanisms.
        self._sync_subclasses()

    def _try_import_for_name(self, normalized_name: str) -> None:
        """Best-effort import for a single provider module.

        This avoids importing every provider module when the caller only needs
        one provider (common for CLI usage). Module names tried, in order: the
        name itself, the name with ``-`` replaced by ``_``, and the part before
        the first ``.``. The search stops at the first module that imports.
        """
        name = str(normalized_name or "").strip().lower()
        if not name or not self.package_name:
            return
        # Keep behavior consistent with full discovery, which skips these too.
        if name in self._SKIPPED_MODULES:
            return
        candidates: List[str] = [name]
        if "-" in name:
            candidates.append(name.replace("-", "_"))
        if "." in name:
            candidates.append(name.split(".", 1)[0])
        for mod_name in candidates:
            if not mod_name:
                continue
            module_path = f"{self.package_name}.{mod_name}"
            if module_path in self._modules:
                continue
            try:
                module = importlib.import_module(module_path)
            except Exception:
                # Usually just "no such module"; a later full discover() logs
                # any real import failure.
                continue
            self._register_module(module)
            # Pick up subclasses in case the module registers indirectly.
            self._sync_subclasses()
            return

    def get(self, name: str) -> Optional[ProviderInfo]:
        """Return the info registered under *name* (canonical or alias), or None.

        On a miss this may import provider modules: first only the module
        matching *name*, then, if still unresolved, the whole package.
        """
        if not name:
            return None
        normalized = self._normalize(name)
        info = self._lookup.get(normalized)
        if info is not None:
            return info
        # If we haven't done a full discovery yet, try importing just the
        # module that matches the requested name.
        if not self._discovered:
            self._try_import_for_name(normalized)
            info = self._lookup.get(normalized)
            if info is not None:
                return info
        # Fall back to full package scan.
        self.discover()
        return self._lookup.get(normalized)

    def iter_providers(self) -> Iterable[ProviderInfo]:
        """Run full discovery and return a snapshot tuple of all registered infos."""
        self.discover()
        return tuple(self._infos.values())

    def has_name(self, name: str) -> bool:
        """Return True when *name* resolves to a provider (may trigger imports)."""
        return self.get(name) is not None

    def _sync_subclasses(self) -> None:
        """Walk all Provider subclasses in memory and register them.

        ``SearchProvider`` and ``FileProvider`` are traversed but not
        registered themselves. A class that cannot be registered is reported
        through ``debug`` and skipped; its own subclasses are still visited.
        """
        def _walk(cls: Type[Provider]) -> None:
            for sub in cls.__subclasses__():
                if sub not in {SearchProvider, FileProvider}:
                    try:
                        self.register(sub)
                    except Exception as exc:
                        debug(f"[provider] Skipped subclass {sub!r}: {exc}")
                _walk(sub)

        _walk(Provider)
# Module-level registry instance bound to the "Provider" package; the
# module-level helper functions in this file look providers up through it.
REGISTRY = ProviderRegistry("Provider")
@lru_cache(maxsize=512)
def _provider_url_patterns(provider_class: Type[Provider]) -> Sequence[str]:
try:
return list(provider_class.url_patterns())
except Exception:
return []
def register_provider(
    provider_class: Type[Provider],
    *,
    name: Optional[str] = None,
    aliases: Optional[Sequence[str]] = None,
    module_name: Optional[str] = None,
    replace: bool = False,
) -> ProviderInfo:
    """Add *provider_class* to the module-level registry.

    Intended for tests and third-party packages. ``name`` and ``aliases`` are
    handed to the registry as its ``override_name`` and ``extra_aliases``;
    ``module_name`` and ``replace`` pass straight through. Returns whatever
    ``REGISTRY.register`` returns.
    """
    options = {
        "override_name": name,
        "extra_aliases": aliases,
        "module_name": module_name,
        "replace": replace,
    }
    return REGISTRY.register(provider_class, **options)
def get_provider_class(name: str) -> Optional[Type[Provider]]:
    """Return the provider class registered under *name*, or None if unknown."""
    entry = REGISTRY.get(name)
    return entry.provider_class if entry is not None else None
def selection_auto_stage_for_table(
    table_type: str,
    stage_args: Optional[Sequence[str]] = None,
) -> Optional[list[str]]:
    """Ask the provider that owns *table_type* for its automatic selection stage.

    The table type is stripped and lower-cased. The provider is looked up by
    the part before the first ``.`` and, failing that, by the full table type.
    Returns the provider's ``selection_auto_stage(table, stage_args)`` result,
    or None when the table type is blank, no provider matches, or the call
    raises.
    """
    table = str(table_type or "").strip().lower()
    if table == "":
        return None

    prefix, _, _ = table.partition(".")
    owner = get_provider_class(prefix) or get_provider_class(table)
    if owner is None:
        return None

    try:
        staged = owner.selection_auto_stage(table, stage_args)
    except Exception:
        return None
    return staged
def is_known_provider_name(name: str) -> bool:
    """Report whether *name* resolves to a registered provider (name or alias)."""
    return REGISTRY.get(name) is not None
def _supports_search(provider: Provider) -> bool:
    """Return True when the instance's class overrides ``Provider.search``."""
    cls = provider.__class__
    return cls.search is not Provider.search
def _supports_upload(provider: Provider) -> bool:
    """Return True when the instance's class overrides ``Provider.upload`` and is exposed.

    A class opts out with a falsy ``EXPOSE_AS_FILE_PROVIDER``; if reading the
    flag fails the class is treated as exposed.
    """
    cls = provider.__class__
    try:
        flag = getattr(cls, "EXPOSE_AS_FILE_PROVIDER", True)
        exposed = bool(flag)
    except Exception:
        exposed = True
    if not exposed:
        return False
    return cls.upload is not Provider.upload
def _normalize_choice_entry(entry: Any) -> Optional[Dict[str, Any]]:
if entry is None:
return None
if isinstance(entry, dict):
value = entry.get("value")
text = entry.get("text") or entry.get("label") or value
aliases = entry.get("alias") or entry.get("aliases") or []
value_str = str(value) if value is not None else (str(text) if text is not None else None)
text_str = str(text) if text is not None else value_str
if not value_str or not text_str:
return None
alias_list = [str(a) for a in aliases if a is not None]
return {"value": value_str, "text": text_str, "aliases": alias_list}
return {"value": str(entry), "text": str(entry), "aliases": []}
def _collect_inline_choice_mapping(provider: Provider) -> Dict[str, List[Dict[str, Any]]]:
mapping: Dict[str, List[Dict[str, Any]]] = {}
base = getattr(provider, "QUERY_ARG_CHOICES", None)
if not isinstance(base, dict):
base = getattr(provider, "INLINE_QUERY_FIELD_CHOICES", None)
def _merge_from(obj: Any) -> None:
if not isinstance(obj, dict):
return
for key, value in obj.items():
normalized: List[Dict[str, Any]] = []
seq = value
try:
if callable(seq):
seq = seq()
except Exception:
seq = value
if isinstance(seq, dict):
seq = seq.get("choices") or seq.get("values") or seq
if isinstance(seq, (list, tuple, set)):
for entry in seq:
n = _normalize_choice_entry(entry)
if n:
normalized.append(n)
if normalized:
mapping[str(key).strip().lower()] = normalized
_merge_from(base)
try:
fn = getattr(provider, "inline_query_field_choices", None)
if callable(fn):
_merge_from(fn())
except Exception:
pass
return mapping
def get_provider(name: str, config: Optional[Dict[str, Any]] = None) -> Optional[Provider]:
    """Instantiate the provider registered under *name* with *config*.

    Returns the instance only when its ``validate()`` is truthy. Returns None,
    with a debug message, when the name is unknown, validation fails, or
    construction/validation raises.
    """
    entry = REGISTRY.get(name)
    if entry is None:
        debug(f"[provider] Unknown provider: {name}")
        return None

    try:
        instance = entry.provider_class(config)
        if instance.validate():
            return instance
        debug(f"[provider] Provider '{name}' is not available")
        return None
    except Exception as exc:
        debug(f"[provider] Error initializing '{name}': {exc}")
        return None
def list_providers(config: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
    """Map every registered provider's canonical name to its availability.

    Each provider class is instantiated with *config* and asked to
    ``validate()``. The result is coerced to ``bool`` so the mapping matches
    its declared type and the sibling ``list_*_providers`` helpers. A provider
    whose construction or validation raises is reported as unavailable.
    """
    availability: Dict[str, bool] = {}
    for info in REGISTRY.iter_providers():
        try:
            provider = info.provider_class(config)
            availability[info.canonical_name] = bool(provider.validate())
        except Exception:
            availability[info.canonical_name] = False
    return availability
def get_search_provider(name: str,
                        config: Optional[Dict[str, Any]] = None) -> Optional[SearchProvider]:
    """Return an available provider for *name* only if it implements search.

    Returns None when the provider cannot be obtained, or (with a debug
    message) when its class does not override ``search``.
    """
    candidate = get_provider(name, config)
    if candidate is not None and not _supports_search(candidate):
        debug(f"[provider] Provider '{name}' does not support search")
        return None
    return candidate  # type: ignore[return-value]
def list_search_providers(config: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
    """Map each registered provider's canonical name to "usable for search".

    A provider counts as usable when it instantiates with *config*, its
    ``validate()`` is truthy, and its class overrides ``search``. Providers
    that raise are reported as False.
    """
    status: Dict[str, bool] = {}
    for entry in REGISTRY.iter_providers():
        try:
            instance = entry.provider_class(config)
            usable = bool(instance.validate() and entry.supports_search)
        except Exception:
            usable = False
        status[entry.canonical_name] = usable
    return status
def get_file_provider(name: str,
                      config: Optional[Dict[str, Any]] = None) -> Optional[FileProvider]:
    """Return an available provider for *name* only if it implements upload.

    Returns None when the provider cannot be obtained, or (with a debug
    message) when it does not support upload.
    """
    candidate = get_provider(name, config)
    if candidate is not None and not _supports_upload(candidate):
        debug(f"[provider] Provider '{name}' does not support upload")
        return None
    return candidate  # type: ignore[return-value]
def list_file_providers(config: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
    """Map each registered provider's canonical name to "usable for upload".

    A provider counts as usable when it instantiates with *config*, its
    ``validate()`` is truthy, and its info reports upload support. Providers
    that raise are reported as False.
    """
    status: Dict[str, bool] = {}
    for entry in REGISTRY.iter_providers():
        try:
            instance = entry.provider_class(config)
            usable = bool(instance.validate() and entry.supports_upload)
        except Exception:
            usable = False
        status[entry.canonical_name] = usable
    return status
def match_provider_name_for_url(url: str) -> Optional[str]:
    """Return the canonical name of the provider that claims *url*, or None.

    Resolution order:
      1. Built-in host rules: ``openlibrary.org`` (and subdomains) maps to
         ``openlibrary``; ``archive.org`` (and subdomains) maps to
         ``openlibrary`` for borrow/stream/loan paths and to
         ``internetarchive`` otherwise. Each rule yields None when the
         target provider is not registered.
      2. Each registered provider's URL patterns. A pattern containing
         ``://``, starting with ``magnet:``, ending with ``:`` or containing
         the magnet emoji is matched as a case-insensitive prefix of the whole
         URL; any other pattern is treated as a domain that matches the host
         itself or any subdomain. A leading ``www.`` is ignored on both sides.
    """
    text = str(url or "").strip()
    text_lower = text.lower()
    try:
        parts = urlparse(text)
        host = (parts.hostname or "").strip().lower()
        path = (parts.path or "").strip()
    except Exception:
        host, path = "", ""

    def _strip_www(name: str) -> str:
        cleaned = str(name or "").strip().lower()
        return cleaned[4:] if cleaned.startswith("www.") else cleaned

    target_host = _strip_www(host)

    def _host_is(domain: str) -> bool:
        # Exact host or any subdomain of *domain*.
        return target_host == domain or target_host.endswith("." + domain)

    def _if_registered(provider_name: str) -> Optional[str]:
        return provider_name if REGISTRY.has_name(provider_name) else None

    if target_host:
        if _host_is("openlibrary.org"):
            return _if_registered("openlibrary")
        if _host_is("archive.org"):
            lowered_path = str(path or "").lower()
            borrow_like = (
                lowered_path.startswith(("/borrow/", "/stream/"))
                or "/services/loans/" in lowered_path
            )
            return _if_registered("openlibrary" if borrow_like else "internetarchive")

    for entry in REGISTRY.iter_providers():
        for pattern in _provider_url_patterns(entry.provider_class) or ():
            needle = str(pattern or "").strip().lower()
            if not needle:
                continue
            is_prefix_pattern = (
                "://" in needle
                or needle.startswith("magnet:")
                or needle.endswith(":")
                or "🧲" in needle
            )
            if is_prefix_pattern:
                if text_lower.startswith(needle):
                    return entry.canonical_name
                continue
            domain = _strip_www(needle)
            if domain and target_host and _host_is(domain):
                return entry.canonical_name
    return None
def provider_inline_query_choices(
    provider_name: str,
    field_name: str,
    config: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """Return provider-declared inline query choices for a field (e.g., system:GBA).

    Providers can expose a mapping via ``QUERY_ARG_CHOICES`` (preferred) or
    ``INLINE_QUERY_FIELD_CHOICES`` / ``inline_query_field_choices()``. The
    result lists each choice's display text (falling back to its value)
    followed by its aliases, de-duplicated in first-seen order. An empty list
    is returned for blank arguments, an unavailable provider, an undeclared
    field, or any error while collecting.
    """
    provider_key = str(provider_name or "").strip().lower()
    field_key = str(field_name or "").strip().lower()
    if not (provider_key and field_key):
        return []

    # Prefer the search-capable instance, but accept any available provider.
    provider = get_search_provider(provider_key, config)
    if provider is None:
        provider = get_provider(provider_key, config)
    if provider is None:
        return []

    try:
        entries = _collect_inline_choice_mapping(provider).get(field_key, [])
        # A dict doubles as an insertion-ordered set of the labels emitted.
        ordered: Dict[str, None] = {}
        for entry in entries:
            label = entry.get("text") or entry.get("value")
            if not label:
                continue
            label = str(label)
            if label in ordered:
                # A repeated label also skips that entry's aliases.
                continue
            ordered[label] = None
            for alias in entry.get("aliases", []):
                alias_text = str(alias)
                if alias_text and alias_text not in ordered:
                    ordered[alias_text] = None
        return list(ordered)
    except Exception:
        return []
def get_provider_for_url(url: str,
                         config: Optional[Dict[str, Any]] = None) -> Optional[Provider]:
    """Return an available provider instance that claims *url*, or None."""
    matched = match_provider_name_for_url(url)
    return get_provider(matched, config) if matched else None
def resolve_inline_filters(
    provider: Provider,
    inline_args: Dict[str, Any],
    *,
    field_transforms: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
    """Translate inline query arguments into provider filter values.

    For each argument the key is stripped and lower-cased and the value is
    stripped; ``None`` values and blank keys/values are skipped. The value is
    compared case-insensitively with the text, value and aliases of the
    choices the provider declares for that field. On a match the choice's
    value (or its text) is used, otherwise the user's text is kept. An
    optional callable in *field_transforms* (keyed by the normalized field
    name) is then applied; if it raises, the untransformed value is kept.
    Empty results are dropped and everything else is returned as ``str``.
    """
    resolved_filters: Dict[str, str] = {}
    if not inline_args:
        return resolved_filters

    choices_by_field = _collect_inline_choice_mapping(provider)
    transforms = field_transforms or {}

    for raw_key, raw_val in inline_args.items():
        if raw_val is None:
            continue
        key = str(raw_key or "").strip().lower()
        wanted = str(raw_val).strip()
        if not (key and wanted):
            continue

        wanted_lower = wanted.lower()
        chosen: Any = wanted  # default when no declared choice matches
        for entry in choices_by_field.get(key, []):
            text = str(entry.get("text") or "").strip()
            value = str(entry.get("value") or "").strip()
            accepted = {text.lower(), value.lower()}
            accepted.update(
                str(alias).strip().lower()
                for alias in entry.get("aliases", [])
                if alias is not None
            )
            if wanted_lower in accepted:
                chosen = value or text or wanted
                break

        transform = transforms.get(key)
        if callable(transform):
            try:
                chosen = transform(chosen)
            except Exception:
                pass

        if chosen:
            resolved_filters[key] = str(chosen)
    return resolved_filters
# Names exported by ``from ... import *``. Besides the helpers defined in this
# module, the list re-exports Provider, SearchProvider, FileProvider and
# SearchResult, which are imported from ProviderCore.base.
# NOTE(review): is_known_provider_name, resolve_inline_filters, REGISTRY and
# ProviderRegistry are defined here but not listed — confirm that is intended.
__all__ = [
"ProviderInfo",
"Provider",
"SearchProvider",
"FileProvider",
"SearchResult",
"register_provider",
"get_provider",
"list_providers",
"get_search_provider",
"list_search_providers",
"get_file_provider",
"list_file_providers",
"match_provider_name_for_url",
"get_provider_for_url",
"get_provider_class",
"selection_auto_stage_for_table",
"download_soulseek_file",
"provider_inline_query_choices",
]