Source code for simdb.remote.core.pydantic_utils

from __future__ import annotations

import contextlib
import copy
import functools
import gzip
import inspect
import logging
import zlib
from typing import (
    Annotated,
    Any,
    get_args,
    get_origin,
    get_type_hints,
)

from flask import Response, request
from flask_restx import Api, Namespace
from pydantic import BaseModel, ValidationError

from simdb.remote.core.errors import error as _error
from simdb.remote.models import ErrorResponse

# Module-level logger; ``pydantic_validate`` uses it to report explicit
# server errors and unhandled exceptions raised by decorated endpoints.
logger = logging.getLogger(__name__)


class ResponseException(Exception):
    """Signal a client-side (HTTP 4xx) problem with the incoming request.

    Attributes
    ----------
    message:
        Human-readable description; also passed on as the exception argument.
    return_code:
        HTTP status code to respond with (400 unless overridden).
    """

    def __init__(self, message, return_code: int = 400) -> None:
        super().__init__(message)
        self.message, self.return_code = message, return_code
class ServerException(Exception):
    """Signal an unexpected failure on the server side (HTTP 5xx).

    Attributes
    ----------
    message:
        Human-readable description; also passed on as the exception argument.
    return_code:
        HTTP status code to respond with (500 unless overridden).
    """

    def __init__(self, message, return_code: int = 500) -> None:
        super().__init__(message)
        self.message, self.return_code = message, return_code
# ---------------------------------------------------------------------------
# Marker classes for Annotated-style parameter declarations
# ---------------------------------------------------------------------------


class _ParamSource:
    """Common base for the request-part markers used inside ``Annotated``.

    Marker instances carry no state of their own; only their concrete type
    matters, since that is what the ``isinstance`` checks dispatch on.
    """
class Body(_ParamSource):
    """Request-body marker.

    Use as ``Annotated[SomeModel, Body()]`` to have the parameter parsed and
    validated from the JSON request payload (optionally gzip-compressed and
    flagged with ``Content-Encoding: gzip``).
    """
class Query(_ParamSource):
    """Marker: populate this parameter from ``request.args`` (query string)."""


class Header(_ParamSource):
    """Marker: populate this parameter from the HTTP request headers.

    Header names are lower-cased before validation, so the model's field
    names (or aliases) are expected to be lowercase.
    """
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _register_defs(ns: Namespace | Api, defs: dict[str, Any]) -> None:
    """Register each JSON-schema definition in *defs* as a model on *ns*.

    Each definition is deep-copied before its own ``$defs`` (if any) are
    removed, so the caller's dictionaries are left untouched.  Those nested
    definitions are then registered recursively.
    """
    for name, schema in defs.items():
        clean_schema = copy.deepcopy(schema)
        children = clean_schema.pop("$defs", {})
        ns.schema_model(name, clean_schema)
        _register_defs(ns, children)


def _collect_and_register(ns: Namespace | Api, model: type[BaseModel]):
    """Register a Pydantic model and all of its nested dependencies on *ns*.

    The schema is generated with ``$ref`` targets of the form
    ``#/definitions/<Name>``, and every nested definition is registered under
    that same name.  Presumably this makes the refs resolve inside the Swagger
    document Flask-RESTX builds.

    Returns
    -------
    Whatever ``ns.schema_model`` returns for the root model.
    """
    full_schema = model.model_json_schema(ref_template="#/definitions/{model}")
    # 1. Register all sub-models found in $defs first.
    _register_defs(ns, full_schema.get("$defs", {}))
    # 2. Register the root model itself, without its embedded $defs.
    root_schema = copy.deepcopy(full_schema)
    root_schema.pop("$defs", None)
    return ns.schema_model(model.__name__, root_schema)


def _resolve_type_hints(f: Any) -> dict[str, Any]:
    """Best-effort ``typing.get_type_hints`` that keeps ``Annotated`` metadata.

    Returns an empty dict if any annotation cannot be evaluated (for example
    a name only imported under ``TYPE_CHECKING``).  Callers can then fall back
    to the raw annotations.
    """
    try:
        return get_type_hints(f, include_extras=True)
    except Exception:
        # Deliberately broad: evaluating arbitrary annotation strings can
        # raise NameError, TypeError, SyntaxError, ...
        return {}


def _get_annotated_params(
    f: Any,
) -> list[tuple[str, type[BaseModel], _ParamSource]]:
    """Inspect *f*'s signature and return a list of
    ``(param_name, model, source)`` tuples for every parameter annotated as
    ``Annotated[SomePydanticModel, Source()]``.

    Only the first ``Annotated`` metadata item is considered as the source.
    String annotations (such as those produced by
    ``from __future__ import annotations`` in the endpoint's module) are
    resolved first.  If they cannot be resolved, they are skipped.
    """
    results = []
    hints: dict[str, Any] | None = None  # resolved lazily, only if needed
    for param_name, param in inspect.signature(f).parameters.items():
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            continue
        if isinstance(annotation, str):
            # Postponed annotation (PEP 563): without evaluating it,
            # get_origin() would never see ``Annotated``.
            if hints is None:
                hints = _resolve_type_hints(f)
            annotation = hints.get(param_name, annotation)
        if get_origin(annotation) is not Annotated:
            continue
        args = get_args(annotation)
        if len(args) < 2:
            continue
        model_type, source = args[0], args[1]
        if (
            isinstance(source, _ParamSource)
            and isinstance(model_type, type)
            and issubclass(model_type, BaseModel)
        ):
            results.append((param_name, model_type, source))
    return results


