add BaseInterface class with helper + improve docstring handling

This commit is contained in:
axolotle 2023-01-08 21:23:54 +01:00
parent 87317cea66
commit fee2ee741c
3 changed files with 228 additions and 101 deletions

View file

@ -1,52 +1,46 @@
from __future__ import annotations from __future__ import (
annotations,
) # Enable self reference a class in its own method arguments
import inspect import inspect
import types import re
import fastapi import fastapi
import pydantic import pydantic
from typing import Optional, Union from typing import Any, Optional, Union
from yunohost.interface.base import (
BaseInterface,
InterfaceKind,
merge_dicts,
get_params_doc,
override_function,
)
def snake_to_camel_case(snake: str) -> str: def snake_to_camel_case(snake: str) -> str:
return "".join(word.title() for word in snake.split("_")) return "".join(word.title() for word in snake.split("_"))
def get_path_param(route: str) -> Optional[str]: def parse_api_route(route: str) -> list[str]:
last_path = route.split("/")[-1] return re.findall(r"{(\w+)}", route)
if last_path and "{" in last_path:
return last_path.strip("{}")
return None
def alter_params_for_body(
parameters, as_body, as_body_except
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
if as_body_except:
body_params = []
rest = []
for param in parameters:
if param.name not in as_body_except:
body_params.append(param)
else:
rest.append(param)
return [body_params, *rest]
if as_body:
return [parameters]
return parameters
def params_to_body( def params_to_body(
params: list[inspect.Parameter], func_name: str params: list[inspect.Parameter],
data: dict[str, Any],
doc: dict[str, Any],
func_name: str,
) -> inspect.Parameter: ) -> inspect.Parameter:
model = pydantic.create_model( model = pydantic.create_model(
snake_to_camel_case(func_name), snake_to_camel_case(func_name),
**{ **{
param.name: ( param.name: (
param.annotation, param.annotation,
param.default if param.default != param.empty else ..., pydantic.Field(
param.default if param.default != param.empty else ...,
description=doc.get(param.name, None),
**data.get(param.name, {}),
),
) )
for param in params for param in params
}, },
@ -60,15 +54,15 @@ def params_to_body(
) )
class Interface: class Interface(BaseInterface):
type = "api" kind = InterfaceKind.API
instance: Union[fastapi.FastAPI, fastapi.APIRouter] instance: Union[fastapi.FastAPI, fastapi.APIRouter]
name: str name: str
def __init__(self, root: bool = False, name: Optional[str] = None): def __init__(self, root: bool = False, **kwargs):
super().__init__(root=root, **kwargs)
self.instance = fastapi.FastAPI() if root else fastapi.APIRouter() self.instance = fastapi.FastAPI() if root else fastapi.APIRouter()
self.name = "root" if root else name or ""
def add(self, interface: Interface): def add(self, interface: Interface):
assert isinstance(interface.instance, fastapi.APIRouter) assert isinstance(interface.instance, fastapi.APIRouter)
@ -76,11 +70,28 @@ class Interface:
interface.instance, prefix=f"/{interface.name}", tags=[interface.name] interface.instance, prefix=f"/{interface.name}", tags=[interface.name]
) )
def cli(self, *args, **kwargs): def prepare_params(
def decorator(func): self,
return func params: list[inspect.Parameter],
as_body: bool,
as_body_except: Optional[list[str]] = None,
) -> list[Union[list[inspect.Parameter], inspect.Parameter]]:
params = self.filter_params(params)
return decorator if as_body_except:
body_params = []
rest = []
for param in params:
if param.name not in as_body_except:
body_params.append(param)
else:
rest.append(param)
return [body_params, *rest]
if as_body:
return [params]
return params
def api( def api(
self, self,
@ -88,38 +99,40 @@ class Interface:
method: str = "get", method: str = "get",
as_body: bool = False, as_body: bool = False,
as_body_except: Optional[list[str]] = None, as_body_except: Optional[list[str]] = None,
**kwargs, **extra_data,
): ):
as_body = as_body if not as_body_except else True as_body = as_body if not as_body_except else True
def decorator(func): def decorator(func):
signature = inspect.signature(func) signature = inspect.signature(func)
override_params = [] override_params = []
params = alter_params_for_body( params = self.prepare_params(
signature.parameters.values(), as_body, as_body_except signature.parameters.values(), as_body, as_body_except
) )
path_param = get_path_param(route) local_data = merge_dicts(self.local_data, extra_data)
params_doc = get_params_doc(func.__doc__)
paths = parse_api_route(route)
for param in params: for param in params:
if isinstance(param, list): if isinstance(param, list):
override_params.append(params_to_body(param, func.__name__)) override_params.append(
params_to_body(param, local_data, params_doc, func.__name__)
)
else: else:
default_kwargs = kwargs.get(param.name, {}) param_kwargs = local_data.get(param.name, {})
default_cls = ( param_kwargs["description"] = params_doc.get(param.name, None)
fastapi.Path if param.name == path_param else fastapi.Query param_default = (
param.default
if not param.default == param.empty
else ... # required
) )
if param.default is None: if param.name in paths:
default_value = default_cls(None, **default_kwargs) param_default = fastapi.Path(param_default, **param_kwargs)
elif param.default is param.empty:
default_value = default_cls(..., **default_kwargs)
else: else:
default_value = default_cls(param.default, **default_kwargs) param_default = fastapi.Query(param_default, **param_kwargs)
override_params.append(param.replace(default=default_value)) override_params.append(param.replace(default=param_default))
route_func = getattr(self.instance, method)(route)
override_signature = signature.replace(parameters=tuple(override_params))
if as_body: if as_body:
@ -132,19 +145,18 @@ class Interface:
new_kwargs[kwarg] = value new_kwargs[kwarg] = value
return func(*args, **new_kwargs) return func(*args, **new_kwargs)
body_to_args_back.__name__ = func.__name__ route_func = override_function(
body_to_args_back.__signature__ = override_signature func, signature, override_params, decorator=body_to_args_back
route_func(body_to_args_back)
else:
func_copy = types.FunctionType(
func.__code__,
func.__globals__,
func.__name__,
func.__defaults__,
func.__closure__,
) )
func_copy.__signature__ = override_signature else:
route_func(func_copy) route_func = override_function(func, signature, override_params)
summary = func.__doc__.split("\n\n")[0] if func.__doc__ else None
getattr(self.instance, method)(
route, summary=summary, deprecated=local_data.get("deprecated")
)(route_func)
self.clear_local_data()
return func return func

