mirror of
https://github.com/YunoHost/yunohost.git
synced 2024-09-03 20:06:10 +02:00
add BaseInterface class with helper + improve docstring handling
This commit is contained in:
parent
87317cea66
commit
fee2ee741c
3 changed files with 228 additions and 101 deletions
|
@ -1,52 +1,46 @@
|
||||||
from __future__ import annotations
|
from __future__ import (
|
||||||
|
annotations,
|
||||||
|
) # Enable self reference a class in its own method arguments
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import types
|
import re
|
||||||
import fastapi
|
import fastapi
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
from yunohost.interface.base import (
|
||||||
|
BaseInterface,
|
||||||
|
InterfaceKind,
|
||||||
|
merge_dicts,
|
||||||
|
get_params_doc,
|
||||||
|
override_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def snake_to_camel_case(snake: str) -> str:
|
def snake_to_camel_case(snake: str) -> str:
|
||||||
return "".join(word.title() for word in snake.split("_"))
|
return "".join(word.title() for word in snake.split("_"))
|
||||||
|
|
||||||
|
|
||||||
def get_path_param(route: str) -> Optional[str]:
|
def parse_api_route(route: str) -> list[str]:
|
||||||
last_path = route.split("/")[-1]
|
return re.findall(r"{(\w+)}", route)
|
||||||
if last_path and "{" in last_path:
|
|
||||||
return last_path.strip("{}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def alter_params_for_body(
|
|
||||||
parameters, as_body, as_body_except
|
|
||||||
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
|
|
||||||
if as_body_except:
|
|
||||||
body_params = []
|
|
||||||
rest = []
|
|
||||||
for param in parameters:
|
|
||||||
if param.name not in as_body_except:
|
|
||||||
body_params.append(param)
|
|
||||||
else:
|
|
||||||
rest.append(param)
|
|
||||||
return [body_params, *rest]
|
|
||||||
|
|
||||||
if as_body:
|
|
||||||
return [parameters]
|
|
||||||
|
|
||||||
return parameters
|
|
||||||
|
|
||||||
|
|
||||||
def params_to_body(
|
def params_to_body(
|
||||||
params: list[inspect.Parameter], func_name: str
|
params: list[inspect.Parameter],
|
||||||
|
data: dict[str, Any],
|
||||||
|
doc: dict[str, Any],
|
||||||
|
func_name: str,
|
||||||
) -> inspect.Parameter:
|
) -> inspect.Parameter:
|
||||||
model = pydantic.create_model(
|
model = pydantic.create_model(
|
||||||
snake_to_camel_case(func_name),
|
snake_to_camel_case(func_name),
|
||||||
**{
|
**{
|
||||||
param.name: (
|
param.name: (
|
||||||
param.annotation,
|
param.annotation,
|
||||||
param.default if param.default != param.empty else ...,
|
pydantic.Field(
|
||||||
|
param.default if param.default != param.empty else ...,
|
||||||
|
description=doc.get(param.name, None),
|
||||||
|
**data.get(param.name, {}),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for param in params
|
for param in params
|
||||||
},
|
},
|
||||||
|
@ -60,15 +54,15 @@ def params_to_body(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Interface:
|
class Interface(BaseInterface):
|
||||||
type = "api"
|
kind = InterfaceKind.API
|
||||||
|
|
||||||
instance: Union[fastapi.FastAPI, fastapi.APIRouter]
|
instance: Union[fastapi.FastAPI, fastapi.APIRouter]
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
def __init__(self, root: bool = False, name: Optional[str] = None):
|
def __init__(self, root: bool = False, **kwargs):
|
||||||
|
super().__init__(root=root, **kwargs)
|
||||||
self.instance = fastapi.FastAPI() if root else fastapi.APIRouter()
|
self.instance = fastapi.FastAPI() if root else fastapi.APIRouter()
|
||||||
self.name = "root" if root else name or ""
|
|
||||||
|
|
||||||
def add(self, interface: Interface):
|
def add(self, interface: Interface):
|
||||||
assert isinstance(interface.instance, fastapi.APIRouter)
|
assert isinstance(interface.instance, fastapi.APIRouter)
|
||||||
|
@ -76,11 +70,28 @@ class Interface:
|
||||||
interface.instance, prefix=f"/{interface.name}", tags=[interface.name]
|
interface.instance, prefix=f"/{interface.name}", tags=[interface.name]
|
||||||
)
|
)
|
||||||
|
|
||||||
def cli(self, *args, **kwargs):
|
def prepare_params(
|
||||||
def decorator(func):
|
self,
|
||||||
return func
|
params: list[inspect.Parameter],
|
||||||
|
as_body: bool,
|
||||||
|
as_body_except: Optional[list[str]] = None,
|
||||||
|
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
|
||||||
|
params = self.filter_params(params)
|
||||||
|
|
||||||
return decorator
|
if as_body_except:
|
||||||
|
body_params = []
|
||||||
|
rest = []
|
||||||
|
for param in params:
|
||||||
|
if param.name not in as_body_except:
|
||||||
|
body_params.append(param)
|
||||||
|
else:
|
||||||
|
rest.append(param)
|
||||||
|
return [body_params, *rest]
|
||||||
|
|
||||||
|
if as_body:
|
||||||
|
return [params]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
def api(
|
def api(
|
||||||
self,
|
self,
|
||||||
|
@ -88,38 +99,40 @@ class Interface:
|
||||||
method: str = "get",
|
method: str = "get",
|
||||||
as_body: bool = False,
|
as_body: bool = False,
|
||||||
as_body_except: Optional[list[str]] = None,
|
as_body_except: Optional[list[str]] = None,
|
||||||
**kwargs,
|
**extra_data,
|
||||||
):
|
):
|
||||||
as_body = as_body if not as_body_except else True
|
as_body = as_body if not as_body_except else True
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(func)
|
||||||
override_params = []
|
override_params = []
|
||||||
params = alter_params_for_body(
|
params = self.prepare_params(
|
||||||
signature.parameters.values(), as_body, as_body_except
|
signature.parameters.values(), as_body, as_body_except
|
||||||
)
|
)
|
||||||
path_param = get_path_param(route)
|
local_data = merge_dicts(self.local_data, extra_data)
|
||||||
|
params_doc = get_params_doc(func.__doc__)
|
||||||
|
paths = parse_api_route(route)
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
if isinstance(param, list):
|
if isinstance(param, list):
|
||||||
override_params.append(params_to_body(param, func.__name__))
|
override_params.append(
|
||||||
|
params_to_body(param, local_data, params_doc, func.__name__)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
default_kwargs = kwargs.get(param.name, {})
|
param_kwargs = local_data.get(param.name, {})
|
||||||
default_cls = (
|
param_kwargs["description"] = params_doc.get(param.name, None)
|
||||||
fastapi.Path if param.name == path_param else fastapi.Query
|
param_default = (
|
||||||
|
param.default
|
||||||
|
if not param.default == param.empty
|
||||||
|
else ... # required
|
||||||
)
|
)
|
||||||
|
|
||||||
if param.default is None:
|
if param.name in paths:
|
||||||
default_value = default_cls(None, **default_kwargs)
|
param_default = fastapi.Path(param_default, **param_kwargs)
|
||||||
elif param.default is param.empty:
|
|
||||||
default_value = default_cls(..., **default_kwargs)
|
|
||||||
else:
|
else:
|
||||||
default_value = default_cls(param.default, **default_kwargs)
|
param_default = fastapi.Query(param_default, **param_kwargs)
|
||||||
|
|
||||||
override_params.append(param.replace(default=default_value))
|
override_params.append(param.replace(default=param_default))
|
||||||
|
|
||||||
route_func = getattr(self.instance, method)(route)
|
|
||||||
override_signature = signature.replace(parameters=tuple(override_params))
|
|
||||||
|
|
||||||
if as_body:
|
if as_body:
|
||||||
|
|
||||||
|
@ -132,19 +145,18 @@ class Interface:
|
||||||
new_kwargs[kwarg] = value
|
new_kwargs[kwarg] = value
|
||||||
return func(*args, **new_kwargs)
|
return func(*args, **new_kwargs)
|
||||||
|
|
||||||
body_to_args_back.__name__ = func.__name__
|
route_func = override_function(
|
||||||
body_to_args_back.__signature__ = override_signature
|
func, signature, override_params, decorator=body_to_args_back
|
||||||
route_func(body_to_args_back)
|
|
||||||
else:
|
|
||||||
func_copy = types.FunctionType(
|
|
||||||
func.__code__,
|
|
||||||
func.__globals__,
|
|
||||||
func.__name__,
|
|
||||||
func.__defaults__,
|
|
||||||
func.__closure__,
|
|
||||||
)
|
)
|
||||||
func_copy.__signature__ = override_signature
|
else:
|
||||||
route_func(func_copy)
|
route_func = override_function(func, signature, override_params)
|
||||||
|
|
||||||
|
summary = func.__doc__.split("\n\n")[0] if func.__doc__ else None
|
||||||
|
getattr(self.instance, method)(
|
||||||
|
route, summary=summary, deprecated=local_data.get("deprecated")
|
||||||
|
)(route_func)
|
||||||
|
|
||||||
|
self.clear_local_data()
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
94
src/interface/base.py
Normal file
94
src/interface/base.py
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
import inspect
|
||||||
|
import types
|
||||||
|
import re
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
|
|
||||||
|
ViewFunc = TypeVar("ViewFunc")
|
||||||
|
|
||||||
|
|
||||||
|
class InterfaceKind(Enum):
|
||||||
|
API = "api"
|
||||||
|
CLI = "cli"
|
||||||
|
|
||||||
|
|
||||||
|
def pass_func(func: ViewFunc) -> ViewFunc:
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
merged: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for dict_ in dicts:
|
||||||
|
for key, value in dict_.items():
|
||||||
|
if key not in merged or isinstance(merged[key], (str, int, float, bool)):
|
||||||
|
merged[key] = value
|
||||||
|
else:
|
||||||
|
merged[key] |= value
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def get_params_doc(docstring: Optional[str]) -> dict[str, str]:
|
||||||
|
if not docstring:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
param_name: param_desc
|
||||||
|
for param_name, param_desc in re.findall(
|
||||||
|
r"- \*\*(\w+)\*\*: (.*)", docstring, re.MULTILINE
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def override_function(
|
||||||
|
func: Callable,
|
||||||
|
func_signature: inspect.Signature,
|
||||||
|
new_params: list[inspect.Parameter],
|
||||||
|
decorator: Optional[Callable] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
doc: Optional[str] = None,
|
||||||
|
) -> Callable:
|
||||||
|
returned_func = decorator or types.FunctionType(
|
||||||
|
func.__code__,
|
||||||
|
func.__globals__,
|
||||||
|
func.__name__,
|
||||||
|
func.__defaults__,
|
||||||
|
func.__closure__,
|
||||||
|
)
|
||||||
|
returned_func.__name__ = name or func.__name__
|
||||||
|
returned_func.__doc__ = doc or func.__doc__
|
||||||
|
returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params))
|
||||||
|
|
||||||
|
return returned_func
|
||||||
|
|
||||||
|
|
||||||
|
class BaseInterface:
|
||||||
|
kind: InterfaceKind
|
||||||
|
local_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __init__(self, root: bool = False, name: str = "", help: str = ""):
|
||||||
|
self.name = "root" if root else name or ""
|
||||||
|
self.help = help
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
self.local_data = kwargs
|
||||||
|
return pass_func
|
||||||
|
|
||||||
|
def clear_local_data(self):
|
||||||
|
self.local_data = {}
|
||||||
|
|
||||||
|
def filter_params(self, params: list[inspect.Parameter]) -> list[inspect.Parameter]:
|
||||||
|
private = self.local_data.get("private", [])
|
||||||
|
|
||||||
|
if private:
|
||||||
|
return [param for param in params if param.name not in private]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def api(self, *args, **kwargs):
|
||||||
|
return pass_func
|
||||||
|
|
||||||
|
def cli(self, *args, **kwargs):
|
||||||
|
return pass_func
|
|
@ -1,13 +1,23 @@
|
||||||
from __future__ import annotations
|
from __future__ import (
|
||||||
|
annotations,
|
||||||
|
) # Enable self reference a class in its own method arguments
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import typer
|
import typer
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Callable
|
||||||
from rich import print as rprint
|
from rich import print as rprint
|
||||||
from rich.syntax import Syntax
|
from rich.syntax import Syntax
|
||||||
|
|
||||||
|
from yunohost.interface.base import (
|
||||||
|
BaseInterface,
|
||||||
|
InterfaceKind,
|
||||||
|
merge_dicts,
|
||||||
|
get_params_doc,
|
||||||
|
override_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_cli_command(command: str) -> tuple[str, list[str]]:
|
def parse_cli_command(command: str) -> tuple[str, list[str]]:
|
||||||
command, *args = command.split(" ")
|
command, *args = command.split(" ")
|
||||||
|
@ -19,63 +29,74 @@ def print_as_yaml(data: Any):
|
||||||
rprint(Syntax(data, "yaml", background_color="default"))
|
rprint(Syntax(data, "yaml", background_color="default"))
|
||||||
|
|
||||||
|
|
||||||
class Interface:
|
class Interface(BaseInterface):
|
||||||
type = "cli"
|
kind = InterfaceKind.CLI
|
||||||
|
|
||||||
instance: typer.Typer
|
instance: typer.Typer
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
def __init__(self, root: bool = False, name: Optional[str] = None):
|
def __init__(self, root: bool = False, **kwargs):
|
||||||
self.instance = typer.Typer()
|
super().__init__(root=root, **kwargs)
|
||||||
self.name = "root" if root else name or ""
|
self.instance = typer.Typer(rich_markup_mode="markdown")
|
||||||
|
|
||||||
def add(self, interface: Interface):
|
def add(self, interface: Interface):
|
||||||
self.instance.add_typer(interface.instance, name=interface.name)
|
self.instance.add_typer(
|
||||||
|
interface.instance, name=interface.name, help=interface.help
|
||||||
|
)
|
||||||
|
|
||||||
def cli(self, command_def: str, **kwargs):
|
def cli(self, command_def: str, **extra_data):
|
||||||
def decorator(func):
|
def decorator(func: Callable):
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(func)
|
||||||
override_params = []
|
override_params = []
|
||||||
|
params = self.filter_params(signature.parameters.values())
|
||||||
|
local_data = merge_dicts(self.local_data, extra_data)
|
||||||
|
params_doc = get_params_doc(func.__doc__)
|
||||||
command, args = parse_cli_command(command_def)
|
command, args = parse_cli_command(command_def)
|
||||||
|
|
||||||
for param in signature.parameters.values():
|
for param in params:
|
||||||
|
param_kwargs = local_data.get(param.name, {})
|
||||||
|
param_kwargs["help"] = params_doc.get(param.name, None)
|
||||||
|
|
||||||
|
if param_kwargs.pop("deprecated", False):
|
||||||
|
param_kwargs["rich_help_panel"] = "Deprecated Options"
|
||||||
|
|
||||||
|
if param.name not in args and not param_kwargs.get("hidden", False):
|
||||||
|
param_kwargs["prompt"] = True
|
||||||
|
|
||||||
# Auto setup typer Argument or Option kwargs
|
|
||||||
default_kwargs = kwargs.get(param.name, {})
|
|
||||||
# if param.name not in args and not default_kwargs.get("hidden", False):
|
|
||||||
# default_kwargs["prompt"] = True
|
|
||||||
if param.name == "password":
|
if param.name == "password":
|
||||||
default_kwargs["confirmation_prompt"] = True
|
param_kwargs["confirmation_prompt"] = True
|
||||||
default_kwargs["hide_input"] = True
|
param_kwargs["hide_input"] = True
|
||||||
|
|
||||||
# Define new default value for typer
|
# Populate default param value with typer.Argument|Option
|
||||||
default_cls = typer.Argument if param.name in args else typer.Option
|
param_default = (
|
||||||
if param.default is None:
|
param.default
|
||||||
default_value = default_cls(None, **default_kwargs)
|
if not param.default == param.empty
|
||||||
elif param.default is param.empty:
|
else ... # required
|
||||||
default_value = default_cls(..., **default_kwargs)
|
)
|
||||||
|
if param.name in args:
|
||||||
|
param_default = typer.Argument(param_default, **param_kwargs)
|
||||||
else:
|
else:
|
||||||
default_value = default_cls(param.default, **default_kwargs)
|
param_default = typer.Option(param_default, **param_kwargs)
|
||||||
|
|
||||||
override_params.append(param.replace(default=default_value))
|
override_params.append(param.replace(default=param_default))
|
||||||
|
|
||||||
def hook_results(*args, **kwargs):
|
def hook_results(*args, **kwargs):
|
||||||
results = func(*args, **kwargs)
|
results = func(*args, **kwargs)
|
||||||
print_as_yaml(results)
|
print_as_yaml(results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
hook_results.__name__ = func.__name__
|
command_func = override_function(
|
||||||
hook_results.__signature__ = signature.replace(
|
func,
|
||||||
parameters=tuple(override_params)
|
signature,
|
||||||
|
override_params,
|
||||||
|
decorator=hook_results,
|
||||||
|
doc=func.__doc__.split("\b")[0] if func.__doc__ else None,
|
||||||
)
|
)
|
||||||
self.instance.command(command)(hook_results)
|
self.instance.command(
|
||||||
|
command, deprecated=local_data.get("deprecated", False)
|
||||||
|
)(command_func)
|
||||||
|
|
||||||
|
self.clear_local_data()
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def api(self, *args, **kwargs):
|
|
||||||
def decorator(func):
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue