diff --git a/src/interface/api.py b/src/interface/api.py index 0131f5c24..bf84cff79 100644 --- a/src/interface/api.py +++ b/src/interface/api.py @@ -1,52 +1,46 @@ -from __future__ import annotations +from __future__ import ( + annotations, +) # Enable self reference a class in its own method arguments import inspect -import types +import re import fastapi import pydantic -from typing import Optional, Union +from typing import Any, Optional, Union +from yunohost.interface.base import ( + BaseInterface, + InterfaceKind, + merge_dicts, + get_params_doc, + override_function, +) def snake_to_camel_case(snake: str) -> str: return "".join(word.title() for word in snake.split("_")) -def get_path_param(route: str) -> Optional[str]: - last_path = route.split("/")[-1] - if last_path and "{" in last_path: - return last_path.strip("{}") - return None - - -def alter_params_for_body( - parameters, as_body, as_body_except -) -> list[Union[list[inspect.Parameter], inspect.Parameter]]: - if as_body_except: - body_params = [] - rest = [] - for param in parameters: - if param.name not in as_body_except: - body_params.append(param) - else: - rest.append(param) - return [body_params, *rest] - - if as_body: - return [parameters] - - return parameters +def parse_api_route(route: str) -> list[str]: + return re.findall(r"{(\w+)}", route) def params_to_body( - params: list[inspect.Parameter], func_name: str + params: list[inspect.Parameter], + data: dict[str, Any], + doc: dict[str, Any], + func_name: str, ) -> inspect.Parameter: model = pydantic.create_model( snake_to_camel_case(func_name), **{ param.name: ( param.annotation, - param.default if param.default != param.empty else ..., + pydantic.Field( + param.default if param.default != param.empty else ..., + description=doc.get(param.name, None), + **data.get(param.name, {}), + ), ) for param in params }, @@ -60,15 +54,15 @@ def params_to_body( ) -class Interface: - type = "api" +class Interface(BaseInterface): + kind = InterfaceKind.API instance: Union[fastapi.FastAPI, fastapi.APIRouter] name: str - def __init__(self, root: bool = False, name: Optional[str] = None): + def __init__(self, root: bool = False, **kwargs): + super().__init__(root=root, **kwargs) self.instance = fastapi.FastAPI() if root else fastapi.APIRouter() - self.name = "root" if root else name or "" def add(self, interface: Interface): assert isinstance(interface.instance, fastapi.APIRouter) @@ -76,11 +70,28 @@ class Interface: interface.instance, prefix=f"/{interface.name}", tags=[interface.name] ) - def cli(self, *args, **kwargs): - def decorator(func): - return func + def prepare_params( + self, + params: list[inspect.Parameter], + as_body: bool, + as_body_except: Optional[list[str]] = None, + ) -> list[Union[list[inspect.Parameter], inspect.Parameter]]: + params = self.filter_params(params) - return decorator + if as_body_except: + body_params = [] + rest = [] + for param in params: + if param.name not in as_body_except: + body_params.append(param) + else: + rest.append(param) + return [body_params, *rest] + + if as_body: + return [params] + + return params def api( self, @@ -88,38 +99,40 @@ class Interface: method: str = "get", as_body: bool = False, as_body_except: Optional[list[str]] = None, - **kwargs, + **extra_data, ): as_body = as_body if not as_body_except else True def decorator(func): signature = inspect.signature(func) override_params = [] - params = alter_params_for_body( + params = self.prepare_params( signature.parameters.values(), as_body, as_body_except ) - path_param = get_path_param(route) + local_data = merge_dicts(self.local_data, extra_data) + params_doc = get_params_doc(func.__doc__) + paths = parse_api_route(route) for param in params: if isinstance(param, list): - override_params.append(params_to_body(param, func.__name__)) + override_params.append( + params_to_body(param, local_data, params_doc, func.__name__) + ) else: - default_kwargs = kwargs.get(param.name, {}) - default_cls = ( - fastapi.Path if param.name == path_param else fastapi.Query + param_kwargs = local_data.get(param.name, {}) + param_kwargs["description"] = params_doc.get(param.name, None) + param_default = ( + param.default + if not param.default == param.empty + else ... # required ) - if param.default is None: - default_value = default_cls(None, **default_kwargs) - elif param.default is param.empty: - default_value = default_cls(..., **default_kwargs) + if param.name in paths: + param_default = fastapi.Path(param_default, **param_kwargs) else: - default_value = default_cls(param.default, **default_kwargs) + param_default = fastapi.Query(param_default, **param_kwargs) - override_params.append(param.replace(default=default_value)) - - route_func = getattr(self.instance, method)(route) - override_signature = signature.replace(parameters=tuple(override_params)) + override_params.append(param.replace(default=param_default)) if as_body: @@ -132,19 +145,18 @@ class Interface: new_kwargs[kwarg] = value return func(*args, **new_kwargs) - body_to_args_back.__name__ = func.__name__ - body_to_args_back.__signature__ = override_signature - route_func(body_to_args_back) - else: - func_copy = types.FunctionType( - func.__code__, - func.__globals__, - func.__name__, - func.__defaults__, - func.__closure__, + route_func = override_function( + func, signature, override_params, decorator=body_to_args_back ) - func_copy.__signature__ = override_signature - route_func(func_copy) + else: + route_func = override_function(func, signature, override_params) + + summary = func.__doc__.split("\n\n")[0] if func.__doc__ else None + getattr(self.instance, method)( + route, summary=summary, deprecated=local_data.get("deprecated") + )(route_func) + + self.clear_local_data() return func diff --git a/src/interface/base.py b/src/interface/base.py new file mode 100644 index 000000000..e4a01321c --- /dev/null +++ b/src/interface/base.py @@ -0,0 +1,94 @@ +import inspect +import types +import re + +from enum import Enum +from typing import Any, Callable, Optional, TypeVar + +ViewFunc = TypeVar("ViewFunc") + + +class InterfaceKind(Enum): + API = "api" + CLI = "cli" + + +def pass_func(func: ViewFunc) -> ViewFunc: + return func + + +def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]: + merged: dict[str, Any] = {} + + for dict_ in dicts: + for key, value in dict_.items(): + if key not in merged or isinstance(merged[key], (str, int, float, bool)): + merged[key] = value + else: + merged[key] |= value + + return merged + + +def get_params_doc(docstring: Optional[str]) -> dict[str, str]: + if not docstring: + return {} + + return { + param_name: param_desc + for param_name, param_desc in re.findall( + r"- \*\*(\w+)\*\*: (.*)", docstring, re.MULTILINE + ) + } + + +def override_function( + func: Callable, + func_signature: inspect.Signature, + new_params: list[inspect.Parameter], + decorator: Optional[Callable] = None, + name: Optional[str] = None, + doc: Optional[str] = None, +) -> Callable: + returned_func = decorator or types.FunctionType( + func.__code__, + func.__globals__, + func.__name__, + func.__defaults__, + func.__closure__, + ) + returned_func.__name__ = name or func.__name__ + returned_func.__doc__ = doc or func.__doc__ + returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params)) + + return returned_func + + +class BaseInterface: + kind: InterfaceKind + local_data: dict[str, Any] = {} + + def __init__(self, root: bool = False, name: str = "", help: str = ""): + self.name = "root" if root else name or "" + self.help = help + + def __call__(self, *args, **kwargs): + self.local_data = kwargs + return pass_func + + def clear_local_data(self): + self.local_data = {} + + def filter_params(self, params: list[inspect.Parameter]) -> list[inspect.Parameter]: + private = self.local_data.get("private", []) + + if private: + return [param for param in params if param.name not in private] + + return params + + def api(self, *args, **kwargs): + return pass_func + + def cli(self, *args, **kwargs): + return pass_func diff --git a/src/interface/cli.py b/src/interface/cli.py index e9c9f4687..23473ab6c 100644 --- a/src/interface/cli.py +++ b/src/interface/cli.py @@ -1,13 +1,23 @@ -from __future__ import annotations +from __future__ import ( + annotations, +) # Enable self reference a class in its own method arguments import inspect import typer import yaml -from typing import Any, Optional +from typing import Any, Callable from rich import print as rprint from rich.syntax import Syntax +from yunohost.interface.base import ( + BaseInterface, + InterfaceKind, + merge_dicts, + get_params_doc, + override_function, +) + def parse_cli_command(command: str) -> tuple[str, list[str]]: command, *args = command.split(" ") @@ -19,63 +29,74 @@ def print_as_yaml(data: Any): rprint(Syntax(data, "yaml", background_color="default")) -class Interface: - type = "cli" - +class Interface(BaseInterface): + kind = InterfaceKind.CLI instance: typer.Typer name: str - def __init__(self, root: bool = False, name: Optional[str] = None): - self.instance = typer.Typer() - self.name = "root" if root else name or "" + def __init__(self, root: bool = False, **kwargs): + super().__init__(root=root, **kwargs) + self.instance = typer.Typer(rich_markup_mode="markdown") def add(self, interface: Interface): - self.instance.add_typer(interface.instance, name=interface.name) + self.instance.add_typer( + interface.instance, name=interface.name, help=interface.help + ) - def cli(self, command_def: str, **kwargs): - def decorator(func): + def cli(self, command_def: str, **extra_data): + def decorator(func: Callable): signature = inspect.signature(func) override_params = [] + params = self.filter_params(signature.parameters.values()) + local_data = merge_dicts(self.local_data, extra_data) + params_doc = get_params_doc(func.__doc__) command, args = parse_cli_command(command_def) - for param in signature.parameters.values(): + for param in params: + param_kwargs = local_data.get(param.name, {}) + param_kwargs["help"] = params_doc.get(param.name, None) + + if param_kwargs.pop("deprecated", False): + param_kwargs["rich_help_panel"] = "Deprecated Options" + + if param.name not in args and not param_kwargs.get("hidden", False): + param_kwargs["prompt"] = True - # Auto setup typer Argument or Option kwargs - default_kwargs = kwargs.get(param.name, {}) - # if param.name not in args and not default_kwargs.get("hidden", False): - # default_kwargs["prompt"] = True if param.name == "password": - default_kwargs["confirmation_prompt"] = True - default_kwargs["hide_input"] = True + param_kwargs["confirmation_prompt"] = True + param_kwargs["hide_input"] = True - # Define new default value for typer - default_cls = typer.Argument if param.name in args else typer.Option - if param.default is None: - default_value = default_cls(None, **default_kwargs) - elif param.default is param.empty: - default_value = default_cls(..., **default_kwargs) + # Populate default param value with typer.Argument|Option + param_default = ( + param.default + if not param.default == param.empty + else ... # required + ) + if param.name in args: + param_default = typer.Argument(param_default, **param_kwargs) else: - default_value = default_cls(param.default, **default_kwargs) + param_default = typer.Option(param_default, **param_kwargs) - override_params.append(param.replace(default=default_value)) + override_params.append(param.replace(default=param_default)) def hook_results(*args, **kwargs): results = func(*args, **kwargs) print_as_yaml(results) return results - hook_results.__name__ = func.__name__ - hook_results.__signature__ = signature.replace( - parameters=tuple(override_params) + command_func = override_function( + func, + signature, + override_params, + decorator=hook_results, + doc=func.__doc__.split("\b")[0] if func.__doc__ else None, ) - self.instance.command(command)(hook_results) + self.instance.command( + command, deprecated=local_data.get("deprecated", False) + )(command_func) + + self.clear_local_data() return func return decorator - - def api(self, *args, **kwargs): - def decorator(func): - return func - - return decorator