94
src/interface/base.py Normal file
View file

@ -0,0 +1,94 @@
import inspect
import types
import re
from enum import Enum
from typing import Any, Callable, Optional, TypeVar
ViewFunc = TypeVar("ViewFunc")
class InterfaceKind(Enum):
API = "api"
CLI = "cli"
def pass_func(func: ViewFunc) -> ViewFunc:
return func
def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]:
merged: dict[str, Any] = {}
for dict_ in dicts:
for key, value in dict_.items():
if key not in merged or isinstance(merged[key], (str, int, float, bool)):
merged[key] = value
else:
merged[key] |= value
return merged
def get_params_doc(docstring: Optional[str]) -> dict[str, str]:
if not docstring:
return {}
return {
param_name: param_desc
for param_name, param_desc in re.findall(
r"- \*\*(\w+)\*\*: (.*)", docstring, re.MULTILINE
)
}
def override_function(
func: Callable,
func_signature: inspect.Signature,
new_params: list[inspect.Parameter],
decorator: Optional[Callable] = None,
name: Optional[str] = None,
doc: Optional[str] = None,
) -> Callable:
returned_func = decorator or types.FunctionType(
func.__code__,
func.__globals__,
func.__name__,
func.__defaults__,
func.__closure__,
)
returned_func.__name__ = name or func.__name__
returned_func.__doc__ = doc or func.__doc__
returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params))
return returned_func
class BaseInterface:
kind: InterfaceKind
local_data: dict[str, Any] = {}
def __init__(self, root: bool = False, name: str = "", help: str = ""):
self.name = "root" if root else name or ""
self.help = help
def __call__(self, *args, **kwargs):
self.local_data = kwargs
return pass_func
def clear_local_data(self):
self.local_data = {}
def filter_params(self, params: list[inspect.Parameter]) -> list[inspect.Parameter]:
private = self.local_data.get("private", [])
if private:
return [param for param in params if param.name not in private]
return params
def api(self, *args, **kwargs):
return pass_func
def cli(self, *args, **kwargs):
return pass_func

