From 88bc34adec7e1803bea326b96ed6f860044080ad Mon Sep 17 00:00:00 2001 From: axolotle Date: Tue, 17 Jan 2023 18:50:13 +0100 Subject: [PATCH] add Field to typer Argument|Option factory --- src/interface/cli.py | 123 +++++++++++++++++++++++++++++-------------- 1 file changed, 83 insertions(+), 40 deletions(-) diff --git a/src/interface/cli.py b/src/interface/cli.py index 91477840a..5166b2296 100644 --- a/src/interface/cli.py +++ b/src/interface/cli.py @@ -4,10 +4,18 @@ from __future__ import ( import os import inspect -import typer import yaml -from typing import Any, Callable +import typer +import pydantic + +from typing import ( + Any, + Callable, + Optional, + Union, + get_type_hints, +) from rich import print as rprint from rich.syntax import Syntax @@ -77,49 +85,27 @@ class Interface(BaseInterface): def cli(self, command_def: str, **extra_data): def decorator(func: Callable): - signature = inspect.signature(func) - override_params = [] - params = self.filter_params(signature.parameters.values()) local_data = merge_dicts(self.local_data, extra_data) - params_doc = get_params_doc(func.__doc__) - command, args = parse_cli_command(command_def) - for param in params: - param_default = ( - param.default - if not param.default == param.empty - else ... # required - ) + signature = inspect.signature(func) + annotations = get_type_hints(func, include_extras=True) + params = signature.parameters.values() + doc = get_params_doc(func.__doc__) + command, positional_params = parse_cli_command(command_def) - param_kwargs = local_data.get(param.name, {}) - param_kwargs["help"] = params_doc.get(param.name, None) + forward_params = [] + override_params = [] - if param_kwargs.pop("deprecated", False): - param_kwargs["rich_help_panel"] = "Deprecated Options" + for param, field in self.build_fields( + Interface, params, annotations, doc, positional_params + ): + forward_params.append(param) - if param_kwargs.get("prompt", False): - if param.name == "password": - param_kwargs["confirmation_prompt"] = True - param_kwargs["hide_input"] = True - - pattern = param_kwargs.pop("pattern", None) - if pattern: - param_kwargs["callback"] = pattern_validator(pattern, param.name) - - # Populate default param value with typer.Argument|Option - if param_kwargs.pop("file", False): - new_param = param.replace( - annotation=typer.FileText, - default=param_default + if field: + override_param = param.replace( + default=field_to_typer_default(field) ) - elif param.kind == param.VAR_POSITIONAL: - new_param = param - elif param.name in args: - new_param = param.replace(default=typer.Argument(param_default, **param_kwargs)) - else: - new_param = param.replace(default=typer.Option(param_default, **param_kwargs)) - - override_params.append(new_param) + override_params.append(override_param) def hook_results(*args, **kwargs): try: @@ -143,7 +129,64 @@ class Interface(BaseInterface): )(command_func) self.clear_local_data() - + func.__signature__ = override_function(func, signature, forward_params) # type: ignore return func return decorator + + +def field_to_typer_default( + field: pydantic.fields.FieldInfo, +) -> Union[typer.models.ArgumentInfo, typer.models.OptionInfo]: + name = field.extra["name"] + positional = field.extra["positional"] + param_decls = field.extra["param_decls"] + panel = field.extra["panel"] + + generic = { + "callback": None, + # "metavar": None, + "show_default": field.default is not Ellipsis, + "help": field.description, + "hidden": field.extra["hidden"], + # "show_choices": True, + "rich_help_panel": panel, + } + + if field.extra["deprecated"] and not panel: + generic["rich_help_panel"] = "Deprecated Options" + + if field.regex: + generic["callback"] = pattern_validator( + field.regex, field.extra["pattern_name"] + ) + + if not positional: + specific: dict[str, Any] = { + "prompt": field.extra["ask"], # Union[bool, str] + "confirmation_prompt": field.extra["confirm"], # bool + # "prompt_required": True, # bool + "hide_input": field.extra["redac"], # bool + # "is_flag": None, # Optional[bool] + # "flag_value": None, # Optional[Any] + # "count": False, # bool + # "allow_from_autoenv": True, + } + + if positional: + return typer.Argument( + field.default, + **generic, + ) + + if param_decls: + param_decls.insert(0, "--" + name.replace("_", "-")) + else: + param_decls = ["--" + name.replace("_", "-"), "-" + name[0]] + + return typer.Option( + field.default, + *param_decls, + **generic, + **specific, + )