refactored config plugin defintions
This commit is contained in:
@@ -0,0 +1,223 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from SYS.config import global_config
|
||||
from ProviderCore.registry import get_plugin_class, list_plugins
|
||||
from Store.registry import _discover_store_classes, _required_keys_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ConfigField = Dict[str, Any]
|
||||
|
||||
|
||||
def _normalize_schema(fields: Optional[Iterable[Any]]) -> List[ConfigField]:
|
||||
normalized: List[ConfigField] = []
|
||||
seen: set[str] = set()
|
||||
for raw_field in list(fields or []):
|
||||
if not isinstance(raw_field, dict):
|
||||
continue
|
||||
key = str(raw_field.get("key") or "").strip()
|
||||
if not key:
|
||||
continue
|
||||
key_upper = key.upper()
|
||||
if key_upper in seen:
|
||||
continue
|
||||
seen.add(key_upper)
|
||||
|
||||
field = dict(raw_field)
|
||||
field["key"] = key
|
||||
if "label" in field and field.get("label") is not None:
|
||||
field["label"] = str(field.get("label") or "")
|
||||
choices = field.get("choices")
|
||||
if choices is not None and not isinstance(choices, (list, tuple)):
|
||||
field["choices"] = [choices]
|
||||
elif isinstance(choices, tuple):
|
||||
field["choices"] = list(choices)
|
||||
normalized.append(field)
|
||||
return normalized
|
||||
|
||||
|
||||
def _call_schema(owner: Any, label: str) -> List[ConfigField]:
|
||||
schema_fn = getattr(owner, "config_schema", None)
|
||||
if not callable(schema_fn):
|
||||
return []
|
||||
try:
|
||||
return _normalize_schema(schema_fn())
|
||||
except Exception:
|
||||
logger.exception("Failed to load config schema for %s", label)
|
||||
return []
|
||||
|
||||
|
||||
def get_store_schema(store_type: str) -> List[ConfigField]:
|
||||
classes = _discover_store_classes()
|
||||
cls = classes.get(str(store_type or "").strip())
|
||||
if cls is None:
|
||||
return []
|
||||
return _call_schema(cls, f"store '{store_type}'")
|
||||
|
||||
|
||||
def get_provider_schema(provider_name: str) -> List[ConfigField]:
|
||||
plugin_class = get_plugin_class(str(provider_name or "").strip())
|
||||
if plugin_class is None:
|
||||
return []
|
||||
return _call_schema(plugin_class, f"provider '{provider_name}'")
|
||||
|
||||
|
||||
def get_tool_schema(tool_name: str) -> List[ConfigField]:
|
||||
tool_name = str(tool_name or "").strip()
|
||||
if not tool_name:
|
||||
return []
|
||||
try:
|
||||
module = importlib.import_module(f"tool.{tool_name}")
|
||||
except Exception:
|
||||
logger.exception("Failed to import tool module 'tool.%s'", tool_name)
|
||||
return []
|
||||
return _call_schema(module, f"tool '{tool_name}'")
|
||||
|
||||
|
||||
def get_item_schema(item_type: str, item_name: str) -> List[ConfigField]:
|
||||
normalized_type = str(item_type or "").strip()
|
||||
normalized_name = str(item_name or "").strip()
|
||||
if normalized_type.startswith("store-"):
|
||||
return get_store_schema(normalized_type.replace("store-", "", 1))
|
||||
if normalized_type == "provider":
|
||||
return get_provider_schema(normalized_name)
|
||||
if normalized_type == "tool":
|
||||
return get_tool_schema(normalized_name)
|
||||
return []
|
||||
|
||||
|
||||
def get_item_schema_map(item_type: str, item_name: str) -> Dict[str, ConfigField]:
|
||||
return {field["key"].upper(): field for field in get_item_schema(item_type, item_name)}
|
||||
|
||||
|
||||
def get_global_schema() -> List[ConfigField]:
|
||||
return _normalize_schema(global_config())
|
||||
|
||||
|
||||
def get_global_schema_map() -> Dict[str, ConfigField]:
|
||||
return {field["key"].upper(): field for field in get_global_schema()}
|
||||
|
||||
|
||||
def build_default_store_config(store_type: str, instance_name: str) -> Dict[str, Any]:
|
||||
config: Dict[str, Any] = {"NAME": instance_name}
|
||||
schema = get_store_schema(store_type)
|
||||
if schema:
|
||||
for field in schema:
|
||||
key = field["key"]
|
||||
if key.upper() == "NAME":
|
||||
continue
|
||||
config[key] = field.get("default", "")
|
||||
return config
|
||||
|
||||
classes = _discover_store_classes()
|
||||
cls = classes.get(str(store_type or "").strip())
|
||||
if cls is None:
|
||||
return config
|
||||
for required_key in _required_keys_for(cls):
|
||||
if required_key.upper() == "NAME":
|
||||
continue
|
||||
config[required_key] = ""
|
||||
return config
|
||||
|
||||
|
||||
def build_default_provider_config(provider_name: str) -> Dict[str, Any]:
|
||||
config: Dict[str, Any] = {}
|
||||
schema = get_provider_schema(provider_name)
|
||||
if schema:
|
||||
for field in schema:
|
||||
config[field["key"]] = field.get("default", "")
|
||||
return config
|
||||
|
||||
plugin_class = get_plugin_class(str(provider_name or "").strip())
|
||||
if plugin_class is None:
|
||||
return config
|
||||
try:
|
||||
for required_key in plugin_class.required_config_keys():
|
||||
config[str(required_key)] = ""
|
||||
except Exception:
|
||||
logger.exception("Failed to load legacy required config keys for provider '%s'", provider_name)
|
||||
return config
|
||||
|
||||
|
||||
def build_default_tool_config(tool_name: str) -> Dict[str, Any]:
|
||||
config: Dict[str, Any] = {}
|
||||
for field in get_tool_schema(tool_name):
|
||||
config[field["key"]] = field.get("default", "")
|
||||
return config
|
||||
|
||||
|
||||
def get_required_config_keys(item_type: str, item_name: str) -> List[str]:
|
||||
normalized_type = str(item_type or "").strip()
|
||||
normalized_name = str(item_name or "").strip()
|
||||
required_keys: List[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def _add_key(value: Any) -> None:
|
||||
key = str(value or "").strip()
|
||||
if not key:
|
||||
return
|
||||
key_upper = key.upper()
|
||||
if key_upper in seen:
|
||||
return
|
||||
seen.add(key_upper)
|
||||
required_keys.append(key)
|
||||
|
||||
for field in get_item_schema(normalized_type, normalized_name):
|
||||
if field.get("required"):
|
||||
_add_key(field.get("key"))
|
||||
|
||||
if normalized_type.startswith("store-"):
|
||||
store_type = normalized_type.replace("store-", "", 1)
|
||||
classes = _discover_store_classes()
|
||||
cls = classes.get(store_type)
|
||||
if cls is not None:
|
||||
for required_key in _required_keys_for(cls):
|
||||
_add_key(required_key)
|
||||
elif normalized_type == "provider":
|
||||
plugin_class = get_plugin_class(normalized_name)
|
||||
if plugin_class is not None:
|
||||
try:
|
||||
for required_key in plugin_class.required_config_keys():
|
||||
_add_key(required_key)
|
||||
except Exception:
|
||||
logger.exception("Failed to load required config keys for provider '%s'", normalized_name)
|
||||
|
||||
return required_keys
|
||||
|
||||
|
||||
def get_configurable_store_types() -> List[str]:
|
||||
options: List[str] = []
|
||||
for store_type in _discover_store_classes().keys():
|
||||
if get_store_schema(store_type):
|
||||
options.append(str(store_type))
|
||||
return sorted(set(options))
|
||||
|
||||
|
||||
def get_configurable_provider_types() -> List[str]:
|
||||
options: List[str] = []
|
||||
for provider_name in list_plugins().keys():
|
||||
if get_provider_schema(provider_name):
|
||||
options.append(str(provider_name))
|
||||
return sorted(set(options))
|
||||
|
||||
|
||||
def get_configurable_tool_types() -> List[str]:
|
||||
options: List[str] = []
|
||||
try:
|
||||
import tool as tool_package
|
||||
|
||||
for module_info in pkgutil.iter_modules(tool_package.__path__):
|
||||
tool_name = str(module_info.name or "").strip()
|
||||
if not tool_name:
|
||||
continue
|
||||
if get_tool_schema(tool_name):
|
||||
options.append(tool_name)
|
||||
except Exception:
|
||||
logger.exception("Failed to discover configurable tool modules")
|
||||
return sorted(set(options))
|
||||
Reference in New Issue
Block a user