Improve StrEnum

Return the actual string for __str__
Allow case insensitive conversion
This commit is contained in:
Timmy Welch 2024-07-19 15:52:30 -07:00
parent e96cb8ad15
commit 219ede2d5d
2 changed files with 23 additions and 42 deletions

View File

@ -1,51 +1,12 @@
from __future__ import annotations
import dataclasses
import sys
from collections import defaultdict
from collections.abc import Collection
from enum import Enum, auto
from enum import auto
from typing import Any
from comicapi.utils import norm_fold
if sys.version_info < (3, 11):
class StrEnum(str, Enum):
"""
Enum where members are also (and must be) strings
"""
def __new__(cls, *values: Any) -> Any:
"values must already be of type `str`"
if len(values) > 3:
raise TypeError(f"too many arguments for str(): {values!r}")
if len(values) == 1:
# it must be a string
if not isinstance(values[0], str):
raise TypeError(f"{values[0]!r} is not a string")
if len(values) >= 2:
# check that encoding argument is a string
if not isinstance(values[1], str):
raise TypeError(f"encoding must be a string, not {values[1]!r}")
if len(values) == 3:
# check that errors argument is a string
if not isinstance(values[2], str):
raise TypeError("errors must be a string, not %r" % (values[2]))
value = str(*values)
member = str.__new__(cls, value)
member._value_ = value
return member
@staticmethod
def _generate_next_value_(name: str, start: int, count: int, last_values: Any) -> str:
"""
Return the lower-cased version of the member name.
"""
return name.lower()
else:
from enum import StrEnum
from comicapi.utils import StrEnum, norm_fold
@dataclasses.dataclass

View File

@ -82,8 +82,28 @@ if sys.version_info < (3, 11):
"""
return name.lower()
@classmethod
def _missing_(cls, value: Any) -> str | None:
if not isinstance(value, str):
return None
if not hasattr(cls, "_lower_members"):
cls._lower_members = {x.casefold(): x for x in cls} # type: ignore[attr-defined]
return cls._lower_members.get(value.casefold(), None) # type: ignore[attr-defined]
def __str__(self):
return self.value
else:
from enum import StrEnum
from enum import StrEnum as s
class StrEnum(s):
@classmethod
def _missing_(cls, value: Any) -> str | None:
if not isinstance(value, str):
return None
if not hasattr(cls, "_lower_members"):
cls._lower_members = {x.casefold(): x for x in cls} # type: ignore[attr-defined]
return cls._lower_members.get(value.casefold(), None) # type: ignore[attr-defined]
logger = logging.getLogger(__name__)