Implement the most basic local plugin isolation possible
Remove modules belonging to local plugins after loading Remove sys.path entry after loading This means that multiple local plugins can be installed with the same import path and should work correctly This does not allow loading a local plugin that has the same import path as an installed plugin
This commit is contained in:
parent
c9c0c99a2a
commit
4c9096a11b
@ -24,7 +24,6 @@ import pathlib
|
||||
import shutil
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from comicapi import utils
|
||||
from comicapi.archivers import Archiver, UnknownArchiver, ZipArchiver
|
||||
@ -32,16 +31,13 @@ from comicapi.genericmetadata import GenericMetadata
|
||||
from comicapi.tags import Tag
|
||||
from comictaggerlib.ctversion import version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from importlib.metadata import EntryPoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
archivers: list[type[Archiver]] = []
|
||||
tags: dict[str, Tag] = {}
|
||||
|
||||
|
||||
def load_archive_plugins(local_plugins: Iterable[EntryPoint] = tuple()) -> None:
|
||||
def load_archive_plugins(local_plugins: Iterable[type[Archiver]] = tuple()) -> None:
|
||||
if archivers:
|
||||
return
|
||||
if sys.version_info < (3, 10):
|
||||
@ -52,7 +48,7 @@ def load_archive_plugins(local_plugins: Iterable[EntryPoint] = tuple()) -> None:
|
||||
archive_plugins: list[type[Archiver]] = []
|
||||
# A list is used first matching plugin wins
|
||||
|
||||
for ep in itertools.chain(local_plugins, entry_points(group="comicapi.archiver")):
|
||||
for ep in itertools.chain(entry_points(group="comicapi.archiver")):
|
||||
try:
|
||||
spec = importlib.util.find_spec(ep.module)
|
||||
except ValueError:
|
||||
@ -70,11 +66,12 @@ def load_archive_plugins(local_plugins: Iterable[EntryPoint] = tuple()) -> None:
|
||||
else:
|
||||
logger.exception("Failed to load archive plugin: %s", ep.name)
|
||||
archivers.clear()
|
||||
archivers.extend(local_plugins)
|
||||
archivers.extend(archive_plugins)
|
||||
archivers.extend(builtin)
|
||||
|
||||
|
||||
def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterable[EntryPoint] = tuple()) -> None:
|
||||
def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterable[type[Tag]] = tuple()) -> None:
|
||||
if tags:
|
||||
return
|
||||
if sys.version_info < (3, 10):
|
||||
@ -84,7 +81,7 @@ def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterab
|
||||
builtin: dict[str, Tag] = {}
|
||||
tag_plugins: dict[str, tuple[Tag, str]] = {}
|
||||
# A dict is used, last plugin wins
|
||||
for ep in itertools.chain(entry_points(group="comicapi.tags"), local_plugins):
|
||||
for ep in entry_points(group="comicapi.tags"):
|
||||
location = "Unknown"
|
||||
try:
|
||||
_spec = importlib.util.find_spec(ep.module)
|
||||
@ -109,6 +106,9 @@ def load_tag_plugins(version: str = f"ComicAPI/{version}", local_plugins: Iterab
|
||||
tag_plugins[tag.id] = (tag(version), location)
|
||||
except Exception:
|
||||
logger.exception("Failed to load tag plugin: %s from %s", ep.name, location)
|
||||
# A dict is used, last plugin wins
|
||||
for tag in local_plugins:
|
||||
tag_plugins[tag.id] = (tag(version), "Local")
|
||||
|
||||
for tag_id in set(builtin.keys()).intersection(tag_plugins):
|
||||
location = tag_plugins[tag_id][1]
|
||||
|
@ -5,17 +5,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import configparser
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import itertools
|
||||
import logging
|
||||
import pathlib
|
||||
import platform
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Any, NamedTuple
|
||||
import sys
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Any, NamedTuple, TypeVar
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
import tomli
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NORMALIZE_PACKAGE_NAME_RE = re.compile(r"[-_.]+")
|
||||
PLUGIN_GROUPS = frozenset(("comictagger.talker", "comicapi.archiver", "comicapi.tags"))
|
||||
icu_available = importlib.util.find_spec("icu") is not None
|
||||
|
||||
|
||||
def _custom_key(tup: Any) -> Any:
|
||||
import natsort
|
||||
|
||||
lst = []
|
||||
for x in natsort.os_sort_keygen()(tup):
|
||||
ret = x
|
||||
if len(x) > 1 and isinstance(x[1], int) and isinstance(x[0], str) and x[0] == "":
|
||||
ret = ("a", *x[1:])
|
||||
|
||||
lst.append(ret)
|
||||
return tuple(lst)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def os_sorted(lst: Iterable[T]) -> Iterable[T]:
|
||||
import natsort
|
||||
|
||||
key = _custom_key
|
||||
if icu_available or platform.system() == "Windows":
|
||||
key = natsort.os_sort_keygen()
|
||||
return sorted(lst, key=key)
|
||||
|
||||
|
||||
class FailedToLoadPlugin(Exception):
|
||||
@ -47,9 +82,12 @@ class Plugin(NamedTuple):
|
||||
|
||||
package: str
|
||||
version: str
|
||||
entry_point: importlib.metadata.EntryPoint
|
||||
entry_point: importlib_metadata.EntryPoint
|
||||
path: pathlib.Path
|
||||
|
||||
def load(self) -> LoadedPlugin:
|
||||
return LoadedPlugin(self, self.entry_point.load())
|
||||
|
||||
|
||||
class LoadedPlugin(NamedTuple):
|
||||
"""Represents a plugin after being imported."""
|
||||
@ -71,11 +109,11 @@ class LoadedPlugin(NamedTuple):
|
||||
class Plugins(NamedTuple):
|
||||
"""Classified plugins."""
|
||||
|
||||
archivers: list[Plugin]
|
||||
tags: list[Plugin]
|
||||
talkers: list[Plugin]
|
||||
archivers: list[LoadedPlugin]
|
||||
tags: list[LoadedPlugin]
|
||||
talkers: list[LoadedPlugin]
|
||||
|
||||
def all_plugins(self) -> Generator[Plugin]:
|
||||
def all_plugins(self) -> Generator[LoadedPlugin]:
|
||||
"""Return an iterator over all :class:`LoadedPlugin`s."""
|
||||
yield from self.archivers
|
||||
yield from self.tags
|
||||
@ -83,13 +121,24 @@ class Plugins(NamedTuple):
|
||||
|
||||
def versions_str(self) -> str:
|
||||
"""Return a user-displayed list of plugin versions."""
|
||||
return ", ".join(sorted({f"{plugin.package}: {plugin.version}" for plugin in self.all_plugins()}))
|
||||
return ", ".join(sorted({f"{plugin.plugin.package}: {plugin.plugin.version}" for plugin in self.all_plugins()}))
|
||||
|
||||
|
||||
def _find_local_plugins(plugin_path: pathlib.Path) -> Generator[Plugin]:
|
||||
def _find_ep_plugin(plugin_path: pathlib.Path) -> None | Generator[Plugin]:
|
||||
logger.debug("Checking for distributions in %s", plugin_path)
|
||||
for dist in importlib_metadata.distributions(path=[str(plugin_path)]):
|
||||
logger.debug("found distribution %s", dist.name)
|
||||
eps = dist.entry_points
|
||||
for group in PLUGIN_GROUPS:
|
||||
for ep in eps.select(group=group):
|
||||
logger.debug("found EntryPoint group %s %s=%s", group, ep.name, ep.value)
|
||||
yield Plugin(plugin_path.name, dist.version, ep, plugin_path)
|
||||
return None
|
||||
|
||||
|
||||
def _find_cfg_plugin(setup_cfg_path: pathlib.Path) -> Generator[Plugin]:
|
||||
cfg = configparser.ConfigParser(interpolation=None)
|
||||
cfg.read(plugin_path / "setup.cfg")
|
||||
cfg.read(setup_cfg_path)
|
||||
|
||||
for group in PLUGIN_GROUPS:
|
||||
for plugin_s in cfg.get("options.entry_points", group, fallback="").splitlines():
|
||||
@ -98,8 +147,43 @@ def _find_local_plugins(plugin_path: pathlib.Path) -> Generator[Plugin]:
|
||||
|
||||
name, _, entry_str = plugin_s.partition("=")
|
||||
name, entry_str = name.strip(), entry_str.strip()
|
||||
ep = importlib.metadata.EntryPoint(name, entry_str, group)
|
||||
yield Plugin(plugin_path.name, cfg.get("metadata", "version", fallback="0.0.1"), ep, plugin_path)
|
||||
ep = importlib_metadata.EntryPoint(name, entry_str, group)
|
||||
yield Plugin(
|
||||
setup_cfg_path.parent.name, cfg.get("metadata", "version", fallback="0.0.1"), ep, setup_cfg_path.parent
|
||||
)
|
||||
|
||||
|
||||
def _find_pyproject_plugin(pyproject_path: pathlib.Path) -> Generator[Plugin]:
|
||||
cfg = tomli.loads(pyproject_path.read_text())
|
||||
|
||||
for group in PLUGIN_GROUPS:
|
||||
cfg["project"]["entry-points"]
|
||||
for plugins in cfg.get("project", {}).get("entry-points", {}).get(group, {}):
|
||||
if not plugins:
|
||||
continue
|
||||
for name, entry_str in plugins.items():
|
||||
ep = importlib_metadata.EntryPoint(name, entry_str, group)
|
||||
yield Plugin(
|
||||
pyproject_path.parent.name,
|
||||
cfg.get("project", {}).get("version", "0.0.1"),
|
||||
ep,
|
||||
pyproject_path.parent,
|
||||
)
|
||||
|
||||
|
||||
def _find_local_plugins(plugin_path: pathlib.Path) -> Generator[Plugin]:
|
||||
gen = _find_ep_plugin(plugin_path)
|
||||
if gen is not None:
|
||||
yield from gen
|
||||
return
|
||||
|
||||
if (plugin_path / "setup.cfg").is_file():
|
||||
yield from _find_cfg_plugin(plugin_path / "setup.cfg")
|
||||
return
|
||||
|
||||
if (plugin_path / "pyproject.cfg").is_file():
|
||||
yield from _find_pyproject_plugin(plugin_path / "setup.cfg")
|
||||
return
|
||||
|
||||
|
||||
def _check_required_plugins(plugins: list[Plugin], expected: frozenset[str]) -> None:
|
||||
@ -118,30 +202,48 @@ def _check_required_plugins(plugins: list[Plugin], expected: frozenset[str]) ->
|
||||
|
||||
def find_plugins(plugin_folder: pathlib.Path) -> Plugins:
|
||||
"""Discovers all plugins (but does not load them)."""
|
||||
ret: list[Plugin] = []
|
||||
for plugin_path in plugin_folder.glob("*/setup.cfg"):
|
||||
try:
|
||||
ret.extend(_find_local_plugins(plugin_path.parent))
|
||||
except Exception as err:
|
||||
FailedToLoadPlugin(plugin_path.parent.name, err)
|
||||
ret: list[LoadedPlugin] = []
|
||||
|
||||
# for determinism, sort the list
|
||||
ret.sort()
|
||||
dirs = {x.parent for x in plugin_folder.glob("*/setup.cfg")}
|
||||
dirs.update({x.parent for x in plugin_folder.glob("*/pyproject.toml")})
|
||||
|
||||
zips = [x for x in plugin_folder.glob("*.zip") if x.is_file()]
|
||||
|
||||
for plugin_path in itertools.chain(os_sorted(zips), os_sorted(dirs)):
|
||||
logger.debug("looking for plugins in %s", plugin_path)
|
||||
try:
|
||||
sys.path.append(str(plugin_path))
|
||||
for plugin in _find_local_plugins(plugin_path):
|
||||
logger.debug("Attempting to load %s from %s", plugin.entry_point.name, plugin.path)
|
||||
ret.append(plugin.load())
|
||||
# sys.path.remove(str(plugin_path))
|
||||
except Exception as err:
|
||||
FailedToLoadPlugin(plugin_path.name, err)
|
||||
finally:
|
||||
sys.path.remove(str(plugin_path))
|
||||
for mod in list(sys.modules.values()):
|
||||
if (
|
||||
mod is not None
|
||||
and hasattr(mod, "__spec__")
|
||||
and mod.__spec__
|
||||
and str(plugin_path) in (mod.__spec__.origin or "")
|
||||
):
|
||||
sys.modules.pop(mod.__name__)
|
||||
|
||||
return _classify_plugins(ret)
|
||||
|
||||
|
||||
def _classify_plugins(plugins: list[Plugin]) -> Plugins:
|
||||
def _classify_plugins(plugins: list[LoadedPlugin]) -> Plugins:
|
||||
archivers = []
|
||||
tags = []
|
||||
talkers = []
|
||||
|
||||
for p in plugins:
|
||||
if p.entry_point.group == "comictagger.talker":
|
||||
if p.plugin.entry_point.group == "comictagger.talker":
|
||||
talkers.append(p)
|
||||
elif p.entry_point.group == "comicapi.tags":
|
||||
elif p.plugin.entry_point.group == "comicapi.tags":
|
||||
tags.append(p)
|
||||
elif p.entry_point.group == "comicapi.archiver":
|
||||
elif p.plugin.entry_point.group == "comicapi.archiver":
|
||||
archivers.append(p)
|
||||
else:
|
||||
logger.warning(NotImplementedError(f"what plugin type? {p}"))
|
||||
|
@ -44,7 +44,6 @@ if sys.version_info < (3, 10):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
logger = logging.getLogger("comictagger")
|
||||
|
||||
|
||||
@ -124,19 +123,13 @@ class App:
|
||||
|
||||
def load_plugins(self, opts: argparse.Namespace) -> None:
|
||||
local_plugins = plugin_finder.find_plugins(opts.config.user_plugin_dir)
|
||||
self._extend_plugin_paths(local_plugins)
|
||||
|
||||
comicapi.comicarchive.load_archive_plugins(local_plugins=[p.entry_point for p in local_plugins.archivers])
|
||||
comicapi.comicarchive.load_tag_plugins(
|
||||
version=version, local_plugins=[p.entry_point for p in local_plugins.tags]
|
||||
)
|
||||
comicapi.comicarchive.load_archive_plugins(local_plugins=[p.obj for p in local_plugins.archivers])
|
||||
comicapi.comicarchive.load_tag_plugins(version=version, local_plugins=[p.obj for p in local_plugins.tags])
|
||||
self.talkers = comictalker.get_talkers(
|
||||
version, opts.config.user_cache_dir, local_plugins=[p.entry_point for p in local_plugins.talkers]
|
||||
version, opts.config.user_cache_dir, local_plugins=[p.obj for p in local_plugins.talkers]
|
||||
)
|
||||
|
||||
def _extend_plugin_paths(self, plugins: plugin_finder.Plugins) -> None:
|
||||
sys.path.extend(str(p.path.absolute()) for p in plugins.all_plugins())
|
||||
|
||||
def list_plugins(
|
||||
self,
|
||||
talkers: Collection[comictalker.ComicTalker],
|
||||
|
@ -9,9 +9,9 @@ from collections.abc import Sequence
|
||||
from packaging.version import InvalidVersion, parse
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
from importlib_metadata import EntryPoint, entry_points
|
||||
from importlib_metadata import entry_points
|
||||
else:
|
||||
from importlib.metadata import entry_points, EntryPoint
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
from comictalker.comictalker import ComicTalker, TalkerError
|
||||
|
||||
@ -24,14 +24,14 @@ __all__ = [
|
||||
|
||||
|
||||
def get_talkers(
|
||||
version: str, cache: pathlib.Path, local_plugins: Sequence[EntryPoint] = tuple()
|
||||
version: str, cache: pathlib.Path, local_plugins: Sequence[type[ComicTalker]] = tuple()
|
||||
) -> dict[str, ComicTalker]:
|
||||
"""Returns all comic talker instances"""
|
||||
talkers: dict[str, ComicTalker] = {}
|
||||
ct_version = parse(version)
|
||||
|
||||
# A dict is used, last plugin wins
|
||||
for talker in itertools.chain(entry_points(group="comictagger.talker"), local_plugins):
|
||||
for talker in itertools.chain(entry_points(group="comictagger.talker")):
|
||||
try:
|
||||
talker_cls = talker.load()
|
||||
obj = talker_cls(version, cache)
|
||||
@ -56,4 +56,26 @@ def get_talkers(
|
||||
except Exception:
|
||||
logger.exception("Failed to load talker: %s", talker.name)
|
||||
|
||||
# A dict is used, last plugin wins
|
||||
for talker_cls in local_plugins:
|
||||
try:
|
||||
obj = talker_cls(version, cache)
|
||||
try:
|
||||
if ct_version >= parse(talker_cls.comictagger_min_ver):
|
||||
talkers[talker_cls.id] = obj
|
||||
else:
|
||||
logger.error(
|
||||
f"Minimum ComicTagger version required of {talker_cls.comictagger_min_ver} for talker {talker_cls.id} is not met, will NOT load talker"
|
||||
)
|
||||
except InvalidVersion:
|
||||
logger.warning(
|
||||
f"Invalid minimum required ComicTagger version number for talker: {talker_cls.id} - version: {talker_cls.comictagger_min_ver}, will load talker anyway"
|
||||
)
|
||||
# Attempt to use the talker anyway
|
||||
# TODO flag this problem for later display to the user
|
||||
talkers[talker_cls.id] = obj
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to load talker: %s", talker_cls.id)
|
||||
|
||||
return talkers
|
||||
|
Loading…
Reference in New Issue
Block a user