View file

@ -1,13 +1,23 @@
from __future__ import annotations from __future__ import (
annotations,
) # Enable self reference a class in its own method arguments
import inspect import inspect
import typer import typer
import yaml import yaml
from typing import Any, Optional from typing import Any, Callable
from rich import print as rprint from rich import print as rprint
from rich.syntax import Syntax from rich.syntax import Syntax
from yunohost.interface.base import (
BaseInterface,
InterfaceKind,
merge_dicts,
get_params_doc,
override_function,
)
def parse_cli_command(command: str) -> tuple[str, list[str]]: def parse_cli_command(command: str) -> tuple[str, list[str]]:
command, *args = command.split(" ") command, *args = command.split(" ")
@ -19,63 +29,74 @@ def print_as_yaml(data: Any):
rprint(Syntax(data, "yaml", background_color="default")) rprint(Syntax(data, "yaml", background_color="default"))
class Interface: class Interface(BaseInterface):
type = "cli" kind = InterfaceKind.CLI
instance: typer.Typer instance: typer.Typer
name: str name: str
def __init__(self, root: bool = False, name: Optional[str] = None): def __init__(self, root: bool = False, **kwargs):
self.instance = typer.Typer() super().__init__(root=root, **kwargs)
self.name = "root" if root else name or "" self.instance = typer.Typer(rich_markup_mode="markdown")
def add(self, interface: Interface): def add(self, interface: Interface):
self.instance.add_typer(interface.instance, name=interface.name) self.instance.add_typer(
interface.instance, name=interface.name, help=interface.help
)
def cli(self, command_def: str, **kwargs): def cli(self, command_def: str, **extra_data):
def decorator(func): def decorator(func: Callable):
signature = inspect.signature(func) signature = inspect.signature(func)
override_params = [] override_params = []
params = self.filter_params(signature.parameters.values())
local_data = merge_dicts(self.local_data, extra_data)
params_doc = get_params_doc(func.__doc__)
command, args = parse_cli_command(command_def) command, args = parse_cli_command(command_def)
for param in signature.parameters.values(): for param in params:
param_kwargs = local_data.get(param.name, {})
param_kwargs["help"] = params_doc.get(param.name, None)
if param_kwargs.pop("deprecated", False):
param_kwargs["rich_help_panel"] = "Deprecated Options"
if param.name not in args and not param_kwargs.get("hidden", False):
param_kwargs["prompt"] = True
# Auto setup typer Argument or Option kwargs
default_kwargs = kwargs.get(param.name, {})
# if param.name not in args and not default_kwargs.get("hidden", False):
# default_kwargs["prompt"] = True
if param.name == "password": if param.name == "password":
default_kwargs["confirmation_prompt"] = True param_kwargs["confirmation_prompt"] = True
default_kwargs["hide_input"] = True param_kwargs["hide_input"] = True
# Define new default value for typer # Populate default param value with typer.Argument|Option
default_cls = typer.Argument if param.name in args else typer.Option param_default = (
if param.default is None: param.default
default_value = default_cls(None, **default_kwargs) if not param.default == param.empty
elif param.default is param.empty: else ... # required
default_value = default_cls(..., **default_kwargs) )
if param.name in args:
param_default = typer.Argument(param_default, **param_kwargs)
else: else:
default_value = default_cls(param.default, **default_kwargs) param_default = typer.Option(param_default, **param_kwargs)
override_params.append(param.replace(default=default_value)) override_params.append(param.replace(default=param_default))
def hook_results(*args, **kwargs): def hook_results(*args, **kwargs):
results = func(*args, **kwargs) results = func(*args, **kwargs)
print_as_yaml(results) print_as_yaml(results)
return results return results
hook_results.__name__ = func.__name__ command_func = override_function(
hook_results.__signature__ = signature.replace( func,
parameters=tuple(override_params) signature,
override_params,
decorator=hook_results,
doc=func.__doc__.split("\b")[0] if func.__doc__ else None,
) )
self.instance.command(command)(hook_results) self.instance.command(
command, deprecated=local_data.get("deprecated", False)
)(command_func)
self.clear_local_data()
return func return func
return decorator return decorator
def api(self, *args, **kwargs):
def decorator(func):
return func
return decorator