Skip to content
13 changes: 13 additions & 0 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DEFAULT_OPENAPI_VERSION,
)
from aws_lambda_powertools.event_handler.openapi.exceptions import (
RequestUnsupportedContentType,
RequestValidationError,
ResponseValidationError,
SchemaValidationError,
Expand Down Expand Up @@ -2972,6 +2973,18 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
route=route,
)

if isinstance(exp, RequestUnsupportedContentType):
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
return self._response_builder_class(
response=Response(
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": HTTPStatus.UNSUPPORTED_MEDIA_TYPE, "detail": errors},
),
serializer=self._serializer,
route=route,
)

if isinstance(exp, ServiceError):
return self._response_builder_class(
response=Response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
)
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
from aws_lambda_powertools.event_handler.openapi.exceptions import (
RequestUnsupportedContentType,
RequestValidationError,
ResponseValidationError,
)
from aws_lambda_powertools.event_handler.openapi.params import Param

if TYPE_CHECKING:
Expand Down Expand Up @@ -129,7 +133,18 @@ def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
return self._parse_form_data(app)

else:
raise NotImplementedError("Only JSON body or Form() are supported")
raise RequestUnsupportedContentType(
"Only JSON body or Form() are supported",
errors=[
{
"type": "unsupported_content_type",
"loc": ("body",),
"msg": "Only JSON body or Form() are supported",
"input": {},
"ctx": {},
},
],
)

def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
"""Parse JSON data from the request body."""
Expand Down
10 changes: 10 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ class SchemaValidationError(ValidationException):

class OpenAPIMergeError(Exception):
"""Exception raised when there's a conflict during OpenAPI merge."""


class RequestUnsupportedContentType(NotImplementedError, ValidationException):
"""Exception raised when trying to read request body data, with unknown headers"""

# REVIEW: This inheritance is for backwards compatibility.
# Just inherit from ValidationException in Powertools V4
def __init__(self, msg: str, errors: Sequence[Any]) -> None:
NotImplementedError.__init__(self, msg)
ValidationException.__init__(self, errors)
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,32 @@ def handler(user: Model) -> Model:
assert json.loads(result["body"]) == {"name": "John", "age": 30}


def test_validate_unsupported_content_type_headers(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

# WHEN a handler is defined with a body parameter
# WHEN headers has unsupported content-type
@app.post("/")
def handler(user: Model) -> Model:
return user

gw_event["httpMethod"] = "POST"
gw_event["headers"] = {"Content-type": "text/fake-content-type"}
gw_event["path"] = "/"
gw_event["body"] = json.dumps({"name": "John", "age": 30})

# THEN the handler should return 415 (Unsupported Media Type)
# THEN the body must have the "unsupported_content_type" error message
result = app(gw_event, {})
assert result["statusCode"] == 415
assert "unsupported_content_type" in result["body"]


def test_validate_body_param_with_invalid_date(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand Down