Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion durabletask/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from durabletask.entities.entity_lock import EntityLock
from durabletask.entities.entity_context import EntityContext
from durabletask.entities.entity_metadata import EntityMetadata
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException

__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata",
"EntityOperationFailedException"]

PACKAGE_NAME = "durabletask.entities"
12 changes: 9 additions & 3 deletions durabletask/entities/entity_instance_id.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
class EntityInstanceId:
def __init__(self, entity: str, key: str):
self.entity = entity
if not entity or not key:
raise ValueError("Entity name and key cannot be empty.")
if "@" in key:
raise ValueError("Entity key cannot contain '@' symbol.")
Comment on lines +5 to +6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can also be a breaking change unless we had other ways of filtering this that were already present.

self.entity = entity.lower()
self.key = key

def __str__(self) -> str:
Expand Down Expand Up @@ -35,8 +39,10 @@ def parse(entity_id: str) -> "EntityInstanceId":
ValueError
If the input string is not in the correct format.
"""
if not entity_id.startswith("@"):
raise ValueError("Entity ID must start with '@'.")
try:
_, entity, key = entity_id.split("@", 2)
return EntityInstanceId(entity=entity, key=key)
except ValueError as ex:
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
raise ValueError(f"Invalid entity ID format: {entity_id}") from ex
return EntityInstanceId(entity=entity, key=key)
15 changes: 15 additions & 0 deletions durabletask/entities/entity_operation_failed_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from durabletask.internal.orchestrator_service_pb2 import TaskFailureDetails
from durabletask.entities.entity_instance_id import EntityInstanceId


class EntityOperationFailedException(Exception):
    """Exception raised when an operation on an Entity Function fails.

    The constructor arguments are stored unchanged on the instance as
    ``entity_instance_id``, ``operation_name`` and ``failure_details``.
    The human-readable message is built lazily by ``__str__`` from those
    attributes.
    """

    def __init__(self, entity_instance_id: EntityInstanceId, operation_name: str, failure_details: TaskFailureDetails) -> None:
        # Forward the constructor arguments to Exception so that ``self.args``
        # mirrors the signature. Copying and unpickling rebuild an exception
        # by calling ``cls(*self.args)``; with empty ``args`` that call would
        # fail with a TypeError because all three parameters are required.
        super().__init__(entity_instance_id, operation_name, failure_details)
        self.entity_instance_id = entity_instance_id
        self.operation_name = operation_name
        self.failure_details = failure_details

    def __str__(self) -> str:
        """Return a message naming the operation, the entity and the error message from the failure details."""
        return f"Operation '{self.operation_name}' on entity '{self.entity_instance_id}' failed with error: {self.failure_details.errorMessage}"
12 changes: 12 additions & 0 deletions durabletask/internal/json_encode_output_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any


class JsonEncodeOutputException(Exception):
    """Custom exception type used to indicate that an orchestration result could not be JSON-encoded.

    The object that could not be encoded is kept on the instance as
    ``problem_object`` and is included in the string form of the exception.
    """

    def __init__(self, problem_object: Any):
        # Forward the argument to Exception so that ``self.args`` mirrors the
        # signature. Copying and unpickling rebuild an exception by calling
        # ``cls(*self.args)``; with empty ``args`` that call would fail with a
        # TypeError because ``problem_object`` is a required parameter.
        super().__init__(problem_object)
        self.problem_object = problem_object

    def __str__(self) -> str:
        """Return a message that embeds the string form of the offending object."""
        return f"The orchestration result could not be encoded. Object details: {self.problem_object}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own edification, how does this encode into JSON? From the class name, I would assume that its output is JSON rather than just a raw string, but if this is the Python standard, that's fine.

54 changes: 37 additions & 17 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import grpc
from google.protobuf import empty_pb2

from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
from durabletask.internal import helpers
from durabletask.internal.entity_state_shim import StateShim
from durabletask.internal.helpers import new_timestamp
from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
from durabletask.internal.json_encode_output_exception import JsonEncodeOutputException
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
import durabletask.internal.helpers as ph
Expand Down Expand Up @@ -141,14 +143,12 @@ class _Registry:
orchestrators: dict[str, task.Orchestrator]
activities: dict[str, task.Activity]
entities: dict[str, task.Entity]
entity_instances: dict[str, DurableEntity]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is the intended fix for the leak issues, but will this have a significant impact on performance? Seems like an idle purge is a better solution, but if the creation is basically a no-op, I think this could be fine too. Just want to make sure we're considering both sides.

versioning: Optional[VersioningOptions] = None

def __init__(self):
self.orchestrators = {}
self.activities = {}
self.entities = {}
self.entity_instances = {}

def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
if fn is None:
Expand Down Expand Up @@ -201,6 +201,7 @@ def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
def add_named_entity(self, name: str, fn: task.Entity) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have the @ check here as well?

if not name:
raise ValueError("A non-empty entity name is required.")
name = name.lower()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't let me comment in the test file, but do we have a test for registering an entity with different casing working?

if name in self.entities:
raise ValueError(f"A '{name}' entity already exists.")

Expand Down Expand Up @@ -829,7 +830,7 @@ def __init__(self, instance_id: str, registry: _Registry):
self._pending_actions: dict[int, pb.OrchestratorAction] = {}
self._pending_tasks: dict[int, task.CompletableTask] = {}
# Maps entity ID to task ID
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, str, int]] = {}
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
# Maps criticalSectionId to task ID
self._entity_lock_id_map: dict[str, int] = {}
Expand Down Expand Up @@ -902,7 +903,10 @@ def set_complete(
self._result = result
result_json: Optional[str] = None
if result is not None:
result_json = result if is_result_encoded else shared.to_json(result)
try:
result_json = result if is_result_encoded else shared.to_json(result)
except (ValueError, TypeError):
result_json = shared.to_json(str(JsonEncodeOutputException(result)))
Comment on lines 906 to 909
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewers - this is the "safe" approach, where a return value that we cannot serialize to JSON is instead stringified along with an error message. I believe the other SDKs would just fail the orchestration outright - is that preferable here too?

action = ph.new_complete_orchestration_action(
self.next_sequence_number(), status, result_json
)
Expand Down Expand Up @@ -1606,7 +1610,7 @@ def process_event(
raise TypeError("Unexpected sub-orchestration task type")
elif event.HasField("eventRaised"):
if event.eventRaised.name in ctx._entity_task_id_map:
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None))
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
Expand Down Expand Up @@ -1680,9 +1684,10 @@ def process_event(
)
try:
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
operation = event.entityOperationCalled.operation
except ValueError:
raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id)
elif event.HasField("entityOperationSignaled"):
# This history event confirms that the entity signal was successfully scheduled.
# Remove the entityOperationSignaled event from the pending action list so we don't schedule it
Expand Down Expand Up @@ -1743,7 +1748,7 @@ def process_event(
ctx.resume()
elif event.HasField("entityOperationCompleted"):
request_id = event.entityOperationCompleted.requestId
entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
if not entity_id:
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
if not task_id:
Expand All @@ -1762,10 +1767,29 @@ def process_event(
entity_task.complete(result)
ctx.resume()
elif event.HasField("entityOperationFailed"):
if not ctx.is_replaying:
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
pass
request_id = event.entityOperationFailed.requestId
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
if not entity_id:
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
if operation is None:
raise RuntimeError(f"Could not parse operation name from request ID '{request_id}'")
if not task_id:
raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
entity_task = ctx._pending_tasks.pop(task_id, None)
if not entity_task:
if not ctx.is_replaying:
self._logger.warning(
f"{ctx.instance_id}: Ignoring unexpected entityOperationFailed event with request ID = {request_id}."
)
return
failure = EntityOperationFailedException(
entity_id,
operation,
event.entityOperationFailed.failureDetails
)
ctx._entity_context.recover_lock_after_call(entity_id)
entity_task.fail(str(failure), failure)
ctx.resume()
elif event.HasField("orchestratorCompleted"):
# Added in Functions only (for some reason) and does not affect orchestrator flow
pass
Expand All @@ -1777,7 +1801,7 @@ def process_event(
if action and action.HasField("sendEntityMessage"):
if action.sendEntityMessage.HasField("entityOperationCalled"):
entity_id, event_id = self._parse_entity_event_sent_input(event)
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
ctx._entity_task_id_map[event_id] = (entity_id, action.sendEntityMessage.entityOperationCalled.operation, event.eventId)
elif action.sendEntityMessage.HasField("entityLockRequested"):
entity_id, event_id = self._parse_entity_event_sent_input(event)
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
Expand Down Expand Up @@ -1937,11 +1961,7 @@ def execute(
ctx = EntityContext(orchestration_id, operation, state, entity_id)

if isinstance(fn, type) and issubclass(fn, DurableEntity):
if self._registry.entity_instances.get(str(entity_id), None):
entity_instance = self._registry.entity_instances[str(entity_id)]
else:
entity_instance = fn()
self._registry.entity_instances[str(entity_id)] = entity_instance
entity_instance = fn()
if not hasattr(entity_instance, operation):
raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
method = getattr(entity_instance, operation)
Expand Down
Empty file.
Loading
Loading