From 859242cd0482fc9299accfbd779b1043ded4ea8a Mon Sep 17 00:00:00 2001 From: p1c2u Date: Wed, 19 Jul 2023 18:50:04 +0100 Subject: [PATCH] Allow unmarshalling kwargs for unmarshalling processors --- docs/integrations.rst | 48 +++++++++++++++ openapi_core/contrib/falcon/middlewares.py | 4 ++ openapi_core/contrib/flask/decorators.py | 6 +- openapi_core/contrib/flask/views.py | 10 ++-- openapi_core/unmarshalling/processors.py | 10 +++- openapi_core/validation/processors.py | 10 +++- .../falcon/data/v3.0/falconproject/openapi.py | 5 +- .../contrib/flask/test_flask_decorator.py | 60 ++++++++++--------- .../contrib/flask/test_flask_views.py | 8 ++- 9 files changed, 120 insertions(+), 41 deletions(-) diff --git a/docs/integrations.rst b/docs/integrations.rst index 77e0ee85..4a5fbd26 100644 --- a/docs/integrations.rst +++ b/docs/integrations.rst @@ -129,6 +129,23 @@ The Falcon API can be integrated by ``FalconOpenAPIMiddleware`` middleware. middleware=[openapi_middleware], ) +Additional customization parameters can be passed to the middleware. + +.. code-block:: python + :emphasize-lines: 5 + + from openapi_core.contrib.falcon.middlewares import FalconOpenAPIMiddleware + + openapi_middleware = FalconOpenAPIMiddleware.from_spec( + spec, + extra_format_validators=extra_format_validators, + ) + + app = falcon.App( + # ... + middleware=[openapi_middleware], + ) + After that you will have access to validation result object with all validated request data from Falcon view through request context. .. code-block:: python @@ -192,6 +209,18 @@ Flask views can be integrated by ``FlaskOpenAPIViewDecorator`` decorator. def home(): return "Welcome home" +Additional customization parameters can be passed to the decorator. + +.. code-block:: python + :emphasize-lines: 5 + + from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator + + openapi = FlaskOpenAPIViewDecorator.from_spec( + spec, + extra_format_validators=extra_format_validators, + ) + If you want to decorate class based view you can use the decorators attribute: .. code-block:: python @@ -224,6 +253,25 @@ As an alternative to the decorator-based integration, a Flask method based views view_func=MyView.as_view('home', spec), ) +Additional customization parameters can be passed to the view. + +.. code-block:: python + :emphasize-lines: 10 + + from openapi_core.contrib.flask.views import FlaskOpenAPIView + + class MyView(FlaskOpenAPIView): + def get(self): + return "Welcome home" + + app.add_url_rule( + '/home', + view_func=MyView.as_view( + 'home', spec, + extra_format_validators=extra_format_validators, + ), + ) + Request parameters ~~~~~~~~~~~~~~~~~~ diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index 287ea5a9..752dd85f 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -32,11 +32,13 @@ def __init__( request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, + **unmarshaller_kwargs: Any, ): super().__init__( spec, request_unmarshaller_cls=request_unmarshaller_cls, response_unmarshaller_cls=response_unmarshaller_cls, + **unmarshaller_kwargs, ) self.request_class = request_class or self.request_class self.response_class = response_class or self.response_class @@ -51,6 +53,7 @@ def from_spec( request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, + **unmarshaller_kwargs: Any, ) -> "FalconOpenAPIMiddleware": return cls( spec, @@ -59,6 +62,7 @@ def from_spec( request_class=request_class, response_class=response_class, errors_handler=errors_handler, + **unmarshaller_kwargs, ) def process_request(self, req: Request, resp: Response) -> None: # type: ignore diff --git a/openapi_core/contrib/flask/decorators.py b/openapi_core/contrib/flask/decorators.py index 1da178ac..1d360ae4 100644 --- a/openapi_core/contrib/flask/decorators.py +++ b/openapi_core/contrib/flask/decorators.py @@ -36,11 +36,13 @@ def __init__( openapi_errors_handler: Type[ FlaskOpenAPIErrorsHandler ] = FlaskOpenAPIErrorsHandler, + **unmarshaller_kwargs: Any, ): super().__init__( spec, request_unmarshaller_cls=request_unmarshaller_cls, response_unmarshaller_cls=response_unmarshaller_cls, + **unmarshaller_kwargs, ) self.request_class = request_class self.response_class = response_class @@ -73,7 +75,7 @@ def _handle_request_view( request_result: RequestUnmarshalResult, view: Callable[[Any], Response], *args: Any, - **kwargs: Any + **kwargs: Any, ) -> Response: request = self._get_request() request.openapi = request_result # type: ignore @@ -113,6 +115,7 @@ def from_spec( openapi_errors_handler: Type[ FlaskOpenAPIErrorsHandler ] = FlaskOpenAPIErrorsHandler, + **unmarshaller_kwargs: Any, ) -> "FlaskOpenAPIViewDecorator": return cls( spec, @@ -122,4 +125,5 @@ def from_spec( response_class=response_class, request_provider=request_provider, openapi_errors_handler=openapi_errors_handler, + **unmarshaller_kwargs, ) diff --git a/openapi_core/contrib/flask/views.py b/openapi_core/contrib/flask/views.py index 23754bf4..71e1afe7 100644 --- a/openapi_core/contrib/flask/views.py +++ b/openapi_core/contrib/flask/views.py @@ -13,13 +13,15 @@ class FlaskOpenAPIView(MethodView): openapi_errors_handler = FlaskOpenAPIErrorsHandler - def __init__(self, spec: Spec): + def __init__(self, spec: Spec, **unmarshaller_kwargs: Any): super().__init__() self.spec = spec - def dispatch_request(self, *args: Any, **kwargs: Any) -> Any: - decorator = FlaskOpenAPIViewDecorator( + self.decorator = FlaskOpenAPIViewDecorator( self.spec, openapi_errors_handler=self.openapi_errors_handler, + **unmarshaller_kwargs, ) - return decorator(super().dispatch_request)(*args, **kwargs) + + def dispatch_request(self, *args: Any, **kwargs: Any) -> Any: + return self.decorator(super().dispatch_request)(*args, **kwargs) diff --git a/openapi_core/unmarshalling/processors.py b/openapi_core/unmarshalling/processors.py index b2200a90..5a1458c1 100644 --- a/openapi_core/unmarshalling/processors.py +++ b/openapi_core/unmarshalling/processors.py @@ -1,4 +1,5 @@ """OpenAPI core unmarshalling processors module""" +from typing import Any from typing import Optional from typing import Type @@ -20,6 +21,7 @@ def __init__( spec: Spec, request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + **unmarshaller_kwargs: Any, ): self.spec = spec if ( @@ -31,8 +33,12 @@ def __init__( request_unmarshaller_cls = classes.request_unmarshaller_cls if response_unmarshaller_cls is None: response_unmarshaller_cls = classes.response_unmarshaller_cls - self.request_unmarshaller = request_unmarshaller_cls(self.spec) - self.response_unmarshaller = response_unmarshaller_cls(self.spec) + self.request_unmarshaller = request_unmarshaller_cls( + self.spec, **unmarshaller_kwargs + ) + self.response_unmarshaller = response_unmarshaller_cls( + self.spec, **unmarshaller_kwargs + ) def process_request(self, request: Request) -> RequestUnmarshalResult: return self.request_unmarshaller.unmarshal(request) diff --git a/openapi_core/validation/processors.py b/openapi_core/validation/processors.py index 15f0c1b7..cef967af 100644 --- a/openapi_core/validation/processors.py +++ b/openapi_core/validation/processors.py @@ -1,4 +1,5 @@ """OpenAPI core validation processors module""" +from typing import Any from typing import Optional from openapi_core.protocols import Request @@ -15,6 +16,7 @@ def __init__( spec: Spec, request_validator_cls: Optional[RequestValidatorType] = None, response_validator_cls: Optional[ResponseValidatorType] = None, + **unmarshaller_kwargs: Any, ): self.spec = spec if request_validator_cls is None or response_validator_cls is None: @@ -23,8 +25,12 @@ def __init__( request_validator_cls = classes.request_validator_cls if response_validator_cls is None: response_validator_cls = classes.response_validator_cls - self.request_validator = request_validator_cls(self.spec) - self.response_validator = response_validator_cls(self.spec) + self.request_validator = request_validator_cls( + self.spec, **unmarshaller_kwargs + ) + self.response_validator = response_validator_cls( + self.spec, **unmarshaller_kwargs + ) def process_request(self, request: Request) -> None: self.request_validator.validate(request) diff --git a/tests/integration/contrib/falcon/data/v3.0/falconproject/openapi.py b/tests/integration/contrib/falcon/data/v3.0/falconproject/openapi.py index fbe86d14..2676ba21 100644 --- a/tests/integration/contrib/falcon/data/v3.0/falconproject/openapi.py +++ b/tests/integration/contrib/falcon/data/v3.0/falconproject/openapi.py @@ -8,4 +8,7 @@ openapi_spec_path = Path("tests/integration/data/v3.0/petstore.yaml") spec_dict = yaml.load(openapi_spec_path.read_text(), yaml.Loader) spec = Spec.from_dict(spec_dict) -openapi_middleware = FalconOpenAPIMiddleware.from_spec(spec) +openapi_middleware = FalconOpenAPIMiddleware.from_spec( + spec, + extra_media_type_deserializers={}, +) diff --git a/tests/integration/contrib/flask/test_flask_decorator.py b/tests/integration/contrib/flask/test_flask_decorator.py index a8b0c112..19bea449 100644 --- a/tests/integration/contrib/flask/test_flask_decorator.py +++ b/tests/integration/contrib/flask/test_flask_decorator.py @@ -173,7 +173,7 @@ def test_endpoint_error(self, client): } assert result.json == expected_data - def test_valid_response_object(self, client): + def test_response_object_valid(self, client): def view_response_callable(*args, **kwargs): from flask.globals import request @@ -197,7 +197,28 @@ def view_response_callable(*args, **kwargs): "data": "data", } - def test_valid_tuple_str(self, client): + @pytest.mark.parametrize( + "response,expected_status,expected_headers", + [ + # ((body, status, headers)) response tuple + ( + ("Not found", 404, {"X-Rate-Limit": "12"}), + 404, + {"X-Rate-Limit": "12"}, + ), + # (body, status) response tuple + (("Not found", 404), 404, {}), + # (body, headers) response tuple + ( + ({"data": "data"}, {"X-Rate-Limit": "12"}), + 200, + {"X-Rate-Limit": "12"}, + ), + ], + ) + def test_tuple_valid( + self, client, response, expected_status, expected_headers + ): def view_response_callable(*args, **kwargs): from flask.globals import request @@ -208,35 +229,16 @@ def view_response_callable(*args, **kwargs): "id": 12, } ) - return ("Not found", 404) + return response self.view_response_callable = view_response_callable result = client.get("/browse/12/") - assert result.status_code == 404 - assert result.text == "Not found" - - def test_valid_tuple_dict(self, client): - def view_response_callable(*args, **kwargs): - from flask.globals import request - - assert request.openapi - assert not request.openapi.errors - assert request.openapi.parameters == Parameters( - path={ - "id": 12, - } - ) - body = dict(data="data") - headers = {"X-Rate-Limit": "12"} - return (body, headers) - - self.view_response_callable = view_response_callable - - result = client.get("/browse/12/") - - assert result.status_code == 200 - assert result.json == { - "data": "data", - } + assert result.status_code == expected_status + expected_body = response[0] + if isinstance(expected_body, str): + assert result.text == expected_body + else: + assert result.json == expected_body + assert dict(result.headers).items() >= expected_headers.items() diff --git a/tests/integration/contrib/flask/test_flask_views.py b/tests/integration/contrib/flask/test_flask_views.py index 8d2f9d51..5a253ab5 100644 --- a/tests/integration/contrib/flask/test_flask_views.py +++ b/tests/integration/contrib/flask/test_flask_views.py @@ -38,7 +38,9 @@ def get(self, id): def post(self, id): return outer.view_response - return MyDetailsView.as_view("browse_details", spec) + return MyDetailsView.as_view( + "browse_details", spec, extra_media_type_deserializers={} + ) @pytest.fixture def list_view_func(self, spec): @@ -48,7 +50,9 @@ class MyListView(FlaskOpenAPIView): def get(self): return outer.view_response - return MyListView.as_view("browse_list", spec) + return MyListView.as_view( + "browse_list", spec, extra_format_validators={} + ) @pytest.fixture(autouse=True) def view(self, app, details_view_func, list_view_func):