def _validate_param(
    model: type[BaseModel],
    source: _ParamSource,
) -> BaseModel:
    """Validate and return a Pydantic model instance from the appropriate
    part of the current Flask request.

    Raises
    ------
    ValidationError
        If the data does not conform to the model (this includes a body that
        is not valid JSON).
    ValueError
        If the request body is empty (for Body sources).
    """
    if isinstance(source, Header):
        # Convert Werkzeug Headers to a plain dict with lowercase keys.
        raw = {k.lower(): v for k, v in request.headers.items()}
        return model.model_validate(raw)

    if isinstance(source, Query):
        # ImmutableMultiDict -> {key: [values, ...]}; unwrap single values so
        # scalar fields validate, but keep lists for repeated keys.
        multi = request.args.to_dict(flat=False)
        flat = {k: v[0] if len(v) == 1 else v for k, v in multi.items()}
        return model.model_validate(flat)

    # Any other source (i.e. Body) is read from the raw request payload.
    # get_data() returns bytes (b"" when there is no body), never None.
    request_data = request.get_data(cache=False)
    if not request_data:
        raise ValueError("Invalid or missing JSON body")

    encoding = (request.headers.get("Content-Encoding") or "").lower()
    if encoding == "gzip":
        # Best effort: if the payload is not valid gzip (bad header ->
        # BadGzipFile/OSError, truncated -> EOFError, corrupt -> zlib.error),
        # keep the raw bytes and let JSON validation report the problem.
        with contextlib.suppress(OSError, EOFError, zlib.error):
            request_data = gzip.decompress(request_data)
    return model.model_validate_json(request_data)


# ---------------------------------------------------------------------------
# FastAPI-style route decorator
# ---------------------------------------------------------------------------
def _param_schema_type(field_props: dict[str, Any]) -> str:
    """Pick the JSON-schema ``type`` to document for one query/header field.

    Pydantic v2 typically describes ``X | None`` fields as
    ``anyOf: [{"type": X}, {"type": "null"}]`` with no top-level ``type``.
    In that case the first non-null member is used.  Anything else falls back
    to ``"string"``.
    """
    if "type" in field_props:
        return field_props["type"]
    for option in field_props.get("anyOf", ()):
        option_type = option.get("type") if isinstance(option, dict) else None
        if option_type and option_type != "null":
            return option_type
    return "string"


