From 219ede2d5d624b1df446a24d539e7b161799f04e Mon Sep 17 00:00:00 2001 From: Timmy Welch Date: Fri, 19 Jul 2024 15:52:30 -0700 Subject: [PATCH] Improve StrEnum Return the actual string for __str__ Allow case insensitive conversion --- comicapi/merge.py | 43 ++----------------------------------------- comicapi/utils.py | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 42 deletions(-) diff --git a/comicapi/merge.py b/comicapi/merge.py index 2c6c4f6..8e3a724 100644 --- a/comicapi/merge.py +++ b/comicapi/merge.py @@ -1,51 +1,12 @@ from __future__ import annotations import dataclasses -import sys from collections import defaultdict from collections.abc import Collection -from enum import Enum, auto +from enum import auto from typing import Any -from comicapi.utils import norm_fold - -if sys.version_info < (3, 11): - - class StrEnum(str, Enum): - """ - Enum where members are also (and must be) strings - """ - - def __new__(cls, *values: Any) -> Any: - "values must already be of type `str`" - if len(values) > 3: - raise TypeError(f"too many arguments for str(): {values!r}") - if len(values) == 1: - # it must be a string - if not isinstance(values[0], str): - raise TypeError(f"{values[0]!r} is not a string") - if len(values) >= 2: - # check that encoding argument is a string - if not isinstance(values[1], str): - raise TypeError(f"encoding must be a string, not {values[1]!r}") - if len(values) == 3: - # check that errors argument is a string - if not isinstance(values[2], str): - raise TypeError("errors must be a string, not %r" % (values[2])) - value = str(*values) - member = str.__new__(cls, value) - member._value_ = value - return member - - @staticmethod - def _generate_next_value_(name: str, start: int, count: int, last_values: Any) -> str: - """ - Return the lower-cased version of the member name. - """ - return name.lower() - -else: - from enum import StrEnum +from comicapi.utils import StrEnum, norm_fold @dataclasses.dataclass diff --git a/comicapi/utils.py b/comicapi/utils.py index dac74bc..a7165b8 100644 --- a/comicapi/utils.py +++ b/comicapi/utils.py @@ -82,8 +82,28 @@ if sys.version_info < (3, 11): """ return name.lower() + @classmethod + def _missing_(cls, value: Any) -> str | None: + if not isinstance(value, str): + return None + if not hasattr(cls, "_lower_members"): + cls._lower_members = {x.casefold(): x for x in cls} # type: ignore[attr-defined] + return cls._lower_members.get(value.casefold(), None) # type: ignore[attr-defined] + + def __str__(self): + return self.value + else: - from enum import StrEnum + from enum import StrEnum as s + + class StrEnum(s): + @classmethod + def _missing_(cls, value: Any) -> str | None: + if not isinstance(value, str): + return None + if not hasattr(cls, "_lower_members"): + cls._lower_members = {x.casefold(): x for x in cls} # type: ignore[attr-defined] + return cls._lower_members.get(value.casefold(), None) # type: ignore[attr-defined] logger = logging.getLogger(__name__)