6 Commits

Author SHA1 Message Date
c588fc891e Support type generation for dicts an addition to namespaces 2024-02-22 19:15:47 -08:00
1ce6079285 Fix pre-commit 2024-02-22 19:14:45 -08:00
7c748f6815 Merge branch 'pre-commit-ci-update-config' 2024-02-22 14:44:51 -08:00
8d5b30546e Improve type guessing for generic Sequence types 2024-02-22 14:42:07 -08:00
cebca481fc [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/asottile/pyupgrade: v3.15.0 → v3.15.1](https://github.com/asottile/pyupgrade/compare/v3.15.0...v3.15.1)
2024-02-19 17:21:12 +00:00
dd8cd1188e [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/PyCQA/flake8: 6.1.0 → 7.0.0](https://github.com/PyCQA/flake8/compare/6.1.0...7.0.0)
- [github.com/pre-commit/mirrors-mypy: v1.7.0 → v1.8.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.7.0...v1.8.0)
2024-01-08 17:17:39 +00:00
3 changed files with 214 additions and 68 deletions

View File

@ -28,7 +28,7 @@ repos:
hooks: hooks:
- id: dead - id: dead
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.15.0 rev: v3.15.1
hooks: hooks:
- id: pyupgrade - id: pyupgrade
args: [--py38-plus] args: [--py38-plus]
@ -37,11 +37,11 @@ repos:
hooks: hooks:
- id: autopep8 - id: autopep8
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 6.1.0 rev: 7.0.0
hooks: hooks:
- id: flake8 - id: flake8
additional_dependencies: [flake8-encodings, flake8-warnings, flake8-builtins, flake8-length, flake8-print] additional_dependencies: [flake8-encodings, flake8-warnings, flake8-builtins, flake8-print]
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0 rev: v1.8.0
hooks: hooks:
- id: mypy - id: mypy

View File

@ -13,6 +13,7 @@ from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast
from typing import Dict from typing import Dict
from typing import Generic from typing import Generic
from typing import NoReturn from typing import NoReturn
@ -84,6 +85,20 @@ else: # pragma: no cover
removeprefix = str.removeprefix removeprefix = str.removeprefix
def _isnamedtupleinstance(x: Any) -> bool:
t = type(x)
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, '_fields', None)
if not isinstance(f, tuple):
return False
return all(isinstance(n, str) for n in f)
class Setting: class Setting:
def __init__( def __init__(
self, self,
@ -199,6 +214,11 @@ class Setting:
return str return str
else: else:
if not self.cmdline and self.default is not None: if not self.cmdline and self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try:
return cast(type, type(self.default)[type(self.default[0])])
except Exception:
...
return type(self.default) return type(self.default)
return 'Any' return 'Any'
@ -211,6 +231,11 @@ class Setting:
t: type | str = type_hints['return'] t: type | str = type_hints['return']
return t return t
if self.default is not None: if self.default is not None:
if not isinstance(self.default, str) and not _isnamedtupleinstance(self.default) and isinstance(self.default, Sequence) and self.default and self.default[0]:
try:
return cast(type, type(self.default)[type(self.default[0])])
except Exception:
...
return type(self.default) return type(self.default)
return 'Any' return 'Any'
@ -291,8 +316,8 @@ if TYPE_CHECKING:
ns = Namespace | TypedNS | Config[T] | None ns = Namespace | TypedNS | Config[T] | None
def generate_ns(definitions: Definitions) -> str: def generate_ns(definitions: Definitions) -> tuple[str, str]:
initial_imports = ['from __future__ import annotations', '', 'import settngs', ''] initial_imports = ['from __future__ import annotations', '', 'import settngs']
imports: Sequence[str] | set[str] imports: Sequence[str] | set[str]
imports = set() imports = set()
@ -331,7 +356,7 @@ def generate_ns(definitions: Definitions) -> str:
if attributes and attributes[-1] != '': if attributes and attributes[-1] != '':
attributes.append('') attributes.append('')
ns = 'class settngs_namespace(settngs.TypedNS):\n' ns = 'class SettngsNS(settngs.TypedNS):\n'
# Add a '...' expression if there are no attributes # Add a '...' expression if there are no attributes
if not attributes or all(x == '' for x in attributes): if not attributes or all(x == '' for x in attributes):
ns += ' ...\n' ns += ' ...\n'
@ -345,7 +370,69 @@ def generate_ns(definitions: Definitions) -> str:
imports = sorted(list(imports - {'import typing'})) imports = sorted(list(imports - {'import typing'}))
# Merge the imports the ns class definition and the attributes # Merge the imports the ns class definition and the attributes
return '\n'.join(initial_imports + imports) + '\n\n\n' + ns + '\n'.join(attributes) return '\n'.join(initial_imports + imports), ns + '\n'.join(attributes)
def generate_dict(definitions: Definitions) -> tuple[str, str]:
initial_imports = ['from __future__ import annotations', '', 'import typing']
imports: Sequence[str] | set[str]
imports = set()
groups_are_identifiers = all(n.isidentifier() for n in definitions.keys())
classes = []
for group_name, group in definitions.items():
attributes = []
for setting in group.v.values():
t = setting._guess_type()
if t is None:
continue
# Default to any
type_name = 'Any'
# Take a string as is
if isinstance(t, str):
type_name = t
# Handle generic aliases eg dict[str, str] instead of dict
elif isinstance(t, types_GenericAlias):
type_name = str(t)
# Handle standard type objects
elif isinstance(t, type):
type_name = t.__name__
# Builtin types don't need an import
if t.__module__ != 'builtins':
imports.add(f'import {t.__module__}')
# Use the full imported name
type_name = t.__module__ + '.' + type_name
# Expand Any to typing.Any
if type_name == 'Any':
type_name = 'typing.Any'
attribute = f' {setting.dest}: {type_name}'
if attribute not in attributes:
attributes.append(attribute)
if not attributes or all(x == '' for x in attributes):
attributes = [' ...']
classes.append(
f'class {sanitize_name(group_name)}(typing.TypedDict):\n'
+ '\n'.join(attributes) + '\n\n',
)
# Remove the possible duplicate typing import
imports = sorted(list(imports - {'import typing'}))
if groups_are_identifiers:
ns = '\nclass SettngsDict(typing.TypedDict):\n'
ns += '\n'.join(f' {n}: {sanitize_name(n)}' for n in definitions.keys())
else:
ns = '\nSettngsDict = typing.TypedDict(\n'
ns += " 'SettngsDict', {\n"
for n in definitions.keys():
ns += f' {n!r}: {sanitize_name(n)},\n'
ns += ' },\n'
ns += ')\n'
# Merge the imports the ns class definition and the attributes
return '\n'.join(initial_imports + imports), '\n'.join(classes) + ns + '\n'
def sanitize_name(name: str) -> str: def sanitize_name(name: str) -> str:
@ -707,9 +794,12 @@ class Manager:
return Config(c, self.definitions) return Config(c, self.definitions)
return c return c
def generate_ns(self) -> str: def generate_ns(self) -> tuple[str, str]:
return generate_ns(self.definitions) return generate_ns(self.definitions)
def generate_dict(self) -> tuple[str, str]:
return generate_dict(self.definitions)
def create_argparser(self) -> None: def create_argparser(self) -> None:
self.argparser = create_argparser(self.definitions, self.description, self.epilog) self.argparser = create_argparser(self.definitions, self.description, self.epilog)

View File

@ -6,7 +6,6 @@ import json
import pathlib import pathlib
import sys import sys
from collections import defaultdict from collections import defaultdict
from textwrap import dedent
from typing import Generator from typing import Generator
import pytest import pytest
@ -650,83 +649,140 @@ class _customAction(argparse.Action): # pragma: no cover
types = ( types = (
(settngs.Setting('-t', '--test'), str), (0, settngs.Setting('-t', '--test'), str),
(settngs.Setting('-t', '--test', cmdline=False), 'Any'), (1, settngs.Setting('-t', '--test', cmdline=False), 'Any'),
(settngs.Setting('-t', '--test', default=1, file=True, cmdline=False), int), (2, settngs.Setting('-t', '--test', default=1, file=True, cmdline=False), int),
(settngs.Setting('-t', '--test', action='count'), int), (3, settngs.Setting('-t', '--test', action='count'), int),
(settngs.Setting('-t', '--test', action='append'), List[str]), (4, settngs.Setting('-t', '--test', action='append'), List[str]),
(settngs.Setting('-t', '--test', action='extend'), List[str]), (5, settngs.Setting('-t', '--test', action='extend'), List[str]),
(settngs.Setting('-t', '--test', action='store_const', const=1), int), (6, settngs.Setting('-t', '--test', nargs='+'), List[str]),
(settngs.Setting('-t', '--test', action='append_const', const=1), list), (7, settngs.Setting('-t', '--test', action='store_const', const=1), int),
(settngs.Setting('-t', '--test', action='store_true'), bool), (8, settngs.Setting('-t', '--test', action='append_const', const=1), list),
(settngs.Setting('-t', '--test', action='store_false'), bool), (9, settngs.Setting('-t', '--test', action='store_true'), bool),
(settngs.Setting('-t', '--test', action=settngs.BooleanOptionalAction), bool), (10, settngs.Setting('-t', '--test', action='store_false'), bool),
(settngs.Setting('-t', '--test', action=_customAction), 'Any'), (11, settngs.Setting('-t', '--test', action=settngs.BooleanOptionalAction), bool),
(settngs.Setting('-t', '--test', action='help'), None), (12, settngs.Setting('-t', '--test', action=_customAction), 'Any'),
(settngs.Setting('-t', '--test', action='version'), None), (13, settngs.Setting('-t', '--test', action='help'), None),
(settngs.Setting('-t', '--test', type=int), int), (14, settngs.Setting('-t', '--test', action='version'), None),
(settngs.Setting('-t', '--test', type=_typed_function), test_type), (15, settngs.Setting('-t', '--test', type=int), int),
(settngs.Setting('-t', '--test', type=_untyped_function, default=1), int), (16, settngs.Setting('-t', '--test', type=_typed_function), test_type),
(settngs.Setting('-t', '--test', type=_untyped_function), 'Any'), (17, settngs.Setting('-t', '--test', type=_untyped_function, default=1), int),
(18, settngs.Setting('-t', '--test', type=_untyped_function), 'Any'),
) )
@pytest.mark.parametrize('setting,typ', types) @pytest.mark.parametrize('num,setting,typ', types)
def test_guess_type(setting, typ): def test_guess_type(num, setting, typ):
guessed_type = setting._guess_type() guessed_type = setting._guess_type()
assert guessed_type == typ assert guessed_type == typ
expected_src = '''from __future__ import annotations
import settngs
{extra_imports}
class SettngsNS(settngs.TypedNS):
test__test: {typ}
'''
no_type_expected_src = '''from __future__ import annotations
import settngs
class SettngsNS(settngs.TypedNS):
...
'''
settings = ( settings = (
(lambda parser: parser.add_setting('-t', '--test'), 'str'), (0, lambda parser: parser.add_setting('-t', '--test'), expected_src.format(extra_imports='', typ='str')),
(lambda parser: parser.add_setting('-t', '--test', cmdline=False), 'typing.Any'), (1, lambda parser: parser.add_setting('-t', '--test', cmdline=False), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
(lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), 'int'), (2, lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), expected_src.format(extra_imports='', typ='int')),
(lambda parser: parser.add_setting('-t', '--test', action='count'), 'int'), (3, lambda parser: parser.add_setting('-t', '--test', action='count'), expected_src.format(extra_imports='', typ='int')),
(lambda parser: parser.add_setting('-t', '--test', action='append'), List[str]), (4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str]' if sys.version_info < (3, 9) else 'list[str]')),
(lambda parser: parser.add_setting('-t', '--test', action='extend'), List[str]), (5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str]' if sys.version_info < (3, 9) else 'list[str]')),
(lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), 'int'), (6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src.format(extra_imports='import typing\n' if sys.version_info < (3, 9) else '', typ='typing.List[str]' if sys.version_info < (3, 9) else 'list[str]')),
(lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), 'list'), (7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src.format(extra_imports='', typ='int')),
(lambda parser: parser.add_setting('-t', '--test', action='store_true'), 'bool'), (8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src.format(extra_imports='', typ='list')),
(lambda parser: parser.add_setting('-t', '--test', action='store_false'), 'bool'), (9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src.format(extra_imports='', typ='bool')),
(lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), 'bool'), (10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src.format(extra_imports='', typ='bool')),
(lambda parser: parser.add_setting('-t', '--test', action=_customAction), 'typing.Any'), (11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src.format(extra_imports='', typ='bool')),
(lambda parser: parser.add_setting('-t', '--test', action='help'), None), (12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
(lambda parser: parser.add_setting('-t', '--test', action='version'), None), (13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src),
(lambda parser: parser.add_setting('-t', '--test', type=int), 'int'), (14, lambda parser: parser.add_setting('-t', '--test', action='version'), no_type_expected_src),
(lambda parser: parser.add_setting('-t', '--test', nargs='+'), List[str]), (15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src.format(extra_imports='', typ='int')),
(lambda parser: parser.add_setting('-t', '--test', type=_typed_function), 'tests.settngs_test.test_type'), (16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type')),
(lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), 'int'), (17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src.format(extra_imports='', typ='int')),
(lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), 'typing.Any'), (18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src.format(extra_imports='import typing\n', typ='typing.Any')),
) )
@pytest.mark.parametrize('set_options,typ', settings) @pytest.mark.parametrize('num,set_options,expected', settings)
def test_generate_ns(settngs_manager, set_options, typ): def test_generate_ns(settngs_manager, num, set_options, expected):
settngs_manager.add_group('test', set_options) settngs_manager.add_group('test', set_options)
src = dedent('''\ imports, types = settngs_manager.generate_ns()
from __future__ import annotations generated_src = '\n\n\n'.join((imports, types))
import settngs assert generated_src == expected
''')
if 'typing.' in str(typ): ast.parse(generated_src)
src += '\nimport typing'
if typ == 'tests.settngs_test.test_type':
src += '\nimport tests.settngs_test'
src += dedent('''
class settngs_namespace(settngs.TypedNS): expected_src_dict = '''from __future__ import annotations
''')
if typ is None:
src += ' ...\n'
else:
src += f' {settngs_manager.definitions["test"].v["test"].internal_name}: {typ}\n'
generated_src = settngs_manager.generate_ns() import typing
{extra_imports}
assert generated_src == src class test(typing.TypedDict):
test: {typ}
class SettngsDict(typing.TypedDict):
test: test
'''
no_type_expected_src_dict = '''from __future__ import annotations
import typing
class test(typing.TypedDict):
...
class SettngsDict(typing.TypedDict):
test: test
'''
settings_dict = (
(0, lambda parser: parser.add_setting('-t', '--test'), expected_src_dict.format(extra_imports='', typ='str')),
(1, lambda parser: parser.add_setting('-t', '--test', cmdline=False), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(2, lambda parser: parser.add_setting('-t', '--test', default=1, file=True, cmdline=False), expected_src_dict.format(extra_imports='', typ='int')),
(3, lambda parser: parser.add_setting('-t', '--test', action='count'), expected_src_dict.format(extra_imports='', typ='int')),
(4, lambda parser: parser.add_setting('-t', '--test', action='append'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str]' if sys.version_info < (3, 9) else 'list[str]')),
(5, lambda parser: parser.add_setting('-t', '--test', action='extend'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str]' if sys.version_info < (3, 9) else 'list[str]')),
(6, lambda parser: parser.add_setting('-t', '--test', nargs='+'), expected_src_dict.format(extra_imports='' if sys.version_info < (3, 9) else '', typ='typing.List[str]' if sys.version_info < (3, 9) else 'list[str]')),
(7, lambda parser: parser.add_setting('-t', '--test', action='store_const', const=1), expected_src_dict.format(extra_imports='', typ='int')),
(8, lambda parser: parser.add_setting('-t', '--test', action='append_const', const=1), expected_src_dict.format(extra_imports='', typ='list')),
(9, lambda parser: parser.add_setting('-t', '--test', action='store_true'), expected_src_dict.format(extra_imports='', typ='bool')),
(10, lambda parser: parser.add_setting('-t', '--test', action='store_false'), expected_src_dict.format(extra_imports='', typ='bool')),
(11, lambda parser: parser.add_setting('-t', '--test', action=settngs.BooleanOptionalAction), expected_src_dict.format(extra_imports='', typ='bool')),
(12, lambda parser: parser.add_setting('-t', '--test', action=_customAction), expected_src_dict.format(extra_imports='', typ='typing.Any')),
(13, lambda parser: parser.add_setting('-t', '--test', action='help'), no_type_expected_src_dict),
(14, lambda parser: parser.add_setting('-t', '--test', action='version'), no_type_expected_src_dict),
(15, lambda parser: parser.add_setting('-t', '--test', type=int), expected_src_dict.format(extra_imports='', typ='int')),
(16, lambda parser: parser.add_setting('-t', '--test', type=_typed_function), expected_src_dict.format(extra_imports='import tests.settngs_test\n', typ='tests.settngs_test.test_type')),
(17, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), expected_src_dict.format(extra_imports='', typ='int')),
(18, lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), expected_src_dict.format(extra_imports='', typ='typing.Any')),
)
@pytest.mark.parametrize('num,set_options,expected', settings_dict)
def test_generate_dict(settngs_manager, num, set_options, expected):
settngs_manager.add_group('test', set_options)
imports, types = settngs_manager.generate_dict()
generated_src = '\n\n\n'.join((imports, types))
assert generated_src == expected
ast.parse(generated_src) ast.parse(generated_src)