diff --git a/durabletask/entities/__init__.py b/durabletask/entities/__init__.py index 46f059b..28259f9 100644 --- a/durabletask/entities/__init__.py +++ b/durabletask/entities/__init__.py @@ -8,7 +8,9 @@ from durabletask.entities.entity_lock import EntityLock from durabletask.entities.entity_context import EntityContext from durabletask.entities.entity_metadata import EntityMetadata +from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException -__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"] +__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata", + "EntityOperationFailedException"] PACKAGE_NAME = "durabletask.entities" diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py index 02a2595..a4ddbde 100644 --- a/durabletask/entities/entity_instance_id.py +++ b/durabletask/entities/entity_instance_id.py @@ -1,6 +1,10 @@ class EntityInstanceId: def __init__(self, entity: str, key: str): - self.entity = entity + if not entity or not key: + raise ValueError("Entity name and key cannot be empty.") + if "@" in key: + raise ValueError("Entity key cannot contain '@' symbol.") + self.entity = entity.lower() self.key = key def __str__(self) -> str: @@ -35,8 +39,10 @@ def parse(entity_id: str) -> "EntityInstanceId": ValueError If the input string is not in the correct format. """ + if not entity_id.startswith("@"): + raise ValueError("Entity ID must start with '@'.") try: _, entity, key = entity_id.split("@", 2) - return EntityInstanceId(entity=entity, key=key) except ValueError as ex: - raise ValueError(f"Invalid entity ID format: {entity_id}", ex) + raise ValueError(f"Invalid entity ID format: {entity_id}") from ex + return EntityInstanceId(entity=entity, key=key) diff --git a/durabletask/entities/entity_operation_failed_exception.py b/durabletask/entities/entity_operation_failed_exception.py new file mode 100644 index 0000000..a69094e --- /dev/null +++ b/durabletask/entities/entity_operation_failed_exception.py @@ -0,0 +1,15 @@ +from durabletask.internal.orchestrator_service_pb2 import TaskFailureDetails +from durabletask.entities.entity_instance_id import EntityInstanceId + + +class EntityOperationFailedException(Exception): + """Exception raised when an operation on an Entity Function fails.""" + + def __init__(self, entity_instance_id: EntityInstanceId, operation_name: str, failure_details: TaskFailureDetails) -> None: + super().__init__() + self.entity_instance_id = entity_instance_id + self.operation_name = operation_name + self.failure_details = failure_details + + def __str__(self) -> str: + return f"Operation '{self.operation_name}' on entity '{self.entity_instance_id}' failed with error: {self.failure_details.errorMessage}" diff --git a/durabletask/internal/json_encode_output_exception.py b/durabletask/internal/json_encode_output_exception.py new file mode 100644 index 0000000..992b040 --- /dev/null +++ b/durabletask/internal/json_encode_output_exception.py @@ -0,0 +1,12 @@ +from typing import Any + + +class JsonEncodeOutputException(Exception): + """Custom exception type used to indicate that an orchestration result could not be JSON-encoded.""" + + def __init__(self, problem_object: Any): + super().__init__() + self.problem_object = problem_object + + def __str__(self) -> str: + return f"The orchestration result could not be encoded. Object details: {self.problem_object}" diff --git a/durabletask/worker.py b/durabletask/worker.py index 4d9da6d..521bccf 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -19,10 +19,12 @@ import grpc from google.protobuf import empty_pb2 +from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException from durabletask.internal import helpers from durabletask.internal.entity_state_shim import StateShim from durabletask.internal.helpers import new_timestamp from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext +from durabletask.internal.json_encode_output_exception import JsonEncodeOutputException from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub import durabletask.internal.helpers as ph @@ -141,14 +143,12 @@ class _Registry: orchestrators: dict[str, task.Orchestrator] activities: dict[str, task.Activity] entities: dict[str, task.Entity] - entity_instances: dict[str, DurableEntity] versioning: Optional[VersioningOptions] = None def __init__(self): self.orchestrators = {} self.activities = {} self.entities = {} - self.entity_instances = {} def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str: if fn is None: @@ -201,6 +201,7 @@ def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str: def add_named_entity(self, name: str, fn: task.Entity) -> None: if not name: raise ValueError("A non-empty entity name is required.") + name = name.lower() if name in self.entities: raise ValueError(f"A '{name}' entity already exists.") @@ -829,7 +830,7 @@ def __init__(self, instance_id: str, registry: _Registry): self._pending_actions: dict[int, pb.OrchestratorAction] = {} self._pending_tasks: dict[int, task.CompletableTask] = {} # Maps entity ID to task ID - self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} + self._entity_task_id_map: dict[str, tuple[EntityInstanceId, str, int]] = {} self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} # Maps criticalSectionId to task ID self._entity_lock_id_map: dict[str, int] = {} @@ -902,7 +903,10 @@ def set_complete( self._result = result result_json: Optional[str] = None if result is not None: - result_json = result if is_result_encoded else shared.to_json(result) + try: + result_json = result if is_result_encoded else shared.to_json(result) + except (ValueError, TypeError): + result_json = shared.to_json(str(JsonEncodeOutputException(result))) action = ph.new_complete_orchestration_action( self.next_sequence_number(), status, result_json ) @@ -1606,7 +1610,7 @@ def process_event( raise TypeError("Unexpected sub-orchestration task type") elif event.HasField("eventRaised"): if event.eventRaised.name in ctx._entity_task_id_map: - entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None)) + entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None)) self._handle_entity_event_raised(ctx, event, entity_id, task_id, False) elif event.eventRaised.name in ctx._entity_lock_task_id_map: entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None)) @@ -1680,9 +1684,10 @@ def process_event( ) try: entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value) + operation = event.entityOperationCalled.operation except ValueError: raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'") - ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id) + ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id) elif event.HasField("entityOperationSignaled"): # This history event confirms that the entity signal was successfully scheduled. # Remove the entityOperationSignaled event from the pending action list so we don't schedule it @@ -1743,7 +1748,7 @@ def process_event( ctx.resume() elif event.HasField("entityOperationCompleted"): request_id = event.entityOperationCompleted.requestId - entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None)) + entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None)) if not entity_id: raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'") if not task_id: @@ -1762,10 +1767,29 @@ def process_event( entity_task.complete(result) ctx.resume() elif event.HasField("entityOperationFailed"): - if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Entity operation failed.") - self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}") - pass + request_id = event.entityOperationFailed.requestId + entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None)) + if not entity_id: + raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'") + if operation is None: + raise RuntimeError(f"Could not parse operation name from request ID '{request_id}'") + if not task_id: + raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + if not ctx.is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring unexpected entityOperationFailed event with request ID = {request_id}." + ) + return + failure = EntityOperationFailedException( + entity_id, + operation, + event.entityOperationFailed.failureDetails + ) + ctx._entity_context.recover_lock_after_call(entity_id) + entity_task.fail(str(failure), failure) + ctx.resume() elif event.HasField("orchestratorCompleted"): # Added in Functions only (for some reason) and does not affect orchestrator flow pass @@ -1777,7 +1801,7 @@ def process_event( if action and action.HasField("sendEntityMessage"): if action.sendEntityMessage.HasField("entityOperationCalled"): entity_id, event_id = self._parse_entity_event_sent_input(event) - ctx._entity_task_id_map[event_id] = (entity_id, event.eventId) + ctx._entity_task_id_map[event_id] = (entity_id, action.sendEntityMessage.entityOperationCalled.operation, event.eventId) elif action.sendEntityMessage.HasField("entityLockRequested"): entity_id, event_id = self._parse_entity_event_sent_input(event) ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId) @@ -1937,11 +1961,7 @@ def execute( ctx = EntityContext(orchestration_id, operation, state, entity_id) if isinstance(fn, type) and issubclass(fn, DurableEntity): - if self._registry.entity_instances.get(str(entity_id), None): - entity_instance = self._registry.entity_instances[str(entity_id)] - else: - entity_instance = fn() - self._registry.entity_instances[str(entity_id)] = entity_instance + entity_instance = fn() if not hasattr(entity_instance, operation): raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'") method = getattr(entity_instance, operation) diff --git a/tests/durabletask-azuremanaged/entities/__init__.py b/tests/durabletask-azuremanaged/entities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py b/tests/durabletask-azuremanaged/entities/test_dts_class_based_entities_e2e.py similarity index 100% rename from tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py rename to tests/durabletask-azuremanaged/entities/test_dts_class_based_entities_e2e.py diff --git a/tests/durabletask-azuremanaged/entities/test_dts_entity_failure_handling.py b/tests/durabletask-azuremanaged/entities/test_dts_entity_failure_handling.py new file mode 100644 index 0000000..b12d158 --- /dev/null +++ b/tests/durabletask-azuremanaged/entities/test_dts_entity_failure_handling.py @@ -0,0 +1,186 @@ + +import json +import os +from durabletask import client, entities, task + +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Read the environment variables +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + +def test_class_entity_unhandled_failure_fails(): + class FailingEntity(entities.DurableEntity): + def fail(self, _): + raise ValueError("Something went wrong!") + + def test_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("FailingEntity", "testEntity") + yield ctx.call_entity(entity_id, "fail") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(test_orchestrator) + w.add_entity(FailingEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(test_orchestrator) + assert state.instance_id == id + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + # NOTE: Because FailureDetails does not support inner_failure, we can't verify that the inner failure type is + # EntityOperationFailedException. In the future, we should consider adding support for inner failures in + # FailureDetails to make this more robust. This applies to all tests in this file. For now, the error message's + # structure is sufficient to verify that the failure was due to the EntityOperationFailedException. + assert state.failure_details.message == "Operation 'fail' on entity '@failingentity@testEntity' failed with " \ + "error: Something went wrong!" + assert state.runtime_status == client.OrchestrationStatus.FAILED + + +def test_function_entity_unhandled_failure_fails(): + def failing_entity(ctx: entities.EntityContext, _): + raise ValueError("Something went wrong!") + + def test_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("failing_entity", "testEntity") + yield ctx.call_entity(entity_id, "fail") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(test_orchestrator) + w.add_entity(failing_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(test_orchestrator) + assert state.instance_id == id + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + assert state.failure_details.message == "Operation 'fail' on entity '@failing_entity@testEntity' failed with " \ + "error: Something went wrong!" + assert state.runtime_status == client.OrchestrationStatus.FAILED + + +def test_class_entity_handled_failure_succeeds(): + class FailingEntity(entities.DurableEntity): + def fail(self, _): + raise ValueError("Something went wrong!") + + def test_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("FailingEntity", "testEntity") + try: + yield ctx.call_entity(entity_id, "fail") + except task.TaskFailedError as e: + return e.details.message # returning just the message to avoid issues with JSON serialization of FailureDetails + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(test_orchestrator) + w.add_entity(FailingEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(test_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert output == "Operation 'fail' on entity '@failingentity@testEntity' failed with error: Something went wrong!" + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + +def test_function_entity_handled_failure_succeeds(): + def failing_entity(ctx: entities.EntityContext, _): + raise ValueError("Something went wrong!") + + def test_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("failing_entity", "testEntity") + try: + yield ctx.call_entity(entity_id, "fail") + except task.TaskFailedError as e: + return e.details.message # returning just the message to avoid issues with JSON serialization of FailureDetails + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(test_orchestrator) + w.add_entity(failing_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(test_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert output == "Operation 'fail' on entity '@failing_entity@testEntity' failed with error: Something went wrong!" + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + +def test_class_entity_failure_unlocks_entity(): + def failing_entity(ctx: entities.EntityContext, _): + raise ValueError("Something went wrong!") + + def test_orchestrator(ctx: task.OrchestrationContext, _): + exception_count = 0 + entity_id = entities.EntityInstanceId("failing_entity", "testEntity") + with (yield ctx.lock_entities([entity_id])): + try: + yield ctx.call_entity(entity_id, "fail") + except task.TaskFailedError: + exception_count += 1 + try: + yield ctx.call_entity(entity_id, "fail") + except task.TaskFailedError: + exception_count += 1 + return exception_count + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(test_orchestrator) + w.add_entity(failing_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(test_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert output == 2 + assert state.runtime_status == client.OrchestrationStatus.COMPLETED diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py b/tests/durabletask-azuremanaged/entities/test_dts_function_based_entities_e2e.py similarity index 100% rename from tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py rename to tests/durabletask-azuremanaged/entities/test_dts_function_based_entities_e2e.py diff --git a/tests/durabletask/entities/test_entity_id_parsing.py b/tests/durabletask/entities/test_entity_id_parsing.py new file mode 100644 index 0000000..c56a878 --- /dev/null +++ b/tests/durabletask/entities/test_entity_id_parsing.py @@ -0,0 +1,50 @@ +import pytest +from durabletask.entities import EntityInstanceId + + +def test_entity_id_parsing_success(): + entity_id_str = "@MyEntity@TestInstance" + entity_id = EntityInstanceId.parse(entity_id_str) + assert entity_id.entity == "myentity" # should be case-insensitive (lowercased) + assert entity_id.key == "TestInstance" + assert str(entity_id) == "@myentity@TestInstance" + + +def test_is_entity_id_name_case_insensitive(): + id1 = EntityInstanceId("MyEntity", "instance1") + id2 = EntityInstanceId("myentity", "instance1") + assert id1 == id2 + + +def test_entity_id_parsing_failures(): + # Test empty string + with pytest.raises(ValueError): + EntityInstanceId.parse("") + + # Test invalid entity id format + with pytest.raises(ValueError): + EntityInstanceId.parse("invalidEntityId") + + # Test single @ + with pytest.raises(ValueError): + EntityInstanceId.parse("@") + + # Test double @ + with pytest.raises(ValueError): + EntityInstanceId.parse("@@") + + # Test @ with invalid placement + with pytest.raises(ValueError): + EntityInstanceId.parse("@invalid@") + + # Test @@ at end + with pytest.raises(ValueError): + EntityInstanceId.parse("@@invalid") + + # Test symbol in wrong position + with pytest.raises(ValueError): + EntityInstanceId.parse("invalid@symbolplacement") + + # Test multiple @ symbols + with pytest.raises(ValueError): + EntityInstanceId.parse("invalid@symbol@placement")