diff --git a/durabletask/worker.py b/durabletask/worker.py index 0ec2f66..e10538c 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -12,7 +12,7 @@ from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType -from typing import Any, Generator, Optional, Sequence, TypeVar, Union +from typing import Any, Generator, Iterator, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import empty_pb2 @@ -30,7 +30,7 @@ # If `opentelemetry-sdk` is available, enable the tracer try: from opentelemetry import trace - from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + from opentelemetry.trace.propagation.tracecontext TraceContextTextMapPropagator otel_propagator = TraceContextTextMapPropagator() otel_tracer = trace.get_tracer(__name__) @@ -283,7 +283,7 @@ class TaskHubGrpcWorker: activity function. """ - _response_stream: Optional[grpc.Future] = None + _response_stream: Optional[Iterator[Any]] = None _interceptors: Optional[list[shared.ClientInterceptor]] = None def __init__( @@ -418,10 +418,10 @@ def create_fresh_connection(): def invalidate_connection(): nonlocal current_channel, current_stub, current_reader_thread - # Cancel the response stream first to signal the reader thread to stop + # Close the response stream first to signal the reader thread to stop if self._response_stream is not None: try: - self._response_stream.cancel() + self._response_stream.close() except Exception: pass self._response_stream = None @@ -740,7 +740,10 @@ def stop(self): self._logger.info("Stopping gRPC worker...") if self._response_stream is not None: - self._response_stream.cancel() + try: + self._response_stream.close() + except Exception as e: + self._logger.exception(f"Error stopping response stream: {e}") self._shutdown.set() # Explicitly close the gRPC channel to ensure OTel interceptors and other resources are cleaned up if self._current_channel is not None: