Improve type detection

This commit is contained in:
Timmy Welch 2024-05-18 15:49:32 -07:00
parent eca7be0c51
commit 43f6bf1eac
2 changed files with 64 additions and 66 deletions

View File

@ -41,6 +41,11 @@ if sys.version_info < (3, 9): # pragma: no cover
else: else:
return self[:] return self[:]
def get_typing_type(t: type) -> type:
if t.__module__ == 'builtins':
return getattr(typing, t.__name__.title(), t)
return t
class BooleanOptionalAction(argparse.Action): class BooleanOptionalAction(argparse.Action):
def __init__( def __init__(
self, self,
@ -77,7 +82,7 @@ if sys.version_info < (3, 9): # pragma: no cover
metavar=metavar, metavar=metavar,
) )
def __call__(self, parser, namespace, values, option_string=None): # dead: disable def __call__(self, parser, namespace, values, option_string=None): # pragma: no cover dead: disable
if option_string in self.option_strings: if option_string in self.option_strings:
setattr(namespace, self.dest, not option_string.startswith('--no-')) setattr(namespace, self.dest, not option_string.startswith('--no-'))
else: # pragma: no cover else: # pragma: no cover
@ -86,8 +91,11 @@ else: # pragma: no cover
from argparse import BooleanOptionalAction from argparse import BooleanOptionalAction
removeprefix = str.removeprefix removeprefix = str.removeprefix
def get_typing_type(t: type) -> type:
return t
def _isnamedtupleinstance(x: Any) -> bool:
def _isnamedtupleinstance(x: Any) -> bool: # pragma: no cover
t = type(x) t = type(x)
b = t.__bases__ b = t.__bases__
@ -209,58 +217,45 @@ class Setting:
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
def _guess_type(self) -> tuple[type | str | None, bool]: def _guess_type(self) -> tuple[type | str | None, bool]:
if self.type is None and self.action is None:
if self.cmdline:
if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1:
return List[str], self.default is None
return str, self.default is None
else:
if not self.cmdline and self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try:
return cast(type, type(self.default)[type(self.default[0])]), self.default is None
except Exception:
...
return type(self.default), self.default is None
return 'Any', self.default is None
if isinstance(self.type, type): if isinstance(self.type, type):
return self.type, self.default is None return self.type, self.default is None
__action_to_type = {
'store_true': (bool, False),
'store_false': (bool, False),
BooleanOptionalAction: (bool, self.default is None),
'store_const': (type(self.const), self.default is None),
'count': (int, self.default is None),
'append': (List[str], self.default is None),
'extend': (List[str], self.default is None),
'append_const': (List[type(self.const)], self.default is None), # type: ignore[misc]
'help': (None, self.default is None),
'version': (None, self.default is None),
}
if self.action in __action_to_type:
return __action_to_type[self.action]
if self.type is not None: if self.type is not None:
type_hints = typing.get_type_hints(self.type) type_hints = typing.get_type_hints(self.type)
if 'return' in type_hints: if 'return' in type_hints:
t: type | str = type_hints['return'] t: type | str = type_hints['return']
return t, self.default is None return t, self.default is None
if self.default is not None: if self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]: if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try: try:
return cast(type, type(self.default)[type(self.default[0])]), self.default is None t = get_typing_type(type(self.default))
ret = cast(type, t[type(self.default[0])]), self.default is None # type: ignore[index]
return ret
except Exception: except Exception:
... ...
return type(self.default), self.default is None return type(self.default), self.default is None
return 'Any', self.default is None
if self.action in ('store_true', 'store_false'): if self.cmdline and self.action is None and self.type is None:
return bool, False if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1:
if self.action == BooleanOptionalAction:
return bool, self.default is None
if self.action in ('store_const',):
return type(self.const), self.default is None
if self.action in ('count',):
return int, self.default is None
if self.action in ('append', 'extend'):
return List[str], self.default is None return List[str], self.default is None
return str, self.default is None
if self.action in ('append_const',):
return list, self.default is None # list[type(self.const)]
if self.action in ('help', 'version'):
return None, self.default is None
return 'Any', self.default is None return 'Any', self.default is None
def get_dest(self, prefix: str, names: Sequence[str], dest: str | None) -> tuple[str, str, str, bool]: def get_dest(self, prefix: str, names: Sequence[str], dest: str | None) -> tuple[str, str, str, bool]:
@ -329,7 +324,7 @@ def generate_ns(definitions: Definitions) -> tuple[str, str]:
attributes = [] attributes = []
for group in definitions.values(): for group in definitions.values():
for setting in group.v.values(): for setting in group.v.values():
t, no_default = setting._guess_type() t, noneable = setting._guess_type()
if t is None: if t is None:
continue continue
# Default to any # Default to any
@ -354,7 +349,7 @@ def generate_ns(definitions: Definitions) -> tuple[str, str]:
if type_name == 'Any': if type_name == 'Any':
type_name = 'typing.Any' type_name = 'typing.Any'
if no_default and type_name not in ('typing.Any', 'None'): if noneable and type_name not in ('typing.Any', 'None'):
attribute = f' {setting.internal_name}: {type_name} | None' attribute = f' {setting.internal_name}: {type_name} | None'
else: else:
attribute = f' {setting.internal_name}: {type_name}' attribute = f' {setting.internal_name}: {type_name}'

View File

