From 43f6bf1eacb55a70066d15f8e63fa55d5357d825 Mon Sep 17 00:00:00 2001 From: Timmy Welch Date: Sat, 18 May 2024 15:49:32 -0700 Subject: [PATCH] Improve type detection --- settngs/__init__.py | 87 ++++++++++++++++++++----------------------- tests/settngs_test.py | 43 +++++++++++---------- 2 files changed, 64 insertions(+), 66 deletions(-) diff --git a/settngs/__init__.py b/settngs/__init__.py index 03485b0..ffa5ce4 100644 --- a/settngs/__init__.py +++ b/settngs/__init__.py @@ -41,6 +41,11 @@ if sys.version_info < (3, 9): # pragma: no cover else: return self[:] + def get_typing_type(t: type) -> type: + if t.__module__ == 'builtins': + return getattr(typing, t.__name__.title(), t) + return t + class BooleanOptionalAction(argparse.Action): def __init__( self, @@ -77,7 +82,7 @@ if sys.version_info < (3, 9): # pragma: no cover metavar=metavar, ) - def __call__(self, parser, namespace, values, option_string=None): # dead: disable + def __call__(self, parser, namespace, values, option_string=None): # pragma: no cover dead: disable if option_string in self.option_strings: setattr(namespace, self.dest, not option_string.startswith('--no-')) else: # pragma: no cover @@ -86,8 +91,11 @@ else: # pragma: no cover from argparse import BooleanOptionalAction removeprefix = str.removeprefix + def get_typing_type(t: type) -> type: + return t -def _isnamedtupleinstance(x: Any) -> bool: + +def _isnamedtupleinstance(x: Any) -> bool: # pragma: no cover t = type(x) b = t.__bases__ @@ -209,58 +217,45 @@ class Setting: return self.__dict__ == other.__dict__ def _guess_type(self) -> tuple[type | str | None, bool]: - if self.type is None and self.action is None: - if self.cmdline: - if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1: - return List[str], self.default is None - return str, self.default is None - else: - if not self.cmdline and self.default is not None: - if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]: - try: - return cast(type, type(self.default)[type(self.default[0])]), self.default is None - except Exception: - ... - return type(self.default), self.default is None - return 'Any', self.default is None - if isinstance(self.type, type): return self.type, self.default is None + __action_to_type = { + 'store_true': (bool, False), + 'store_false': (bool, False), + BooleanOptionalAction: (bool, self.default is None), + 'store_const': (type(self.const), self.default is None), + 'count': (int, self.default is None), + 'append': (List[str], self.default is None), + 'extend': (List[str], self.default is None), + 'append_const': (List[type(self.const)], self.default is None), # type: ignore[misc] + 'help': (None, self.default is None), + 'version': (None, self.default is None), + } + + if self.action in __action_to_type: + return __action_to_type[self.action] + if self.type is not None: type_hints = typing.get_type_hints(self.type) if 'return' in type_hints: t: type | str = type_hints['return'] return t, self.default is None - if self.default is not None: - if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]: - try: - return cast(type, type(self.default)[type(self.default[0])]), self.default is None - except Exception: - ... - return type(self.default), self.default is None - return 'Any', self.default is None - if self.action in ('store_true', 'store_false'): - return bool, False + if self.default is not None: + if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]: + try: + t = get_typing_type(type(self.default)) + ret = cast(type, t[type(self.default[0])]), self.default is None # type: ignore[index] + return ret + except Exception: + ... + return type(self.default), self.default is None - if self.action == BooleanOptionalAction: - return bool, self.default is None - - if self.action in ('store_const',): - return type(self.const), self.default is None - - if self.action in ('count',): - return int, self.default is None - - if self.action in ('append', 'extend'): - return List[str], self.default is None - - if self.action in ('append_const',): - return list, self.default is None # list[type(self.const)] - - if self.action in ('help', 'version'): - return None, self.default is None + if self.cmdline and self.action is None and self.type is None: + if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1: + return List[str], self.default is None + return str, self.default is None return 'Any', self.default is None def get_dest(self, prefix: str, names: Sequence[str], dest: str | None) -> tuple[str, str, str, bool]: @@ -329,7 +324,7 @@ def generate_ns(definitions: Definitions) -> tuple[str, str]: attributes = [] for group in definitions.values(): for setting in group.v.values(): - t, no_default = setting._guess_type() + t, noneable = setting._guess_type() if t is None: continue # Default to any @@ -354,7 +349,7 @@ def generate_ns(definitions: Definitions) -> tuple[str, str]: if type_name == 'Any': type_name = 'typing.Any' - if no_default and type_name not in ('typing.Any', 'None'): + if noneable and type_name not in ('typing.Any', 'None'): attribute = f' {setting.internal_name}: {type_name} | None' else: attribute = f' {setting.internal_name}: {type_name}' diff --git a/tests/settngs_test.py b/tests/settngs_test.py index 831e70f..db5fd38 100644 --- a/tests/settngs_test.py +++ b/tests/settngs_test.py @@ -644,7 +644,7 @@ class _customAction(argparse.Action): # pragma: no cover help=help, ) - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, values, option_string=None): # pragma: no cover setattr(namespace, self.dest, 'Something') @@ -657,9 +657,9 @@ types = ( (5, settngs.Setting('-t', '--test', action='extend'), List[str], True), (6, settngs.Setting('-t', '--test', nargs='+'), List[str], True), (7, settngs.Setting('-t', '--test', action='store_const', const=1), int, True), - (8, settngs.Setting('-t', '--test', action='append_const', const=1), list, True), - (9, settngs.Setting('-t', '--test', action='store_true'), bool, True), - (10, settngs.Setting('-t', '--test', action='store_false'), bool, True), + (8, settngs.Setting('-t', '--test', action='append_const', const=1), List[int], True), + (9, settngs.Setting('-t', '--test', action='store_true'), bool, False), + (10, settngs.Setting('-t', '--test', action='store_false'), bool, False), (11, settngs.Setting('-t', '--test', action=settngs.BooleanOptionalAction), bool, True), (12, settngs.Setting('-t', '--test', action=_customAction), 'Any', True), (13, settngs.Setting('-t', '--test', action='help'), None, True), @@ -667,16 +667,17 @@ types = ( (15, settngs.Setting('-t', '--test', type=int), int, True), (16, settngs.Setting('-t', '--test', type=_typed_function), test_type, True), (17, settngs.Setting('-t', '--test', type=_untyped_function, default=1), int, False), - (18, settngs.Setting('-t', '--test', type=_untyped_function), 'Any', True), + (18, settngs.Setting('-t', '--test', type=_untyped_function, default=[1]), List[int], False), + (19, settngs.Setting('-t', '--test', type=_untyped_function), 'Any', True), ) -@pytest.mark.parametrize('num,setting,typ,no_default_expected', types) -def test_guess_type(num, setting, typ, no_default_expected): +@pytest.mark.parametrize('num,setting,typ,noneable_expected', types) +def test_guess_type(num, setting, typ, noneable_expected): x = setting._guess_type() - guessed_type, no_default = x + guessed_type, noneable = x assert guessed_type == typ - assert no_default == no_default_expected + assert noneable == noneable_expected expected_src = '''from __future__ import annotations @@ -704,9 +705,9 @@ settings = ( (5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src.format(extra_imports='', typ='int | None')), - (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='', typ='list | None')), - (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool | None')), - (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool | None')), + (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[int] | None' if sys.version_info < (3, 9) else 'list[int] | None')), + (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool')), + (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool')), (11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src.format(extra_imports='', typ='bool | None')), (12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src.format(extra_imports='import typing\n', typ='typing.Any')), (13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src), @@ -714,7 +715,8 @@ settings = ( (15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src.format(extra_imports='', typ='int | None')), (16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')), (17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src.format(extra_imports='', typ='int')), - (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')), + (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=[1]), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[int]' if sys.version_info < (3, 9) else 'list[int]')), + (19, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')), ) @@ -759,13 +761,13 @@ settings_dict = ( (1, lambda parser: parser.add_setting('-t', '--test', cmdline=False), expected_src_dict.format(extra_imports='', typ='typing.Any')), (2, lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), expected_src_dict.format(extra_imports='', typ='int')), (3, lambda parser: parser.add_setting('-t', '--test', action='count'), expected_src_dict.format(extra_imports='', typ='int | None')), - (4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), - (5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), - (6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), + (4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), + (5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), + (6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src_dict.format(extra_imports='', typ='int | None')), - (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='list | None')), - (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool | None')), - (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool | None')), + (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='typing.List[int] | None' if sys.version_info < (3, 9) else 'list[int] | None')), + (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool')), + (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool')), (11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src_dict.format(extra_imports='', typ='bool | None')), (12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src_dict.format(extra_imports='', typ='typing.Any')), (13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src_dict), @@ -773,7 +775,8 @@ settings_dict = ( (15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src_dict.format(extra_imports='', typ='int | None')), (16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src_dict.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')), (17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src_dict.format(extra_imports='', typ='int')), - (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')), + (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=[1]), expected_src_dict.format(extra_imports='', typ='typing.List[int]' if sys.version_info < (3, 9) else 'list[int]')), + (19, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')), )