add Field to fastapi Path|Query|Body factories

This commit is contained in:
axolotle 2023-01-17 18:53:31 +01:00
parent 88bc34adec
commit 396c39a5d3

View file

@ -10,9 +10,7 @@ import fastapi
import pydantic import pydantic
import starlette import starlette
from typing import Any, Optional, Union from typing import Optional, Union, get_type_hints
from pydantic.error_wrappers import ErrorWrapper
from yunohost.interface.base import ( from yunohost.interface.base import (
BaseInterface, BaseInterface,
@ -108,48 +106,44 @@ class Interface(BaseInterface):
as_body_except: Optional[list[str]] = None, as_body_except: Optional[list[str]] = None,
**extra_data, **extra_data,
): ):
as_body = as_body if not as_body_except else True
def decorator(func): def decorator(func):
signature = inspect.signature(func)
override_params = []
params = self.prepare_params(
signature.parameters.values(), as_body, as_body_except
)
local_data = merge_dicts(self.local_data, extra_data) local_data = merge_dicts(self.local_data, extra_data)
params_doc = get_params_doc(func.__doc__)
paths = parse_api_route(route)
for param in params: signature = inspect.signature(func)
if isinstance(param, list): annotations = get_type_hints(func, include_extras=True)
override_params.append( params = signature.parameters.values()
params_to_body(param, local_data, params_doc, func.__name__) doc = get_params_doc(func.__doc__)
) positional_params = parse_api_route(route)
forward_params = []
override_params = []
body_fields = {}
for param, field in self.build_fields(
Interface, params, annotations, doc, positional_params
):
if field and field.extra.get("file"):
param = param.replace(annotation=fastapi.UploadFile)
forward_params.append(param)
if not field or field.extra["deprecated"]:
continue
if as_body or (
as_body_except is not None and param.name not in as_body_except
):
body_fields[param.name] = (param.annotation, field)
else: else:
param_kwargs = local_data.get(param.name, {}) override_param = param.replace(
param_kwargs["description"] = params_doc.get(param.name, None) default=field_to_fastapi_default(field)
param_default = (
param.default
if not param.default == param.empty
else ... # required
) )
override_params.append(override_param)
pattern = param_kwargs.pop("pattern", None) if body_fields:
if pattern: override_params.insert(
# FIXME for now throw generic error (need to catch and update text) 0, field_to_fastapi_body_default(body_fields, func.__name__)
param_kwargs["regex"] = pattern )
if param_kwargs.pop("file", False):
new_param = param.replace(
annotation=fastapi.UploadFile,
default=param_default
)
elif param.name in paths:
new_param = param.replace(default=fastapi.Path(param_default, **param_kwargs))
else:
new_param = param.replace(default=fastapi.Query(param_default, **param_kwargs))
override_params.append(new_param)
def hook_results(*args, **kwargs): def hook_results(*args, **kwargs):
new_kwargs = {} new_kwargs = {}
@ -161,7 +155,7 @@ class Interface(BaseInterface):
new_kwargs = value.dict() | new_kwargs new_kwargs = value.dict() | new_kwargs
elif isinstance(value, starlette.datastructures.UploadFile): elif isinstance(value, starlette.datastructures.UploadFile):
# views expects a opened file (fastapi UploadFile is a bytes SpooledTemporaryFile) # views expects a opened file (fastapi UploadFile is a bytes SpooledTemporaryFile)
new_kwargs[name] = codecs.iterdecode(value.file, 'utf-8') new_kwargs[name] = codecs.iterdecode(value.file, "utf-8")
opened_files.append(name) opened_files.append(name)
else: else:
new_kwargs[name] = value new_kwargs[name] = value
@ -171,7 +165,13 @@ class Interface(BaseInterface):
except YunohostValidationError as e: except YunohostValidationError as e:
# Try to mimic Pydantic validation errors # Try to mimic Pydantic validation errors
# FIXME replace dummy error information # FIXME replace dummy error information
raise fastapi.exceptions.RequestValidationError([ErrorWrapper(ValueError(e.strerror), ("query", "test"))]) raise fastapi.exceptions.RequestValidationError(
[
pydantic.errors.ErrorWrapper(
ValueError(e.strerror), ("query", "test")
)
]
)
except: except:
raise raise
finally: finally:
@ -197,3 +197,53 @@ class Interface(BaseInterface):
return func return func
return decorator return decorator
def field_to_fastapi_body_default(
fields: dict[str, tuple[type, pydantic.fields.FieldInfo]],
func_name: str,
) -> inspect.Parameter:
model = pydantic.create_model(
snake_to_camel_case(func_name),
**fields,
)
default = fastapi.Body(
...,
example={
name: field.extra["example"] or f"missing example of {str(t)}"
if field.default in (..., None)
else field.default
for name, (t, field) in fields.items()
},
)
return inspect.Parameter(
func_name.split("_")[0],
inspect.Parameter.POSITIONAL_ONLY,
default=default,
annotation=model,
)
def field_to_fastapi_default(
field: pydantic.fields.FieldInfo,
) -> Union[fastapi.params.Path, fastapi.params.Query]:
generic = {
"description": field.description,
"regex": field.regex,
"include_in_schema": not field.extra["hidden"],
"deprecated": field.extra["deprecated"],
"example": field.extra["example"],
}
if field.extra["positional"]:
return fastapi.Path(
field.default,
**generic,
)
return fastapi.Query(
field.default,
**generic,
)