From 669bccff57f4bd3658aa342e032c48581d56667a Mon Sep 17 00:00:00 2001 From: axolotle Date: Tue, 17 Jan 2023 18:48:03 +0100 Subject: [PATCH] use pydantic Field as base info defition --- src/interface/base.py | 187 ++++++++++++++++++++++++++++++++++++++--- src/interface/types.py | 15 ++++ 2 files changed, 192 insertions(+), 10 deletions(-) create mode 100644 src/interface/types.py diff --git a/src/interface/base.py b/src/interface/base.py index e3bbc9c64..19dff2d02 100644 --- a/src/interface/base.py +++ b/src/interface/base.py @@ -1,9 +1,29 @@ +from __future__ import ( + annotations, +) # Enable self reference a class in its own method arguments + import inspect import types import re from enum import Enum -from typing import Any, Callable, Optional, TypeVar +from typing import ( + Annotated, + Any, + Callable, + Iterable, + Iterator, + Optional, + TypeVar, + Union, + get_origin, + get_args, +) + +import pydantic + +from yunohost.interface.types import PrivateParam + ViewFunc = TypeVar("ViewFunc") @@ -42,9 +62,12 @@ def get_params_doc(docstring: Optional[str]) -> dict[str, str]: } -def validate_pattern(pattern: str, value: str, name: str): +def validate_pattern(pattern: str, value: str, name: Optional[str] = None): + if not re.match(pattern, value, re.UNICODE): - raise ValueError(f"'{value}' does'nt match pattern '{pattern}'") + error = name if name else "'{value}' does'nt match pattern '{pattern}'" + raise ValueError(error.format(value=value, pattern=pattern)) + return value @@ -65,7 +88,7 @@ def override_function( ) returned_func.__name__ = name or func.__name__ returned_func.__doc__ = doc or func.__doc__ - returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params)) + returned_func.__signature__ = func_signature.replace(parameters=tuple(new_params)) # type: ignore return returned_func @@ -92,16 +115,160 @@ class BaseInterface: def clear_local_data(self): self.local_data = {} - def filter_params(self, params: list[inspect.Parameter]) -> list[inspect.Parameter]: - private = self.local_data.get("private", []) + @staticmethod + def build_fields( + cls, + params: Iterable[inspect.Parameter], + annotations: dict[str, Any], + doc: dict[str, str], + positional_params: list[str], + ) -> Iterator[tuple[inspect.Parameter, Optional[pydantic.fields.FieldInfo]]]: - return [ - param for param in params - if param.name not in private and param.name != "operation_logger" - ] + for param in params: + annotation = annotations[param.name] + field: Optional[pydantic.fields.FieldInfo] = None + description = doc.get(param.name, None) + + if get_origin(annotation) is Annotated: + annotation, field = get_args(annotation) + + if get_origin(field) is PrivateParam: + field = None + elif isinstance(field, pydantic.fields.FieldInfo): + update_field_from_annotation( + field, + param.default, + name=param.name, + description=description, + positional=param.name in positional_params, + ) + else: + raise Exception( + "Views function paramaters can only be 'Annotated[Any, PrivateParam | Param]' but found '{new_param}'" + ) + + else: + field = Field( + param.default, + name=param.name, + description=description, + positional=param.name in positional_params, + ) + + param = param.replace(annotation=annotation) + + yield param, field def api(self, *args, **kwargs): return pass_func def cli(self, *args, **kwargs): return pass_func + + +def Field( + default: Any = ..., + *, + name: Optional[str] = None, + positional: bool = False, + param_decls: Optional[list[str]] = None, + deprecated: bool = False, + description: Optional[str] = None, + hidden: bool = False, + pattern: Optional[Union[str, tuple[str, str]]] = None, + ask: Union[str, bool] = False, + confirm: bool = False, + redac: bool = False, + file: bool = False, + panel: Optional[str] = None, + example: Optional[str] = None, + # default_factory: Optional[NoArgAnyCallable] = None, + # alias: str = None, + # title: str = None, + # exclude: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None, + # include: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None, + # const: bool = None, + # gt: float = None, + # ge: float = None, + # lt: float = None, + # le: float = None, + # multiple_of: float = None, + # allow_inf_nan: bool = None, + # max_digits: int = None, + # decimal_places: int = None, + # min_items: int = None, + # max_items: int = None, + # unique_items: bool = None, + # min_length: int = None, + # max_length: int = None, + # allow_mutation: bool = True, + # discriminator: str = None, + # repr: bool = True, + **kwargs: Any, +) -> pydantic.fields.FieldInfo: + + pattern_name = None + if isinstance(pattern, tuple): + pattern_name, pattern = pattern + + return pydantic.fields.Field( + default=... if default is inspect.Parameter.empty else default, + # default_factory=default_factory, + # alias=alias, + # title=title, + description=description, # type: ignore + # exclude=exclude, + # include=include, + # const=const, + # gt=gt, + # ge=ge, + # lt=lt, + # le=le, + # multiple_of=multiple_of, + # allow_inf_nan=allow_inf_nan, + # max_digits=max_digits, + # decimal_places=decimal_places, + # min_items=min_items, + # max_items=max_items, + # unique_items=unique_items, + # min_length=min_length, + # max_length=max_length, + # allow_mutation=allow_mutation, + regex=pattern, # type: ignore + # discriminator=discriminator, + # repr=repr, + # Yunohost custom + name=name, + param_decls=param_decls, + positional=positional, + deprecated=deprecated, + hidden=hidden, + # Typer Option only + ask=False if deprecated else ask, + confirm=confirm, + redac=redac, + # Type + file=file, + # Rich + panel=panel, + pattern_name=pattern_name, + example=example, + # **kwargs, + ) + + +def update_field_from_annotation( + field: pydantic.fields.FieldInfo, + default: Any, + name: Optional[str] = None, + description: Optional[str] = None, + positional: bool = False, +): + field.default = ... if default is inspect.Parameter.empty else default + if name: + field.extra["name"] = name + if description: + field.description = description + if positional: + field.extra["positional"] = positional + field.extra["ask"] = False diff --git a/src/interface/types.py b/src/interface/types.py new file mode 100644 index 000000000..5156b7472 --- /dev/null +++ b/src/interface/types.py @@ -0,0 +1,15 @@ +from __future__ import ( + annotations, +) # Enable self reference a class in its own method arguments + +from typing import Annotated, Generic, TypeVar + + +T = TypeVar("T") + + +class PrivateParam(Generic[T]): + ... + + +Private = Annotated[T, PrivateParam[T]]