def pydantic_validate(
    ns: Namespace | Api,
    *,
    response_model: type[BaseModel] | None = None,
    error_model: type[BaseModel] = ErrorResponse,
    client_error_codes: tuple[int, ...] = (400,),
) -> Any:
    """Decorator factory that wires up Pydantic validation for a Flask-RESTX
    endpoint.

    Inspects the decorated function's signature for parameters annotated
    with ``Annotated[SomePydanticModel, Header()]``,
    ``Annotated[SomePydanticModel, Body()]`` or
    ``Annotated[SomePydanticModel, Query()]``.  It validates the corresponding
    parts of the incoming request and injects the validated model instances
    as keyword arguments.

    All discovered input models are automatically registered with *ns* for
    Swagger/OpenAPI documentation.  Body models are registered as
    ``@ns.expect`` models.  Each field of a header/query model is documented
    as a separate ``@ns.param`` entry.

    If the function's return annotation is a ``BaseModel`` subclass (or if
    *response_model* is provided explicitly), it is documented as the 200
    response.  A string return annotation (PEP 563) is resolved with
    ``typing.get_type_hints`` first.  At runtime, any returned ``BaseModel``
    instance is serialised with ``model_dump_json(exclude_none=False)`` and
    returned as an ``application/json`` response.

    Error handling distinguishes between client errors and server errors:

    - :class:`ResponseException` → HTTP 4xx (default 400).  Use
      ``return_code`` to customise (e.g. 404, 422).  Request validation
      errors are reported through :func:`simdb.remote.core.errors.error`.
    - Returning an ``ErrorResponse`` (or *error_model*) instance → HTTP 400.
    - :class:`ServerException` → HTTP 5xx (default 500).  Use for explicit
      server-side failures.
    - Any other unhandled :class:`Exception` → HTTP 500, logged with its
      traceback.

    Parameters
    ----------
    ns:
        The Flask-RESTX ``Namespace`` (or ``Api``) to register models on.
    response_model:
        Optional explicit response model.  If ``None``, the decorator tries to
        infer it from the function's return annotation.
    error_model:
        Pydantic model used to serialise error responses; it must accept an
        ``error`` keyword argument.  Passing ``None`` skips documenting error
        responses; ``ErrorResponse`` is then used for serialisation.
    client_error_codes:
        HTTP status codes to document as client error responses in Swagger.

    Returns
    -------
    A decorator suitable for use on Flask-RESTX ``Resource`` methods.

    Example
    -------
    .. code-block:: python

        from typing import Annotated

        from simdb.remote.core.pydantic_utils import pydantic_validate, Header, Body

        class SimulationList(Resource):
            @pydantic_validate(api)
            def get(
                self,
                user: User,
                pagination: Annotated[PaginationData, Header()],
            ) -> PaginatedResponse:
                ...

            @pydantic_validate(api)
            def post(
                self,
                user: User,
                body: Annotated[SimulationPostData, Body()],
            ) -> SimulationPostResponse:
                ...
    """

    def decorator(f):
        annotated_params = _get_annotated_params(f)

        # Determine the response model from the return annotation if it was
        # not given explicitly.
        _response_model = response_model
        if _response_model is None:
            ret = inspect.signature(f).return_annotation
            if isinstance(ret, str):
                # Postponed annotation (PEP 563): evaluate it so a model
                # class can still be recognised.
                try:
                    ret = get_type_hints(f).get("return", ret)
                except Exception:
                    # Best effort: an unresolvable forward reference simply
                    # disables inference.
                    logger.debug(
                        "Could not resolve return annotation of %s",
                        f.__qualname__,
                        exc_info=True,
                    )
            if (
                ret is not inspect.Parameter.empty
                and inspect.isclass(ret)
                and issubclass(ret, BaseModel)
            ):
                _response_model = ret

        # error_model=None only disables documenting error responses; errors
        # must still be serialisable at runtime.
        _error_model = error_model if error_model is not None else ErrorResponse

        # Register all input models with the namespace; the body model (if
        # any) is also attached via @ns.expect below.
        _body_schema = None
        for _param_name, model_type, source in annotated_params:
            schema = _collect_and_register(ns, model_type)
            if isinstance(source, Body):
                _body_schema = schema

        _resp_schema = None
        if _response_model is not None:
            _resp_schema = _collect_and_register(ns, _response_model)

        _error_schema = None
        if error_model is not None:
            _error_schema = _collect_and_register(ns, error_model)

        def _error_response(message, status: int) -> Response:
            # Serialise *message* with the error model as a JSON response.
            return Response(
                response=_error_model(error=message).model_dump_json(
                    exclude_none=False
                ),
                status=status,
                mimetype="application/json",
            )

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            for param_name, model_type, source in annotated_params:
                try:
                    validated = _validate_param(model_type, source)
                # ValidationError subclasses ValueError, so it must be caught
                # first or the formatted message below is never produced.
                except ValidationError as exc:
                    first_error = exc.errors()[0]
                    loc = " -> ".join(str(part) for part in first_error["loc"])
                    if loc:
                        msg = f"Validation error at '{loc}': {first_error['msg']}"
                    else:
                        msg = f"Validation error: {first_error['msg']}"
                    return _error(msg)
                except ValueError as exc:
                    return _error(str(exc))
                kwargs[param_name] = validated

            try:
                result = f(*args, **kwargs)
            except ResponseException as err:
                return _error_response(err.message, err.return_code)
            except ServerException as err:
                logger.error("Server error in %s: %s", f.__qualname__, err.message)
                return _error_response(err.message, err.return_code)
            except Exception as err:
                logger.exception("Unhandled exception in %s", f.__qualname__)
                # NOTE(review): this echoes the internal exception text to the
                # client; consider a generic message if that may be sensitive.
                return _error_response(str(err), 500)

            if isinstance(result, (ErrorResponse, _error_model)):
                return Response(
                    response=result.model_dump_json(exclude_none=False),
                    status=400,
                    mimetype="application/json",
                )
            if isinstance(result, BaseModel):
                return Response(
                    result.model_dump_json(exclude_none=False),
                    mimetype="application/json",
                )
            return result

        if _body_schema is not None:
            wrapper = ns.expect(_body_schema, validate=False)(wrapper)
        if _resp_schema is not None:
            wrapper = ns.response(200, "Success", _resp_schema)(wrapper)
        if _error_schema is not None:
            for error_code in client_error_codes:
                wrapper = ns.response(error_code, "Client error", _error_schema)(
                    wrapper
                )
            wrapper = ns.response(500, "Server error", _error_schema)(wrapper)

        # Document each field of header/query models as its own parameter.
        for _param_name, model_type, source in annotated_params:
            if not isinstance(source, (Query, Header)):
                continue
            location = "query" if isinstance(source, Query) else "header"
            schema = model_type.model_json_schema()
            required_fields = set(schema.get("required", []))
            for field_name, field_props in schema.get("properties", {}).items():
                wrapper = ns.param(
                    name=field_name,
                    description=field_props.get("description", ""),
                    _in=location,
                    required=field_name in required_fields,
                    type=_param_schema_type(field_props),
                    default=field_props.get("default"),
                )(wrapper)

        return wrapper

    return decorator