From 396c39a5d399e7e93fd17801c1731822c3577770 Mon Sep 17 00:00:00 2001 From: axolotle Date: Tue, 17 Jan 2023 18:53:31 +0100 Subject: [PATCH] add Field to fastapi Path|Query|Body factories --- src/interface/api.py | 132 +++++++++++++++++++++++++++++-------------- 1 file changed, 91 insertions(+), 41 deletions(-) diff --git a/src/interface/api.py b/src/interface/api.py index a658f5874..318fbc453 100644 --- a/src/interface/api.py +++ b/src/interface/api.py @@ -10,9 +10,7 @@ import fastapi import pydantic import starlette -from typing import Any, Optional, Union - -from pydantic.error_wrappers import ErrorWrapper +from typing import Optional, Union, get_type_hints from yunohost.interface.base import ( BaseInterface, @@ -108,48 +106,44 @@ class Interface(BaseInterface): as_body_except: Optional[list[str]] = None, **extra_data, ): - as_body = as_body if not as_body_except else True - def decorator(func): - signature = inspect.signature(func) - override_params = [] - params = self.prepare_params( - signature.parameters.values(), as_body, as_body_except - ) local_data = merge_dicts(self.local_data, extra_data) - params_doc = get_params_doc(func.__doc__) - paths = parse_api_route(route) - for param in params: - if isinstance(param, list): - override_params.append( - params_to_body(param, local_data, params_doc, func.__name__) - ) + signature = inspect.signature(func) + annotations = get_type_hints(func, include_extras=True) + params = signature.parameters.values() + doc = get_params_doc(func.__doc__) + positional_params = parse_api_route(route) + + forward_params = [] + override_params = [] + body_fields = {} + for param, field in self.build_fields( + Interface, params, annotations, doc, positional_params + ): + + if field and field.extra.get("file"): + param = param.replace(annotation=fastapi.UploadFile) + + forward_params.append(param) + + if not field or field.extra["deprecated"]: + continue + + if as_body or ( + as_body_except is not None and param.name not in as_body_except + ): + body_fields[param.name] = (param.annotation, field) else: - param_kwargs = local_data.get(param.name, {}) - param_kwargs["description"] = params_doc.get(param.name, None) - param_default = ( - param.default - if not param.default == param.empty - else ... # required + override_param = param.replace( + default=field_to_fastapi_default(field) ) + override_params.append(override_param) - pattern = param_kwargs.pop("pattern", None) - if pattern: - # FIXME for now throw generic error (need to catch and update text) - param_kwargs["regex"] = pattern - - if param_kwargs.pop("file", False): - new_param = param.replace( - annotation=fastapi.UploadFile, - default=param_default - ) - elif param.name in paths: - new_param = param.replace(default=fastapi.Path(param_default, **param_kwargs)) - else: - new_param = param.replace(default=fastapi.Query(param_default, **param_kwargs)) - - override_params.append(new_param) + if body_fields: + override_params.insert( + 0, field_to_fastapi_body_default(body_fields, func.__name__) + ) def hook_results(*args, **kwargs): new_kwargs = {} @@ -161,7 +155,7 @@ class Interface(BaseInterface): new_kwargs = value.dict() | new_kwargs elif isinstance(value, starlette.datastructures.UploadFile): # views expects a opened file (fastapi UploadFile is a bytes SpooledTemporaryFile) - new_kwargs[name] = codecs.iterdecode(value.file, 'utf-8') + new_kwargs[name] = codecs.iterdecode(value.file, "utf-8") opened_files.append(name) else: new_kwargs[name] = value @@ -171,7 +165,13 @@ class Interface(BaseInterface): except YunohostValidationError as e: # Try to mimic Pydantic validation errors # FIXME replace dummy error information - raise fastapi.exceptions.RequestValidationError([ErrorWrapper(ValueError(e.strerror), ("query", "test"))]) + raise fastapi.exceptions.RequestValidationError( + [ + pydantic.errors.ErrorWrapper( + ValueError(e.strerror), ("query", "test") + ) + ] + ) except: raise finally: @@ -197,3 +197,53 @@ class Interface(BaseInterface): return func return decorator + + +def field_to_fastapi_body_default( + fields: dict[str, tuple[type, pydantic.fields.FieldInfo]], + func_name: str, +) -> inspect.Parameter: + + model = pydantic.create_model( + snake_to_camel_case(func_name), + **fields, + ) + default = fastapi.Body( + ..., + example={ + name: field.extra["example"] or f"missing example of {str(t)}" + if field.default in (..., None) + else field.default + for name, (t, field) in fields.items() + }, + ) + + return inspect.Parameter( + func_name.split("_")[0], + inspect.Parameter.POSITIONAL_ONLY, + default=default, + annotation=model, + ) + + +def field_to_fastapi_default( + field: pydantic.fields.FieldInfo, +) -> Union[fastapi.params.Path, fastapi.params.Query]: + generic = { + "description": field.description, + "regex": field.regex, + "include_in_schema": not field.extra["hidden"], + "deprecated": field.extra["deprecated"], + "example": field.extra["example"], + } + + if field.extra["positional"]: + return fastapi.Path( + field.default, + **generic, + ) + + return fastapi.Query( + field.default, + **generic, + )