mirror of
https://github.com/YunoHost/yunohost.git
synced 2024-09-03 20:06:10 +02:00
add BaseInterface class with helper + improve docstring handling
This commit is contained in:
parent
87317cea66
commit
fee2ee741c
3 changed files with 228 additions and 101 deletions
|
@ -1,52 +1,46 @@
|
|||
from __future__ import annotations
|
||||
from __future__ import (
|
||||
annotations,
|
||||
) # Enable self reference a class in its own method arguments
|
||||
|
||||
import inspect
|
||||
import types
|
||||
import re
|
||||
import fastapi
|
||||
import pydantic
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from yunohost.interface.base import (
|
||||
BaseInterface,
|
||||
InterfaceKind,
|
||||
merge_dicts,
|
||||
get_params_doc,
|
||||
override_function,
|
||||
)
|
||||
|
||||
|
||||
def snake_to_camel_case(snake: str) -> str:
|
||||
return "".join(word.title() for word in snake.split("_"))
|
||||
|
||||
|
||||
def get_path_param(route: str) -> Optional[str]:
|
||||
last_path = route.split("/")[-1]
|
||||
if last_path and "{" in last_path:
|
||||
return last_path.strip("{}")
|
||||
return None
|
||||
|
||||
|
||||
def alter_params_for_body(
|
||||
parameters, as_body, as_body_except
|
||||
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
|
||||
if as_body_except:
|
||||
body_params = []
|
||||
rest = []
|
||||
for param in parameters:
|
||||
if param.name not in as_body_except:
|
||||
body_params.append(param)
|
||||
else:
|
||||
rest.append(param)
|
||||
return [body_params, *rest]
|
||||
|
||||
if as_body:
|
||||
return [parameters]
|
||||
|
||||
return parameters
|
||||
def parse_api_route(route: str) -> list[str]:
|
||||
return re.findall(r"{(\w+)}", route)
|
||||
|
||||
|
||||
def params_to_body(
|
||||
params: list[inspect.Parameter], func_name: str
|
||||
params: list[inspect.Parameter],
|
||||
data: dict[str, Any],
|
||||
doc: dict[str, Any],
|
||||
func_name: str,
|
||||
) -> inspect.Parameter:
|
||||
model = pydantic.create_model(
|
||||
snake_to_camel_case(func_name),
|
||||
**{
|
||||
param.name: (
|
||||
param.annotation,
|
||||
param.default if param.default != param.empty else ...,
|
||||
pydantic.Field(
|
||||
param.default if param.default != param.empty else ...,
|
||||
description=doc.get(param.name, None),
|
||||
**data.get(param.name, {}),
|
||||
),
|
||||
)
|
||||
for param in params
|
||||
},
|
||||
|
@ -60,15 +54,15 @@ def params_to_body(
|
|||
)
|
||||
|
||||
|
||||
class Interface:
|
||||
type = "api"
|
||||
class Interface(BaseInterface):
|
||||
kind = InterfaceKind.API
|
||||
|
||||
instance: Union[fastapi.FastAPI, fastapi.APIRouter]
|
||||
name: str
|
||||
|
||||
def __init__(self, root: bool = False, name: Optional[str] = None):
|
||||
def __init__(self, root: bool = False, **kwargs):
|
||||
super().__init__(root=root, **kwargs)
|
||||
self.instance = fastapi.FastAPI() if root else fastapi.APIRouter()
|
||||
self.name = "root" if root else name or ""
|
||||
|
||||
def add(self, interface: Interface):
|
||||
assert isinstance(interface.instance, fastapi.APIRouter)
|
||||
|
@ -76,11 +70,28 @@ class Interface:
|
|||
interface.instance, prefix=f"/{interface.name}", tags=[interface.name]
|
||||
)
|
||||
|
||||
def cli(self, *args, **kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
def prepare_params(
|
||||
self,
|
||||
params: list[inspect.Parameter],
|
||||
as_body: bool,
|
||||
as_body_except: Optional[list[str]] = None,
|
||||
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
|
||||
params = self.filter_params(params)
|
||||
|
||||
return decorator
|
||||
if as_body_except:
|
||||
body_params = []
|
||||
rest = []
|
||||
for param in params:
|
||||
if param.name not in as_body_except:
|
||||
body_params.append(param)
|
||||
else:
|
||||
rest.append(param)
|
||||
return [body_params, *rest]
|
||||
|
||||
if as_body:
|
||||
return [params]
|
||||
|
||||
return params
|
||||
|
||||
def api(
|
||||
self,
|
||||
|
@ -88,38 +99,40 @@ class Interface:
|
|||
method: str = "get",
|
||||
as_body: bool = False,
|
||||
as_body_except: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
**extra_data,
|
||||
):
|
||||
as_body = as_body if not as_body_except else True
|
||||
|
||||
def decorator(func):
|
||||
signature = inspect.signature(func)
|
||||
override_params = []
|
||||
params = alter_params_for_body(
|
||||
params = self.prepare_params(
|
||||
signature.parameters.values(), as_body, as_body_except
|
||||
)
|
||||
path_param = get_path_param(route)
|
||||
local_data = merge_dicts(self.local_data, extra_data)
|
||||
params_doc = get_params_doc(func.__doc__)
|
||||
paths = parse_api_route(route)
|
||||
|
||||
for param in params:
|
||||
if isinstance(param, list):
|
||||
override_params.append(params_to_body(param, func.__name__))
|
||||
override_params.append(
|
||||
params_to_body(param, local_data, params_doc, func.__name__)
|
||||
)
|
||||
else:
|
||||
default_kwargs = kwargs.get(param.name, {})
|
||||
default_cls = (
|
||||
fastapi.Path if param.name == path_param else fastapi.Query
|
||||
param_kwargs = local_data.get(param.name, {})
|
||||
param_kwargs["description"] = params_doc.get(param.name, None)
|
||||
param_default = (
|
||||
param.default
|
||||
if not param.default == param.empty
|
||||
else ... # required
|
||||
)
|
||||
|
||||
if param.default is None:
|
||||
default_value = default_cls(None, **default_kwargs)
|
||||
elif param.default is param.empty:
|
||||
default_value = default_cls(..., **default_kwargs)
|
||||
if param.name in paths:
|
||||
param_default = fastapi.Path(param_default, **param_kwargs)
|
||||
else:
|
||||
default_value = default_cls(param.default, **default_kwargs)
|
||||
param_default = fastapi.Query(param_default, **param_kwargs)
|
||||
|
||||
override_params.append(param.replace(default=default_value))
|
||||
|
||||
route_func = getattr(self.instance, method)(route)
|
||||
override_signature = signature.replace(parameters=tuple(override_params))
|
||||
override_params.append(param.replace(default=param_default))
|
||||
|
||||
if as_body:
|
||||
|
||||
|
@ -132,19 +145,18 @@ class Interface:
|
|||
new_kwargs[kwarg] = value
|
||||
return func(*args, **new_kwargs)
|
||||
|
||||
body_to_args_back.__name__ = func.__name__
|
||||
body_to_args_back.__signature__ = override_signature
|
||||
route_func(body_to_args_back)
|
||||
else:
|
||||
func_copy = types.FunctionType(
|
||||
func.__code__,
|
||||
func.__globals__,
|
||||
func.__name__,
|
||||
func.__defaults__,
|
||||
func.__closure__,
|
||||
route_func = override_function(
|
||||
func, signature, override_params, decorator=body_to_args_back
|
||||
)
|
||||
func_copy.__signature__ = override_signature
|
||||
route_func(func_copy)
|
||||
else:
|
||||
route_func = override_function(func, signature, override_params)
|
||||
|
||||
summary = func.__doc__.split("\n\n")[0] if func.__doc__ else None
|
||||
getattr(self.instance, method)(
|
||||
route, summary=summary, deprecated=local_data.get("deprecated")
|
||||
)(route_func)
|
||||
|
||||
self.clear_local_data()
|
||||
|
||||
return func
|
||||
|
||||
|
|
94
src/interface/base.py
Normal file
94
src/interface/base.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
import inspect
|
||||
import types
|
||||
import re
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
ViewFunc = TypeVar("ViewFunc")
|
||||
|
||||
|
||||
class InterfaceKind(Enum):
|
||||
API = "api"
|
||||
CLI = "cli"
|
||||
|
||||
|
||||
def pass_func(func: ViewFunc) -> ViewFunc:
|
||||
return func
|
||||
|
||||
|
||||
def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]:
|
||||
merged: dict[str, Any] = {}
|
||||
|
||||
for dict_ in dicts:
|
||||
for key, value in dict_.items():
|
||||
if key not in merged or isinstance(merged[key], (str, int, float, bool)):
|
||||
merged[key] = value
|
||||
else:
|
||||
merged[key] |= value
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def get_params_doc(docstring: Optional[str]) -> dict[str, str]:
|
||||
if not docstring:
|
||||
return {}
|
||||
|
||||
return {
|
||||
param_name: param_desc
|
||||
for param_name, param_desc in re.findall(
|
||||
r"- \*\*(\w+)\*\*: (.*)", docstring, re.MULTILINE
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def override_function(
|
||||
func: Callable,
|
||||
func_signature: inspect.Signature,
|
||||
new_params: list[inspect.Parameter],
|
||||
decorator: Optional[Callable] = None,
|
||||
name: Optional[str] = None,
|
||||
doc: Optional[str] = None,
|
||||
) -> Callable:
|
||||
returned_func = decorator or types.FunctionType(
|
||||
func.__code__,
|
||||
func.__globals__,
|
||||
func.__name__,
|
||||
func.__defaults__,
|
||||
func.__closure__,
|
||||
)
|
||||
returned_func.__name__ = name or func.__name__
|
||||
returned_func.__doc__ = doc or func.__doc__
|
||||
returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params))
|
||||
|
||||
return returned_func
|
||||
|
||||
|
||||
class BaseInterface:
|
||||
kind: InterfaceKind
|
||||
local_data: dict[str, Any] = {}
|
||||
|
||||
def __init__(self, root: bool = False, name: str = "", help: str = ""):
|
||||
self.name = "root" if root else name or ""
|
||||
self.help = help
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.local_data = kwargs
|
||||
return pass_func
|
||||
|
||||
def clear_local_data(self):
|
||||
self.local_data = {}
|
||||
|
||||
def filter_params(self, params: list[inspect.Parameter]) -> list[inspect.Parameter]:
|
||||
private = self.local_data.get("private", [])
|
||||
|
||||
if private:
|
||||
return [param for param in params if param.name not in private]
|
||||
|
||||
return params
|
||||
|
||||
def api(self, *args, **kwargs):
|
||||
return pass_func
|
||||
|
||||
def cli(self, *args, **kwargs):
|
||||
return pass_func
|
|
@ -1,13 +1,23 @@
|
|||
from __future__ import annotations
|
||||
from __future__ import (
|
||||
annotations,
|
||||
) # Enable self reference a class in its own method arguments
|
||||
|
||||
import inspect
|
||||
import typer
|
||||
import yaml
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable
|
||||
from rich import print as rprint
|
||||
from rich.syntax import Syntax
|
||||
|
||||
from yunohost.interface.base import (
|
||||
BaseInterface,
|
||||
InterfaceKind,
|
||||
merge_dicts,
|
||||
get_params_doc,
|
||||
override_function,
|
||||
)
|
||||
|
||||
|
||||
def parse_cli_command(command: str) -> tuple[str, list[str]]:
|
||||
command, *args = command.split(" ")
|
||||
|
@ -19,63 +29,74 @@ def print_as_yaml(data: Any):
|
|||
rprint(Syntax(data, "yaml", background_color="default"))
|
||||
|
||||
|
||||
class Interface:
|
||||
type = "cli"
|
||||
|
||||
class Interface(BaseInterface):
|
||||
kind = InterfaceKind.CLI
|
||||
instance: typer.Typer
|
||||
name: str
|
||||
|
||||
def __init__(self, root: bool = False, name: Optional[str] = None):
|
||||
self.instance = typer.Typer()
|
||||
self.name = "root" if root else name or ""
|
||||
def __init__(self, root: bool = False, **kwargs):
|
||||
super().__init__(root=root, **kwargs)
|
||||
self.instance = typer.Typer(rich_markup_mode="markdown")
|
||||
|
||||
def add(self, interface: Interface):
|
||||
self.instance.add_typer(interface.instance, name=interface.name)
|
||||
self.instance.add_typer(
|
||||
interface.instance, name=interface.name, help=interface.help
|
||||
)
|
||||
|
||||
def cli(self, command_def: str, **kwargs):
|
||||
def decorator(func):
|
||||
def cli(self, command_def: str, **extra_data):
|
||||
def decorator(func: Callable):
|
||||
signature = inspect.signature(func)
|
||||
override_params = []
|
||||
params = self.filter_params(signature.parameters.values())
|
||||
local_data = merge_dicts(self.local_data, extra_data)
|
||||
params_doc = get_params_doc(func.__doc__)
|
||||
command, args = parse_cli_command(command_def)
|
||||
|
||||
for param in signature.parameters.values():
|
||||
for param in params:
|
||||
param_kwargs = local_data.get(param.name, {})
|
||||
param_kwargs["help"] = params_doc.get(param.name, None)
|
||||
|
||||
if param_kwargs.pop("deprecated", False):
|
||||
param_kwargs["rich_help_panel"] = "Deprecated Options"
|
||||
|
||||
if param.name not in args and not param_kwargs.get("hidden", False):
|
||||
param_kwargs["prompt"] = True
|
||||
|
||||
# Auto setup typer Argument or Option kwargs
|
||||
default_kwargs = kwargs.get(param.name, {})
|
||||
# if param.name not in args and not default_kwargs.get("hidden", False):
|
||||
# default_kwargs["prompt"] = True
|
||||
if param.name == "password":
|
||||
default_kwargs["confirmation_prompt"] = True
|
||||
default_kwargs["hide_input"] = True
|
||||
param_kwargs["confirmation_prompt"] = True
|
||||
param_kwargs["hide_input"] = True
|
||||
|
||||
# Define new default value for typer
|
||||
default_cls = typer.Argument if param.name in args else typer.Option
|
||||
if param.default is None:
|
||||
default_value = default_cls(None, **default_kwargs)
|
||||
elif param.default is param.empty:
|
||||
default_value = default_cls(..., **default_kwargs)
|
||||
# Populate default param value with typer.Argument|Option
|
||||
param_default = (
|
||||
param.default
|
||||
if not param.default == param.empty
|
||||
else ... # required
|
||||
)
|
||||
if param.name in args:
|
||||
param_default = typer.Argument(param_default, **param_kwargs)
|
||||
else:
|
||||
default_value = default_cls(param.default, **default_kwargs)
|
||||
param_default = typer.Option(param_default, **param_kwargs)
|
||||
|
||||
override_params.append(param.replace(default=default_value))
|
||||
override_params.append(param.replace(default=param_default))
|
||||
|
||||
def hook_results(*args, **kwargs):
|
||||
results = func(*args, **kwargs)
|
||||
print_as_yaml(results)
|
||||
return results
|
||||
|
||||
hook_results.__name__ = func.__name__
|
||||
hook_results.__signature__ = signature.replace(
|
||||
parameters=tuple(override_params)
|
||||
command_func = override_function(
|
||||
func,
|
||||
signature,
|
||||
override_params,
|
||||
decorator=hook_results,
|
||||
doc=func.__doc__.split("\b")[0] if func.__doc__ else None,
|
||||
)
|
||||
self.instance.command(command)(hook_results)
|
||||
self.instance.command(
|
||||
command, deprecated=local_data.get("deprecated", False)
|
||||
)(command_func)
|
||||
|
||||
self.clear_local_data()
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def api(self, *args, **kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
|
Loading…
Add table
Reference in a new issue