1 change: 1 addition & 0 deletions src/amp/loaders/base.py
@@ -68,6 +68,7 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
state_config_dict = config.get('state', {})
self.state_enabled = state_config_dict.get('enabled', True)
self.state_storage = state_config_dict.get('storage', 'memory')
self.state_data_dir = state_config_dict.get('data_dir', '.amp_state')
self.store_batch_id = state_config_dict.get('store_batch_id', True)
self.store_full_metadata = state_config_dict.get('store_full_metadata', False)

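For orientation, here is a hedged sketch of the loader 'state' config that these lines read; only the key names and defaults come from the code above, while the overall config shape and the 'lmdb' storage value are assumptions for illustration.

# Sketch only: the 'state' keys read in the loader __init__ above.
# The 'lmdb' value for 'storage' is an assumption about how the new store is selected.
config = {
    'state': {
        'enabled': True,             # state_enabled (default True)
        'storage': 'lmdb',           # state_storage (default 'memory')
        'data_dir': '.amp_state',    # state_data_dir, new in this change (default '.amp_state')
        'store_batch_id': True,
        'store_full_metadata': False,
    },
}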
2 changes: 2 additions & 0 deletions src/amp/streaming/__init__.py
@@ -1,5 +1,6 @@
# Streaming module for continuous data loading
from .iterator import StreamingResultIterator
from .lmdb_state import LMDBStreamStateStore
from .parallel import (
BlockRangePartitionStrategy,
ParallelConfig,
@@ -35,6 +36,7 @@
'StreamStateStore',
'InMemoryStreamStateStore',
'NullStreamStateStore',
'LMDBStreamStateStore',
'BatchIdentifier',
'ProcessedBatch',
]
371 changes: 371 additions & 0 deletions src/amp/streaming/lmdb_state.py
@@ -0,0 +1,371 @@
"""
LMDB-based stream state store for durable batch tracking.

This implementation uses LMDB (Lightning Memory-Mapped Database) for fast,
embedded, durable storage of batch processing state. It can be used with any
loader (Kafka, PostgreSQL, etc.) to provide crash recovery and idempotency.
"""

import json
import logging
from pathlib import Path
from typing import Dict, List, Optional

import lmdb

from .state import BatchIdentifier, StreamStateStore
from .types import BlockRange, ResumeWatermark


class LMDBStreamStateStore(StreamStateStore):
env: lmdb.Environment
Contributor:
Not sure if this matters for systems that provide users with documentation using these comments, but I think standard style is to put the class attributes after the comment block.

Member Author:
Will correct.

"""
Generic LMDB-based state store for tracking processed batches.

Uses LMDB for fast, durable key-value storage with ACID transactions.
Tracks individual batches with unique hash-based IDs to support:
- Crash recovery and resume
- Idempotency (duplicate detection)
- Reorg handling (invalidate by block hash)
- Gap detection for parallel loading

Uses two LMDB sub-databases for efficient queries:
1. "batches" - Individual batch records keyed by batch_id
2. "metadata" - Max block metadata per network for fast resume position queries

Batch database layout:
- Key: {connection_name}|{table_name}|{batch_id}
- Value: JSON with {network, start_block, end_block, end_hash, start_parent_hash}

Metadata database layout:
- Key: {connection_name}|{table_name}|{network}
- Value: JSON with {end_block, end_hash, start_parent_hash} (max processed block)
"""

def __init__(
self,
connection_name: str,
data_dir: str = '.amp_state',
map_size: int = 10 * 1024 * 1024 * 1024,
sync: bool = True,
):
"""
Initialize LMDB state store with two sub-databases.

Args:
connection_name: Name of the connection (for multi-connection support)
data_dir: Directory to store LMDB database files
map_size: Maximum database size in bytes (default: 10GB)
sync: Whether to sync writes to disk (True for durability, False for speed)
"""
self.connection_name = connection_name
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)

self.logger = logging.getLogger(__name__)

self.env = lmdb.open(str(self.data_dir), map_size=map_size, sync=sync, max_dbs=2)

self.batches_db = self.env.open_db(b'batches')
self.metadata_db = self.env.open_db(b'metadata')

self.logger.info(f'Initialized LMDB state store at {self.data_dir} with 2 sub-databases')

def _make_batch_key(self, connection_name: str, table_name: str, batch_id: str) -> bytes:
"""Create composite key for batch database."""
return f'{connection_name}|{table_name}|{batch_id}'.encode('utf-8')

def _make_metadata_key(self, connection_name: str, table_name: str, network: str) -> bytes:
"""Create composite key for metadata database."""
return f'{connection_name}|{table_name}|{network}'.encode('utf-8')

def _parse_key(self, key: bytes) -> tuple[str, str, str]:
"""Parse composite key into (connection_name, table_name, batch_id/network)."""
parts = key.decode('utf-8').split('|')
return parts[0], parts[1], parts[2]

def _serialize_batch(self, batch: BatchIdentifier) -> bytes:
"""Serialize BatchIdentifier to JSON bytes."""
batch_value_dict = {
'network': batch.network,
'start_block': batch.start_block,
'end_block': batch.end_block,
'end_hash': batch.end_hash,
'start_parent_hash': batch.start_parent_hash,
}
return json.dumps(batch_value_dict).encode('utf-8')

def _serialize_metadata(self, end_block: int, end_hash: str, start_parent_hash: str) -> bytes:
"""Serialize metadata to JSON bytes."""
meta_value_dict = {
'end_block': end_block,
'end_hash': end_hash,
'start_parent_hash': start_parent_hash,
}
return json.dumps(meta_value_dict).encode('utf-8')

def _deserialize_batch(self, value: bytes) -> Dict:
"""Deserialize batch data from JSON bytes."""
return json.loads(value.decode('utf-8'))

def is_processed(self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier]) -> bool:
"""
Check if all given batches have already been processed.

Args:
connection_name: Connection identifier
table_name: Name of the table being loaded
batch_ids: List of batch identifiers to check

Returns:
True only if ALL batches are already processed
"""
if not batch_ids:
return True
Contributor:
In the InMemoryStreamStateStore we return False for this case when an empty batch_ids list is provided:

if not batch_ids:
    return False

It doesn't look like we check for empty batch_ids before calling this, so this code path never gets hit; however, it would be good to align the implementations for consistency.

Do you have any preference @incrypto32?

Member Author:
No preference, I'll update the code to return False for consistency.


with self.env.begin(db=self.batches_db) as txn:
for batch_id in batch_ids:
key = self._make_batch_key(connection_name, table_name, batch_id.unique_id)
value = txn.get(key)
if value is None:
return False

return True

def mark_processed(self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier]) -> None:
"""
Mark batches as processed in durable storage.

Atomically updates both batch records and metadata (max block per network).

Args:
connection_name: Connection identifier
table_name: Name of the table being loaded
batch_ids: List of batch identifiers to mark as processed
"""
with self.env.begin(write=True) as txn:
for batch in batch_ids:
batch_key = self._make_batch_key(connection_name, table_name, batch.unique_id)
batch_value = self._serialize_batch(batch)
txn.put(batch_key, batch_value, db=self.batches_db)

meta_key = self._make_metadata_key(connection_name, table_name, batch.network)
current_meta = txn.get(meta_key, db=self.metadata_db)

should_update = False
if current_meta is None:
should_update = True
else:
current_meta_dict = self._deserialize_batch(current_meta)
if batch.end_block > current_meta_dict['end_block']:
should_update = True

if should_update:
meta_value = self._serialize_metadata(batch.end_block, batch.end_hash, batch.start_parent_hash)
Contributor:
No start block?

txn.put(meta_key, meta_value, db=self.metadata_db)

self.logger.debug(f'Marked {len(batch_ids)} batches as processed in {table_name}')

def get_resume_position(
self, connection_name: str, table_name: str, detect_gaps: bool = False
) -> Optional[ResumeWatermark]:
"""
Get the resume watermark (max processed block per network).

Reads only from metadata database. Does not scan batch records.

Args:
connection_name: Connection identifier
table_name: Destination table name
detect_gaps: If True, detect gaps in processed batches. Not implemented; raises NotImplementedError.

Returns:
ResumeWatermark with max block ranges for all networks, or None if no state exists

Raises:
NotImplementedError: If detect_gaps=True
"""
if detect_gaps:
raise NotImplementedError('Gap detection not implemented in LMDB state store')

prefix = f'{connection_name}|{table_name}|'.encode('utf-8')
ranges = []

with self.env.begin(db=self.metadata_db) as txn:
cursor = txn.cursor()

if not cursor.set_range(prefix):
return None

for key, value in cursor:
if not key.startswith(prefix):
break

try:
_, _, network = self._parse_key(key)
meta_data = self._deserialize_batch(value)

ranges.append(
BlockRange(
network=network,
start=meta_data['end_block'],
Contributor:
So this is an invalidation range (open ended), hence the start=end?

Let's leave a comment explaining why above, so this doesn't trip us up in the future. Maybe link to the crash recovery explanation or code.

end=meta_data['end_block'],
hash=meta_data.get('end_hash'),
prev_hash=meta_data.get('start_parent_hash'),
Contributor:
This start_parent_hash value is (bear with me because this is a mouthful) the hash of the parent block of the start block of the last range that was consumed?

start_parent -> start_block --..--> end_block/end_hash/inval_range_start_block

This hash is disconnected from the other values in the invalidation range. Did you intend to use this start_parent_hash? If not, I'd suggest we don't set prev_hash at all.

Member Author:
Oops! This was an oversight. Corrected.

)
)

except (json.JSONDecodeError, KeyError) as e:
self.logger.warning(f'Failed to parse metadata: {e}')
continue

if not ranges:
return None

return ResumeWatermark(ranges=ranges)

def invalidate_from_block(
self, connection_name: str, table_name: str, network: str, from_block: int
) -> List[BatchIdentifier]:
"""
Invalidate (delete) all batches from a specific block onwards.

Used for reorg handling to remove invalidated data. Requires full scan
of batches database to find matching batches.

Args:
connection_name: Connection identifier
table_name: Name of the table
network: Network name
from_block: Block number to invalidate from (inclusive)

Returns:
List of BatchIdentifier objects that were invalidated
"""
prefix = f'{connection_name}|{table_name}|'.encode('utf-8')
invalidated_batch_ids = []
keys_to_delete = []

with self.env.begin(db=self.batches_db) as txn:
cursor = txn.cursor()

if not cursor.set_range(prefix):
return []

for key, value in cursor:
if not key.startswith(prefix):
break

try:
batch_data = self._deserialize_batch(value)

if batch_data['network'] == network and batch_data['end_block'] >= from_block:
batch_id = BatchIdentifier(
network=batch_data['network'],
start_block=batch_data['start_block'],
end_block=batch_data['end_block'],
end_hash=batch_data.get('end_hash'),
start_parent_hash=batch_data.get('start_parent_hash'),
)
invalidated_batch_ids.append(batch_id)
keys_to_delete.append(key)

except (json.JSONDecodeError, KeyError) as e:
self.logger.warning(f'Failed to parse batch data during invalidation: {e}')
continue

if keys_to_delete:
with self.env.begin(write=True) as txn:
for key in keys_to_delete:
txn.delete(key, db=self.batches_db)

meta_key = self._make_metadata_key(connection_name, table_name, network)

remaining_batches = []
cursor = txn.cursor(db=self.batches_db)
if cursor.set_range(prefix):
for key, value in cursor:
if not key.startswith(prefix):
break
try:
batch_data = self._deserialize_batch(value)
if batch_data['network'] == network:
remaining_batches.append(batch_data)
except (json.JSONDecodeError, KeyError) as e:
self.logger.warning(f'Failed to parse batch data during metadata recalculation: {e}')
continue

if remaining_batches:
remaining_batches.sort(key=lambda b: b['end_block'])
max_batch = remaining_batches[-1]
meta_value = self._serialize_metadata(
max_batch['end_block'], max_batch.get('end_hash'), max_batch.get('start_parent_hash')
)
txn.put(meta_key, meta_value, db=self.metadata_db)
else:
txn.delete(meta_key, db=self.metadata_db)

self.logger.info(
f'Invalidated {len(invalidated_batch_ids)} batches from block {from_block} on {network} in {table_name}'
)

return invalidated_batch_ids

def cleanup_before_block(self, connection_name: str, table_name: str, network: str, before_block: int) -> None:
"""
Clean up old batch records before a specific block.

Removes batches where end_block < before_block. Requires full scan
to find matching batches for the given network.

Args:
connection_name: Connection identifier
table_name: Name of the table
network: Network name
before_block: Block number to clean up before (exclusive)
"""
prefix = f'{connection_name}|{table_name}|'.encode('utf-8')
keys_to_delete = []

with self.env.begin(db=self.batches_db) as txn:
cursor = txn.cursor()

if not cursor.set_range(prefix):
return

for key, value in cursor:
if not key.startswith(prefix):
break

try:
batch_data = self._deserialize_batch(value)

if batch_data['network'] == network and batch_data['end_block'] < before_block:
keys_to_delete.append(key)

except (json.JSONDecodeError, KeyError) as e:
self.logger.warning(f'Failed to parse batch data during cleanup: {e}')
continue

if keys_to_delete:
with self.env.begin(write=True, db=self.batches_db) as txn:
for key in keys_to_delete:
txn.delete(key)

self.logger.info(
f'Cleaned up {len(keys_to_delete)} old batches before block {before_block} on {network} in {table_name}'
)

def close(self) -> None:
"""Close the LMDB environment."""
if self.env:
self.env.close()
self.logger.info('Closed LMDB state store')

def __enter__(self) -> 'LMDBStreamStateStore':
"""Context manager entry."""
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close()
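
For reference, a minimal usage sketch of the store as exposed by this module; the connection name, table name, and block values are illustrative, and the BatchIdentifier fields are taken from the constructor calls above.

from amp.streaming import BatchIdentifier, LMDBStreamStateStore

# Illustrative batch; only the field names come from the code above.
batch = BatchIdentifier(
    network='mainnet',
    start_block=100,
    end_block=199,
    end_hash='0x...',
    start_parent_hash='0x...',
)

with LMDBStreamStateStore(connection_name='demo', data_dir='.amp_state') as store:
    if not store.is_processed('demo', 'transfers', [batch]):
        # ... load the batch into the destination table here (assumed to happen elsewhere) ...
        store.mark_processed('demo', 'transfers', [batch])

    # Resume watermark: max processed block per network, or None if no state exists.
    watermark = store.get_resume_position('demo', 'transfers')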