add BaseInterface class with helper + improve docstring handling

This commit is contained in:
axolotle 2023-01-08 21:23:54 +01:00
parent 87317cea66
commit fee2ee741c
3 changed files with 228 additions and 101 deletions

View file

@ -1,52 +1,46 @@
from __future__ import annotations
from __future__ import (
annotations,
) # Enable self reference a class in its own method arguments
import inspect
import types
import re
import fastapi
import pydantic
from typing import Optional, Union
from typing import Any, Optional, Union
from yunohost.interface.base import (
BaseInterface,
InterfaceKind,
merge_dicts,
get_params_doc,
override_function,
)
def snake_to_camel_case(snake: str) -> str:
return "".join(word.title() for word in snake.split("_"))
def get_path_param(route: str) -> Optional[str]:
last_path = route.split("/")[-1]
if last_path and "{" in last_path:
return last_path.strip("{}")
return None
def alter_params_for_body(
parameters, as_body, as_body_except
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
if as_body_except:
body_params = []
rest = []
for param in parameters:
if param.name not in as_body_except:
body_params.append(param)
else:
rest.append(param)
return [body_params, *rest]
if as_body:
return [parameters]
return parameters
def parse_api_route(route: str) -> list[str]:
return re.findall(r"{(\w+)}", route)
def params_to_body(
params: list[inspect.Parameter], func_name: str
params: list[inspect.Parameter],
data: dict[str, Any],
doc: dict[str, Any],
func_name: str,
) -> inspect.Parameter:
model = pydantic.create_model(
snake_to_camel_case(func_name),
**{
param.name: (
param.annotation,
param.default if param.default != param.empty else ...,
pydantic.Field(
param.default if param.default != param.empty else ...,
description=doc.get(param.name, None),
**data.get(param.name, {}),
),
)
for param in params
},
@ -60,15 +54,15 @@ def params_to_body(
)
class Interface:
type = "api"
class Interface(BaseInterface):
kind = InterfaceKind.API
instance: Union[fastapi.FastAPI, fastapi.APIRouter]
name: str
def __init__(self, root: bool = False, name: Optional[str] = None):
def __init__(self, root: bool = False, **kwargs):
super().__init__(root=root, **kwargs)
self.instance = fastapi.FastAPI() if root else fastapi.APIRouter()
self.name = "root" if root else name or ""
def add(self, interface: Interface):
assert isinstance(interface.instance, fastapi.APIRouter)
@ -76,11 +70,28 @@ class Interface:
interface.instance, prefix=f"/{interface.name}", tags=[interface.name]
)
def cli(self, *args, **kwargs):
def decorator(func):
return func
def prepare_params(
self,
params: list[inspect.Parameter],
as_body: bool,
as_body_except: Optional[list[str]] = None,
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
params = self.filter_params(params)
return decorator
if as_body_except:
body_params = []
rest = []
for param in params:
if param.name not in as_body_except:
body_params.append(param)
else:
rest.append(param)
return [body_params, *rest]
if as_body:
return [params]
return params
def api(
self,
@ -88,38 +99,40 @@ class Interface:
method: str = "get",
as_body: bool = False,
as_body_except: Optional[list[str]] = None,
**kwargs,
**extra_data,
):
as_body = as_body if not as_body_except else True
def decorator(func):
signature = inspect.signature(func)
override_params = []
params = alter_params_for_body(
params = self.prepare_params(
signature.parameters.values(), as_body, as_body_except
)
path_param = get_path_param(route)
local_data = merge_dicts(self.local_data, extra_data)
params_doc = get_params_doc(func.__doc__)
paths = parse_api_route(route)
for param in params:
if isinstance(param, list):
override_params.append(params_to_body(param, func.__name__))
override_params.append(
params_to_body(param, local_data, params_doc, func.__name__)
)
else:
default_kwargs = kwargs.get(param.name, {})
default_cls = (
fastapi.Path if param.name == path_param else fastapi.Query
param_kwargs = local_data.get(param.name, {})
param_kwargs["description"] = params_doc.get(param.name, None)
param_default = (
param.default
if not param.default == param.empty
else ... # required
)
if param.default is None:
default_value = default_cls(None, **default_kwargs)
elif param.default is param.empty:
default_value = default_cls(..., **default_kwargs)
if param.name in paths:
param_default = fastapi.Path(param_default, **param_kwargs)
else:
default_value = default_cls(param.default, **default_kwargs)
param_default = fastapi.Query(param_default, **param_kwargs)
override_params.append(param.replace(default=default_value))
route_func = getattr(self.instance, method)(route)
override_signature = signature.replace(parameters=tuple(override_params))
override_params.append(param.replace(default=param_default))
if as_body:
@ -132,19 +145,18 @@ class Interface:
new_kwargs[kwarg] = value
return func(*args, **new_kwargs)
body_to_args_back.__name__ = func.__name__
body_to_args_back.__signature__ = override_signature
route_func(body_to_args_back)
else:
func_copy = types.FunctionType(
func.__code__,
func.__globals__,
func.__name__,
func.__defaults__,
func.__closure__,
route_func = override_function(
func, signature, override_params, decorator=body_to_args_back
)
func_copy.__signature__ = override_signature
route_func(func_copy)
else:
route_func = override_function(func, signature, override_params)
summary = func.__doc__.split("\n\n")[0] if func.__doc__ else None
getattr(self.instance, method)(
route, summary=summary, deprecated=local_data.get("deprecated")
)(route_func)
self.clear_local_data()
return func

94
src/interface/base.py Normal file
View file

@ -0,0 +1,94 @@
import inspect
import types
import re
from enum import Enum
from typing import Any, Callable, Optional, TypeVar
ViewFunc = TypeVar("ViewFunc")
class InterfaceKind(Enum):
API = "api"
CLI = "cli"
def pass_func(func: ViewFunc) -> ViewFunc:
return func
def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]:
merged: dict[str, Any] = {}
for dict_ in dicts:
for key, value in dict_.items():
if key not in merged or isinstance(merged[key], (str, int, float, bool)):
merged[key] = value
else:
merged[key] |= value
return merged
def get_params_doc(docstring: Optional[str]) -> dict[str, str]:
if not docstring:
return {}
return {
param_name: param_desc
for param_name, param_desc in re.findall(
r"- \*\*(\w+)\*\*: (.*)", docstring, re.MULTILINE
)
}
def override_function(
func: Callable,
func_signature: inspect.Signature,
new_params: list[inspect.Parameter],
decorator: Optional[Callable] = None,
name: Optional[str] = None,
doc: Optional[str] = None,
) -> Callable:
returned_func = decorator or types.FunctionType(
func.__code__,
func.__globals__,
func.__name__,
func.__defaults__,
func.__closure__,
)
returned_func.__name__ = name or func.__name__
returned_func.__doc__ = doc or func.__doc__
returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params))
return returned_func
class BaseInterface:
kind: InterfaceKind
local_data: dict[str, Any] = {}
def __init__(self, root: bool = False, name: str = "", help: str = ""):
self.name = "root" if root else name or ""
self.help = help
def __call__(self, *args, **kwargs):
self.local_data = kwargs
return pass_func
def clear_local_data(self):
self.local_data = {}
def filter_params(self, params: list[inspect.Parameter]) -> list[inspect.Parameter]:
private = self.local_data.get("private", [])
if private:
return [param for param in params if param.name not in private]
return params
def api(self, *args, **kwargs):
return pass_func
def cli(self, *args, **kwargs):
return pass_func

View file

@ -1,13 +1,23 @@
from __future__ import annotations
from __future__ import (
annotations,
) # Enable self reference a class in its own method arguments
import inspect
import typer
import yaml
from typing import Any, Optional
from typing import Any, Callable
from rich import print as rprint
from rich.syntax import Syntax
from yunohost.interface.base import (
BaseInterface,
InterfaceKind,
merge_dicts,
get_params_doc,
override_function,
)
def parse_cli_command(command: str) -> tuple[str, list[str]]:
command, *args = command.split(" ")
@ -19,63 +29,74 @@ def print_as_yaml(data: Any):
rprint(Syntax(data, "yaml", background_color="default"))
class Interface:
type = "cli"
class Interface(BaseInterface):
kind = InterfaceKind.CLI
instance: typer.Typer
name: str
def __init__(self, root: bool = False, name: Optional[str] = None):
self.instance = typer.Typer()
self.name = "root" if root else name or ""
def __init__(self, root: bool = False, **kwargs):
super().__init__(root=root, **kwargs)
self.instance = typer.Typer(rich_markup_mode="markdown")
def add(self, interface: Interface):
self.instance.add_typer(interface.instance, name=interface.name)
self.instance.add_typer(
interface.instance, name=interface.name, help=interface.help
)
def cli(self, command_def: str, **kwargs):
def decorator(func):
def cli(self, command_def: str, **extra_data):
def decorator(func: Callable):
signature = inspect.signature(func)
override_params = []
params = self.filter_params(signature.parameters.values())
local_data = merge_dicts(self.local_data, extra_data)
params_doc = get_params_doc(func.__doc__)
command, args = parse_cli_command(command_def)
for param in signature.parameters.values():
for param in params:
param_kwargs = local_data.get(param.name, {})
param_kwargs["help"] = params_doc.get(param.name, None)
if param_kwargs.pop("deprecated", False):
param_kwargs["rich_help_panel"] = "Deprecated Options"
if param.name not in args and not param_kwargs.get("hidden", False):
param_kwargs["prompt"] = True
# Auto setup typer Argument or Option kwargs
default_kwargs = kwargs.get(param.name, {})
# if param.name not in args and not default_kwargs.get("hidden", False):
# default_kwargs["prompt"] = True
if param.name == "password":
default_kwargs["confirmation_prompt"] = True
default_kwargs["hide_input"] = True
param_kwargs["confirmation_prompt"] = True
param_kwargs["hide_input"] = True
# Define new default value for typer
default_cls = typer.Argument if param.name in args else typer.Option
if param.default is None:
default_value = default_cls(None, **default_kwargs)
elif param.default is param.empty:
default_value = default_cls(..., **default_kwargs)
# Populate default param value with typer.Argument|Option
param_default = (
param.default
if not param.default == param.empty
else ... # required
)
if param.name in args:
param_default = typer.Argument(param_default, **param_kwargs)
else:
default_value = default_cls(param.default, **default_kwargs)
param_default = typer.Option(param_default, **param_kwargs)
override_params.append(param.replace(default=default_value))
override_params.append(param.replace(default=param_default))
def hook_results(*args, **kwargs):
results = func(*args, **kwargs)
print_as_yaml(results)
return results
hook_results.__name__ = func.__name__
hook_results.__signature__ = signature.replace(
parameters=tuple(override_params)
command_func = override_function(
func,
signature,
override_params,
decorator=hook_results,
doc=func.__doc__.split("\b")[0] if func.__doc__ else None,
)
self.instance.command(command)(hook_results)
self.instance.command(
command, deprecated=local_data.get("deprecated", False)
)(command_func)
self.clear_local_data()
return func
return decorator
def api(self, *args, **kwargs):
def decorator(func):
return func
return decorator