Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,27 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
return self._handle_response(route=route, response=response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
if response.body and response.is_json():
response.body = self._serialize_response(
field=route.dependant.return_param,
field = route.dependant.return_param

if field is None:
if not response.is_json():
return response
else:
# JSON serialize the body without validation
response.body = jsonable_encoder(response.body, custom_serializer=self._validation_serializer)
else:
response.body = self._serialize_response_with_validation(
field=field,
response_content=response.body,
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
)

return response

def _serialize_response(
def _serialize_response_with_validation(
self,
*,
field: ModelField | None = None,
field: ModelField,
response_content: Any,
include: IncEx | None = None,
exclude: IncEx | None = None,
Expand All @@ -254,45 +261,42 @@ def _serialize_response(
"""
Serialize the response content according to the field type.
"""
if field:
errors: list[dict[str, Any]] = []
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
# route-level validation must take precedence over app-level
if has_route_custom_response_validation:
raise ResponseValidationError(
errors=_normalize_errors(errors),
body=response_content,
source="route",
)
if self._has_response_validation_error:
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")

raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
return field.serialize(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
errors: list[dict[str, Any]] = []
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
# route-level validation must take precedence over app-level
if has_route_custom_response_validation:
raise ResponseValidationError(
errors=_normalize_errors(errors),
body=response_content,
source="route",
)
return jsonable_encoder(
if self._has_response_validation_error:
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")

raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
return field.serialize(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_serializer=self._validation_serializer,
)
else:
# Just serialize the response content returned from the handler.
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

return jsonable_encoder(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_serializer=self._validation_serializer,
)

def _prepare_response_content(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def search(
# =============================================================================


@pytest.mark.skip("Due to issue #7981.")
@pytest.mark.asyncio
async def test_async_handler_with_validation():
# GIVEN an app with async handler and validation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,39 @@ def handler(user_id: int = 123):
assert result["statusCode"] == 200


@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
def test_validate_list_response(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

response_before_validation = [
{
"name": "Joe",
"age": 20,
},
{
"name": "Jane",
"age": 20,
},
]

@app.get("/list_response_with_same_element_types")
def handler_different_list() -> List[Model]:
return response_before_validation

# WHEN returning list with the same element type as the non-Optional return type
gw_event["path"] = "/list_response_with_same_element_types"
result = app(gw_event, {})
body = json.loads(result["body"])

# THEN it should return a validation error
assert result["statusCode"] == 200
assert body == response_before_validation


def test_validation_error_none_returned_non_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand All @@ -1656,6 +1688,32 @@ def handler_none_not_allowed() -> Model:
assert body["detail"][0]["loc"] == ["response"]


def test_validation_error_different_list_returned_non_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

different_list_response = ["a", "b", "c"]

@app.get("/list_response_with_different_element_types")
def handler_different_list() -> List[Model]:
return different_list_response

# WHEN returning list with the different element type as the non-Optional return type
gw_event["path"] = "/list_response_with_different_element_types"
result = app(gw_event, {})

# THEN it should return a validation error
assert result["statusCode"] == 422
body = json.loads(result["body"])
assert len(body["detail"]) == len(different_list_response)
assert body["detail"][0]["type"] == "model_attributes_type"
assert body["detail"][0]["loc"] == ["response", 0]


def test_validation_error_incomplete_model_returned_non_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand Down