@ -644,7 +644,7 @@ class _customAction(argparse.Action): # pragma: no cover
help=help, help=help,
) )
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None): # pragma: no cover
setattr(namespace, self.dest, 'Something') setattr(namespace, self.dest, 'Something')
@ -657,9 +657,9 @@ types = (
(5, settngs.Setting('-t', '--test', action='extend'), List[str], True), (5, settngs.Setting('-t', '--test', action='extend'), List[str], True),
(6, settngs.Setting('-t', '--test', nargs='+'), List[str], True), (6, settngs.Setting('-t', '--test', nargs='+'), List[str], True),
(7, settngs.Setting('-t', '--test', action='store_const', const=1), int, True), (7, settngs.Setting('-t', '--test', action='store_const', const=1), int, True),
(8, settngs.Setting('-t', '--test', action='append_const', const=1), list, True), (8, settngs.Setting('-t', '--test', action='append_const', const=1), List[int], True),
(9, settngs.Setting('-t', '--test', action='store_true'), bool, True), (9, settngs.Setting('-t', '--test', action='store_true'), bool, False),
(10, settngs.Setting('-t', '--test', action='store_false'), bool, True), (10, settngs.Setting('-t', '--test', action='store_false'), bool, False),
(11, settngs.Setting('-t', '--test', action=settngs.BooleanOptionalAction), bool, True), (11, settngs.Setting('-t', '--test', action=settngs.BooleanOptionalAction), bool, True),
(12, settngs.Setting('-t', '--test', action=_customAction), 'Any', True), (12, settngs.Setting('-t', '--test', action=_customAction), 'Any', True),
(13, settngs.Setting('-t', '--test', action='help'), None, True), (13, settngs.Setting('-t', '--test', action='help'), None, True),
@ -667,16 +667,17 @@ types = (
(15, settngs.Setting('-t', '--test', type=int), int, True), (15, settngs.Setting('-t', '--test', type=int), int, True),
(16, settngs.Setting('-t', '--test', type=_typed_function), test_type, True), (16, settngs.Setting('-t', '--test', type=_typed_function), test_type, True),
(17, settngs.Setting('-t', '--test', type=_untyped_function, default=1), int, False), (17, settngs.Setting('-t', '--test', type=_untyped_function, default=1), int, False),
(18, settngs.Setting('-t', '--test', type=_untyped_function), 'Any', True), (18, settngs.Setting('-t', '--test', type=_untyped_function, default=[1]), List[int], False),
(19, settngs.Setting('-t', '--test', type=_untyped_function), 'Any', True),
) )
@pytest.mark.parametrize('num,setting,typ,no_default_expected', types) @pytest.mark.parametrize('num,setting,typ,noneable_expected', types)
def test_guess_type(num, setting, typ, no_default_expected): def test_guess_type(num, setting, typ, noneable_expected):
x = setting._guess_type() x = setting._guess_type()
guessed_type, no_default = x guessed_type, noneable = x
assert guessed_type == typ assert guessed_type == typ
assert no_default == no_default_expected assert noneable == noneable_expected
expected_src = '''from __future__ import annotations expected_src = '''from __future__ import annotations
@ -704,9 +705,9 @@ settings = (
(5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src.format(extra_imports='', typ='int | None')), (7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src.format(extra_imports='', typ='int | None')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='', typ='list | None')), (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[int] | None' if sys.version_info < (3, 9) else 'list[int] | None')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool | None')), (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool | None')), (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool')),
(11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src.format(extra_imports='', typ='bool | None')), (11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src.format(extra_imports='', typ='bool | None')),
(12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src.format(extra_imports='import typing\n', typ='typing.Any')), (12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
(13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src), (13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src),
@ -714,7 +715,8 @@ settings = (
(15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src.format(extra_imports='', typ='int | None')), (15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src.format(extra_imports='', typ='int | None')),
(16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')), (16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')),
(17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src.format(extra_imports='', typ='int')), (17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src.format(extra_imports='', typ='int')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')), (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=[1]), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[int]' if sys.version_info < (3, 9) else 'list[int]')),
(19, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
) )
@ -759,13 +761,13 @@ settings_dict = (
(1, lambda parser: parser.add_setting('-t', '--test', cmdline=False), expected_src_dict.format(extra_imports='', typ='typing.Any')), (1, lambda parser: parser.add_setting('-t', '--test', cmdline=False), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(2, lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), expected_src_dict.format(extra_imports='', typ='int')), (2, lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), expected_src_dict.format(extra_imports='', typ='int')),
(3, lambda parser: parser.add_setting('-t', '--test', action='count'), expected_src_dict.format(extra_imports='', typ='int | None')), (3, lambda parser: parser.add_setting('-t', '--test', action='count'), expected_src_dict.format(extra_imports='', typ='int | None')),
(4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')), (6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src_dict.format(extra_imports='', typ='int | None')), (7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src_dict.format(extra_imports='', typ='int | None')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='list | None')), (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='typing.List[int] | None' if sys.version_info < (3, 9) else 'list[int] | None')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool | None')), (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool | None')), (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool')),
(11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src_dict.format(extra_imports='', typ='bool | None')), (11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src_dict.format(extra_imports='', typ='bool | None')),
(12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src_dict.format(extra_imports='', typ='typing.Any')), (12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src_dict), (13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src_dict),
@ -773,7 +775,8 @@ settings_dict = (
(15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src_dict.format(extra_imports='', typ='int | None')), (15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src_dict.format(extra_imports='', typ='int | None')),
(16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src_dict.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')), (16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src_dict.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')),
(17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src_dict.format(extra_imports='', typ='int')), (17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src_dict.format(extra_imports='', typ='int')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')), (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=[1]), expected_src_dict.format(extra_imports='', typ='typing.List[int]' if sys.version_info < (3, 9) else 'list[int]')),
(19, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')),
) )