From 4c9096a11b49b28b2703abb50699e00c23973cb1 Mon Sep 17 00:00:00 2001 From: Timmy Welch Date: Sun, 15 Sep 2024 17:09:33 -0700 Subject: [PATCH] Implement the most basic local plugin isolation possible Remove modules belonging to local plugins after loading Remove sys.path entry after loading This means that multiple local plugins can be installed with the same import path and should work correctly This does not allow loading a local plugin that has the same import path as an installed plugin --- comicapi/comicarchive.py | 16 +-- comictaggerlib/ctsettings/plugin_finder.py | 152 +++++++++++++++++---- comictaggerlib/main.py | 13 +- comictalker/__init__.py | 30 +++- setup.cfg | 1 + 5 files changed, 165 insertions(+), 47 deletions(-) diff --git a/comicapi/comicarchive.py b/comicapi/comicarchive.py index 269134a..8348a65 100644 --- a/comicapi/comicarchive.py +++ b/comicapi/comicarchive.py @@ -24,7 +24,6 @@ import pathlib import shutil import sys from collections.abc import Iterable -from typing import TYPE_CHECKING from comicapi import utils from comicapi.archivers import Archiver, UnknownArchiver, ZipArchiver @@ -32,16 +31,13 @@ from comicapi.genericmetadata import GenericMetadata from comicapi.tags import Tag from comictaggerlib.ctversion import version -if TYPE_CHECKING: - from importlib.metadata import EntryPoint - logger = logging.getLogger(__name__) archivers: list[type[Archiver]] = [] tags: dict[str, Tag] = {} -def load_archive_plugins(local_plugins: Iterable[EntryPoint] = tuple()) -> None: +def load_archive_plugins(local_plugins: Iterable[type[Archiver]] = tuple()) -> None: if archivers: return if sys.version_info < (3, 10): @@ -52,7 +48,7 @@ def load_archive_plugins(local_plugins: Iterable[EntryPoint] = tuple()) -> None: archive_plugins: list[type[Archiver]] = [] # A list is used first matching plugin wins - for ep in itertools.chain(local_plugins, entry_points(group="comicapi.archiver")): + for ep in itertools.chain(entry_points(group="comicapi.archiver")): try: spec = importlib.util.find_spec(ep.module) except ValueError: @@ -70,11 +66,12 @@ def load_archive_plugins(local_plugins: Iterable[EntryPoint] = tuple()) -> None: else: logger.exception("Failed to load archive plugin: %s", ep.name) archivers.clear() + archivers.extend(local_plugins) archivers.extend(archive_plugins) archivers.extend(builtin) -def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterable[EntryPoint] = tuple()) -> None: +def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterable[type[Tag]] = tuple()) -> None: if tags: return if sys.version_info < (3, 10): @@ -84,7 +81,7 @@ def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterab builtin: dict[str, Tag] = {} tag_plugins: dict[str, tuple[Tag, str]] = {} # A dict is used, last plugin wins - for ep in itertools.chain(entry_points(group="comicapi.tags"), local_plugins): + for ep in entry_points(group="comicapi.tags"): location = "Unknown" try: _spec = importlib.util.find_spec(ep.module) @@ -109,6 +106,9 @@ def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterab tag_plugins[tag.id] = (tag(version), location) except Exception: logger.exception("Failed to load tag plugin: %s from %s", ep.name, location) + # A dict is used, last plugin wins + for tag in local_plugins: + tag_plugins[tag.id] = (tag(version), "Local") for tag_id in set(builtin.keys()).intersection(tag_plugins): location = tag_plugins[tag_id][1] diff --git a/comictaggerlib/ctsettings/plugin_finder.py b/comictaggerlib/ctsettings/plugin_finder.py index 75b8d15..73b9d93 100644 --- a/comictaggerlib/ctsettings/plugin_finder.py +++ b/comictaggerlib/ctsettings/plugin_finder.py @@ -5,17 +5,52 @@ from __future__ import annotations import configparser -import importlib.metadata +import importlib.util +import itertools import logging import pathlib +import platform import re -from collections.abc import Generator -from typing import Any, NamedTuple +import sys +from collections.abc import Generator, Iterable +from typing import Any, NamedTuple, TypeVar + +if sys.version_info < (3, 10): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata +import tomli logger = logging.getLogger(__name__) NORMALIZE_PACKAGE_NAME_RE = re.compile(r"[-_.]+") PLUGIN_GROUPS = frozenset(("comictagger.talker", "comicapi.archiver", "comicapi.tags")) +icu_available = importlib.util.find_spec("icu") is not None + + +def _custom_key(tup: Any) -> Any: + import natsort + + lst = [] + for x in natsort.os_sort_keygen()(tup): + ret = x + if len(x) > 1 and isinstance(x[1], int) and isinstance(x[0], str) and x[0] == "": + ret = ("a", *x[1:]) + + lst.append(ret) + return tuple(lst) + + +T = TypeVar("T") + + +def os_sorted(lst: Iterable[T]) -> Iterable[T]: + import natsort + + key = _custom_key + if icu_available or platform.system() == "Windows": + key = natsort.os_sort_keygen() + return sorted(lst, key=key) class FailedToLoadPlugin(Exception): @@ -47,9 +82,12 @@ class Plugin(NamedTuple): package: str version: str - entry_point: importlib.metadata.EntryPoint + entry_point: importlib_metadata.EntryPoint path: pathlib.Path + def load(self) -> LoadedPlugin: + return LoadedPlugin(self, self.entry_point.load()) + class LoadedPlugin(NamedTuple): """Represents a plugin after being imported.""" @@ -71,11 +109,11 @@ class LoadedPlugin(NamedTuple): class Plugins(NamedTuple): """Classified plugins.""" - archivers: list[Plugin] - tags: list[Plugin] - talkers: list[Plugin] + archivers: list[LoadedPlugin] + tags: list[LoadedPlugin] + talkers: list[LoadedPlugin] - def all_plugins(self) -> Generator[Plugin]: + def all_plugins(self) -> Generator[LoadedPlugin]: """Return an iterator over all :class:`LoadedPlugin`s.""" yield from self.archivers yield from self.tags @@ -83,13 +121,24 @@ class Plugins(NamedTuple): def versions_str(self) -> str: """Return a user-displayed list of plugin versions.""" - return ", ".join(sorted({f"{plugin.package}: {plugin.version}" for plugin in self.all_plugins()})) + return ", ".join(sorted({f"{plugin.plugin.package}: {plugin.plugin.version}" for plugin in self.all_plugins()})) -def _find_local_plugins(plugin_path: pathlib.Path) -> Generator[Plugin]: +def _find_ep_plugin(plugin_path: pathlib.Path) -> None | Generator[Plugin]: + logger.debug("Checking for distributions in %s", plugin_path) + for dist in importlib_metadata.distributions(path=[str(plugin_path)]): + logger.debug("found distribution %s", dist.name) + eps = dist.entry_points + for group in PLUGIN_GROUPS: + for ep in eps.select(group=group): + logger.debug("found EntryPoint group %s %s=%s", group, ep.name, ep.value) + yield Plugin(plugin_path.name, dist.version, ep, plugin_path) + return None + +def _find_cfg_plugin(setup_cfg_path: pathlib.Path) -> Generator[Plugin]: cfg = configparser.ConfigParser(interpolation=None) - cfg.read(plugin_path / "setup.cfg") + cfg.read(setup_cfg_path) for group in PLUGIN_GROUPS: for plugin_s in cfg.get("options.entry_points", group, fallback="").splitlines(): @@ -98,8 +147,43 @@ def _find_local_plugins(plugin_path: pathlib.Path) -> Generator[Plugin]: name, _, entry_str = plugin_s.partition("=") name, entry_str = name.strip(), entry_str.strip() - ep = importlib.metadata.EntryPoint(name, entry_str, group) - yield Plugin(plugin_path.name, cfg.get("metadata", "version", fallback="0.0.1"), ep, plugin_path) + ep = importlib_metadata.EntryPoint(name, entry_str, group) + yield Plugin( + setup_cfg_path.parent.name, cfg.get("metadata", "version", fallback="0.0.1"), ep, setup_cfg_path.parent + ) + + +def _find_pyproject_plugin(pyproject_path: pathlib.Path) -> Generator[Plugin]: + cfg = tomli.loads(pyproject_path.read_text()) + + for group in PLUGIN_GROUPS: + cfg["project"]["entry-points"] + for plugins in cfg.get("project", {}).get("entry-points", {}).get(group, {}): + if not plugins: + continue + for name, entry_str in plugins.items(): + ep = importlib_metadata.EntryPoint(name, entry_str, group) + yield Plugin( + pyproject_path.parent.name, + cfg.get("project", {}).get("version", "0.0.1"), + ep, + pyproject_path.parent, + ) + + +def _find_local_plugins(plugin_path: pathlib.Path) -> Generator[Plugin]: + gen = _find_ep_plugin(plugin_path) + if gen is not None: + yield from gen + return + + if (plugin_path / "setup.cfg").is_file(): + yield from _find_cfg_plugin(plugin_path / "setup.cfg") + return + + if (plugin_path / "pyproject.cfg").is_file(): + yield from _find_pyproject_plugin(plugin_path / "setup.cfg") + return def _check_required_plugins(plugins: list[Plugin], expected: frozenset[str]) -> None: @@ -118,30 +202,48 @@ def _check_required_plugins(plugins: list[Plugin], expected: frozenset[str]) -> def find_plugins(plugin_folder: pathlib.Path) -> Plugins: """Discovers all plugins (but does not load them).""" - ret: list[Plugin] = [] - for plugin_path in plugin_folder.glob("*/setup.cfg"): - try: - ret.extend(_find_local_plugins(plugin_path.parent)) - except Exception as err: - FailedToLoadPlugin(plugin_path.parent.name, err) + ret: list[LoadedPlugin] = [] - # for determinism, sort the list - ret.sort() + dirs = {x.parent for x in plugin_folder.glob("*/setup.cfg")} + dirs.update({x.parent for x in plugin_folder.glob("*/pyproject.toml")}) + + zips = [x for x in plugin_folder.glob("*.zip") if x.is_file()] + + for plugin_path in itertools.chain(os_sorted(zips), os_sorted(dirs)): + logger.debug("looking for plugins in %s", plugin_path) + try: + sys.path.append(str(plugin_path)) + for plugin in _find_local_plugins(plugin_path): + logger.debug("Attempting to load %s from %s", plugin.entry_point.name, plugin.path) + ret.append(plugin.load()) + # sys.path.remove(str(plugin_path)) + except Exception as err: + FailedToLoadPlugin(plugin_path.name, err) + finally: + sys.path.remove(str(plugin_path)) + for mod in list(sys.modules.values()): + if ( + mod is not None + and hasattr(mod, "__spec__") + and mod.__spec__ + and str(plugin_path) in (mod.__spec__.origin or "") + ): + sys.modules.pop(mod.__name__) return _classify_plugins(ret) -def _classify_plugins(plugins: list[Plugin]) -> Plugins: +def _classify_plugins(plugins: list[LoadedPlugin]) -> Plugins: archivers = [] tags = [] talkers = [] for p in plugins: - if p.entry_point.group == "comictagger.talker": + if p.plugin.entry_point.group == "comictagger.talker": talkers.append(p) - elif p.entry_point.group == "comicapi.tags": + elif p.plugin.entry_point.group == "comicapi.tags": tags.append(p) - elif p.entry_point.group == "comicapi.archiver": + elif p.plugin.entry_point.group == "comicapi.archiver": archivers.append(p) else: logger.warning(NotImplementedError(f"what plugin type? {p}")) diff --git a/comictaggerlib/main.py b/comictaggerlib/main.py index e0dec98..7c598c7 100644 --- a/comictaggerlib/main.py +++ b/comictaggerlib/main.py @@ -44,7 +44,6 @@ if sys.version_info < (3, 10): import importlib_metadata else: import importlib.metadata as importlib_metadata - logger = logging.getLogger("comictagger") @@ -124,19 +123,13 @@ class App: def load_plugins(self, opts: argparse.Namespace) -> None: local_plugins = plugin_finder.find_plugins(opts.config.user_plugin_dir) - self._extend_plugin_paths(local_plugins) - comicapi.comicarchive.load_archive_plugins(local_plugins=[p.entry_point for p in local_plugins.archivers]) - comicapi.comicarchive.load_tag_plugins( - version=version, local_plugins=[p.entry_point for p in local_plugins.tags] - ) + comicapi.comicarchive.load_archive_plugins(local_plugins=[p.obj for p in local_plugins.archivers]) + comicapi.comicarchive.load_tag_plugins(version=version, local_plugins=[p.obj for p in local_plugins.tags]) self.talkers = comictalker.get_talkers( - version, opts.config.user_cache_dir, local_plugins=[p.entry_point for p in local_plugins.talkers] + version, opts.config.user_cache_dir, local_plugins=[p.obj for p in local_plugins.talkers] ) - def _extend_plugin_paths(self, plugins: plugin_finder.Plugins) -> None: - sys.path.extend(str(p.path.absolute()) for p in plugins.all_plugins()) - def list_plugins( self, talkers: Collection[comictalker.ComicTalker], diff --git a/comictalker/__init__.py b/comictalker/__init__.py index 070e472..a80903b 100644 --- a/comictalker/__init__.py +++ b/comictalker/__init__.py @@ -9,9 +9,9 @@ from collections.abc import Sequence from packaging.version import InvalidVersion, parse if sys.version_info < (3, 10): - from importlib_metadata import EntryPoint, entry_points + from importlib_metadata import entry_points else: - from importlib.metadata import entry_points, EntryPoint + from importlib.metadata import entry_points from comictalker.comictalker import ComicTalker, TalkerError @@ -24,14 +24,14 @@ __all__ = [ def get_talkers( - version: str, cache: pathlib.Path, local_plugins: Sequence[EntryPoint] = tuple() + version: str, cache: pathlib.Path, local_plugins: Sequence[type[ComicTalker]] = tuple() ) -> dict[str, ComicTalker]: """Returns all comic talker instances""" talkers: dict[str, ComicTalker] = {} ct_version = parse(version) # A dict is used, last plugin wins - for talker in itertools.chain(entry_points(group="comictagger.talker"), local_plugins): + for talker in itertools.chain(entry_points(group="comictagger.talker")): try: talker_cls = talker.load() obj = talker_cls(version, cache) @@ -56,4 +56,26 @@ def get_talkers( except Exception: logger.exception("Failed to load talker: %s", talker.name) + # A dict is used, last plugin wins + for talker_cls in local_plugins: + try: + obj = talker_cls(version, cache) + try: + if ct_version >= parse(talker_cls.comictagger_min_ver): + talkers[talker_cls.id] = obj + else: + logger.error( + f"Minimum ComicTagger version required of {talker_cls.comictagger_min_ver} for talker {talker_cls.id} is not met, will NOT load talker" + ) + except InvalidVersion: + logger.warning( + f"Invalid minimum required ComicTagger version number for talker: {talker_cls.id} - version: {talker_cls.comictagger_min_ver}, will load talker anyway" + ) + # Attempt to use the talker anyway + # TODO flag this problem for later display to the user + talkers[talker_cls.id] = obj + + except Exception: + logger.exception("Failed to load talker: %s", talker_cls.id) + return talkers diff --git a/setup.cfg b/setup.cfg index 761ef77..92d2ea3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,6 +50,7 @@ install_requires = requests==2.* settngs==0.10.4 text2digits + tomli typing-extensions>=4.3.0 wordninja python_requires = >=3.9