Improve generated namespace

Improve formatting of namespace
Allow _guess_type to return any string
Add tests
This commit is contained in:
Timmy Welch 2023-11-18 23:31:24 -08:00
parent 2c79e62765
commit ccacca1b32
3 changed files with 151 additions and 45 deletions

View File

@ -14,7 +14,6 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Generic from typing import Generic
from typing import Literal
from typing import NoReturn from typing import NoReturn
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
@ -133,7 +132,7 @@ class Setting:
raise ValueError('names must be specified') raise ValueError('names must be specified')
# We prefix the destination name used by argparse so that there are no conflicts # We prefix the destination name used by argparse so that there are no conflicts
# Argument names will still cause an exception if there is a conflict e.g. if '-f' is defined twice # Argument names will still cause an exception if there is a conflict e.g. if '-f' is defined twice
self.internal_name, dest, flag = self.get_dest(group, names, dest) self.internal_name, dest, self.flag = self.get_dest(group, names, dest)
args: Sequence[str] = names args: Sequence[str] = names
# We then also set the metavar so that '--config' in the group runtime shows as 'CONFIG' instead of 'RUNTIME_CONFIG' # We then also set the metavar so that '--config' in the group runtime shows as 'CONFIG' instead of 'RUNTIME_CONFIG'
@ -142,7 +141,7 @@ class Setting:
# If we are not a flag, no '--' or '-' in front # If we are not a flag, no '--' or '-' in front
# we use internal_name as argparse sets dest to args[0] # we use internal_name as argparse sets dest to args[0]
if not flag: if not self.flag:
args = tuple((self.internal_name, *names[1:])) args = tuple((self.internal_name, *names[1:]))
self.action = action self.action = action
@ -172,7 +171,7 @@ class Setting:
'required': required, 'required': required,
'help': help, 'help': help,
'metavar': metavar, 'metavar': metavar,
'dest': self.internal_name if flag else None, 'dest': self.internal_name if self.flag else None,
} }
def __str__(self) -> str: # pragma: no cover def __str__(self) -> str: # pragma: no cover
@ -186,7 +185,7 @@ class Setting:
return NotImplemented return NotImplemented
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
def _guess_type(self) -> type | Literal['Any'] | None: def _guess_type(self) -> type | str | None:
if self.type is None and self.action is None: if self.type is None and self.action is None:
if self.cmdline: if self.cmdline:
if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1: if self.nargs in ('+', '*') or isinstance(self.nargs, int) and self.nargs > 1:
@ -202,8 +201,9 @@ class Setting:
if self.type is not None: if self.type is not None:
type_hints = typing.get_type_hints(self.type) type_hints = typing.get_type_hints(self.type)
if 'return' in type_hints and isinstance(type_hints['return'], type): if 'return' in type_hints:
return type_hints['return'] t: type | str = type_hints['return']
return t
if self.default is not None: if self.default is not None:
return type(self.default) return type(self.default)
return 'Any' return 'Any'
@ -284,36 +284,58 @@ if TYPE_CHECKING:
def generate_ns(definitions: Definitions) -> str: def generate_ns(definitions: Definitions) -> str:
imports = ['from __future__ import annotations', 'import typing', 'import settngs'] initial_imports = ['from __future__ import annotations', '', 'import settngs', '']
ns = 'class settngs_namespace(settngs.TypedNS):\n' imports: Sequence[str] | set[str]
types = [] imports = set()
for group_name, group in definitions.items():
for setting_name, setting in group.v.items(): attributes = []
for group in definitions.values():
for setting in group.v.values():
t = setting._guess_type() t = setting._guess_type()
if t is None: if t is None:
continue continue
# Default to any
type_name = 'Any' type_name = 'Any'
# Take a string as is
if isinstance(t, str): if isinstance(t, str):
type_name = t type_name = t
# Handle generic aliases eg dict[str, str] instead of dict
elif isinstance(t, types_GenericAlias): elif isinstance(t, types_GenericAlias):
type_name = str(t) type_name = str(t)
# Handle standard type objects
elif isinstance(t, type): elif isinstance(t, type):
type_name = t.__name__ type_name = t.__name__
# Builtin types don't need an import
if t.__module__ != 'builtins': if t.__module__ != 'builtins':
imports.append(f'import {t.__module__}') imports.add(f'import {t.__module__}')
# Use the full imported name
type_name = t.__module__ + '.' + type_name type_name = t.__module__ + '.' + type_name
# Expand Any to typing.Any
if type_name == 'Any': if type_name == 'Any':
type_name = 'typing.Any' type_name = 'typing.Any'
types.append(f' {setting.internal_name}: {type_name}') attributes.append(f' {setting.internal_name}: {type_name}')
if types and types[-1] != '': # Add a blank line between groups
types.append('') if attributes and attributes[-1] != '':
attributes.append('')
if not types or all(x == '' for x in types): ns = 'class settngs_namespace(settngs.TypedNS):\n'
# Add a '...' expression if there are no attributes
if not attributes or all(x == '' for x in attributes):
ns += ' ...\n' ns += ' ...\n'
types = [''] attributes = ['']
return '\n'.join(imports) + '\n\n' + ns + '\n'.join(types) # Add the tying import before extra imports
if 'typing.' in '\n'.join(attributes):
initial_imports.append('import typing')
# Remove the possible duplicate typing import
imports = sorted(list(imports - {'import typing'}))
# Merge the imports the ns class definition and the attributes
return '\n'.join(initial_imports + imports) + '\n\n\n' + ns + '\n'.join(attributes)
def sanitize_name(name: str) -> str: def sanitize_name(name: str) -> str:
@ -558,8 +580,11 @@ def create_argparser(definitions: Definitions, description: str, epilog: str) ->
else: else:
groups[setting.group] = argparser.add_argument_group(setting.group) groups[setting.group] = argparser.add_argument_group(setting.group)
# hard coded exception for files # Hard coded exception for positional arguments
if not (setting.group == 'runtime' and setting.nargs == '*'): # Ensures that the option shows at the top of the help output
if 'runtime' in setting.group.casefold() and setting.nargs == '*' and not setting.flag:
current_group = argparser
else:
current_group = groups[setting.group] current_group = groups[setting.group]
current_group.add_argument(*argparse_args, **argparse_kwargs) current_group.add_argument(*argparse_args, **argparse_kwargs)
return argparser return argparser
@ -643,6 +668,11 @@ class Manager:
self.exclusive_group = False self.exclusive_group = False
self.current_group_name = '' self.current_group_name = ''
def _get_config(self, c: T | Config[T]) -> Config[T]:
if not isinstance(c, Config):
return Config(c, self.definitions)
return c
def generate_ns(self) -> str: def generate_ns(self) -> str:
return generate_ns(self.definitions) return generate_ns(self.definitions)
@ -651,6 +681,7 @@ class Manager:
def add_setting(self, *args: Any, **kwargs: Any) -> None: def add_setting(self, *args: Any, **kwargs: Any) -> None:
"""Passes all arguments through to `Setting`, `group` and `exclusive` are already set""" """Passes all arguments through to `Setting`, `group` and `exclusive` are already set"""
setting = Setting(*args, **kwargs, group=self.current_group_name, exclusive=self.exclusive_group) setting = Setting(*args, **kwargs, group=self.current_group_name, exclusive=self.exclusive_group)
self.definitions[self.current_group_name].v[setting.dest] = setting self.definitions[self.current_group_name].v[setting.dest] = setting
@ -663,6 +694,7 @@ class Manager:
group: A function that registers individual options using :meth:`add_setting` group: A function that registers individual options using :meth:`add_setting`
exclusive_group: If this group is an argparse exclusive group exclusive_group: If this group is an argparse exclusive group
""" """
if self.current_group_name != '': if self.current_group_name != '':
raise ValueError('Sub groups are not allowed') raise ValueError('Sub groups are not allowed')
self.current_group_name = name self.current_group_name = name
@ -681,6 +713,7 @@ class Manager:
group: A function that registers individual options using :meth:`add_setting` group: A function that registers individual options using :meth:`add_setting`
exclusive_group: If this group is an argparse exclusive group exclusive_group: If this group is an argparse exclusive group
""" """
if self.current_group_name != '': if self.current_group_name != '':
raise ValueError('Sub groups are not allowed') raise ValueError('Sub groups are not allowed')
self.current_group_name = name self.current_group_name = name
@ -713,9 +746,7 @@ class Manager:
cmdline: Include cmdline options cmdline: Include cmdline options
""" """
if not isinstance(config, Config): return clean_config(self._get_config(config), file=file, cmdline=cmdline)
config = Config(config, self.definitions)
return clean_config(config, file=file, cmdline=cmdline)
def normalize_config( def normalize_config(
self, self,
@ -738,10 +769,8 @@ class Manager:
persistent: Include unknown keys in persistent groups persistent: Include unknown keys in persistent groups
""" """
if not isinstance(config, Config):
config = Config(config, self.definitions)
return normalize_config( return normalize_config(
config=config, config=self._get_config(config),
file=file, file=file,
cmdline=cmdline, cmdline=cmdline,
default=default, default=default,
@ -769,11 +798,9 @@ class Manager:
persistent: Include unknown keys in persistent groups persistent: Include unknown keys in persistent groups
""" """
if isinstance(config, Config): return get_namespace(
self.definitions = config[1] self._get_config(config), file=file, cmdline=cmdline, default=default, persistent=persistent,
else: )
config = Config(config, self.definitions)
return get_namespace(config, file=file, cmdline=cmdline, default=default, persistent=persistent)
def parse_file(self, filename: pathlib.Path) -> tuple[Config[Values], bool]: def parse_file(self, filename: pathlib.Path) -> tuple[Config[Values], bool]:
""" """
@ -784,6 +811,7 @@ class Manager:
Args: Args:
filename: A pathlib.Path object to read a JSON dictionary from filename: A pathlib.Path object to read a JSON dictionary from
""" """
return parse_file(filename=filename, definitions=self.definitions) return parse_file(filename=filename, definitions=self.definitions)
def save_file(self, config: T | Config[T], filename: pathlib.Path) -> bool: def save_file(self, config: T | Config[T], filename: pathlib.Path) -> bool:
@ -796,9 +824,8 @@ class Manager:
config: The options to save to a json dictionary config: The options to save to a json dictionary
filename: A pathlib.Path object to save the json dictionary to filename: A pathlib.Path object to save the json dictionary to
""" """
if not isinstance(config, Config):
config = Config(config, self.definitions) return save_file(self._get_config(config), filename=filename)
return save_file(config, filename=filename)
def parse_cmdline(self, args: list[str] | None = None, config: ns[T] = None) -> Config[Values]: def parse_cmdline(self, args: list[str] | None = None, config: ns[T] = None) -> Config[Values]:
""" """
@ -891,7 +918,7 @@ def _main(args: list[str] | None = None) -> None:
if merged_namespace.values.Example_Group_save: if merged_namespace.values.Example_Group_save:
if manager.save_file(merged_config, settings_path): if manager.save_file(merged_config, settings_path):
print(f'Successfully saved settings to {settings_path}') # noqa: T201 print(f'Successfully saved settings to {settings_path}') # noqa: T201
else: else: # pragma: no cover
print(f'Failed saving settings to a {settings_path}') # noqa: T201 print(f'Failed saving settings to a {settings_path}') # noqa: T201
if merged_namespace.values.Example_Group_verbose: if merged_namespace.values.Example_Group_verbose:
print(f'{merged_namespace.values.Example_Group_verbose=}') # noqa: T201 print(f'{merged_namespace.values.Example_Group_verbose=}') # noqa: T201

View File

@ -81,6 +81,7 @@ success = [
'display_name': 'test_setting', # defaults to dest 'display_name': 'test_setting', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': True,
'group': 'tst', 'group': 'tst',
'help': None, 'help': None,
'internal_name': 'tst_test_setting', # Should almost always be "{group}_{dest}" 'internal_name': 'tst_test_setting', # Should almost always be "{group}_{dest}"
@ -121,6 +122,7 @@ success = [
'display_name': 'testing', # defaults to dest 'display_name': 'testing', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': True,
'group': 'tst', 'group': 'tst',
'help': None, 'help': None,
'internal_name': 'tst_testing', # Should almost always be "{group}_{dest}" 'internal_name': 'tst_testing', # Should almost always be "{group}_{dest}"
@ -160,6 +162,7 @@ success = [
'display_name': 'test', # defaults to dest 'display_name': 'test', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': True,
'group': 'tst', 'group': 'tst',
'help': None, 'help': None,
'internal_name': 'tst_test', # Should almost always be "{group}_{dest}" 'internal_name': 'tst_test', # Should almost always be "{group}_{dest}"
@ -200,6 +203,7 @@ success = [
'display_name': 'test', # defaults to dest 'display_name': 'test', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': True,
'group': 'tst', 'group': 'tst',
'help': None, 'help': None,
'internal_name': 'tst_test', # Should almost always be "{group}_{dest}" 'internal_name': 'tst_test', # Should almost always be "{group}_{dest}"
@ -239,6 +243,7 @@ success = [
'display_name': 'test', # defaults to dest 'display_name': 'test', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': True,
'group': 'tst', 'group': 'tst',
'help': None, 'help': None,
'internal_name': 'tst_test', 'internal_name': 'tst_test',
@ -278,6 +283,7 @@ success = [
'display_name': 'test', # defaults to dest 'display_name': 'test', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': False,
'group': 'tst', 'group': 'tst',
'help': None, 'help': None,
'internal_name': 'tst_test', 'internal_name': 'tst_test',
@ -315,6 +321,7 @@ success = [
'display_name': 'test', # defaults to dest 'display_name': 'test', # defaults to dest
'exclusive': False, 'exclusive': False,
'file': True, 'file': True,
'flag': True,
'group': '', 'group': '',
'help': None, 'help': None,
'internal_name': 'test', # No group, leading _ is stripped 'internal_name': 'test', # No group, leading _ is stripped

View File

@ -6,6 +6,7 @@ import json
import pathlib import pathlib
import sys import sys
from collections import defaultdict from collections import defaultdict
from textwrap import dedent
from typing import Generator from typing import Generator
import pytest import pytest
@ -63,6 +64,71 @@ def test_add_setting(settngs_manager):
assert settngs_manager.add_setting('--test') is None assert settngs_manager.add_setting('--test') is None
def test_add_setting_invalid_name(settngs_manager):
with pytest.raises(Exception, match='Cannot use test¥ in a namespace'):
assert settngs_manager.add_setting('--test¥') is None
def test_sub_group(settngs_manager):
with pytest.raises(Exception, match='Sub groups are not allowed'):
settngs_manager.add_group('tst', lambda parser: parser.add_group('tst', lambda parser: parser.add_setting('--test2', default='hello')))
def test_sub_persistent_group(settngs_manager):
with pytest.raises(Exception, match='Sub groups are not allowed'):
settngs_manager.add_persistent_group('tst', lambda parser: parser.add_persistent_group('tst', lambda parser: parser.add_setting('--test2', default='hello')))
def test_redefine_persistent_group(settngs_manager):
settngs_manager.add_group('tst', lambda parser: parser.add_setting('--test2', default='hello'))
with pytest.raises(Exception, match='Group already exists and is not persistent'):
settngs_manager.add_persistent_group('tst', None)
def test_exclusive_group(settngs_manager):
settngs_manager.add_group('tst', lambda parser: parser.add_setting('--test', default='hello'), exclusive_group=True)
settngs_manager.create_argparser()
args = settngs_manager.argparser.parse_args(['--test', 'never'])
assert args.tst_test == 'never'
with pytest.raises(SystemExit):
settngs_manager.add_group('tst', lambda parser: parser.add_setting('--test2', default='hello'), exclusive_group=True)
settngs_manager.create_argparser()
args = settngs_manager.argparser.parse_args(['--test', 'never', '--test2', 'never'])
def test_files_group(capsys, settngs_manager):
settngs_manager.add_group('runtime', lambda parser: parser.add_setting('test', default='hello', nargs='*'))
settngs_manager.create_argparser()
settngs_manager.argparser.print_help()
captured = capsys.readouterr()
assert captured.out == dedent('''\
usage: __main__.py [-h] [TEST [TEST ...]]
positional arguments:
TEST
optional arguments:
-h, --help show this help message and exit
''')
def test_setting_without_group(capsys, settngs_manager):
settngs_manager.add_setting('test', default='hello', nargs='*')
settngs_manager.create_argparser()
settngs_manager.argparser.print_help()
captured = capsys.readouterr()
assert captured.out == dedent('''\
usage: __main__.py [-h] [TEST [TEST ...]]
positional arguments:
TEST
optional arguments:
-h, --help show this help message and exit
''')
class TestValues: class TestValues:
def test_invalid_normalize(self, settngs_manager): def test_invalid_normalize(self, settngs_manager):
@ -533,6 +599,7 @@ settings = (
(lambda parser: parser.add_setting('-t', '--test', action='help'), None), (lambda parser: parser.add_setting('-t', '--test', action='help'), None),
(lambda parser: parser.add_setting('-t', '--test', action='version'), None), (lambda parser: parser.add_setting('-t', '--test', action='version'), None),
(lambda parser: parser.add_setting('-t', '--test', type=int), 'int'), (lambda parser: parser.add_setting('-t', '--test', type=int), 'int'),
(lambda parser: parser.add_setting('-t', '--test', nargs='+'), List[str]),
(lambda parser: parser.add_setting('-t', '--test', type=_typed_function), 'tests.settngs_test.test_type'), (lambda parser: parser.add_setting('-t', '--test', type=_typed_function), 'tests.settngs_test.test_type'),
(lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), 'int'), (lambda parser: parser.add_setting('-t', '--test', type=_untyped_function, default=1), 'int'),
(lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), 'typing.Any'), (lambda parser: parser.add_setting('-t', '--test', type=_untyped_function), 'typing.Any'),
@ -543,16 +610,21 @@ settings = (
def test_generate_ns(settngs_manager, set_options, typ): def test_generate_ns(settngs_manager, set_options, typ):
settngs_manager.add_group('test', set_options) settngs_manager.add_group('test', set_options)
src = '''\ src = dedent('''\
from __future__ import annotations from __future__ import annotations
import typing
import settngs import settngs
''' ''')
if 'typing.' in str(typ):
src += '\nimport typing'
if typ == 'tests.settngs_test.test_type': if typ == 'tests.settngs_test.test_type':
src += 'import tests.settngs_test\n' src += '\nimport tests.settngs_test'
src += ''' src += dedent('''
class settngs_namespace(settngs.TypedNS):
'''
class settngs_namespace(settngs.TypedNS):
''')
if typ is None: if typ is None:
src += ' ...\n' src += ' ...\n'
else: else: