Improve type detection

This commit is contained in:
Timmy Welch 2024-05-18 15:49:32 -07:00
parent eca7be0c51
commit 43f6bf1eac
2 changed files with 64 additions and 66 deletions

View File

@ -41,6 +41,11 @@ if sys.version_info < (3, 9): # pragma: no cover
else:
return self[:]
def get_typing_type(t: type) -> type:
if t.__module__ == 'builtins':
return getattr(typing, t.__name__.title(), t)
return t
class BooleanOptionalAction(argparse.Action):
def __init__(
self,
@ -77,7 +82,7 @@ if sys.version_info < (3, 9): # pragma: no cover
metavar=metavar,
)
def __call__(self, parser, namespace, values, option_string=None): # dead: disable
def __call__(self, parser, namespace, values, option_string=None): # pragma: no cover dead: disable
if option_string in self.option_strings:
setattr(namespace, self.dest, not option_string.startswith('--no-'))
else: # pragma: no cover
@ -86,8 +91,11 @@ else: # pragma: no cover
from argparse import BooleanOptionalAction
removeprefix = str.removeprefix
def get_typing_type(t: type) -> type:
return t
def _isnamedtupleinstance(x: Any) -> bool:
def _isnamedtupleinstance(x: Any) -> bool: # pragma: no cover
t = type(x)
b = t.__bases__
@ -209,58 +217,45 @@ class Setting:
return self.__dict__ == other.__dict__
def _guess_type(self) -> tuple[type | str | None, bool]:
if self.type is None and self.action is None:
if self.cmdline:
if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1:
return List[str], self.default is None
return str, self.default is None
else:
if not self.cmdline and self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try:
return cast(type, type(self.default)[type(self.default[0])]), self.default is None
except Exception:
...
return type(self.default), self.default is None
return 'Any', self.default is None
if isinstance(self.type, type):
return self.type, self.default is None
__action_to_type = {
'store_true': (bool, False),
'store_false': (bool, False),
BooleanOptionalAction: (bool, self.default is None),
'store_const': (type(self.const), self.default is None),
'count': (int, self.default is None),
'append': (List[str], self.default is None),
'extend': (List[str], self.default is None),
'append_const': (List[type(self.const)], self.default is None), # type: ignore[misc]
'help': (None, self.default is None),
'version': (None, self.default is None),
}
if self.action in __action_to_type:
return __action_to_type[self.action]
if self.type is not None:
type_hints = typing.get_type_hints(self.type)
if 'return' in type_hints:
t: type | str = type_hints['return']
return t, self.default is None
if self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try:
return cast(type, type(self.default)[type(self.default[0])]), self.default is None
except Exception:
...
return type(self.default), self.default is None
return 'Any', self.default is None
if self.action in ('store_true', 'store_false'):
return bool, False
if self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try:
t = get_typing_type(type(self.default))
ret = cast(type, t[type(self.default[0])]), self.default is None # type: ignore[index]
return ret
except Exception:
...
return type(self.default), self.default is None
if self.action == BooleanOptionalAction:
return bool, self.default is None
if self.action in ('store_const',):
return type(self.const), self.default is None
if self.action in ('count',):
return int, self.default is None
if self.action in ('append', 'extend'):
return List[str], self.default is None
if self.action in ('append_const',):
return list, self.default is None # list[type(self.const)]
if self.action in ('help', 'version'):
return None, self.default is None
if self.cmdline and self.action is None and self.type is None:
if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1:
return List[str], self.default is None
return str, self.default is None
return 'Any', self.default is None
def get_dest(self, prefix: str, names: Sequence[str], dest: str | None) -> tuple[str, str, str, bool]:
@ -329,7 +324,7 @@ def generate_ns(definitions: Definitions) -> tuple[str, str]:
attributes = []
for group in definitions.values():
for setting in group.v.values():
t, no_default = setting._guess_type()
t, noneable = setting._guess_type()
if t is None:
continue
# Default to any
@ -354,7 +349,7 @@ def generate_ns(definitions: Definitions) -> tuple[str, str]:
if type_name == 'Any':
type_name = 'typing.Any'
if no_default and type_name not in ('typing.Any', 'None'):
if noneable and type_name not in ('typing.Any', 'None'):
attribute = f' {setting.internal_name}: {type_name} | None'
else:
attribute = f' {setting.internal_name}: {type_name}'

