Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/amp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def query_and_load_streaming(

# Optionally wrap with reorg detection
if with_reorg_detection:
stream_iterator = ReorgAwareStream(stream_iterator)
stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark)
self.logger.info('Reorg detection enabled for streaming query')

# Start continuous loading with checkpoint support
Expand Down
53 changes: 43 additions & 10 deletions src/amp/streaming/reorg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Iterator, List

from .iterator import StreamingResultIterator
from .types import BlockRange, ResponseBatch
from .types import BlockRange, ResponseBatch, ResumeWatermark


class ReorgAwareStream:
Expand All @@ -16,20 +16,32 @@ class ReorgAwareStream:
This class monitors the block ranges in consecutive batches to detect chain
reorganizations (reorgs). When a reorg is detected, a ResponseBatch with
is_reorg=True is emitted containing the invalidation ranges.

Supports cross-restart reorg detection by initializing from a resume watermark
that contains the last known block hashes from persistent state.
"""

def __init__(self, stream_iterator: StreamingResultIterator):
def __init__(self, stream_iterator: StreamingResultIterator, resume_watermark: ResumeWatermark = None):
"""
Initialize the reorg-aware stream.

Args:
stream_iterator: The underlying streaming result iterator
resume_watermark: Optional watermark from persistent state (LMDB) containing
last known block ranges with hashes for cross-restart reorg detection
"""
self.stream_iterator = stream_iterator
# Track the latest range for each network
self.prev_ranges_by_network: Dict[str, BlockRange] = {}
self.logger = logging.getLogger(__name__)

if resume_watermark:
for block_range in resume_watermark.ranges:
self.prev_ranges_by_network[block_range.network] = block_range
self.logger.debug(
f'Initialized reorg detection for {block_range.network} '
f'from block {block_range.end} hash {block_range.hash}'
)

def __iter__(self) -> Iterator[ResponseBatch]:
"""Return iterator instance"""
return self
Expand Down Expand Up @@ -89,9 +101,9 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
"""
Detect reorganizations by comparing current ranges with previous ranges.

A reorg is detected when:
- A range starts at or before the end of the previous range for the same network
- The range is different from the previous range
A reorg is detected when either:
1. Block number overlap: current range starts at or before previous range end
2. Hash mismatch: server's prev_hash doesn't match our stored hash (cross-restart detection)

Args:
current_ranges: Block ranges from the current batch
Expand All @@ -102,18 +114,39 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
invalidation_ranges = []

for current_range in current_ranges:
# Get the previous range for this network
prev_range = self.prev_ranges_by_network.get(current_range.network)

if prev_range:
# Check if this indicates a reorg
is_reorg = False

# Detection 1: Block number overlap (original logic)
if current_range != prev_range and current_range.start <= prev_range.end:
# Reorg detected - create invalidation range
# Invalidate from the start of the current range to the max end
is_reorg = True
self.logger.info(
f'Reorg detected via block overlap: {current_range.network} '
f'current start {current_range.start} <= prev end {prev_range.end}'
)

# Detection 2: Hash mismatch (cross-restart detection)
# Server sends prev_hash = hash of block before current range
# If it doesn't match our stored hash, chain has changed
elif (
current_range.prev_hash is not None
and prev_range.hash is not None
and current_range.prev_hash != prev_range.hash
):
is_reorg = True
self.logger.info(
f'Reorg detected via hash mismatch: {current_range.network} '
f'server prev_hash {current_range.prev_hash} != stored hash {prev_range.hash}'
)

if is_reorg:
invalidation = BlockRange(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we invalidate the entire previous range in this case to be safe?

Copy link
Member Author

@incrypto32 incrypto32 Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes good catch. for the case of hash mismatch we need to invalidate entire previous range to be safe. But just setting it to previous range would create a gap since previous range would be skipped completely since processing started in the next range.
So there need to be some changes on how this ReorgAwareStream works. I'll look into and come up with a better solution. I also identified a bug in ReorgAwareStream currently when there is _pending_batch it gets returned but the current_batch that was just fetched from the stream gets dropped completely.

I'll come up with a fix for that as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proper fix would be to trigger a backfill when we invalidate more than the current range.

network=current_range.network,
start=current_range.start,
end=max(current_range.end, prev_range.end),
hash=prev_range.hash,
)
invalidation_ranges.append(invalidation)

Expand Down
102 changes: 102 additions & 0 deletions tests/unit/test_streaming_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,105 @@ class MockIterator:
stream = ReorgAwareStream(MockIterator())

assert stream._is_duplicate_batch([]) == False

def test_init_from_resume_watermark(self):
"""Test initialization from resume watermark for cross-restart reorg detection"""

class MockIterator:
pass

watermark = ResumeWatermark(
ranges=[
BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'),
BlockRange(network='polygon', start=50, end=150, hash='0xdef456'),
]
)

stream = ReorgAwareStream(MockIterator(), resume_watermark=watermark)

assert 'ethereum' in stream.prev_ranges_by_network
assert 'polygon' in stream.prev_ranges_by_network
assert stream.prev_ranges_by_network['ethereum'].hash == '0xabc123'
assert stream.prev_ranges_by_network['polygon'].hash == '0xdef456'

def test_detect_reorg_hash_mismatch(self):
"""Test reorg detection via hash mismatch (cross-restart detection)"""

class MockIterator:
pass

stream = ReorgAwareStream(MockIterator())

stream.prev_ranges_by_network = {
'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xoriginal'),
}

current_ranges = [
BlockRange(network='ethereum', start=201, end=300, prev_hash='0xdifferent'),
]

invalidations = stream._detect_reorg(current_ranges)

assert len(invalidations) == 1
assert invalidations[0].network == 'ethereum'
assert invalidations[0].hash == '0xoriginal'

def test_detect_reorg_hash_match_no_reorg(self):
"""Test no reorg when hashes match across restart"""

class MockIterator:
pass

stream = ReorgAwareStream(MockIterator())

stream.prev_ranges_by_network = {
'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xsame'),
}

current_ranges = [
BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsame'),
]

invalidations = stream._detect_reorg(current_ranges)

assert len(invalidations) == 0

def test_detect_reorg_hash_mismatch_with_none_prev_hash(self):
"""Test no reorg detection when server prev_hash is None (genesis block)"""

class MockIterator:
pass

stream = ReorgAwareStream(MockIterator())

stream.prev_ranges_by_network = {
'ethereum': BlockRange(network='ethereum', start=0, end=0, hash='0xgenesis'),
}

current_ranges = [
BlockRange(network='ethereum', start=1, end=100, prev_hash=None),
]

invalidations = stream._detect_reorg(current_ranges)

assert len(invalidations) == 0

def test_detect_reorg_hash_mismatch_with_none_stored_hash(self):
"""Test no reorg detection when stored hash is None"""

class MockIterator:
pass

stream = ReorgAwareStream(MockIterator())

stream.prev_ranges_by_network = {
'ethereum': BlockRange(network='ethereum', start=100, end=200, hash=None),
}

current_ranges = [
BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsome_hash'),
]

invalidations = stream._detect_reorg(current_ranges)

assert len(invalidations) == 0
Loading