Improve StrEnum
Return the actual string for __str__ Allow case insensitive conversion
This commit is contained in:
parent
e96cb8ad15
commit
219ede2d5d
@ -1,51 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Collection
|
||||
from enum import Enum, auto
|
||||
from enum import auto
|
||||
from typing import Any
|
||||
|
||||
from comicapi.utils import norm_fold
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""
|
||||
Enum where members are also (and must be) strings
|
||||
"""
|
||||
|
||||
def __new__(cls, *values: Any) -> Any:
|
||||
"values must already be of type `str`"
|
||||
if len(values) > 3:
|
||||
raise TypeError(f"too many arguments for str(): {values!r}")
|
||||
if len(values) == 1:
|
||||
# it must be a string
|
||||
if not isinstance(values[0], str):
|
||||
raise TypeError(f"{values[0]!r} is not a string")
|
||||
if len(values) >= 2:
|
||||
# check that encoding argument is a string
|
||||
if not isinstance(values[1], str):
|
||||
raise TypeError(f"encoding must be a string, not {values[1]!r}")
|
||||
if len(values) == 3:
|
||||
# check that errors argument is a string
|
||||
if not isinstance(values[2], str):
|
||||
raise TypeError("errors must be a string, not %r" % (values[2]))
|
||||
value = str(*values)
|
||||
member = str.__new__(cls, value)
|
||||
member._value_ = value
|
||||
return member
|
||||
|
||||
@staticmethod
|
||||
def _generate_next_value_(name: str, start: int, count: int, last_values: Any) -> str:
|
||||
"""
|
||||
Return the lower-cased version of the member name.
|
||||
"""
|
||||
return name.lower()
|
||||
|
||||
else:
|
||||
from enum import StrEnum
|
||||
from comicapi.utils import StrEnum, norm_fold
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -82,8 +82,28 @@ if sys.version_info < (3, 11):
|
||||
"""
|
||||
return name.lower()
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
if not hasattr(cls, "_lower_members"):
|
||||
cls._lower_members = {x.casefold(): x for x in cls} # type: ignore[attr-defined]
|
||||
return cls._lower_members.get(value.casefold(), None) # type: ignore[attr-defined]
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
else:
|
||||
from enum import StrEnum
|
||||
from enum import StrEnum as s
|
||||
|
||||
class StrEnum(s):
|
||||
@classmethod
|
||||
def _missing_(cls, value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
if not hasattr(cls, "_lower_members"):
|
||||
cls._lower_members = {x.casefold(): x for x in cls} # type: ignore[attr-defined]
|
||||
return cls._lower_members.get(value.casefold(), None) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
Loading…
Reference in New Issue
Block a user