diff --git a/src/amp/client.py b/src/amp/client.py index 2eee462..20beaae 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -791,42 +791,62 @@ def query_and_load_streaming( self.logger.warning(f'Failed to load checkpoint, starting from beginning: {e}') try: - # Execute streaming query with Flight SQL - # Create a CommandStatementQuery message - command_query = FlightSql_pb2.CommandStatementQuery() - command_query.query = query - - # Add resume watermark if provided - if resume_watermark: - # TODO: Add watermark to query metadata when Flight SQL supports it - self.logger.info(f'Resuming stream from watermark: {resume_watermark}') - - # Wrap the CommandStatementQuery in an Any type - any_command = Any() - any_command.Pack(command_query) - cmd = any_command.SerializeToString() - - self.logger.info('Establishing Flight SQL connection...') - flight_descriptor = flight.FlightDescriptor.for_command(cmd) - info = self.conn.get_flight_info(flight_descriptor) - reader = self.conn.do_get(info.endpoints[0].ticket) - - # Create streaming iterator - stream_iterator = StreamingResultIterator(reader) - self.logger.info('Stream connection established, waiting for data...') - - # Optionally wrap with reorg detection - if with_reorg_detection: - stream_iterator = ReorgAwareStream(stream_iterator) - self.logger.info('Reorg detection enabled for streaming query') - - # Start continuous loading with checkpoint support with loader_instance: - self.logger.info(f'Starting continuous load to {destination}. 
Press Ctrl+C to stop.') - # Pass connection_name for checkpoint saving - yield from loader_instance.load_stream_continuous( - stream_iterator, destination, connection_name=connection_name, **load_config.__dict__ - ) + while True: + # Execute streaming query with Flight SQL + # Create a CommandStatementQuery message + command_query = FlightSql_pb2.CommandStatementQuery() + command_query.query = query + + # Add resume watermark if provided + if resume_watermark: + # TODO: Add watermark to query metadata when Flight SQL supports it + self.logger.info(f'Resuming stream from watermark: {resume_watermark}') + + # Wrap the CommandStatementQuery in an Any type + any_command = Any() + any_command.Pack(command_query) + cmd = any_command.SerializeToString() + + self.logger.info('Establishing Flight SQL connection...') + flight_descriptor = flight.FlightDescriptor.for_command(cmd) + info = self.conn.get_flight_info(flight_descriptor) + reader = self.conn.do_get(info.endpoints[0].ticket) + + # Create streaming iterator + stream_iterator = StreamingResultIterator(reader) + self.logger.info('Stream connection established, waiting for data...') + + # Optionally wrap with reorg detection + if with_reorg_detection: + stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark) + self.logger.info('Reorg detection enabled for streaming query') + + # Start continuous loading with checkpoint support + self.logger.info(f'Starting continuous load to {destination}. 
Press Ctrl+C to stop.') + + reorg_result = None + # Pass connection_name for checkpoint saving + for result in loader_instance.load_stream_continuous( + stream_iterator, destination, connection_name=connection_name, **load_config.__dict__ + ): + yield result + # Break on reorg to restart stream + if result.is_reorg: + reorg_result = result + break + + # Check if we need to restart due to reorg + if reorg_result: + # Close the old stream before restarting + if hasattr(stream_iterator, 'close'): + stream_iterator.close() + self.logger.info('Reorg detected, restarting stream with new resume position...') + resume_watermark = loader_instance.state_store.get_resume_position(connection_name, destination) + continue + + # Normal exit - stream completed + break except Exception as e: self.logger.error(f'Streaming query failed: {e}') diff --git a/src/amp/streaming/reorg.py b/src/amp/streaming/reorg.py index 9083db7..de870ee 100644 --- a/src/amp/streaming/reorg.py +++ b/src/amp/streaming/reorg.py @@ -6,7 +6,7 @@ from typing import Dict, Iterator, List from .iterator import StreamingResultIterator -from .types import BlockRange, ResponseBatch +from .types import BlockRange, ResponseBatch, ResumeWatermark class ReorgAwareStream: @@ -16,20 +16,32 @@ class ReorgAwareStream: This class monitors the block ranges in consecutive batches to detect chain reorganizations (reorgs). When a reorg is detected, a ResponseBatch with is_reorg=True is emitted containing the invalidation ranges. + + Supports cross-restart reorg detection by initializing from a resume watermark + that contains the last known block hashes from persistent state. """ - def __init__(self, stream_iterator: StreamingResultIterator): + def __init__(self, stream_iterator: StreamingResultIterator, resume_watermark: 'ResumeWatermark | None' = None): """ Initialize the reorg-aware stream. 
Args: stream_iterator: The underlying streaming result iterator + resume_watermark: Optional watermark from persistent state (LMDB) containing + last known block ranges with hashes for cross-restart reorg detection """ self.stream_iterator = stream_iterator - # Track the latest range for each network self.prev_ranges_by_network: Dict[str, BlockRange] = {} self.logger = logging.getLogger(__name__) + if resume_watermark: + for block_range in resume_watermark.ranges: + self.prev_ranges_by_network[block_range.network] = block_range + self.logger.debug( + f'Initialized reorg detection for {block_range.network} ' + f'from block {block_range.end} hash {block_range.hash}' + ) + def __iter__(self) -> Iterator[ResponseBatch]: """Return iterator instance""" return self @@ -63,20 +75,16 @@ def __next__(self) -> ResponseBatch: for range in batch.metadata.ranges: self.prev_ranges_by_network[range.network] = range - # If we detected a reorg, yield the reorg notification first + # If we detected a reorg, return reorg batch + # Caller decides whether to stop/restart or continue if invalidation_ranges: self.logger.info(f'Reorg detected with {len(invalidation_ranges)} invalidation ranges') - # Store the batch to yield after the reorg - self._pending_batch = batch + # Clear memory for affected networks so restart works correctly + for inv_range in invalidation_ranges: + if inv_range.network in self.prev_ranges_by_network: + del self.prev_ranges_by_network[inv_range.network] return ResponseBatch.reorg_batch(invalidation_ranges) - # Check if we have a pending batch from a previous reorg detection - # REVIEW: I think we should remove this - if hasattr(self, '_pending_batch'): - pending = self._pending_batch - delattr(self, '_pending_batch') - return pending - # Normal case - just return the data batch return batch @@ -89,9 +97,9 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]: """ Detect reorganizations by comparing current ranges with previous ranges. 
- A reorg is detected when: - - A range starts at or before the end of the previous range for the same network - - The range is different from the previous range + A reorg is detected when either: + 1. Block number overlap: current range starts at or before previous range end + 2. Hash mismatch: server's prev_hash doesn't match our stored hash (cross-restart detection) Args: current_ranges: Block ranges from the current batch @@ -102,18 +110,39 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]: invalidation_ranges = [] for current_range in current_ranges: - # Get the previous range for this network prev_range = self.prev_ranges_by_network.get(current_range.network) if prev_range: - # Check if this indicates a reorg + is_reorg = False + + # Detection 1: Block number overlap (original logic) if current_range != prev_range and current_range.start <= prev_range.end: - # Reorg detected - create invalidation range - # Invalidate from the start of the current range to the max end + is_reorg = True + self.logger.info( + f'Reorg detected via block overlap: {current_range.network} ' + f'current start {current_range.start} <= prev end {prev_range.end}' + ) + + # Detection 2: Hash mismatch (cross-restart detection) + # Server sends prev_hash = hash of block before current range + # If it doesn't match our stored hash, chain has changed + elif ( + current_range.prev_hash is not None + and prev_range.hash is not None + and current_range.prev_hash != prev_range.hash + ): + is_reorg = True + self.logger.info( + f'Reorg detected via hash mismatch: {current_range.network} ' + f'server prev_hash {current_range.prev_hash} != stored hash {prev_range.hash}' + ) + + if is_reorg: invalidation = BlockRange( network=current_range.network, - start=current_range.start, + start=prev_range.start, end=max(current_range.end, prev_range.end), + hash=prev_range.hash, ) invalidation_ranges.append(invalidation) diff --git a/tests/unit/test_streaming_types.py 
b/tests/unit/test_streaming_types.py index 47eede2..88ad817 100644 --- a/tests/unit/test_streaming_types.py +++ b/tests/unit/test_streaming_types.py @@ -474,7 +474,7 @@ class MockIterator: assert len(invalidations) == 1 assert invalidations[0].network == 'ethereum' - assert invalidations[0].start == 180 + assert invalidations[0].start == 100 # prev_range.start assert invalidations[0].end == 280 # max(280, 200) def test_detect_reorg_multiple_networks(self): @@ -504,12 +504,12 @@ class MockIterator: # Check ethereum reorg eth_inv = next(inv for inv in invalidations if inv.network == 'ethereum') - assert eth_inv.start == 150 + assert eth_inv.start == 100 # prev_range.start assert eth_inv.end == 250 # Check polygon reorg poly_inv = next(inv for inv in invalidations if inv.network == 'polygon') - assert poly_inv.start == 140 + assert poly_inv.start == 50 # prev_range.start assert poly_inv.end == 240 def test_detect_reorg_same_range_no_reorg(self): @@ -546,7 +546,7 @@ class MockIterator: invalidations = stream._detect_reorg(current_ranges) assert len(invalidations) == 1 - assert invalidations[0].start == 250 + assert invalidations[0].start == 100 # prev_range.start assert invalidations[0].end == 300 # max(280, 300) def test_is_duplicate_batch_all_same(self): @@ -619,3 +619,105 @@ class MockIterator: stream = ReorgAwareStream(MockIterator()) assert stream._is_duplicate_batch([]) == False + + def test_init_from_resume_watermark(self): + """Test initialization from resume watermark for cross-restart reorg detection""" + + class MockIterator: + pass + + watermark = ResumeWatermark( + ranges=[ + BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'), + BlockRange(network='polygon', start=50, end=150, hash='0xdef456'), + ] + ) + + stream = ReorgAwareStream(MockIterator(), resume_watermark=watermark) + + assert 'ethereum' in stream.prev_ranges_by_network + assert 'polygon' in stream.prev_ranges_by_network + assert stream.prev_ranges_by_network['ethereum'].hash == 
'0xabc123' + assert stream.prev_ranges_by_network['polygon'].hash == '0xdef456' + + def test_detect_reorg_hash_mismatch(self): + """Test reorg detection via hash mismatch (cross-restart detection)""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xoriginal'), + } + + current_ranges = [ + BlockRange(network='ethereum', start=201, end=300, prev_hash='0xdifferent'), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 1 + assert invalidations[0].network == 'ethereum' + assert invalidations[0].hash == '0xoriginal' + + def test_detect_reorg_hash_match_no_reorg(self): + """Test no reorg when hashes match across restart""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xsame'), + } + + current_ranges = [ + BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsame'), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 0 + + def test_detect_reorg_hash_mismatch_with_none_prev_hash(self): + """Test no reorg detection when server prev_hash is None (genesis block)""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=0, end=0, hash='0xgenesis'), + } + + current_ranges = [ + BlockRange(network='ethereum', start=1, end=100, prev_hash=None), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 0 + + def test_detect_reorg_hash_mismatch_with_none_stored_hash(self): + """Test no reorg detection when stored hash is None""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': 
BlockRange(network='ethereum', start=100, end=200, hash=None), + } + + current_ranges = [ + BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsome_hash'), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 0