View File

@ -644,7 +644,7 @@ class _customAction(argparse.Action): # pragma: no cover
help=help,
)
def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser, namespace, values, option_string=None): # pragma: no cover
setattr(namespace, self.dest, 'Something')
@ -657,9 +657,9 @@ types = (
(5, settngs.Setting('-t', '--test', action='extend'), List[str], True),
(6, settngs.Setting('-t', '--test', nargs='+'), List[str], True),
(7, settngs.Setting('-t', '--test', action='store_const', const=1), int, True),
(8, settngs.Setting('-t', '--test', action='append_const', const=1), list, True),
(9, settngs.Setting('-t', '--test', action='store_true'), bool, True),
(10, settngs.Setting('-t', '--test', action='store_false'), bool, True),
(8, settngs.Setting('-t', '--test', action='append_const', const=1), List[int], True),
(9, settngs.Setting('-t', '--test', action='store_true'), bool, False),
(10, settngs.Setting('-t', '--test', action='store_false'), bool, False),
(11, settngs.Setting('-t', '--test', action=settngs.BooleanOptionalAction), bool, True),
(12, settngs.Setting('-t', '--test', action=_customAction), 'Any', True),
(13, settngs.Setting('-t', '--test', action='help'), None, True),
@ -667,16 +667,17 @@ types = (
(15, settngs.Setting('-t', '--test', type=int), int, True),
(16, settngs.Setting('-t', '--test', type=_typed_function), test_type, True),
(17, settngs.Setting('-t', '--test', type=_untyped_function, default=1), int, False),
(18, settngs.Setting('-t', '--test', type=_untyped_function), 'Any', True),
(18, settngs.Setting('-t', '--test', type=_untyped_function, default=[1]), List[int], False),
(19, settngs.Setting('-t', '--test', type=_untyped_function), 'Any', True),
)
@pytest.mark.parametrize('num,setting,typ,no_default_expected', types)
def test_guess_type(num, setting, typ, no_default_expected):
@pytest.mark.parametrize('num,setting,typ,noneable_expected', types)
def test_guess_type(num, setting, typ, noneable_expected):
x = setting._guess_type()
guessed_type, no_default = x
guessed_type, noneable = x
assert guessed_type == typ
assert no_default == no_default_expected
assert noneable == noneable_expected
expected_src = '''from __future__ import annotations
@ -704,9 +705,9 @@ settings = (
(5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src.format(extra_imports='', typ='int | None')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='', typ='list | None')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool | None')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool | None')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[int] | None' if sys.version_info < (3, 9) else 'list[int] | None')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool')),
(11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src.format(extra_imports='', typ='bool | None')),
(12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
(13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src),
@ -714,7 +715,8 @@ settings = (
(15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src.format(extra_imports='', typ='int | None')),
(16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')),
(17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src.format(extra_imports='', typ='int')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=[1]), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[int]' if sys.version_info < (3, 9) else 'list[int]')),
(19, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
)
@ -759,13 +761,13 @@ settings_dict = (
(1, lambda parser: parser.add_setting('-t', '--test', cmdline=False), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(2, lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), expected_src_dict.format(extra_imports='', typ='int')),
(3, lambda parser: parser.add_setting('-t', '--test', action='count'), expected_src_dict.format(extra_imports='', typ='int | None')),
(4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='', typ='typing.List[str] | None' if sys.version_info < (3, 9) else 'list[str] | None')),
(7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src_dict.format(extra_imports='', typ='int | None')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='list | None')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool | None')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool | None')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='typing.List[int] | None' if sys.version_info < (3, 9) else 'list[int] | None')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool')),
(11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src_dict.format(extra_imports='', typ='bool | None')),
(12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src_dict),
@ -773,7 +775,8 @@ settings_dict = (
(15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src_dict.format(extra_imports='', typ='int | None')),
(16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src_dict.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type | None')),
(17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src_dict.format(extra_imports='', typ='int')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=[1]), expected_src_dict.format(extra_imports='', typ='typing.List[int]' if sys.version_info < (3, 9) else 'list[int]')),
(19, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')),
)