diff --git a/CLI.py b/CLI.py index 260e79d..139bc46 100644 --- a/CLI.py +++ b/CLI.py @@ -808,14 +808,6 @@ class CmdletIntrospection: provider_choices: List[str] = [] - if canonical_cmd in {"search-provider" - } and list_search_providers is not None: - providers = list_search_providers(config) or {} - available = [ - name for name, is_ready in providers.items() if is_ready - ] - return sorted(available) if available else sorted(providers.keys()) - if canonical_cmd in {"add-file"} and list_file_providers is not None: providers = list_file_providers(config) or {} available = [ @@ -1298,48 +1290,6 @@ class CmdletExecutor: emitted_items: Optional[List[Any]] = None, cmd_args: Optional[List[str]] = None, ) -> str: - if cmd_name in ("search-provider", "search_provider") and cmd_args: - provider: str = "" - query: str = "" - tokens = [str(a) for a in (cmd_args or [])] - pos: List[str] = [] - i = 0 - while i < len(tokens): - low = tokens[i].lower() - if low in {"-provider", - "--provider"} and i + 1 < len(tokens): - provider = str(tokens[i + 1]).strip() - i += 2 - continue - if low in {"-query", - "--query"} and i + 1 < len(tokens): - query = str(tokens[i + 1]).strip() - i += 2 - continue - if low in {"-limit", - "--limit"} and i + 1 < len(tokens): - i += 2 - continue - if not str(tokens[i]).startswith("-"): - pos.append(str(tokens[i])) - i += 1 - - if not provider and pos: - provider = str(pos[0]).strip() - pos = pos[1:] - if not query and pos: - query = " ".join(pos).strip() - - if provider and query: - provider_lower = provider.lower() - if provider_lower == "youtube": - provider_label = "Youtube" - elif provider_lower == "openlibrary": - provider_label = "OpenLibrary" - else: - provider_label = provider[:1].upper() + provider[1:] - return f"{provider_label}: {query}".strip().rstrip(":") - title_map = { "search-file": "Results", "search_file": "Results", @@ -1807,10 +1757,6 @@ class CmdletExecutor: "tags", "search-file", "search_file", - "search-provider", - "search_provider", - "search-store", - "search_store", } if cmd_name in self_managing_commands: @@ -3029,8 +2975,6 @@ class PipelineExecutor: stage_is_last = (stage_index + 1 >= len(stages)) if filter_spec is not None and stage_is_last: try: - from SYS.result_table import ResultTable - base_table = stage_table if base_table is None: base_table = ctx.get_last_result_table() @@ -3801,6 +3745,15 @@ class MedeiaCLI: def __init__(self) -> None: self._config_loader = ConfigLoader(root=self.ROOT) + + # Optional dependency auto-install for configured tools (best-effort). 
+ try: + from SYS.optional_deps import maybe_auto_install_configured_tools + + maybe_auto_install_configured_tools(self._config_loader.load()) + except Exception: + pass + self._cmdlet_executor = CmdletExecutor(config_loader=self._config_loader) self._pipeline_executor = PipelineExecutor(config_loader=self._config_loader) @@ -3833,54 +3786,6 @@ class MedeiaCLI: pass return value - def _complete_search_provider(ctx, param, incomplete: str): # pragma: no cover - try: - from click.shell_completion import CompletionItem - except Exception: - return [] - - try: - from ProviderCore.registry import list_search_providers - - providers = list_search_providers(self._config_loader.load()) or {} - available = [n for n, ok in providers.items() if ok] - choices = sorted(available) if available else sorted(providers.keys()) - except Exception: - choices = [] - - inc = (incomplete or "").lower() - return [ - CompletionItem(name) for name in choices - if name and name.lower().startswith(inc) - ] - - @app.command("search-provider") - def search_provider( - provider: str = typer.Option( - ..., - "--provider", - "-p", - help="Provider name (bandcamp, libgen, soulseek, youtube)", - shell_complete=_complete_search_provider, - ), - query: str = typer.Argument(..., - help="Search query (quote for spaces)"), - limit: int = typer.Option( - 36, - "--limit", - "-l", - help="Maximum results to return" - ), - ) -> None: - self._cmdlet_executor.execute( - "search-provider", - ["-provider", - provider, - query, - "-limit", - str(limit)] - ) - @app.command("pipeline") def pipeline( command: str = typer.Option( @@ -3942,7 +3847,7 @@ class MedeiaCLI: if ctx.invoked_subcommand is None: self.run_repl() - _ = (search_provider, pipeline, repl, main_callback) + _ = (pipeline, repl, main_callback) # Dynamically register all cmdlets as top-level Typer commands so users can # invoke `mm [args]` directly from the shell. We use Click/Typer @@ -3950,9 +3855,7 @@ class MedeiaCLI: # the cmdlet system without Typer trying to parse them. 
try: names = list_cmdlet_names() - skip = {"search-provider", - "pipeline", - "repl"} + skip = {"pipeline", "repl"} for nm in names: if not nm or nm in skip: continue @@ -3964,7 +3867,7 @@ class MedeiaCLI: cmd_name, context_settings={ "ignore_unknown_options": True, - "allow_extra_args": True + "allow_extra_args": True, }, ) def _handler(ctx: typer.Context): @@ -4119,6 +4022,13 @@ Come to love it when others take what you share, as there is no greater joy block = provider_cfg.get(str(name).strip().lower()) return isinstance(block, dict) and bool(block) + def _has_tool(cfg: dict, name: str) -> bool: + tool_cfg = cfg.get("tool") + if not isinstance(tool_cfg, dict): + return False + block = tool_cfg.get(str(name).strip().lower()) + return isinstance(block, dict) and bool(block) + def _ping_url(url: str, timeout: float = 3.0) -> tuple[bool, str]: try: from API.HTTP import HTTPClient @@ -4542,6 +4452,45 @@ Come to love it when others take what you share, as there is no greater joy except Exception as exc: _add_startup_check("ERROR", "Cookies", detail=str(exc)) + # Tool checks (configured via [tool=...]) + if _has_tool(config, "florencevision"): + try: + tool_cfg = config.get("tool") + fv_cfg = tool_cfg.get("florencevision") if isinstance(tool_cfg, dict) else None + enabled = bool(fv_cfg.get("enabled")) if isinstance(fv_cfg, dict) else False + if not enabled: + _add_startup_check( + "DISABLED", + "FlorenceVision", + provider="tool", + detail="Not enabled", + ) + else: + from SYS.optional_deps import florencevision_missing_modules + + missing = florencevision_missing_modules() + if missing: + _add_startup_check( + "DISABLED", + "FlorenceVision", + provider="tool", + detail="Missing: " + ", ".join(missing), + ) + else: + _add_startup_check( + "ENABLED", + "FlorenceVision", + provider="tool", + detail="Ready", + ) + except Exception as exc: + _add_startup_check( + "DISABLED", + "FlorenceVision", + provider="tool", + detail=str(exc), + ) + if startup_table.rows: stdout_console().print() stdout_console().print(startup_table) @@ -4709,7 +4658,7 @@ Come to love it when others take what you share, as there is no greater joy if last_table is None: last_table = ctx.get_last_result_table() - # Auto-refresh search-store tables when navigating back, + # Auto-refresh search-file tables when navigating back, # so row payloads (titles/tags) reflect latest store state. 
try: src_cmd = ( @@ -4720,7 +4669,7 @@ Come to love it when others take what you share, as there is no greater joy if (isinstance(src_cmd, str) and src_cmd.lower().replace("_", - "-") == "search-store"): + "-") == "search-file"): src_args = ( getattr(last_table, "source_args", @@ -4748,7 +4697,7 @@ Come to love it when others take what you share, as there is no greater joy else: ctx.set_current_command_text( " ".join( - ["search-store", + ["search-file", *cleaned_args] ).strip() ) @@ -4756,7 +4705,7 @@ Come to love it when others take what you share, as there is no greater joy pass try: self._cmdlet_executor.execute( - "search-store", + "search-file", cleaned_args + ["--refresh"] ) finally: @@ -4768,7 +4717,7 @@ Come to love it when others take what you share, as there is no greater joy continue except Exception as exc: print( - f"Error refreshing search-store table: {exc}", + f"Error refreshing search-file table: {exc}", file=sys.stderr ) diff --git a/Provider/internetarchive.py b/Provider/internetarchive.py index 51cb111..863e9f3 100644 --- a/Provider/internetarchive.py +++ b/Provider/internetarchive.py @@ -311,7 +311,7 @@ class InternetArchive(Provider): """Internet Archive provider using the `internetarchive` Python module. Supports: - - search-provider -provider internetarchive + - search-file -provider internetarchive - download-file / provider.download() from search results - add-file -provider internetarchive (uploads) diff --git a/SYS/config.py b/SYS/config.py index 413fd95..802a7ef 100644 --- a/SYS/config.py +++ b/SYS/config.py @@ -14,13 +14,35 @@ _CONFIG_CACHE: Dict[str, Dict[str, Any]] = {} def _strip_inline_comment(line: str) -> str: - # Keep it simple: only strip full-line comments and inline comments that start after whitespace. - # Users can always quote values that contain '#' or ';'. + # Strip comments in a way that's friendly to common .conf usage: + # - Full-line comments starting with '#' or ';' + # - Inline comments starting with '#' or ';' *outside quotes* + # (e.g. dtype="float16" # optional) stripped = line.strip() if not stripped: return "" if stripped.startswith("#") or stripped.startswith(";"): return "" + + in_single = False + in_double = False + for i, ch in enumerate(line): + if ch == "'" and not in_double: + in_single = not in_single + continue + if ch == '"' and not in_single: + in_double = not in_double + continue + if in_single or in_double: + continue + + if ch in {"#", ";"}: + # Treat as a comment start only when preceded by whitespace. + # This keeps values like paths or tokens containing '#' working + # when quoted, and reduces surprises for unquoted values. 
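+                # Example (assumed behavior per the rule above):
+                #   `dtype="float16" # optional` -> `dtype="float16"`,
+                # while an unquoted value like `tag=c#m` is kept intact
+                # because its '#' is not preceded by whitespace.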
+ if i == 0 or line[i - 1].isspace(): + return line[:i].rstrip() + return line diff --git a/SYS/optional_deps.py b/SYS/optional_deps.py new file mode 100644 index 0000000..f53657e --- /dev/null +++ b/SYS/optional_deps.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import importlib +import os +import subprocess +import sys +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from SYS.logger import log +from SYS.rich_display import stdout_console + + +def _as_bool(value: Any, default: bool = False) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + s = str(value).strip().lower() + if s in {"1", "true", "yes", "on"}: + return True + if s in {"0", "false", "no", "off"}: + return False + return default + + +def _is_pytest() -> bool: + return bool(os.environ.get("PYTEST_CURRENT_TEST")) + + +def _try_import(module: str) -> bool: + try: + importlib.import_module(module) + return True + except Exception: + return False + + +def florencevision_missing_modules() -> List[str]: + missing: List[str] = [] + # pillow is already in requirements, but keep the check for robustness. + if not _try_import("transformers"): + missing.append("transformers") + if not _try_import("torch"): + missing.append("torch") + if not _try_import("PIL"): + missing.append("pillow") + # Florence-2 remote code frequently requires these extras. + if not _try_import("einops"): + missing.append("einops") + if not _try_import("timm"): + missing.append("timm") + return missing + + +def _pip_install(requirements: List[str]) -> Tuple[bool, str]: + if not requirements: + return True, "No requirements" + + cmd = [sys.executable, "-m", "pip", "install", "--upgrade", *requirements] + try: + proc = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + ) + if proc.returncode == 0: + importlib.invalidate_caches() + return True, proc.stdout.strip() or "Installed" + out = (proc.stdout or "") + "\n" + (proc.stderr or "") + return False, out.strip() or f"pip exited with code {proc.returncode}" + except Exception as exc: + return False, str(exc) + + +def maybe_auto_install_configured_tools(config: Dict[str, Any]) -> None: + """Best-effort dependency auto-installer for configured tools. + + This is intentionally conservative: + - Only acts when a tool block is enabled. + - Skips under pytest. + + Current supported tool(s): florencevision + """ + if _is_pytest(): + return + + tool_cfg = (config or {}).get("tool") + if not isinstance(tool_cfg, dict): + return + + fv = tool_cfg.get("florencevision") + if isinstance(fv, dict) and _as_bool(fv.get("enabled"), False): + auto_install = _as_bool(fv.get("auto_install"), True) + if not auto_install: + return + + missing = florencevision_missing_modules() + if not missing: + return + + names = ", ".join(missing) + try: + with stdout_console().status( + f"Installing FlorenceVision dependencies: {names}", + spinner="dots", + ): + ok, detail = _pip_install(missing) + except Exception: + log(f"[startup] FlorenceVision dependencies missing ({names}). Attempting auto-install...") + ok, detail = _pip_install(missing) + + if ok: + log("[startup] FlorenceVision dependency install OK") + else: + log(f"[startup] FlorenceVision dependency auto-install failed. 
{detail}") + + +__all__ = ["maybe_auto_install_configured_tools", "florencevision_missing_modules"] diff --git a/cmdlet/add_file.py b/cmdlet/add_file.py index 75885ca..3525fe4 100644 --- a/cmdlet/add_file.py +++ b/cmdlet/add_file.py @@ -38,6 +38,95 @@ from SYS.metadata import write_metadata SUPPORTED_MEDIA_EXTENSIONS = ALL_SUPPORTED_EXTENSIONS +def _maybe_apply_florencevision_tags( + media_path: Path, + tags: List[str], + config: Dict[str, Any], + pipe_obj: Optional[models.PipeObject] = None, +) -> List[str]: + """Optionally auto-tag images using the FlorenceVision tool. + + Controlled via config: + [tool=florencevision] + enabled=true + strict=false + + If strict=false (default), failures log a warning and return the original tags. + If strict=true, failures raise to abort the ingest. + """ + try: + tool_block = (config or {}).get("tool") + fv_block = tool_block.get("florencevision") if isinstance(tool_block, dict) else None + enabled = False + strict = False + if isinstance(fv_block, dict): + enabled = bool(fv_block.get("enabled")) + strict = bool(fv_block.get("strict")) + if not enabled: + return tags + + from tool.florencevision import FlorenceVisionTool + + # Special-case: if this file was produced by the `screen-shot` cmdlet, + # OCR is more useful than caption/detection for tagging screenshots. + cfg_for_tool: Dict[str, Any] = config + try: + action = str(getattr(pipe_obj, "action", "") or "") if pipe_obj is not None else "" + cmdlet_name = "" + if action.lower().startswith("cmdlet:"): + cmdlet_name = action.split(":", 1)[1].strip().lower() + if cmdlet_name in {"screen-shot", "screen_shot", "screenshot"}: + tool_block2 = dict((config or {}).get("tool") or {}) + fv_block2 = dict(tool_block2.get("florencevision") or {}) + fv_block2["task"] = "ocr" + tool_block2["florencevision"] = fv_block2 + cfg_for_tool = dict(config or {}) + cfg_for_tool["tool"] = tool_block2 + except Exception: + cfg_for_tool = config + + fv = FlorenceVisionTool(cfg_for_tool) + if not fv.enabled() or not fv.applicable_path(media_path): + return tags + + auto_tags = fv.tags_for_file(media_path) + + # Capture caption (if any) into PipeObject notes for downstream persistence. + try: + caption_text = getattr(fv, "last_caption", None) + if caption_text and pipe_obj is not None: + if not isinstance(pipe_obj.extra, dict): + pipe_obj.extra = {} + notes = pipe_obj.extra.get("notes") + if not isinstance(notes, dict): + notes = {} + notes.setdefault("caption", caption_text) + pipe_obj.extra["notes"] = notes + except Exception: + pass + + if not auto_tags: + return tags + + merged = merge_sequences(tags or [], auto_tags, case_sensitive=False) + debug(f"[add-file] FlorenceVision added {len(auto_tags)} tag(s)") + return merged + except Exception as exc: + # Decide strictness from config if we couldn't read it above. + strict2 = False + try: + tool_block = (config or {}).get("tool") + fv_block = tool_block.get("florencevision") if isinstance(tool_block, dict) else None + strict2 = bool(fv_block.get("strict")) if isinstance(fv_block, dict) else False + except Exception: + strict2 = False + + if strict or strict2: + raise + log(f"[add-file] Warning: FlorenceVision tagging failed: {exc}", file=sys.stderr) + return tags + + class Add_File(Cmdlet): """Add file into the DB""" @@ -349,14 +438,14 @@ class Add_File(Cmdlet): successes = 0 failures = 0 - # When add-file -store is the last stage, always show a final search-store table. + # When add-file -store is the last stage, always show a final search-file table. 
# This is especially important for multi-item ingests (e.g., multi-clip downloads) # so the user always gets a selectable ResultTable. - want_final_search_store = ( + want_final_search_file = ( bool(is_last_stage) and bool(is_storage_backend_location) and bool(location) ) - auto_search_store_after_add = False + auto_search_file_after_add = False # When ingesting multiple items into a backend store, defer URL association and # apply it once at the end (bulk) to avoid per-item URL API calls. @@ -879,9 +968,9 @@ class Add_File(Cmdlet): pending_url_associations= pending_url_associations, suppress_last_stage_overlay= - want_final_search_store, - auto_search_store= - auto_search_store_after_add, + want_final_search_file, + auto_search_file= + auto_search_file_after_add, ) else: code = self._handle_local_export( @@ -1005,8 +1094,8 @@ class Add_File(Cmdlet): collect_relationship_pairs=pending_relationship_pairs, defer_url_association=defer_url_association, pending_url_associations=pending_url_associations, - suppress_last_stage_overlay=want_final_search_store, - auto_search_store=auto_search_store_after_add, + suppress_last_stage_overlay=want_final_search_file, + auto_search_file=auto_search_file_after_add, ) else: code = self._handle_local_export( @@ -1053,7 +1142,7 @@ class Add_File(Cmdlet): # Always end add-file -store (when last stage) by showing the canonical store table. # This keeps output consistent and ensures @N selection works for multi-item ingests. - if want_final_search_store and collected_payloads: + if want_final_search_file and collected_payloads: try: hashes: List[str] = [] for payload in collected_payloads: @@ -1064,7 +1153,7 @@ class Add_File(Cmdlet): seen: set[str] = set() hashes = [h for h in hashes if not (h in seen or seen.add(h))] - refreshed_items = Add_File._try_emit_search_store_by_hashes( + refreshed_items = Add_File._try_emit_search_file_by_hashes( store=str(location), hash_values=hashes, config=config, @@ -1102,29 +1191,29 @@ class Add_File(Cmdlet): return 1 @staticmethod - def _try_emit_search_store_by_hashes( + def _try_emit_search_file_by_hashes( *, store: str, hash_values: List[str], config: Dict[str, Any] ) -> Optional[List[Any]]: - """Run search-store for a list of hashes and promote the table to a display overlay. + """Run search-file for a list of hashes and promote the table to a display overlay. - Returns the emitted search-store payload items on success, else None. + Returns the emitted search-file payload items on success, else None. """ hashes = [h for h in (hash_values or []) if isinstance(h, str) and len(h) == 64] if not store or not hashes: return None try: - from cmdlet.search_store import CMDLET as search_store_cmdlet + from cmdlet.search_file import CMDLET as search_file_cmdlet query = "hash:" + ",".join(hashes) args = ["-store", str(store), query] - debug(f'[add-file] Refresh: search-store -store {store} "{query}"') + debug(f'[add-file] Refresh: search-file -store {store} "{query}"') - # Run search-store under a temporary stage context so its ctx.emit() calls + # Run search-file under a temporary stage context so its ctx.emit() calls # don't interfere with the outer add-file pipeline stage. 
prev_ctx = ctx.get_stage_context() temp_ctx = ctx.PipelineStageContext( @@ -1137,7 +1226,7 @@ class Add_File(Cmdlet): ) ctx.set_stage_context(temp_ctx) try: - code = search_store_cmdlet.run(None, args, config) + code = search_file_cmdlet.run(None, args, config) emitted_items = list(getattr(temp_ctx, "emits", []) or []) finally: ctx.set_stage_context(prev_ctx) @@ -1145,7 +1234,7 @@ class Add_File(Cmdlet): if code != 0: return None - # Promote the search-store result to a display overlay so the CLI prints it + # Promote the search-file result to a display overlay so the CLI prints it # for action commands like add-file. stage_ctx = ctx.get_stage_context() is_last = (stage_ctx @@ -1171,7 +1260,7 @@ class Add_File(Cmdlet): return emitted_items except Exception as exc: debug( - f"[add-file] Failed to run search-store after add-file: {type(exc).__name__}: {exc}" + f"[add-file] Failed to run search-file after add-file: {type(exc).__name__}: {exc}" ) return None @@ -2109,7 +2198,7 @@ class Add_File(Cmdlet): """Emit a storage-style result payload. - Always emits the dict downstream (when in a pipeline). - - If this is the last stage (or not in a pipeline), prints a search-store-like table + - If this is the last stage (or not in a pipeline), prints a search-file-like table and sets an overlay table/items for @N selection. """ # Emit for downstream commands (no-op if not in a pipeline) @@ -2139,28 +2228,28 @@ class Add_File(Cmdlet): pass @staticmethod - def _try_emit_search_store_by_hash( + def _try_emit_search_file_by_hash( *, store: str, hash_value: str, config: Dict[str, Any] ) -> Optional[List[Any]]: - """Run search-store for a single hash so the final table/payload is consistent. + """Run search-file for a single hash so the final table/payload is consistent. Important: `add-file` is treated as an action command by the CLI, so the CLI only - prints tables for it when a display overlay exists. After running search-store, + prints tables for it when a display overlay exists. After running search-file, this copies the resulting table into the display overlay (when this is the last stage) so the canonical store table is what the user sees and can select from. - Returns the emitted search-store payload items on success, else None. + Returns the emitted search-file payload items on success, else None. """ try: - from cmdlet.search_store import CMDLET as search_store_cmdlet + from cmdlet.search_file import CMDLET as search_file_cmdlet args = ["-store", str(store), f"hash:{str(hash_value)}"] - # Run search-store under a temporary stage context so its ctx.emit() calls + # Run search-file under a temporary stage context so its ctx.emit() calls # don't interfere with the outer add-file pipeline stage. prev_ctx = ctx.get_stage_context() temp_ctx = ctx.PipelineStageContext( @@ -2173,14 +2262,14 @@ class Add_File(Cmdlet): ) ctx.set_stage_context(temp_ctx) try: - code = search_store_cmdlet.run(None, args, config) + code = search_file_cmdlet.run(None, args, config) emitted_items = list(getattr(temp_ctx, "emits", []) or []) finally: ctx.set_stage_context(prev_ctx) if code != 0: return None - # Promote the search-store result to a display overlay so the CLI prints it + # Promote the search-file result to a display overlay so the CLI prints it # for action commands like add-file. 
stage_ctx = ctx.get_stage_context() is_last = (stage_ctx @@ -2206,7 +2295,7 @@ class Add_File(Cmdlet): return emitted_items except Exception as exc: debug( - f"[add-file] Failed to run search-store after add-file: {type(exc).__name__}: {exc}" + f"[add-file] Failed to run search-file after add-file: {type(exc).__name__}: {exc}" ) return None @@ -3097,7 +3186,7 @@ class Add_File(Cmdlet): List[tuple[str, List[str]]]]] = None, suppress_last_stage_overlay: bool = False, - auto_search_store: bool = True, + auto_search_file: bool = True, ) -> int: """Handle uploading to a registered storage backend (e.g., 'test' folder store, 'hydrus', etc.).""" ##log(f"Adding file to storage backend '{backend_name}': {media_path.name}", file=sys.stderr) @@ -3217,6 +3306,15 @@ class Add_File(Cmdlet): ) ] + # Auto-tag (best-effort) BEFORE uploading so tags land with the stored file. + try: + tags = _maybe_apply_florencevision_tags(media_path, list(tags or []), config, pipe_obj=pipe_obj) + pipe_obj.tag = list(tags or []) + except Exception as exc: + # strict mode raises from helper; treat here as a hard failure + log(f"[add-file] FlorenceVision tagging error: {exc}", file=sys.stderr) + return 1 + # Call backend's add_file with full metadata # Backend returns hash as identifier file_identifier = backend.add_file( @@ -3254,7 +3352,7 @@ class Add_File(Cmdlet): }, ) - # Emit a search-store-like payload for consistent tables and natural piping. + # Emit a search-file-like payload for consistent tables and natural piping. # Keep hash/store for downstream commands (get-tag, get-file, etc.). resolved_hash = ( file_identifier if len(file_identifier) == 64 else @@ -3299,6 +3397,15 @@ class Add_File(Cmdlet): except Exception: pass + caption_note = Add_File._get_note_text(result, pipe_obj, "caption") + if caption_note: + try: + setter = getattr(backend, "set_note", None) + if callable(setter): + setter(resolved_hash, "caption", caption_note) + except Exception: + pass + meta: Dict[str, Any] = {} try: @@ -3350,16 +3457,16 @@ class Add_File(Cmdlet): pass # Keep the add-file 1-row summary overlay (when last stage), then emit the - # canonical search-store payload/table for piping/selection consistency. - if auto_search_store and resolved_hash and resolved_hash != "unknown": - # Show the add-file summary (overlay only) but let search-store provide the downstream payload. + # canonical search-file payload/table for piping/selection consistency. + if auto_search_file and resolved_hash and resolved_hash != "unknown": + # Show the add-file summary (overlay only) but let search-file provide the downstream payload. 
Add_File._emit_storage_result( payload, overlay=not suppress_last_stage_overlay, emit=False ) - refreshed_items = Add_File._try_emit_search_store_by_hash( + refreshed_items = Add_File._try_emit_search_file_by_hash( store=backend_name, hash_value=resolved_hash, config=config, diff --git a/cmdlet/download_file.py b/cmdlet/download_file.py index 877be90..8aeaf9e 100644 --- a/cmdlet/download_file.py +++ b/cmdlet/download_file.py @@ -1079,13 +1079,13 @@ class Download_File(Cmdlet): f"[download-file] Not available on OpenLibrary; searching LibGen for: {title_text}", file=sys.stderr, ) - from cmdlet.search_provider import CMDLET as _SEARCH_PROVIDER_CMDLET + from cmdlet.search_file import CMDLET as _SEARCH_FILE_CMDLET fallback_query = title_text - exec_fn = getattr(_SEARCH_PROVIDER_CMDLET, "exec", None) + exec_fn = getattr(_SEARCH_FILE_CMDLET, "exec", None) if not callable(exec_fn): log( - "[download-file] search-provider cmdlet unavailable; cannot run LibGen fallback search", + "[download-file] search-file cmdlet unavailable; cannot run LibGen fallback search", file=sys.stderr, ) continue @@ -1099,7 +1099,7 @@ class Download_File(Cmdlet): config, ) - # Promote the search-provider table to a display overlay so it renders. + # Promote the search-file table to a display overlay so it renders. try: table_obj = pipeline_context.get_last_result_table() items_obj = pipeline_context.get_last_result_items() diff --git a/cmdlet/download_media.py b/cmdlet/download_media.py index 59a0c0f..1acce20 100644 --- a/cmdlet/download_media.py +++ b/cmdlet/download_media.py @@ -1469,6 +1469,17 @@ class Download_Media(Cmdlet): clip_values: List[str] = [] item_values: List[str] = [] + def _uniq(values: Sequence[str]) -> List[str]: + seen: set[str] = set() + out: List[str] = [] + for v in values: + key = str(v) + if key in seen: + continue + seen.add(key) + out.append(v) + return out + if clip_spec: # Support keyed clip syntax: # -query "clip:3m4s-3m14s,1h22m-1h33m,item:2-3" @@ -1482,6 +1493,10 @@ class Download_Media(Cmdlet): clip_values.extend(query_keyed.get("clip", []) or []) item_values.extend(query_keyed.get("item", []) or []) + # QueryArg also hydrates clip via -query, so combine and deduplicate here + clip_values = _uniq(clip_values) + item_values = _uniq(item_values) + if item_values and not parsed.get("item"): parsed["item"] = ",".join([v for v in item_values if v]) diff --git a/cmdlet/get_url.py b/cmdlet/get_url.py index 9e1adc1..383d038 100644 --- a/cmdlet/get_url.py +++ b/cmdlet/get_url.py @@ -27,6 +27,8 @@ class UrlItem: hash: str store: str title: str = "" + size: int | None = None + ext: str = "" class Get_Url(Cmdlet): @@ -183,6 +185,58 @@ class Get_Url(Cmdlet): return "" + @staticmethod + def _resolve_size_ext_for_hash(backend: Any, file_hash: str, hit: Any = None) -> tuple[int | None, str]: + """Best-effort (size, ext) resolution for a found hash.""" + # First: see if the hit already includes these fields. 
+ try: + size_val = get_field(hit, "size") + if size_val is None: + size_val = get_field(hit, "file_size") + if size_val is None: + size_val = get_field(hit, "filesize") + if size_val is None: + size_val = get_field(hit, "size_bytes") + size_int = int(size_val) if isinstance(size_val, (int, float)) else None + except Exception: + size_int = None + + try: + ext_val = get_field(hit, "ext") + if ext_val is None: + ext_val = get_field(hit, "extension") + ext = str(ext_val).strip().lstrip(".") if isinstance(ext_val, str) else "" + except Exception: + ext = "" + + if size_int is not None or ext: + return size_int, ext + + # Next: backend.get_metadata(hash) when available. + try: + if hasattr(backend, "get_metadata"): + meta = backend.get_metadata(file_hash) + if isinstance(meta, dict): + size_val2 = meta.get("size") + if size_val2 is None: + size_val2 = meta.get("file_size") + if size_val2 is None: + size_val2 = meta.get("filesize") + if size_val2 is None: + size_val2 = meta.get("size_bytes") + if isinstance(size_val2, (int, float)): + size_int = int(size_val2) + + ext_val2 = meta.get("ext") + if ext_val2 is None: + ext_val2 = meta.get("extension") + if isinstance(ext_val2, str) and ext_val2.strip(): + ext = ext_val2.strip().lstrip(".") + except Exception: + pass + + return size_int, ext + def _search_urls_across_stores(self, pattern: str, config: Dict[str, @@ -210,6 +264,7 @@ class Get_Url(Cmdlet): backend = storage[store_name] title_cache: Dict[str, str] = {} + meta_cache: Dict[str, tuple[int | None, str]] = {} # Search only URL-bearing records using the backend's URL search capability. # This avoids the expensive/incorrect "search('*')" scan. @@ -250,6 +305,11 @@ class Get_Url(Cmdlet): title = self._resolve_title_for_hash(backend, file_hash, hit) title_cache[file_hash] = title + size, ext = meta_cache.get(file_hash, (None, "")) + if size is None and not ext: + size, ext = self._resolve_size_ext_for_hash(backend, file_hash, hit) + meta_cache[file_hash] = (size, ext) + try: urls = backend.get_url(file_hash) except Exception: @@ -264,6 +324,8 @@ class Get_Url(Cmdlet): hash=str(file_hash), store=str(store_name), title=str(title or ""), + size=size, + ext=str(ext or ""), ) ) found_stores.add(str(store_name)) @@ -308,22 +370,44 @@ class Get_Url(Cmdlet): log(f"No urls matching pattern: {search_pattern}", file=sys.stderr) return 1 + # NOTE: The CLI can auto-render tables from emitted items. When emitting + # dataclass objects, the generic-object renderer will include `hash` as a + # visible column. To keep HASH available for chaining but hidden from the + # table, emit dicts (dict rendering hides `hash`) and provide an explicit + # `columns` list to force display order and size formatting. + display_items: List[Dict[str, Any]] = [] + table = ( ResultTable( - "URL Search Results", - max_columns=3 - ).set_preserve_order(True).set_table("urls").set_value_case("preserve") + "url", + max_columns=5 + ).set_preserve_order(True).set_table("url").set_value_case("preserve") ) table.set_source_command("get-url", ["-url", search_pattern]) for item in items: - row = table.add_row() - row.add_column("Title", item.title) - row.add_column("Url", item.url) - row.add_column("Store", item.store) - ctx.emit(item) + payload: Dict[str, Any] = { + # Keep fields for downstream cmdlets. + "hash": item.hash, + "store": item.store, + "url": item.url, + "title": item.title, + "size": item.size, + "ext": item.ext, + # Force the visible table columns + ordering. 
+ "columns": [ + ("Title", item.title), + ("Url", item.url), + ("Size", item.size), + ("Ext", item.ext), + ("Store", item.store), + ], + } + display_items.append(payload) + table.add_result(payload) + ctx.emit(payload) - ctx.set_last_result_table(table if items else None, items, subject=result) + ctx.set_last_result_table(table if display_items else None, display_items, subject=result) log( f"Found {len(items)} matching url(s) in {len(stores_searched)} store(s)" ) diff --git a/cmdlet/search_store.py b/cmdlet/search_file.py similarity index 66% rename from cmdlet/search_store.py rename to cmdlet/search_file.py index 9a66f99..0483639 100644 --- a/cmdlet/search_store.py +++ b/cmdlet/search_file.py @@ -1,15 +1,18 @@ -"""Search-store cmdlet: Search for files in storage backends (Folder, Hydrus).""" +"""search-file cmdlet: Search for files in storage backends (Folder, Hydrus).""" from __future__ import annotations from typing import Any, Dict, Sequence, List, Optional +import importlib +import uuid from pathlib import Path -from collections import OrderedDict import re import json import sys from SYS.logger import log, debug +from ProviderCore.registry import get_search_provider, list_search_providers +from SYS.config import get_local_storage_path from . import _shared as sh @@ -39,14 +42,14 @@ STORAGE_ORIGINS = {"local", "folder"} -class Search_Store(Cmdlet): - """Class-based search-store cmdlet for searching storage backends.""" +class search_file(Cmdlet): + """Class-based search-file cmdlet for searching storage backends.""" def __init__(self) -> None: super().__init__( - name="search-store", - summary="Search storage backends (Folder, Hydrus) for files.", - usage="search-store [-query ] [-store BACKEND] [-limit N]", + name="search-file", + summary="Search storage backends (Folder, Hydrus) or external providers (via -provider).", + usage="search-file [-query ] [-store BACKEND] [-limit N] [-provider NAME]", arg=[ CmdletArg( "limit", @@ -55,6 +58,17 @@ class Search_Store(Cmdlet): ), SharedArgs.STORE, SharedArgs.QUERY, + CmdletArg( + "provider", + type="string", + description= + "External provider name: bandcamp, libgen, soulseek, youtube, alldebrid, loc, internetarchive", + ), + CmdletArg( + "open", + type="integer", + description="(alldebrid) Open folder/magnet by ID and list its files", + ), ], detail=[ "Search across storage backends: Folder stores and Hydrus instances", @@ -64,14 +78,19 @@ class Search_Store(Cmdlet): "Hydrus-style extension: system:filetype = png", "Results include hash for downstream commands (get-file, add-tag, etc.)", "Examples:", - "search-store -query foo # Search all storage backends", - "search-store -store home -query '*' # Search 'home' Hydrus instance", - "search-store -store test -query 'video' # Search 'test' folder store", - "search-store -query 'hash:deadbeef...' # Search by SHA256 hash", - "search-store -query 'url:*' # Files that have any URL", - "search-store -query 'url:youtube.com' # Files whose URL contains substring", - "search-store -query 'ext:png' # Files whose metadata ext is png", - "search-store -query 'system:filetype = png' # Hydrus: native; Folder: maps to metadata.ext", + "search-file -query foo # Search all storage backends", + "search-file -store home -query '*' # Search 'home' Hydrus instance", + "search-file -store test -query 'video' # Search 'test' folder store", + "search-file -query 'hash:deadbeef...' 
# Search by SHA256 hash", + "search-file -query 'url:*' # Files that have any URL", + "search-file -query 'url:youtube.com' # Files whose URL contains substring", + "search-file -query 'ext:png' # Files whose metadata ext is png", + "search-file -query 'system:filetype = png' # Hydrus: native; Folder: maps to metadata.ext", + "", + "Provider search (-provider):", + "search-file -provider youtube 'tutorial' # Search YouTube provider", + "search-file -provider alldebrid '*' # List AllDebrid magnets", + "search-file -provider alldebrid -open 123 '*' # Show files for a magnet", ], exec=self.run, ) @@ -120,6 +139,172 @@ class Search_Store(Cmdlet): # This allows the table to respect max_columns and apply consistent formatting return payload + def _run_provider_search( + self, + *, + provider_name: str, + query: str, + limit: int, + limit_set: bool, + open_id: Optional[int], + args_list: List[str], + refresh_mode: bool, + config: Dict[str, Any], + ) -> int: + """Execute external provider search.""" + + if not provider_name or not query: + log("Error: search-file -provider requires both provider and query", file=sys.stderr) + log(f"Usage: {self.usage}", file=sys.stderr) + log("Available providers:", file=sys.stderr) + providers = list_search_providers(config) + for name, available in sorted(providers.items()): + status = "\u2713" if available else "\u2717" + log(f" {status} {name}", file=sys.stderr) + return 1 + + # Align with provider default when user did not set -limit. + if not limit_set: + limit = 50 + + debug(f"[search-file] provider={provider_name}, query={query}, limit={limit}, open_id={open_id}") + + provider = get_search_provider(provider_name, config) + if not provider: + log(f"Error: Provider '{provider_name}' is not available", file=sys.stderr) + log("Available providers:", file=sys.stderr) + providers = list_search_providers(config) + for name, available in sorted(providers.items()): + if available: + log(f" - {name}", file=sys.stderr) + return 1 + + worker_id = str(uuid.uuid4()) + library_root = get_local_storage_path(config or {}) if get_local_storage_path else None + + db = None + if library_root: + try: + from API.folder import API_folder_store + + db = API_folder_store(library_root) + db.__enter__() + db.insert_worker( + worker_id, + "search-file", + title=f"Search: {query}", + description=f"Provider: {provider_name}, Query: {query}", + pipe=ctx.get_current_command_text(), + ) + except Exception: + db = None + + try: + results_list: List[Dict[str, Any]] = [] + + from SYS import result_table + + importlib.reload(result_table) + from SYS.result_table import ResultTable + + provider_text = str(provider_name or "").strip() + provider_lower = provider_text.lower() + if provider_lower == "youtube": + provider_label = "Youtube" + elif provider_lower == "openlibrary": + provider_label = "OpenLibrary" + elif provider_lower == "loc": + provider_label = "LoC" + else: + provider_label = provider_text[:1].upper() + provider_text[1:] if provider_text else "Provider" + + if provider_lower == "alldebrid" and open_id is not None: + table_title = f"{provider_label} Files: {open_id}".strip().rstrip(":") + else: + table_title = f"{provider_label}: {query}".strip().rstrip(":") + + preserve_order = provider_lower in {"youtube", "openlibrary", "loc"} + table = ResultTable(table_title).set_preserve_order(preserve_order) + table.set_table(provider_name) + table.set_source_command("search-file", list(args_list)) + + debug(f"[search-file] Calling {provider_name}.search()") + if provider_lower == 
"alldebrid": + if open_id is not None: + results = provider.search(query, limit=limit, filters={"view": "files", "magnet_id": open_id}) + else: + results = provider.search(query, limit=limit, filters={"view": "folders"}) + else: + results = provider.search(query, limit=limit) + debug(f"[search-file] {provider_name} -> {len(results or [])} result(s)") + + if not results: + log(f"No results found for query: {query}", file=sys.stderr) + if db is not None: + db.append_worker_stdout(worker_id, json.dumps([], indent=2)) + db.update_worker_status(worker_id, "completed") + return 0 + + for search_result in results: + item_dict = ( + search_result.to_dict() + if hasattr(search_result, "to_dict") + else dict(search_result) + if isinstance(search_result, dict) + else {"title": str(search_result)} + ) + + if "table" not in item_dict: + item_dict["table"] = provider_name + + row_index = len(table.rows) + table.add_result(search_result) + + try: + if provider_lower == "alldebrid" and getattr(search_result, "media_kind", "") == "folder": + magnet_id = None + meta = getattr(search_result, "full_metadata", None) + if isinstance(meta, dict): + magnet_id = meta.get("magnet_id") + if magnet_id is not None: + table.set_row_selection_args(row_index, ["-open", str(magnet_id), "-query", "*"]) + except Exception: + pass + + results_list.append(item_dict) + ctx.emit(item_dict) + + if refresh_mode: + ctx.set_last_result_table_preserve_history(table, results_list) + else: + ctx.set_last_result_table(table, results_list) + + ctx.set_current_stage_table(table) + + if db is not None: + db.append_worker_stdout(worker_id, json.dumps(results_list, indent=2)) + db.update_worker_status(worker_id, "completed") + + return 0 + + except Exception as exc: + log(f"Error searching provider '{provider_name}': {exc}", file=sys.stderr) + import traceback + + debug(traceback.format_exc()) + if db is not None: + try: + db.update_worker_status(worker_id, "error") + except Exception: + pass + return 1 + finally: + if db is not None: + try: + db.__exit__(None, None, None) + except Exception: + pass + # --- Execution ------------------------------------------------------ def run(self, result: Any, args: Sequence[str], config: Dict[str, Any]) -> int: """Search storage backends for files.""" @@ -164,7 +349,7 @@ class Search_Store(Cmdlet): raw_title = None command_title = (str(raw_title).strip() if raw_title else - "") or _format_command_title("search-store", + "") or _format_command_title("search-file", list(args_list)) # Build dynamic flag variants from cmdlet arg definitions. 
@@ -182,11 +367,22 @@ class Search_Store(Cmdlet): f.lower() for f in (flag_registry.get("limit") or {"-limit", "--limit"}) } + provider_flags = { + f.lower() + for f in (flag_registry.get("provider") or {"-provider", "--provider"}) + } + open_flags = { + f.lower() + for f in (flag_registry.get("open") or {"-open", "--open"}) + } # Parse arguments query = "" storage_backend: Optional[str] = None + provider_name: Optional[str] = None + open_id: Optional[int] = None limit = 100 + limit_set = False searched_backends: List[str] = [] i = 0 @@ -198,10 +394,26 @@ class Search_Store(Cmdlet): query = f"{query} {chunk}".strip() if query else chunk i += 2 continue + if low in provider_flags and i + 1 < len(args_list): + provider_name = args_list[i + 1] + i += 2 + continue + if low in open_flags and i + 1 < len(args_list): + try: + open_id = int(args_list[i + 1]) + except ValueError: + log( + f"Warning: Invalid open value '{args_list[i + 1]}', ignoring", + file=sys.stderr, + ) + open_id = None + i += 2 + continue if low in store_flags and i + 1 < len(args_list): storage_backend = args_list[i + 1] i += 2 elif low in limit_flags and i + 1 < len(args_list): + limit_set = True try: limit = int(args_list[i + 1]) except ValueError: @@ -213,6 +425,20 @@ class Search_Store(Cmdlet): else: i += 1 + query = query.strip() + + if provider_name: + return self._run_provider_search( + provider_name=provider_name, + query=query, + limit=limit, + limit_set=limit_set, + open_id=open_id, + args_list=args_list, + refresh_mode=refresh_mode, + config=config, + ) + store_filter: Optional[str] = None if query: match = re.search(r"\bstore:([^\s,]+)", query, flags=re.IGNORECASE) @@ -232,8 +458,6 @@ class Search_Store(Cmdlet): return 1 from API.folder import API_folder_store - from SYS.config import get_local_storage_path - import uuid worker_id = str(uuid.uuid4()) library_root = get_local_storage_path(config or {}) @@ -246,7 +470,7 @@ class Search_Store(Cmdlet): try: db.insert_worker( worker_id, - "search-store", + "search-file", title=f"Search: {query}", description=f"Query: {query}", pipe=ctx.get_current_command_text(), @@ -261,7 +485,7 @@ class Search_Store(Cmdlet): table = ResultTable(command_title) try: - table.set_source_command("search-store", list(args_list)) + table.set_source_command("search-file", list(args_list)) except Exception: pass if hash_query: @@ -441,10 +665,10 @@ class Search_Store(Cmdlet): ) db.update_worker_status(worker_id, "error") return 1 - debug(f"[search-store] Searching '{backend_to_search}'") + debug(f"[search-file] Searching '{backend_to_search}'") results = target_backend.search(query, limit=limit) debug( - f"[search-store] '{backend_to_search}' -> {len(results or [])} result(s)" + f"[search-file] '{backend_to_search}' -> {len(results or [])} result(s)" ) else: all_results = [] @@ -453,13 +677,13 @@ class Search_Store(Cmdlet): backend = storage[backend_name] searched_backends.append(backend_name) - debug(f"[search-store] Searching '{backend_name}'") + debug(f"[search-file] Searching '{backend_name}'") backend_results = backend.search( query, limit=limit - len(all_results) ) debug( - f"[search-store] '{backend_name}' -> {len(backend_results or [])} result(s)" + f"[search-file] '{backend_name}' -> {len(backend_results or [])} result(s)" ) if backend_results: all_results.extend(backend_results) @@ -542,4 +766,4 @@ class Search_Store(Cmdlet): return 1 -CMDLET = Search_Store() +CMDLET = search_file() diff --git a/cmdlet/search_provider.py b/cmdlet/search_provider.py deleted file mode 100644 index 
f800f85..0000000 --- a/cmdlet/search_provider.py +++ /dev/null @@ -1,357 +0,0 @@ -"""search-provider cmdlet: Search external providers (bandcamp, libgen, soulseek, youtube, alldebrid).""" - -from __future__ import annotations - -from typing import Any, Dict, List, Sequence, Optional -import sys -import json -import uuid -import importlib - -from SYS.logger import log, debug -from ProviderCore.registry import get_search_provider, list_search_providers - -from . import _shared as sh - -Cmdlet, CmdletArg, should_show_help = ( - sh.Cmdlet, - sh.CmdletArg, - sh.should_show_help, -) -from SYS import pipeline as ctx - -# Optional dependencies -try: - from SYS.config import get_local_storage_path -except Exception: # pragma: no cover - get_local_storage_path = None # type: ignore - - -class Search_Provider(Cmdlet): - """Search external content providers.""" - - def __init__(self): - super().__init__( - name="search-provider", - summary= - "Search external providers (bandcamp, libgen, soulseek, youtube, alldebrid, loc, internetarchive)", - usage="search-provider -provider [-limit N] [-open ID]", - arg=[ - CmdletArg( - "provider", - type="string", - required=True, - description= - "Provider name: bandcamp, libgen, soulseek, youtube, alldebrid, loc, internetarchive", - ), - CmdletArg( - "query", - type="string", - required=True, - description="Search query (supports provider-specific syntax)", - ), - CmdletArg( - "limit", - type="int", - description="Maximum results to return (default: 50)" - ), - CmdletArg( - "open", - type="int", - description= - "(alldebrid) Open folder/magnet by ID and list its files", - ), - ], - detail=[ - "Search external content providers:", - "- alldebrid: List your AllDebrid account folders (magnets). Select @N to view files.", - ' Example: search-provider -provider alldebrid "*"', - ' Example: search-provider -provider alldebrid -open 123 "*"', - "- bandcamp: Search for music albums/tracks", - ' Example: search-provider -provider bandcamp "artist:altrusian grace"', - "- libgen: Search Library Genesis for books", - ' Example: search-provider -provider libgen "python programming"', - "- loc: Search Library of Congress (Chronicling America)", - ' Example: search-provider -provider loc "lincoln"', - "- soulseek: Search P2P network for music", - ' Example: search-provider -provider soulseek "pink floyd"', - "- youtube: Search YouTube for videos", - ' Example: search-provider -provider youtube "tutorial"', - "- internetarchive: Search archive.org items (advancedsearch syntax)", - ' Example: search-provider -provider internetarchive "title:(lincoln) AND mediatype:texts"', - "", - "Query syntax:", - "- bandcamp: Use 'artist:Name' to search by artist", - "- libgen: Supports isbn:, author:, title: prefixes", - "- soulseek: Plain text search", - "- youtube: Plain text search", - "- internetarchive: Archive.org advancedsearch query syntax", - "", - "Results can be piped to other cmdlet:", - ' search-provider -provider bandcamp "artist:grace" | @1 | download-file', - ], - exec=self.run, - ) - self.register() - - def run(self, result: Any, args: Sequence[str], config: Dict[str, Any]) -> int: - """Execute search-provider cmdlet.""" - if should_show_help(args): - ctx.emit(self.__dict__) - return 0 - - args_list = [str(a) for a in (args or [])] - - # Dynamic flag variants from cmdlet arg definitions. 
- flag_registry = self.build_flag_registry() - provider_flags = { - f.lower() - for f in (flag_registry.get("provider") or {"-provider", "--provider"}) - } - query_flags = { - f.lower() - for f in (flag_registry.get("query") or {"-query", "--query"}) - } - limit_flags = { - f.lower() - for f in (flag_registry.get("limit") or {"-limit", "--limit"}) - } - open_flags = { - f.lower() - for f in (flag_registry.get("open") or {"-open", "--open"}) - } - - provider_name: Optional[str] = None - query: Optional[str] = None - limit = 50 - open_id: Optional[int] = None - positionals: List[str] = [] - - i = 0 - while i < len(args_list): - token = args_list[i] - low = token.lower() - if low in provider_flags and i + 1 < len(args_list): - provider_name = args_list[i + 1] - i += 2 - elif low in query_flags and i + 1 < len(args_list): - query = args_list[i + 1] - i += 2 - elif low in limit_flags and i + 1 < len(args_list): - try: - limit = int(args_list[i + 1]) - except ValueError: - log( - f"Warning: Invalid limit value '{args_list[i + 1]}', using default 50", - file=sys.stderr, - ) - limit = 50 - i += 2 - elif low in open_flags and i + 1 < len(args_list): - try: - open_id = int(args_list[i + 1]) - except ValueError: - log( - f"Warning: Invalid open value '{args_list[i + 1]}', ignoring", - file=sys.stderr, - ) - open_id = None - i += 2 - elif not token.startswith("-"): - positionals.append(token) - i += 1 - else: - i += 1 - - # Backwards-compatible positional form: search-provider - if provider_name is None and positionals: - provider_name = positionals[0] - positionals = positionals[1:] - - if query is None and positionals: - query = " ".join(positionals).strip() or None - - if not provider_name or not query: - log("Error: search-provider requires a provider and query", file=sys.stderr) - log(f"Usage: {self.usage}", file=sys.stderr) - log("Available providers:", file=sys.stderr) - providers = list_search_providers(config) - for name, available in sorted(providers.items()): - status = "✓" if available else "✗" - log(f" {status} {name}", file=sys.stderr) - return 1 - - debug( - f"[search-provider] provider={provider_name}, query={query}, limit={limit}" - ) - - # Get provider - provider = get_search_provider(provider_name, config) - if not provider: - log(f"Error: Provider '{provider_name}' is not available", file=sys.stderr) - log("Available providers:", file=sys.stderr) - providers = list_search_providers(config) - for name, available in sorted(providers.items()): - if available: - log(f" - {name}", file=sys.stderr) - return 1 - - worker_id = str(uuid.uuid4()) - library_root = get_local_storage_path( - config or {} - ) if get_local_storage_path else None - - db = None - if library_root: - try: - from API.folder import API_folder_store - - db = API_folder_store(library_root) - except Exception: - db = None - - try: - # Use the worker DB if available; otherwise, run as a stateless one-off. 
- if db is not None: - db.__enter__() - db.insert_worker( - worker_id, - "search-provider", - title=f"Search: {query}", - description=f"Provider: {provider_name}, Query: {query}", - pipe=ctx.get_current_command_text(), - ) - - results_list = [] - from SYS import result_table - - importlib.reload(result_table) - from SYS.result_table import ResultTable - - provider_text = str(provider_name or "").strip() - provider_lower = provider_text.lower() - if provider_lower == "youtube": - provider_label = "Youtube" - elif provider_lower == "openlibrary": - provider_label = "OpenLibrary" - elif provider_lower == "loc": - provider_label = "LoC" - else: - provider_label = ( - provider_text[:1].upper() + - provider_text[1:] if provider_text else "Provider" - ) - - if provider_lower == "alldebrid" and open_id is not None: - table_title = f"{provider_label} Files: {open_id}".strip().rstrip(":") - else: - table_title = f"{provider_label}: {query}".strip().rstrip(":") - preserve_order = provider_name.lower() in ("youtube", "openlibrary", "loc") - table = ResultTable(table_title).set_preserve_order(preserve_order) - table.set_table(provider_name) - table.set_source_command("search-provider", list(args)) - - debug(f"[search-provider] Calling {provider_name}.search()") - if provider_lower == "alldebrid": - if open_id is not None: - # Second-stage: show files for selected folder/magnet. - results = provider.search( - query, - limit=limit, - filters={ - "view": "files", - "magnet_id": open_id - } - ) - else: - # Default: show folders (magnets) so user can select @N. - results = provider.search( - query, - limit=limit, - filters={ - "view": "folders" - } - ) - else: - results = provider.search(query, limit=limit) - debug(f"[search-provider] Got {len(results)} results") - - if not results: - log(f"No results found for query: {query}", file=sys.stderr) - if db is not None: - db.append_worker_stdout(worker_id, json.dumps([], indent=2)) - db.update_worker_status(worker_id, "completed") - return 0 - - # Emit results for pipeline - for search_result in results: - item_dict = ( - search_result.to_dict() - if hasattr(search_result, - "to_dict") else dict(search_result) - ) - - # Ensure table field is set (should be by provider, but just in case) - if "table" not in item_dict: - item_dict["table"] = provider_name - - row_index = len(table.rows) - table.add_result( - search_result - ) # ResultTable handles SearchResult objects - - # For AllDebrid folder rows, allow @N to open and show files. - try: - if (provider_lower == "alldebrid" and getattr(search_result, - "media_kind", - "") == "folder"): - magnet_id = None - meta = getattr(search_result, "full_metadata", None) - if isinstance(meta, dict): - magnet_id = meta.get("magnet_id") - if magnet_id is not None: - table.set_row_selection_args( - row_index, - ["-open", - str(magnet_id), - "-query", - "*"] - ) - except Exception: - pass - results_list.append(item_dict) - ctx.emit(item_dict) - - ctx.set_last_result_table(table, results_list) - # Ensure @N selection expands against this newly displayed table. 
- ctx.set_current_stage_table(table) - if db is not None: - db.append_worker_stdout(worker_id, json.dumps(results_list, indent=2)) - db.update_worker_status(worker_id, "completed") - - return 0 - - except Exception as e: - log(f"Error searching {provider_name}: {e}", file=sys.stderr) - import traceback - - debug(traceback.format_exc()) - if db is not None: - try: - db.update_worker_status(worker_id, "error") - except Exception: - pass - return 1 - finally: - if db is not None: - try: - db.__exit__(None, None, None) - except Exception: - pass - - -# Register cmdlet instance (catalog + REPL autocomplete expects module-level CMDLET) -CMDLET = Search_Provider() - -# Backwards-compatible alias -Search_Provider_Instance = CMDLET diff --git a/cmdnat/out_table.py b/cmdnat/out_table.py index 029e87e..dd79973 100644 --- a/cmdnat/out_table.py +++ b/cmdnat/out_table.py @@ -26,8 +26,8 @@ CMDLET = Cmdlet( "Exports the most recent table (overlay/stage/last) as an SVG using Rich.", "Default filename is derived from the table title (sanitized).", "Examples:", - 'search-store "ext:mp3" | .out-table -path "C:\\Users\\Admin\\Desktop"', - 'search-store "ext:mp3" | .out-table -path "C:\\Users\\Admin\\Desktop\\my-table.svg"', + 'search-file "ext:mp3" | .out-table -path "C:\\Users\\Admin\\Desktop"', + 'search-file "ext:mp3" | .out-table -path "C:\\Users\\Admin\\Desktop\\my-table.svg"', ], ) diff --git a/docs/img/search-store.svg b/docs/img/search-store.svg index 3c3199d..dfa9b9e 100644 --- a/docs/img/search-store.svg +++ b/docs/img/search-store.svg @@ -70,7 +70,7 @@ - ╭─────────────────────────────────────────────────────────────────────────────────────────── search-store "ubuntu" ────────────────────────────────────────────────────────────────────────────────────────────╮ + ╭─────────────────────────────────────────────────────────────────────────────────────────── search-file "ubuntu" ────────────────────────────────────────────────────────────────────────────────────────────╮          #TITLE                                                                                                    STORE                               SIZE                 EXT               ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────  diff --git a/docs/tutorial.md b/docs/tutorial.md index 58dc210..2f45d62 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -60,11 +60,11 @@ we added freeform tags, freeform tags are tags that dont have colons in them, th
```
-search-store "ubuntu"
+search-file "ubuntu"
```

-![search-store]()
+![search-file]()

to access your file and view it, you can run either

@@ -79,11 +79,11 @@ or if you have mpv installed (the preferred way for video files)

# Bandcamp downloading (provider method)
```
-search-provider -provider bandcamp -query "artist:altrusian grace media"
+search-file -provider bandcamp -query "artist:altrusian grace media"
```

-![search-provider -provider bandcamp](img/bandcamp-artist.svg)
+![search-file -provider bandcamp](img/bandcamp-artist.svg)

this brings up special scraper for bandcamp on the artist page, the query syntax of "artist:" is how we parse args for some cmdlets. next run the following
@@ -164,7 +164,7 @@ restart the cli and check the startup table, if soulseek says ENABLED then you a
```
-search-provider -provider soulseek "erika herms niel"
+search-file -provider soulseek "erika herms niel"
```

@@ -202,7 +202,7 @@ openlibrary allows us to borrow books, merge them into a permement pdf, then ret
-we could have use the search-provider -provider openlibrary, but to show the versatile of the app, we able to use download-file and medios will be able to intelligently direct it the correct provider (with exception of download-media, download-media is just the frontend for yt-dlp).
+we could have used search-file -provider openlibrary, but to show the versatility of the app, we are able to use download-file and medios will intelligently direct it to the correct provider (the exception is download-media, which is just the frontend for yt-dlp).

# Libgen

libgen is self-explanatory,
diff --git a/readme.md b/readme.md
index 8bddbb9..cefa7f3 100644
--- a/readme.md
+++ b/readme.md
@@ -70,7 +70,7 @@ download-file "https://openlibrary.org/books/OLxxxxxM/Book_Title" | add-file -st

Search your library:

```bash
-search-store "ext:mp3"
+search-file "ext:mp3"
```

## Providers & stores
diff --git a/requirements.txt b/requirements.txt
index a36547d..3b0bfbd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,6 +22,12 @@ Pillow>=10.0.0
python-bidi>=0.4.2
ffmpeg-python>=0.2.0

+# AI tagging (FlorenceVision tool)
+transformers>=4.45.0
+torch>=2.4.0
+einops>=0.8.0
+timm>=1.0.0
+
# Metadata extraction and processing
musicbrainzngs>=0.7.0
lxml>=4.9.0
diff --git a/tool/__init__.py b/tool/__init__.py
index cd9f29d..ddbda1e 100644
--- a/tool/__init__.py
+++ b/tool/__init__.py
@@ -7,5 +7,13 @@ common defaults (cookies, timeouts, format selectors) and users can override the
from .ytdlp import YtDlpTool, YtDlpDefaults
from .playwright import PlaywrightTool, PlaywrightDefaults
+from .florencevision import FlorenceVisionTool, FlorenceVisionDefaults

-__all__ = ["YtDlpTool", "YtDlpDefaults", "PlaywrightTool", "PlaywrightDefaults"]
+__all__ = [
+    "YtDlpTool",
+    "YtDlpDefaults",
+    "PlaywrightTool",
+    "PlaywrightDefaults",
+    "FlorenceVisionTool",
+    "FlorenceVisionDefaults",
+]
diff --git a/tool/florencevision.py b/tool/florencevision.py
new file mode 100644
index 0000000..ea2aa68
--- /dev/null
+++ b/tool/florencevision.py
@@ -0,0 +1,1000 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from SYS.logger import debug
+
+
+def _truncate_debug_text(text: str, max_chars: int = 12000) -> str:
+    s = str(text or "")
+    if len(s) <= max_chars:
+        return s
+    return s[:max_chars] + f"\n... (truncated; {len(s)} chars total)"
+
+
+def _debug_repr(value: Any, max_chars: int = 12000) -> str:
+    """Pretty-ish repr for debug without risking huge output."""
+    try:
+        import pprint
+
+        s = pprint.pformat(value, width=120, compact=True)
+    except Exception:
+        try:
+            s = repr(value)
+        except Exception:
+            s = f"<{type(value).__name__}>"
+    return _truncate_debug_text(s, max_chars=max_chars)
+
+
+def _get_nested(config: Dict[str, Any], *path: str) -> Any:
+    cur: Any = config
+    for key in path:
+        if not isinstance(cur, dict):
+            return None
+        cur = cur.get(key)
+    return cur
+
+
+def _as_bool(value: Any, default: bool = False) -> bool:
+    if value is None:
+        return default
+    if isinstance(value, bool):
+        return value
+    s = str(value).strip().lower()
+    if s in {"1", "true", "yes", "on"}:
+        return True
+    if s in {"0", "false", "no", "off"}:
+        return False
+    return default
+
+
+def _as_int(value: Any, default: int) -> int:
+    try:
+        return int(value)
+    except Exception:
+        return default
+
+
+def _clean_tag_value(text: str) -> str:
+    # Keep tags conservative: strip confidence suffixes, lowercase, trim,
+    # replace whitespace runs with underscore.
+    s0 = str(text or "").strip()
+    if not s0:
+        return ""
+
+    # Strip common confidence suffix patterns, e.g.:
+    #   "car (0.98)", "car:0.98", "car 0.98", "car (98%)"
+    # Keep this conservative to avoid mangling labels like "iphone 14".
+    import re
+
+    s0 = re.sub(r"\s*[\(\[]\s*\d+(?:\.\d+)?\s*%?\s*[\)\]]\s*$", "", s0)
+    s0 = re.sub(r"\s*[:=]\s*\d+(?:\.\d+)?\s*%?\s*$", "", s0)
+    s0 = re.sub(r"\s+\d+\.\d+\s*%?\s*$", "", s0)
+
+    s = s0.strip().lower()
+    if not s:
+        return ""
+    # remove leading/trailing punctuation
+    s = s.strip("\"'`.,;:!?()[]{}<>|\\/")
+    # Common list markers / bullets that show up in OCR/tag outputs.
+    s = s.strip("-–—•·")
+    s = "_".join([p for p in s.replace("\t", " ").replace("\n", " ").split() if p])
+    # avoid empty or purely underscores
+    s = s.strip("_")
+    if not s:
+        return ""
+    # Drop values that have no alphanumerics (e.g. "-" / "___").
+    if not any(ch.isalnum() for ch in s):
+        return ""
+    return s
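+
+
+# Illustrative only (not part of the pipeline): expected _clean_tag_value
+# normalization on a few representative inputs.
+#   _clean_tag_value("Car (0.98)")  -> "car"
+#   _clean_tag_value("iPhone 14")   -> "iphone_14"
+#   _clean_tag_value("___")         -> ""  (no alphanumerics survive)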
+
+
+def _normalize_task_prompt(task: str) -> str:
+    """Normalize human-friendly task names to Florence prompt tokens.
+
+    Accepts either Florence tokens (e.g. "<OD>" / "<CAPTION>" or "<|...|>")
+    or friendly aliases. Default and most aliases map to Florence's supported
+    detailed-caption + grounding combo to yield both labels and a caption:
+      - "tag" / "tags" -> "<|detailed_caption|><|grounding|>"
+      - "detection" / "detect" / "od" / "grounding" -> "<|detailed_caption|><|grounding|>"
+      - "caption" -> "<|detailed_caption|>"
+      - "ocr" -> "<|ocr|>"
+    """
+    raw = str(task or "").strip()
+    if not raw:
+        return "<|detailed_caption|><|grounding|>"
+
+    # If user already provided a Florence token, keep it unless it's a legacy OD token
+    # (then expand to the detailed_caption+grounding combo).
+    if raw.startswith("<") and raw.endswith(">"):
+        od_like = raw.strip().lower()
+        if od_like in {"<od>", "<|od|>", "<|object_detection|>", "<|object-detection|>"}:
+            return "<|detailed_caption|><|grounding|>"
+        return raw if raw.startswith("<|") else raw.upper()
+
+    key = raw.strip().lower().replace("_", "-")
+    key = " ".join(key.split())
+
+    if key in {"tag", "tags"}:
+        return "<|detailed_caption|><|grounding|>"
+    if key in {
+        "detection",
+        "detect",
+        "object-detection",
+        "object detection",
+        "object_detection",
+        "od",
+    }:
+        return "<|detailed_caption|><|grounding|>"
+    if key in {"grounding", "bbox", "boxes", "box"}:
+        return "<|detailed_caption|><|grounding|>"
+    if key in {"caption", "cap", "describe", "description"}:
+        return "<|detailed_caption|>"
+    if key in {"more-detailed-caption", "detailed-caption", "more detailed caption", "detailed_caption"}:
+        return "<|detailed_caption|>"
+    if key in {"ocr", "text", "read", "extract-text", "extract text"}:
+        return "<|ocr|>"
+
+    # Unknown strings: pass through (remote-code models sometimes accept custom prompts).
+    return raw
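+
+
+# Illustrative only: how friendly task names resolve to Florence prompt tokens.
+#   _normalize_task_prompt("tag")     -> "<|detailed_caption|><|grounding|>"
+#   _normalize_task_prompt("caption") -> "<|detailed_caption|>"
+#   _normalize_task_prompt("ocr")     -> "<|ocr|>"
+#   _normalize_task_prompt("<od>")    -> "<|detailed_caption|><|grounding|>"  (legacy OD token)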
+
+
+def _is_caption_task(prompt: str) -> bool:
+    p = str(prompt or "").upper()
+    return "CAPTION" in p
+
+
+def _is_tag_task(prompt: str) -> bool:
+    p = str(prompt or "").strip().lower()
+    return "tag" in p or "<|tag|>" in p
+
+
+def _is_ocr_task(prompt: str) -> bool:
+    p = str(prompt or "").strip().lower()
+    return "ocr" in p or "<|ocr|>" in p
+
+
+def _strip_florence_tokens(text: str) -> str:
+    """Remove Florence prompt/special tokens from generated text."""
+    import re
+
+    s = str(text or "")
+    if not s:
+        return ""
+
+    # Remove <|...|> style tokens and legacy <TASK>-style tokens.
+    s = re.sub(r"<\|[^>]+?\|>", " ", s)
+    s = re.sub(r"<[^>]+?>", " ", s)
+
+    # Remove common leftover special tokens.
+    s = s.replace("<s>", " ").replace("</s>", " ")
+    return " ".join(s.split())
+
+
+def _split_text_to_labels(text: str) -> List[str]:
+    """Split a generated text blob into candidate labels."""
+    raw = _strip_florence_tokens(text)
+    if not raw:
+        return []
+
+    out: List[str] = []
+    for line in raw.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
+        line = line.strip()
+        if not line:
+            continue
+        for part in line.replace(";", ",").split(","):
+            part = part.strip()
+            if part:
+                out.append(part)
+    return out
+
+
+def _clean_caption_text(text: str) -> str:
+    """Strip Florence tokens and collapse whitespace for caption text."""
+    cleaned = _strip_florence_tokens(text)
+    return " ".join(str(cleaned or "").split())
+
+
+def _collect_candidate_strings(value: Any) -> List[str]:
+    """Best-effort extraction of tag-like strings from nested Florence outputs."""
+    out: List[str] = []
+
+    if value is None:
+        return out
+    if isinstance(value, str):
+        s = value.strip()
+        if s:
+            out.append(s)
+        return out
+
+    if isinstance(value, dict):
+        # Prefer common semantic keys first.
+        for key in (
+            "labels",
+            "label",
+            "text",
+            "texts",
+            "words",
+            "word",
+            "caption",
+            "captions",
+            "phrase",
+            "phrases",
+            "name",
+            "names",
+        ):
+            if key in value:
+                out.extend(_collect_candidate_strings(value.get(key)))
+        # Then fall back to scanning other values (numbers/bboxes are ignored).
+ for v in value.values(): + out.extend(_collect_candidate_strings(v)) + return out + + if isinstance(value, (list, tuple)): + for item in value: + out.extend(_collect_candidate_strings(item)) + return out + + return out + + +def _collect_captions(value: Any, key_hint: str = "") -> List[str]: + """Extract caption-like strings from nested structures by key name.""" + out: List[str] = [] + + def _norm(val: Any) -> Optional[str]: + if val is None: + return None + try: + s = str(val).strip() + except Exception: + return None + return s if s else None + + try: + hint_has_caption = "caption" in str(key_hint or "").lower() + except Exception: + hint_has_caption = False + + if isinstance(value, str): + if hint_has_caption: + s = _norm(value) + if s: + cleaned = _clean_caption_text(s) + if cleaned: + out.append(cleaned) + else: + out.append(s) + return out + + if isinstance(value, dict): + for k, v in value.items(): + out.extend(_collect_captions(v, key_hint=str(k))) + return out + + if isinstance(value, (list, tuple)): + for item in value: + out.extend(_collect_captions(item, key_hint=key_hint)) + return out + + return out + + +@dataclass(slots=True) +class FlorenceVisionDefaults: + enabled: bool = False + strict: bool = False + model: str = "microsoft/Florence-2-large" + device: str = "cpu" # "cpu" | "cuda" | "mps" + dtype: Optional[str] = None # e.g. "float16" | "bfloat16" | None + max_tags: int = 12 + namespace: str = "florence" + task: str = "tag" # Friendly aliases: tag/detection/caption/ocr (or raw Florence prompt tokens) + + +class FlorenceVisionTool: + """Microsoft Florence vision model wrapper. + + Designed to be dependency-light at import time; heavy deps are imported lazily. + + Config: + [tool=florencevision] + enabled=true + strict=false + model="microsoft/Florence-2-large" + device="cpu" + dtype="float16" # optional + max_tags=12 + task="<|tag|>" # or <|od|>, <|caption|>, <|ocr|> + + Notes: + Florence-2 typically requires `trust_remote_code=True` when loading via Transformers. 
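+
+    Example (hypothetical image path, not from this repo's docs):
+        tool = FlorenceVisionTool(config)
+        tags = tool.tags_for_file(Path("cover.jpg"))  # [] unless enabled and an image
+        caption = tool.last_caption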
+ """ + + IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"} + + def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: + self._config: Dict[str, Any] = dict(config or {}) + self.defaults = self._load_defaults() + self._model = None + self._processor = None + self._last_caption: Optional[str] = None + + def _load_defaults(self) -> FlorenceVisionDefaults: + cfg = self._config + tool_block = _get_nested(cfg, "tool", "florencevision") + if not isinstance(tool_block, dict): + tool_block = {} + + base = FlorenceVisionDefaults() + + defaults = FlorenceVisionDefaults( + enabled=_as_bool(tool_block.get("enabled"), False), + strict=_as_bool(tool_block.get("strict"), False), + model=str(tool_block.get("model") or base.model), + device=str(tool_block.get("device") or base.device), + dtype=(str(tool_block.get("dtype")).strip() if tool_block.get("dtype") else None), + max_tags=_as_int(tool_block.get("max_tags"), base.max_tags), + namespace=str(tool_block.get("namespace") or base.namespace), + task=str(tool_block.get("task") or base.task), + ) + return defaults + + def enabled(self) -> bool: + return bool(self.defaults.enabled) + + def applicable_path(self, media_path: Path) -> bool: + try: + return media_path.suffix.lower() in self.IMAGE_EXTS + except Exception: + return False + + def _ensure_loaded(self) -> None: + if self._model is not None and self._processor is not None: + return + + try: + import torch + from transformers import AutoModelForCausalLM, AutoProcessor + except Exception as exc: + raise RuntimeError( + "FlorenceVision requires optional dependencies. Install at least: torch, transformers, pillow. " + "(Florence-2 typically also needs trust_remote_code=True)." + ) from exc + + model_id = self.defaults.model + device = self.defaults.device + + debug(f"[florencevision] Loading processor/model: model={model_id} device={device} dtype={self.defaults.dtype}") + + dtype = None + if self.defaults.dtype: + dt = self.defaults.dtype.strip().lower() + dtype = { + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float32": torch.float32, + "fp32": torch.float32, + }.get(dt) + + # FlorenceVision often runs on CPU for many users; float16/bfloat16 on CPU + # is fragile (and can produce dtype mismatches like float vs Half bias). + # If the configured device is CPU, force float32 unless explicitly set to float32. + if str(device).strip().lower() == "cpu" and dtype in {torch.float16, torch.bfloat16}: + debug( + f"[florencevision] Overriding dtype to float32 on CPU (was {self.defaults.dtype})" + ) + dtype = torch.float32 + + # Florence-2 usually requires trust_remote_code. + self._processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + base_kwargs: Dict[str, Any] = { + "trust_remote_code": True, + "torch_dtype": dtype, + } + + # Transformers attention backends have been a moving target. Some Florence-2 + # remote-code builds trigger AttributeError on SDPA capability checks. + # Prefer eager attention when supported, otherwise fall back. + try: + self._model = AutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="eager", + **base_kwargs, + ) + except TypeError: + # Older Transformers: no attn_implementation kwarg. 
+ self._model = AutoModelForCausalLM.from_pretrained(model_id, **base_kwargs) + except AttributeError as exc: + if "_supports_sdpa" in str(exc): + try: + self._model = AutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="eager", + **base_kwargs, + ) + except TypeError: + self._model = AutoModelForCausalLM.from_pretrained(model_id, **base_kwargs) + else: + raise + + # Defensive compatibility patch: some Florence-2 implementations do not + # declare SDPA support flags but newer Transformers paths may probe them. + try: + if self._model is not None and not hasattr(self._model, "_supports_sdpa"): + setattr(self._model, "_supports_sdpa", False) + except Exception: + pass + + try: + self._model.to(device) # type: ignore[union-attr] + except Exception: + # Fallback to cpu + self._model.to("cpu") # type: ignore[union-attr] + + try: + self._model.eval() # type: ignore[union-attr] + except Exception: + pass + + try: + md = getattr(self._model, "device", None) + dt = None + try: + dt = next(self._model.parameters()).dtype # type: ignore[union-attr] + except Exception: + dt = None + debug(f"[florencevision] Model loaded: device={md} param_dtype={dt}") + except Exception: + pass + + def tags_for_image(self, media_path: Path) -> List[str]: + """Return Florence-derived tags for an image. + + Uses the configured Florence task (default: tag) and turns model output into tags. + """ + self._ensure_loaded() + self._last_caption = None + + try: + from PIL import Image + except Exception as exc: + raise RuntimeError("FlorenceVision requires pillow (PIL).") from exc + + if self._processor is None or self._model is None: + return [] + + prompt = _normalize_task_prompt(str(self.defaults.task or "tag")) + try: + debug(f"[florencevision] Task prompt: {prompt}") + except Exception: + pass + + max_tags = max(0, int(self.defaults.max_tags or 0)) + + try: + debug( + f"[florencevision] Opening image: path={media_path} exists={media_path.exists()} size_bytes={(media_path.stat().st_size if media_path.exists() else 'n/a')}" + ) + except Exception: + debug(f"[florencevision] Opening image: path={media_path}") + + image = Image.open(str(media_path)).convert("RGB") + try: + debug(f"[florencevision] Image loaded: mode={image.mode} size={image.width}x{image.height}") + except Exception: + pass + + processor = self._processor + model = self._model + + # Inspect forward signature once; reused across cascaded runs. + forward_params: set[str] = set() + try: + import inspect + + forward_params = set(inspect.signature(getattr(model, "forward")).parameters.keys()) + except Exception: + forward_params = set() + + def _run_prompt(task_prompt: str) -> Tuple[str, Any, Any]: + """Run a single Florence prompt and return (generated_text, parsed, seq).""" + try: + debug(f"[florencevision] running prompt: {task_prompt}") + except Exception: + pass + + try: + # Florence expects special task tokens like <|tag|>, <|od|>, <|caption|>, <|ocr|>. 
+ inputs = processor(text=task_prompt, images=image, return_tensors="pt") + try: + import torch + + keys = [] + try: + keys = list(dict(inputs).keys()) + except Exception: + try: + keys = list(getattr(inputs, "keys")()) + except Exception: + keys = [] + + debug(f"[florencevision] Processor output keys: {keys}") + + for k in keys: + try: + v = dict(inputs).get(k) + except Exception: + try: + v = inputs.get(k) # type: ignore[union-attr] + except Exception: + v = None + + if v is None: + debug(f"[florencevision] {k}: None") + continue + if hasattr(v, "shape"): + try: + debug( + f"[florencevision] {k}: tensor shape={tuple(v.shape)} dtype={getattr(v, 'dtype', None)}" + ) + continue + except Exception: + pass + if isinstance(v, (list, tuple)): + has_none = any(x is None for x in v) + debug(f"[florencevision] {k}: {type(v).__name__} len={len(v)} has_none={has_none}") + continue + debug(f"[florencevision] {k}: type={type(v).__name__}") + except Exception: + pass + + try: + inputs = inputs.to(model.device) # type: ignore[attr-defined] + except Exception: + pass + + # Align floating-point input tensors with the model's parameter dtype. + try: + import torch + + try: + model_dtype = next(model.parameters()).dtype # type: ignore[union-attr] + except Exception: + model_dtype = None + + if model_dtype is not None: + for k, v in list(inputs.items()): + try: + if hasattr(v, "dtype") and torch.is_floating_point(v): + inputs[k] = v.to(dtype=model_dtype) + except Exception: + continue + except Exception: + pass + + try: + gen_inputs_all = {k: v for k, v in dict(inputs).items() if v is not None} + except Exception: + gen_inputs_all = inputs # type: ignore[assignment] + + gen_inputs: Dict[str, Any] = {} + if isinstance(gen_inputs_all, dict): + input_ids = gen_inputs_all.get("input_ids") + pixel_values = gen_inputs_all.get("pixel_values") + attention_mask = gen_inputs_all.get("attention_mask") + + if input_ids is not None: + gen_inputs["input_ids"] = input_ids + if pixel_values is not None: + gen_inputs["pixel_values"] = pixel_values + + try: + if ( + attention_mask is not None + and hasattr(attention_mask, "shape") + and hasattr(input_ids, "shape") + and tuple(attention_mask.shape) == tuple(input_ids.shape) + ): + gen_inputs["attention_mask"] = attention_mask + except Exception: + pass + + try: + debug( + "[florencevision] model forward supports: " + f"pixel_mask={'pixel_mask' in forward_params} " + f"image_attention_mask={'image_attention_mask' in forward_params} " + f"pixel_attention_mask={'pixel_attention_mask' in forward_params}" + ) + except Exception: + pass + + try: + gen_inputs.setdefault("use_cache", False) + gen_inputs.setdefault("num_beams", 1) + except Exception: + pass + + try: + debug(f"[florencevision] generate kwargs: {sorted(list(gen_inputs.keys()))}") + except Exception: + pass + + pv = gen_inputs.get("pixel_values") + if pv is None: + raise RuntimeError( + "FlorenceVision processor did not produce 'pixel_values'. " + "This usually indicates an image preprocessing issue." 
+ ) + + try: + import torch + + cm = torch.inference_mode + except Exception: + cm = None + + def _do_generate(kwargs: Dict[str, Any]) -> Any: + if cm is not None: + with cm(): + return model.generate(**kwargs, max_new_tokens=1024) + return model.generate(**kwargs, max_new_tokens=1024) + + try: + generated_ids = _do_generate(gen_inputs) + except AttributeError as exc: + msg = str(exc) + if "_supports_sdpa" in msg: + try: + if not hasattr(model, "_supports_sdpa"): + setattr(model, "_supports_sdpa", False) + except Exception: + pass + generated_ids = _do_generate(gen_inputs) + elif "NoneType" in msg and "shape" in msg: + retry_inputs = dict(gen_inputs) + + try: + if ( + "attention_mask" not in retry_inputs + and isinstance(gen_inputs_all, dict) + and gen_inputs_all.get("attention_mask") is not None + ): + am = gen_inputs_all.get("attention_mask") + ii = retry_inputs.get("input_ids") + if ( + am is not None + and ii is not None + and hasattr(am, "shape") + and hasattr(ii, "shape") + and tuple(am.shape) == tuple(ii.shape) + ): + retry_inputs["attention_mask"] = am + except Exception: + pass + + try: + import torch + + pv2 = retry_inputs.get("pixel_values") + if pv2 is not None and hasattr(pv2, "shape") and len(pv2.shape) == 4: + b, _c, h, w = tuple(pv2.shape) + mask = torch.ones((b, h, w), dtype=torch.long, device=pv2.device) + if "pixel_mask" in forward_params and "pixel_mask" not in retry_inputs: + retry_inputs["pixel_mask"] = mask + elif "image_attention_mask" in forward_params and "image_attention_mask" not in retry_inputs: + retry_inputs["image_attention_mask"] = mask + elif "pixel_attention_mask" in forward_params and "pixel_attention_mask" not in retry_inputs: + retry_inputs["pixel_attention_mask"] = mask + except Exception: + pass + + try: + debug( + f"[florencevision] generate retry kwargs: {sorted(list(retry_inputs.keys()))}" + ) + except Exception: + pass + + generated_ids = _do_generate(retry_inputs) + else: + raise + + try: + debug(f"[florencevision] generated_ids type={type(generated_ids).__name__}") + except Exception: + pass + + seq = getattr(generated_ids, "sequences", generated_ids) + generated_text = processor.batch_decode(seq, skip_special_tokens=False)[0] + except Exception as exc: + try: + import traceback + + debug(f"[florencevision] prompt run failed: {type(exc).__name__}: {exc}") + debug("[florencevision] traceback:\n" + traceback.format_exc()) + except Exception: + pass + raise + + parsed = None + try: + parsed = processor.post_process_generation( + generated_text, + task=task_prompt, + image_size=(image.width, image.height), + ) + except Exception: + parsed = None + + try: + generated_text_no_special = None + try: + generated_text_no_special = processor.batch_decode(seq, skip_special_tokens=True)[0] + except Exception: + generated_text_no_special = None + + debug("[florencevision] ===== RAW GENERATED (skip_special_tokens=False) =====") + debug(_truncate_debug_text(str(generated_text or ""))) + if generated_text_no_special is not None: + debug("[florencevision] ===== RAW GENERATED (skip_special_tokens=True) =====") + debug(_truncate_debug_text(str(generated_text_no_special or ""))) + + if parsed is None: + debug("[florencevision] post_process_generation: None") + elif isinstance(parsed, dict): + try: + keys = list(parsed.keys()) + except Exception: + keys = [] + debug(f"[florencevision] post_process_generation: dict keys={keys}") + try: + if task_prompt in parsed: + debug(f"[florencevision] post_process[{task_prompt!r}] type={type(parsed.get(task_prompt)).__name__}") 
+                            debug("[florencevision] post_process[prompt] repr:\n" + _debug_repr(parsed.get(task_prompt)))
+                        elif len(parsed) == 1:
+                            only_key = next(iter(parsed.keys()))
+                            debug(f"[florencevision] post_process single key {only_key!r} type={type(parsed.get(only_key)).__name__}")
+                            debug("[florencevision] post_process[single] repr:\n" + _debug_repr(parsed.get(only_key)))
+                        else:
+                            for k in list(parsed.keys())[:5]:
+                                debug(f"[florencevision] post_process[{k!r}] type={type(parsed.get(k)).__name__}")
+                                debug("[florencevision] post_process[key] repr:\n" + _debug_repr(parsed.get(k)))
+                    except Exception:
+                        pass
+                else:
+                    debug(f"[florencevision] post_process_generation: type={type(parsed).__name__}")
+                    debug("[florencevision] post_process repr:\n" + _debug_repr(parsed))
+            except Exception:
+                pass
+
+            return generated_text, parsed, seq
+
+        def _extract_labels_and_captions(task_prompt: str, generated_text: str, parsed: Any) -> Tuple[List[str], List[str], List[str], List[Tuple[str, str, str]]]:
+            labels: List[str] = []
+            caption_candidates: List[str] = []
+
+            if isinstance(parsed, dict):
+                for k, v in parsed.items():
+                    key_lower = str(k).lower()
+                    if "caption" in key_lower:
+                        caption_candidates.extend(_collect_captions(v, key_hint=str(k)))
+                        continue
+                    labels.extend(_collect_candidate_strings(v))
+            elif parsed is not None:
+                if isinstance(parsed, str) and parsed.strip() and _is_caption_task(task_prompt):
+                    caption_candidates.append(parsed.strip())
+                else:
+                    labels.extend(_collect_candidate_strings(parsed))
+
+            if not labels:
+                raw = str(generated_text or "").strip()
+                if raw:
+                    labels.extend(_split_text_to_labels(raw))
+
+            try:
+                debug(f"[florencevision] candidate label strings ({len(labels)}): {labels!r}")
+            except Exception:
+                pass
+
+            out: List[str] = []
+            seen: set[str] = set()
+            dropped: List[Tuple[str, str, str]] = []
+            for lab in labels:
+                v = _clean_tag_value(lab)
+                if not v:
+                    dropped.append((str(lab), "", "cleaned_empty"))
+                    continue
+
+                if v in {
+                    "od",
+                    "caption",
+                    "more_detailed_caption",
+                    "more-detailed-caption",
+                    "ocr",
+                    "tag",
+                    "grounding",
+                    "object_detection",
+                    "detailed_caption",
+                    "caption_to_phrase_grounding",
+                }:
+                    dropped.append((str(lab), v, "filtered_task_token"))
+                    continue
+
+                if v.startswith("florence:"):
+                    v = v.split(":", 1)[1].strip("_")
+                    if not v:
+                        dropped.append((str(lab), "", "stripped_namespace_empty"))
+                        continue
+
+                key = v.lower()
+                if key in seen:
+                    dropped.append((str(lab), v, "duplicate"))
+                    continue
+                seen.add(key)
+                out.append(v)
+                if max_tags and len(out) >= max_tags:
+                    break
+
+            try:
+                debug(f"[florencevision] cleaned tags ({len(out)}): {out!r}")
+                if dropped:
+                    debug(f"[florencevision] dropped ({len(dropped)}):")
+                    for raw_lab, cleaned, reason in dropped:
+                        debug(f"[florencevision] drop reason={reason} raw={raw_lab!r} cleaned={cleaned!r}")
+            except Exception:
+                pass
+
+            return labels, caption_candidates, out, dropped
+
+        def _best_caption(candidates: Sequence[str]) -> Optional[str]:
+            cleaned: List[str] = []
+            raw: List[str] = []
+            for c in candidates:
+                try:
+                    s = str(c).strip()
+                except Exception:
+                    continue
+                if not s:
+                    continue
+                raw.append(s)
+                cc = _clean_caption_text(s)
+                if cc:
+                    cleaned.append(cc)
+
+            if cleaned:
+                try:
+                    return max(cleaned, key=lambda s: len(str(s)), default=None)
+                except Exception:
+                    pass
+            try:
+                return max(raw, key=lambda s: len(str(s)), default=None)
+            except Exception:
+                return None
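+
+        # Illustrative only: _best_caption prefers the longest cleaned candidate,
+        # e.g. ["a dog", "a dog on a beach"] -> "a dog on a beach".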
+        def _grounding_candidates_from_caption(caption_text: Optional[str], fallback_tags: Sequence[str]) -> List[str]:
+            import re
+
+            words: List[str] = []
+            if caption_text:
+                cap_clean = _clean_caption_text(caption_text)
+                if cap_clean:
+                    words.extend(re.split(r"[^A-Za-z0-9_\-]+", cap_clean))
+
+            # Add any fallback tags (e.g., cleaned caption labels) to seed grounding.
+            for tag in fallback_tags or []:
+                cleaned_tag = _clean_tag_value(tag)
+                if cleaned_tag:
+                    words.append(cleaned_tag)
+
+            seen: set[str] = set()
+            out: List[str] = []
+            for w in words:
+                if not w:
+                    continue
+                w_clean = re.sub(r"[^A-Za-z0-9_\-]+", "", w).strip("._-")
+                if len(w_clean) < 3:
+                    continue
+                if not any(ch.isalpha() for ch in w_clean):
+                    continue
+                if re.match(r"loc[_-]?\d", w_clean, re.IGNORECASE):
+                    continue
+                if w_clean.lower() in {"detailed", "caption", "grounding", "poly", "task"}:
+                    continue
+                key = w_clean.lower()
+                if key in seen:
+                    continue
+                seen.add(key)
+                out.append(w_clean)
+                if len(out) >= max(max_tags * 2, 10):
+                    break
+            return out
+
+        is_combo_prompt = "<|detailed_caption|>" in prompt and "<|grounding|>" in prompt
+
+        final_tags: List[str] = []
+        caption_text: Optional[str] = None
+
+        if is_combo_prompt:
+            # Cascaded flow: caption first, then grounding seeded by caption terms.
+            cap_text, cap_parsed, _cap_seq = _run_prompt("<|detailed_caption|>")
+            cap_labels, cap_captions, cap_cleaned, _cap_dropped = _extract_labels_and_captions("<|detailed_caption|>", cap_text, cap_parsed)
+
+            best_cap = _best_caption(cap_captions) or _best_caption([_strip_florence_tokens(cap_text)])
+            if best_cap:
+                cap_cleaned_text = _clean_caption_text(best_cap)
+                if cap_cleaned_text:
+                    caption_text = cap_cleaned_text
+
+            candidates = _grounding_candidates_from_caption(caption_text, cap_cleaned or cap_labels)
+            grounding_prompt = "<|grounding|>" if not candidates else "<|grounding|> Find and label: " + ", ".join(candidates)
+            try:
+                debug(f"[florencevision] grounding prompt: {grounding_prompt}")
+            except Exception:
+                pass
+
+            grd_text, grd_parsed, _grd_seq = _run_prompt(grounding_prompt)
+            _grd_labels, grd_captions, grd_cleaned, _grd_dropped = _extract_labels_and_captions(grounding_prompt, grd_text, grd_parsed)
+
+            final_tags = grd_cleaned or cap_cleaned
+            if not caption_text:
+                caption_text = _best_caption(grd_captions)
+
+            # If grounding still produced nothing useful, fall back to raw split of grounding text.
+            if not final_tags:
+                fallback_labels = _split_text_to_labels(grd_text)
+                final_tags = [_clean_tag_value(v) for v in fallback_labels if _clean_tag_value(v)]
+                if max_tags:
+                    final_tags = final_tags[:max_tags]
+        else:
+            gen_text, parsed, _seq = _run_prompt(prompt)
+            _labels, captions, cleaned_tags, _dropped = _extract_labels_and_captions(prompt, gen_text, parsed)
+            final_tags = cleaned_tags
+            caption_text = _best_caption(captions)
+
+        # Fallback: if combo-like prompt yields only task tokens, retry with caption-only once.
+ try: + is_combo = "<|detailed_caption|>" in prompt and "<|grounding|>" in prompt + only_task_tokens = not final_tags or all(t in {"object_detection", "grounding", "tag"} for t in final_tags) + except Exception: + is_combo = False + only_task_tokens = False + + if is_combo and only_task_tokens and not getattr(self, "_od_tag_retrying", False): + try: + self._od_tag_retrying = True + debug("[florencevision] caption+grounding produced no labels; retrying with <|detailed_caption|> only") + original_task = self.defaults.task + try: + self.defaults.task = "<|detailed_caption|>" + except Exception: + pass + final_tags = self.tags_for_image(media_path) + finally: + try: + self.defaults.task = original_task + except Exception: + pass + self._od_tag_retrying = False + + self._last_caption = caption_text if caption_text else None + return final_tags + + @property + def last_caption(self) -> Optional[str]: + return self._last_caption + + def tags_for_file(self, media_path: Path) -> List[str]: + if not self.enabled(): + return [] + if not self.applicable_path(media_path): + return [] + return self.tags_for_image(media_path) + + +__all__ = ["FlorenceVisionTool", "FlorenceVisionDefaults"]
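
For reviewers, a minimal usage sketch of the new tool under the config shape read by `_load_defaults` (the image path and config values here are hypothetical, not from this repo's docs):

```python
from pathlib import Path

from tool import FlorenceVisionTool

# Hypothetical config dict: mirrors the [tool=florencevision] block
# documented in the class docstring.
config = {
    "tool": {
        "florencevision": {
            "enabled": True,
            "device": "cpu",
            "max_tags": 12,
            "task": "tag",
        }
    }
}

tool = FlorenceVisionTool(config)
image = Path("cover.jpg")  # hypothetical path

# tags_for_file() gates on enabled() and applicable_path(); the heavy deps
# (torch/transformers) are only imported once a tag run actually happens.
tags = tool.tags_for_file(image)
caption = tool.last_caption  # detailed caption from the cascaded run, or None
```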