From 931a4abdf57feb95526378b6cb7c29433bf15fba Mon Sep 17 00:00:00 2001 From: Ford Date: Mon, 2 Feb 2026 13:28:31 -0800 Subject: [PATCH] Add separate async client --- src/amp/admin/async_client.py | 168 +++++++ src/amp/admin/async_datasets.py | 244 ++++++++++ src/amp/admin/async_jobs.py | 187 ++++++++ src/amp/admin/async_schema.py | 65 +++ src/amp/async_client.py | 718 +++++++++++++++++++++++++++++ src/amp/registry/async_client.py | 180 ++++++++ src/amp/registry/async_datasets.py | 437 ++++++++++++++++++ tests/unit/test_async_client.py | 328 +++++++++++++ 8 files changed, 2327 insertions(+) create mode 100644 src/amp/admin/async_client.py create mode 100644 src/amp/admin/async_datasets.py create mode 100644 src/amp/admin/async_jobs.py create mode 100644 src/amp/admin/async_schema.py create mode 100644 src/amp/async_client.py create mode 100644 src/amp/registry/async_client.py create mode 100644 src/amp/registry/async_datasets.py create mode 100644 tests/unit/test_async_client.py diff --git a/src/amp/admin/async_client.py b/src/amp/admin/async_client.py new file mode 100644 index 0000000..33ae241 --- /dev/null +++ b/src/amp/admin/async_client.py @@ -0,0 +1,168 @@ +"""Async HTTP client for Amp Admin API. + +This module provides the async AdminClient class for communicating +with the Amp Admin API over HTTP using asyncio and httpx. +""" + +import os +from typing import Optional + +import httpx + +from .errors import map_error_response + + +class AsyncAdminClient: + """Async HTTP client for Amp Admin API. + + Provides access to Admin API endpoints through sub-clients for + datasets, jobs, and schema operations using async/await. + + Args: + base_url: Base URL for Admin API (e.g., 'http://localhost:8080') + auth_token: Optional Bearer token for authentication (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth + + Example: + >>> # Use amp auth from file + >>> async with AsyncAdminClient('http://localhost:8080', auth=True) as client: + ... datasets = await client.datasets.list_all() + >>> + >>> # Use manual token + >>> async with AsyncAdminClient('http://localhost:8080', auth_token='your-token') as client: + ... job = await client.jobs.get(123) + """ + + def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False): + """Initialize async Admin API client. + + Args: + base_url: Base URL for Admin API (e.g., 'http://localhost:8080') + auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp/cache + + Raises: + ValueError: If both auth=True and auth_token are provided + """ + if auth and auth_token: + raise ValueError('Cannot specify both auth=True and auth_token. 
Choose one authentication method.') + + self.base_url = base_url.rstrip('/') + + # Resolve auth token provider with priority: explicit param > env var > auth file + self._get_token = None + if auth_token: + # Priority 1: Explicit auth_token parameter (static token) + self._get_token = lambda: auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + self._get_token = lambda: env_token + elif auth: + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing) + from amp.auth import AuthService + + auth_service = AuthService() + self._get_token = auth_service.get_token # Callable that auto-refreshes + + # Create async HTTP client (no auth header yet - will be added per-request) + self._http = httpx.AsyncClient( + base_url=self.base_url, + timeout=30.0, + follow_redirects=True, + ) + + async def _request( + self, method: str, path: str, json: Optional[dict] = None, params: Optional[dict] = None, **kwargs + ) -> httpx.Response: + """Make async HTTP request with error handling. + + Args: + method: HTTP method (GET, POST, DELETE, etc.) + path: API endpoint path (e.g., '/datasets') + json: Optional JSON request body + params: Optional query parameters + **kwargs: Additional arguments passed to httpx.request() + + Returns: + HTTP response object + + Raises: + AdminAPIError: If the API returns an error response + """ + # Add auth header dynamically (auto-refreshes if needed) + headers = kwargs.get('headers', {}) + if self._get_token: + headers['Authorization'] = f'Bearer {self._get_token()}' + kwargs['headers'] = headers + + response = await self._http.request(method, path, json=json, params=params, **kwargs) + + # Handle error responses + if response.status_code >= 400: + try: + error_data = response.json() + raise map_error_response(response.status_code, error_data) + except ValueError: + # Response is not JSON, fall back to generic HTTP error + response.raise_for_status() + + return response + + @property + def datasets(self): + """Access async datasets client. + + Returns: + AsyncDatasetsClient for dataset operations + """ + from .async_datasets import AsyncDatasetsClient + + return AsyncDatasetsClient(self) + + @property + def jobs(self): + """Access async jobs client. + + Returns: + AsyncJobsClient for job operations + """ + from .async_jobs import AsyncJobsClient + + return AsyncJobsClient(self) + + @property + def schema(self): + """Access async schema client. + + Returns: + AsyncSchemaClient for schema operations + """ + from .async_schema import AsyncSchemaClient + + return AsyncSchemaClient(self) + + async def close(self): + """Close the HTTP client and release resources. + + Example: + >>> client = AsyncAdminClient('http://localhost:8080') + >>> try: + ... datasets = await client.datasets.list_all() + ... finally: + ... await client.close() + """ + await self._http.aclose() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/amp/admin/async_datasets.py b/src/amp/admin/async_datasets.py new file mode 100644 index 0000000..f87668c --- /dev/null +++ b/src/amp/admin/async_datasets.py @@ -0,0 +1,244 @@ +"""Async datasets client for Admin API. + +This module provides the AsyncDatasetsClient class for managing datasets, +including registration, deployment, versioning, and manifest operations. 
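+
+A typical flow registers a manifest and then deploys it. A minimal sketch,
+assuming a server at http://localhost:8080 and a prepared manifest dict:
+
+    >>> async with AsyncAdminClient('http://localhost:8080', auth=True) as client:
+    ...     await client.datasets.register('_', 'my_dataset', '1.0.0', manifest)
+    ...     response = await client.datasets.deploy('_', 'my_dataset', '1.0.0')
+    ...     print(response.job_id)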
+""" + +from typing import TYPE_CHECKING, Dict, Optional + +from amp.utils.manifest_inspector import describe_manifest + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncAdminClient + + +class AsyncDatasetsClient: + """Async client for dataset operations. + + Provides async methods for registering, deploying, listing, and managing datasets + through the Admin API. + + Args: + admin_client: Parent AsyncAdminClient instance + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... datasets = await client.datasets.list_all() + """ + + def __init__(self, admin_client: 'AsyncAdminClient'): + """Initialize async datasets client. + + Args: + admin_client: Parent AsyncAdminClient instance + """ + self._admin = admin_client + + async def register(self, namespace: str, name: str, version: str, manifest: dict) -> None: + """Register a dataset manifest. + + Registers a new dataset configuration in the server's local registry. + The manifest defines tables, dependencies, and extraction logic. + + Args: + namespace: Dataset namespace (e.g., '_') + name: Dataset name + version: Semantic version (e.g., '1.0.0') or tag ('latest', 'dev') + manifest: Dataset manifest dict (kind='manifest') + + Raises: + InvalidManifestError: If manifest is invalid + DependencyValidationError: If dependencies are invalid + ManifestRegistrationError: If registration fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.datasets.register('_', 'my_dataset', '1.0.0', manifest) + """ + request_data = models.RegisterRequest(namespace=namespace, name=name, version=version, manifest=manifest) + + await self._admin._request('POST', '/datasets', json=request_data.model_dump(mode='json', exclude_none=True)) + + async def deploy( + self, + namespace: str, + name: str, + revision: str, + end_block: Optional[str] = None, + parallelism: Optional[int] = None, + ) -> models.DeployResponse: + """Deploy a dataset version. + + Triggers data extraction for the specified dataset version. + + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag ('latest', 'dev', '1.0.0', etc.) + end_block: Optional end block ('latest', '-100', '1000000', or null) + parallelism: Optional number of parallel workers + + Returns: + DeployResponse with job_id + + Raises: + DatasetNotFoundError: If dataset/version not found + SchedulerError: If deployment fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... response = await client.datasets.deploy('_', 'my_dataset', '1.0.0', parallelism=4) + ... print(f'Job ID: {response.job_id}') + """ + path = f'/datasets/{namespace}/{name}/versions/{revision}/deploy' + + # Build request body (POST requires JSON body, not query params) + body = {} + if end_block is not None: + body['end_block'] = end_block + if parallelism is not None: + body['parallelism'] = parallelism + + response = await self._admin._request('POST', path, json=body if body else {}) + return models.DeployResponse.model_validate(response.json()) + + async def list_all(self) -> models.DatasetsResponse: + """List all registered datasets. + + Returns all datasets across all namespaces with version information. + + Returns: + DatasetsResponse with list of datasets + + Raises: + ListAllDatasetsError: If listing fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... datasets = await client.datasets.list_all() + ... for ds in datasets.datasets: + ... 
print(f'{ds.namespace}/{ds.name}: {ds.latest_version}') + """ + response = await self._admin._request('GET', '/datasets') + return models.DatasetsResponse.model_validate(response.json()) + + async def get_versions(self, namespace: str, name: str) -> models.VersionsResponse: + """List all versions of a dataset. + + Returns version information including semantic versions and special tags. + + Args: + namespace: Dataset namespace + name: Dataset name + + Returns: + VersionsResponse with version list + + Raises: + DatasetNotFoundError: If dataset not found + ListDatasetVersionsError: If listing fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... versions = await client.datasets.get_versions('_', 'eth_firehose') + ... print(f'Latest: {versions.special_tags.latest}') + """ + path = f'/datasets/{namespace}/{name}/versions' + response = await self._admin._request('GET', path) + return models.VersionsResponse.model_validate(response.json()) + + async def get_version(self, namespace: str, name: str, revision: str) -> models.VersionInfo: + """Get detailed information about a specific dataset version. + + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag or semantic version + + Returns: + VersionInfo with dataset details + + Raises: + DatasetNotFoundError: If dataset/version not found + GetDatasetVersionError: If retrieval fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... info = await client.datasets.get_version('_', 'eth_firehose', '1.0.0') + ... print(f'Kind: {info.kind}') + """ + path = f'/datasets/{namespace}/{name}/versions/{revision}' + response = await self._admin._request('GET', path) + return models.VersionInfo.model_validate(response.json()) + + async def get_manifest(self, namespace: str, name: str, revision: str) -> dict: + """Get the manifest for a specific dataset version. + + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag or semantic version + + Returns: + Manifest dict + + Raises: + DatasetNotFoundError: If dataset/version not found + GetManifestError: If retrieval fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... manifest = await client.datasets.get_manifest('_', 'eth_firehose', '1.0.0') + ... print(manifest['kind']) + """ + path = f'/datasets/{namespace}/{name}/versions/{revision}/manifest' + response = await self._admin._request('GET', path) + return response.json() + + async def describe( + self, namespace: str, name: str, revision: str = 'latest' + ) -> Dict[str, list[Dict[str, str | bool]]]: + """Get a structured summary of tables and columns in a dataset. + + Returns a dictionary mapping table names to lists of column information, + making it easy to programmatically inspect the dataset schema. + + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag (default: 'latest') + + Returns: + dict: Mapping of table names to column information. + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... schema = await client.datasets.describe('_', 'eth_firehose', 'latest') + ... for table_name, columns in schema.items(): + ... print(f"Table: {table_name}") + """ + manifest = await self.get_manifest(namespace, name, revision) + return describe_manifest(manifest) + + async def delete(self, namespace: str, name: str) -> None: + """Delete all versions and metadata for a dataset. 
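+
+        Because this removes every version of the dataset, a cautious pattern
+        is to inspect the versions first. A sketch, assuming the dataset exists:
+
+            >>> async with AsyncAdminClient('http://localhost:8080') as client:
+            ...     versions = await client.datasets.get_versions('_', 'my_old_dataset')
+            ...     # review versions, then:
+            ...     await client.datasets.delete('_', 'my_old_dataset')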
+ + Removes all manifest links and version tags for the dataset. + Orphaned manifests (not referenced by other datasets) are also deleted. + + Args: + namespace: Dataset namespace + name: Dataset name + + Raises: + InvalidPathError: If namespace/name invalid + UnlinkDatasetManifestsError: If deletion fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.datasets.delete('_', 'my_old_dataset') + """ + path = f'/datasets/{namespace}/{name}' + await self._admin._request('DELETE', path) diff --git a/src/amp/admin/async_jobs.py b/src/amp/admin/async_jobs.py new file mode 100644 index 0000000..c811e09 --- /dev/null +++ b/src/amp/admin/async_jobs.py @@ -0,0 +1,187 @@ +"""Async jobs client for Admin API. + +This module provides the AsyncJobsClient class for monitoring and managing +extraction jobs using async/await. +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Optional + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncAdminClient + + +class AsyncJobsClient: + """Async client for job operations. + + Provides async methods for monitoring, managing, and waiting for extraction jobs. + + Args: + admin_client: Parent AsyncAdminClient instance + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... job = await client.jobs.get(123) + ... print(f'Status: {job.status}') + """ + + def __init__(self, admin_client: 'AsyncAdminClient'): + """Initialize async jobs client. + + Args: + admin_client: Parent AsyncAdminClient instance + """ + self._admin = admin_client + + async def get(self, job_id: int) -> models.JobInfo: + """Get job information by ID. + + Args: + job_id: Job ID to retrieve + + Returns: + JobInfo with job details + + Raises: + JobNotFoundError: If job not found + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... job = await client.jobs.get(123) + ... print(f'Status: {job.status}') + """ + path = f'/jobs/{job_id}' + response = await self._admin._request('GET', path) + return models.JobInfo.model_validate(response.json()) + + async def list(self, limit: int = 50, last_job_id: Optional[int] = None) -> models.JobsResponse: + """List jobs with pagination. + + Args: + limit: Maximum number of jobs to return (default: 50, max: 1000) + last_job_id: Cursor from previous page's next_cursor field + + Returns: + JobsResponse with jobs and optional next_cursor + + Raises: + ListJobsError: If listing fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... response = await client.jobs.list(limit=100) + ... for job in response.jobs: + ... print(f'{job.id}: {job.status}') + """ + params = {'limit': limit} + if last_job_id is not None: + params['last_job_id'] = last_job_id + + response = await self._admin._request('GET', '/jobs', params=params) + return models.JobsResponse.model_validate(response.json()) + + async def wait_for_completion( + self, job_id: int, poll_interval: int = 5, timeout: Optional[int] = None + ) -> models.JobInfo: + """Poll job until completion or timeout. + + Continuously polls the job status until it reaches a terminal state + (Completed, Failed, or Stopped). Uses asyncio.sleep for non-blocking waits. 
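+
+        Because the wait is non-blocking, several jobs can be awaited
+        concurrently from one event loop. A sketch, assuming jobs 123 and
+        124 exist:
+
+            >>> import asyncio
+            >>> async with AsyncAdminClient('http://localhost:8080') as client:
+            ...     done = await asyncio.gather(
+            ...         client.jobs.wait_for_completion(123),
+            ...         client.jobs.wait_for_completion(124),
+            ...     )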
+ + Args: + job_id: Job ID to monitor + poll_interval: Seconds between status checks (default: 5) + timeout: Optional timeout in seconds (default: None = infinite) + + Returns: + Final JobInfo when job completes + + Raises: + JobNotFoundError: If job not found + TimeoutError: If timeout is reached before completion + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... deploy_resp = await client.datasets.deploy('_', 'my_dataset', '1.0.0') + ... final_job = await client.jobs.wait_for_completion(deploy_resp.job_id) + ... print(f'Final status: {final_job.status}') + """ + elapsed = 0.0 + terminal_states = {'Completed', 'Failed', 'Stopped'} + + while True: + job = await self.get(job_id) + + # Check if job reached terminal state + if job.status in terminal_states: + return job + + # Check timeout + if timeout is not None and elapsed >= timeout: + raise TimeoutError( + f'Job {job_id} did not complete within {timeout} seconds. Current status: {job.status}' + ) + + # Wait before next poll (non-blocking) + await asyncio.sleep(poll_interval) + elapsed += poll_interval + + async def stop(self, job_id: int) -> None: + """Stop a running job. + + Requests the job to stop gracefully. The job will transition through + StopRequested and Stopping states before reaching Stopped. + + Args: + job_id: Job ID to stop + + Raises: + JobNotFoundError: If job not found + JobStopError: If stop request fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.jobs.stop(123) + """ + path = f'/jobs/{job_id}/stop' + await self._admin._request('POST', path) + + async def delete(self, job_id: int) -> None: + """Delete a job in terminal state. + + Only jobs in terminal states (Completed, Failed, Stopped) can be deleted. + + Args: + job_id: Job ID to delete + + Raises: + JobNotFoundError: If job not found + JobDeleteError: If job is not in terminal state or deletion fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.jobs.delete(123) + """ + path = f'/jobs/{job_id}' + await self._admin._request('DELETE', path) + + async def delete_many(self, job_ids: list[int]) -> None: + """Delete multiple jobs in bulk. + + All specified jobs must be in terminal states. + + Args: + job_ids: List of job IDs to delete + + Raises: + JobsDeleteError: If any deletion fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.jobs.delete_many([123, 124, 125]) + """ + await self._admin._request('DELETE', '/jobs', json={'job_ids': job_ids}) diff --git a/src/amp/admin/async_schema.py b/src/amp/admin/async_schema.py new file mode 100644 index 0000000..74297eb --- /dev/null +++ b/src/amp/admin/async_schema.py @@ -0,0 +1,65 @@ +"""Async schema client for Admin API. + +This module provides the AsyncSchemaClient class for querying output schemas +of SQL queries without executing them using async/await. +""" + +from typing import TYPE_CHECKING + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncAdminClient + + +class AsyncSchemaClient: + """Async client for schema operations. + + Provides async methods for validating SQL queries and determining output schemas + using DataFusion's query planner. + + Args: + admin_client: Parent AsyncAdminClient instance + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... 
schema = await client.schema.get_output_schema('SELECT * FROM eth.blocks', True) + """ + + def __init__(self, admin_client: 'AsyncAdminClient'): + """Initialize async schema client. + + Args: + admin_client: Parent AsyncAdminClient instance + """ + self._admin = admin_client + + async def get_output_schema(self, sql_query: str, is_sql_dataset: bool = True) -> models.OutputSchemaResponse: + """Get output schema for a SQL query. + + Validates the query and returns the Arrow schema that would be produced, + without actually executing the query. + + Args: + sql_query: SQL query to analyze + is_sql_dataset: Whether this is for a SQL dataset (default: True) + + Returns: + OutputSchemaResponse with Arrow schema + + Raises: + GetOutputSchemaError: If schema analysis fails + DependencyValidationError: If query references invalid dependencies + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... schema_resp = await client.schema.get_output_schema( + ... 'SELECT block_num, hash FROM eth.blocks WHERE block_num > 1000000', + ... is_sql_dataset=True + ... ) + ... print(schema_resp.schema) + """ + request_data = models.OutputSchemaRequest(sql_query=sql_query, is_sql_dataset=is_sql_dataset) + + response = await self._admin._request('POST', '/schema', json=request_data.model_dump(mode='json')) + return models.OutputSchemaResponse.model_validate(response.json()) diff --git a/src/amp/async_client.py b/src/amp/async_client.py new file mode 100644 index 0000000..9b5ecc1 --- /dev/null +++ b/src/amp/async_client.py @@ -0,0 +1,718 @@ +"""Async Flight SQL client with data loading capabilities. + +This module provides the AsyncAmpClient class for async operations +with the Flight SQL server and Admin/Registry APIs. + +The async client is optimized for: +- Non-blocking HTTP API calls (Admin, Registry) +- Concurrent operations using asyncio +- Streaming data with async iteration + +Note: Flight SQL (gRPC) operations currently remain synchronous as PyArrow's +Flight client doesn't have native async support. For streaming operations, +consider using run_in_executor or the sync Client. +""" + +import asyncio +import logging +import os +from typing import AsyncIterator, Dict, Iterator, List, Optional, Union + +import pyarrow as pa +from google.protobuf.any_pb2 import Any +from pyarrow import flight +from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory + +from . import FlightSql_pb2 +from .config.connection_manager import ConnectionManager +from .config.label_manager import LabelManager +from .loaders.registry import create_loader, get_available_loaders +from .loaders.types import LabelJoinConfig, LoadConfig, LoadMode, LoadResult +from .streaming import ( + ReorgAwareStream, + ResumeWatermark, + StreamingResultIterator, +) + + +class AuthMiddleware(ClientMiddleware): + """Flight middleware to add Bearer token authentication header.""" + + def __init__(self, get_token): + """Initialize auth middleware. + + Args: + get_token: Callable that returns the current access token + """ + self.get_token = get_token + + def sending_headers(self): + """Add Authorization header to outgoing requests.""" + return {'authorization': f'Bearer {self.get_token()}'} + + +class AuthMiddlewareFactory(ClientMiddlewareFactory): + """Factory for creating auth middleware instances.""" + + def __init__(self, get_token): + """Initialize auth middleware factory. 
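+
+        A sketch of how the factory is attached to a Flight connection
+        (this mirrors what AsyncAmpClient does internally):
+
+            >>> factory = AuthMiddlewareFactory(lambda: 'your-token')
+            >>> conn = flight.connect('grpc://localhost:1602', middleware=[factory])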
+ + Args: + get_token: Callable that returns the current access token + """ + self.get_token = get_token + + def start_call(self, info): + """Create auth middleware for each call.""" + return AuthMiddleware(self.get_token) + + +class AsyncQueryBuilder: + """Async chainable query builder for data loading operations. + + Provides async versions of query operations. + """ + + def __init__(self, client: 'AsyncAmpClient', query: str): + self.client = client + self.query = query + self._result_cache = None + self._dependencies: Dict[str, str] = {} + self.logger = logging.getLogger(__name__) + + async def load( + self, + connection: str, + destination: str, + config: Dict[str, any] = None, + label_config: Optional[LabelJoinConfig] = None, + **kwargs, + ) -> Union[LoadResult, AsyncIterator[LoadResult]]: + """ + Async load query results to specified destination. + + Note: The actual data loading operations run synchronously in a thread + pool executor since PyArrow Flight doesn't support native async. + + Args: + connection: Named connection or connection name for auto-discovery + destination: Target destination (table name, key, path, etc.) + config: Inline configuration dict (alternative to connection) + label_config: Optional LabelJoinConfig for joining with label data + **kwargs: Additional loader-specific options + + Returns: + LoadResult or async iterator of LoadResults + """ + # Handle streaming mode + if kwargs.get('stream', False): + kwargs.pop('stream') + streaming_query = self._ensure_streaming_query(self.query) + return await self.client.query_and_load_streaming( + query=streaming_query, + destination=destination, + connection_name=connection, + config=config, + label_config=label_config, + **kwargs, + ) + + # Validate that parallel_config is only used with stream=True + if kwargs.get('parallel_config'): + raise ValueError('parallel_config requires stream=True') + + kwargs.setdefault('read_all', False) + + return await self.client.query_and_load( + query=self.query, + destination=destination, + connection_name=connection, + config=config, + label_config=label_config, + **kwargs, + ) + + def _ensure_streaming_query(self, query: str) -> str: + """Ensure query has SETTINGS stream = true""" + query = query.strip().rstrip(';') + if 'SETTINGS stream = true' not in query.upper(): + query += ' SETTINGS stream = true' + return query + + async def stream(self) -> AsyncIterator[pa.RecordBatch]: + """Stream query results as Arrow batches asynchronously.""" + self.logger.debug(f'Starting async stream for query: {self.query[:50]}...') + # Run synchronous Flight SQL operation in executor + loop = asyncio.get_event_loop() + batches = await loop.run_in_executor(None, lambda: list(self.client.get_sql_sync(self.query, read_all=False))) + for batch in batches: + yield batch + + async def to_arrow(self) -> pa.Table: + """Get query results as Arrow table asynchronously.""" + if self._result_cache is None: + self.logger.debug(f'Executing query for Arrow table: {self.query[:50]}...') + loop = asyncio.get_event_loop() + self._result_cache = await loop.run_in_executor( + None, lambda: self.client.get_sql_sync(self.query, read_all=True) + ) + return self._result_cache + + async def to_manifest(self, table_name: str, network: str = 'mainnet') -> dict: + """Generate a dataset manifest from this query asynchronously. + + Automatically fetches the Arrow schema using the Admin API /schema endpoint. + Requires the Client to be initialized with admin_url. 
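+
+        A minimal sketch, assuming the client was created with admin_url
+        (the dependency reference string is illustrative):
+
+            >>> async with AsyncAmpClient(query_url='grpc://localhost:1602', admin_url='http://localhost:8080') as client:
+            ...     builder = client.sql('SELECT block_num, hash FROM eth.blocks')
+            ...     builder = builder.with_dependency('eth', 'some/dataset@1.0.0')  # illustrative reference
+            ...     manifest = await builder.to_manifest('blocks', network='mainnet')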
+ + Args: + table_name: Name for the table in the manifest + network: Network name (default: 'mainnet') + + Returns: + Complete manifest dict ready for registration + """ + # Get schema from Admin API + schema_response = await self.client.schema.get_output_schema(self.query, is_sql_dataset=True) + + # Build manifest structure + manifest = { + 'kind': 'manifest', + 'dependencies': self._dependencies, + 'tables': { + table_name: { + 'input': {'sql': self.query}, + 'schema': schema_response.schema_, + 'network': network, + } + }, + 'functions': {}, + } + return manifest + + def with_dependency(self, alias: str, reference: str) -> 'AsyncQueryBuilder': + """Add a dataset dependency for manifest generation.""" + self._dependencies[alias] = reference + return self + + def __repr__(self): + return f"AsyncQueryBuilder(query='{self.query[:50]}{'...' if len(self.query) > 50 else ''}')" + + +class AsyncAmpClient: + """Async Flight SQL client with data loading capabilities. + + Supports both query operations (via Flight SQL) and optional admin operations + (via async HTTP Admin API) and registry operations (via async Registry API). + + The Flight SQL operations are run in a thread pool executor since PyArrow's + Flight client doesn't have native async support. HTTP operations (Admin, + Registry) are fully async. + + Args: + url: Flight SQL URL (for backward compatibility, treated as query_url) + query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602') + admin_url: Optional Admin API URL (e.g., 'http://localhost:8080') + registry_url: Optional Registry API URL (default: staging registry) + auth_token: Optional Bearer token for authentication (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth + + Example: + >>> # Query with async admin operations + >>> async with AsyncAmpClient( + ... query_url='grpc://localhost:1602', + ... admin_url='http://localhost:8080', + ... auth=True + ... ) as client: + ... datasets = await client.datasets.list_all() + ... 
table = await client.sql("SELECT * FROM eth.blocks LIMIT 10").to_arrow() + """ + + def __init__( + self, + url: Optional[str] = None, + query_url: Optional[str] = None, + admin_url: Optional[str] = None, + registry_url: str = 'https://api.registry.amp.staging.thegraph.com', + auth_token: Optional[str] = None, + auth: bool = False, + ): + # Backward compatibility: url parameter → query_url + if url and not query_url: + query_url = url + + # Resolve auth token provider with priority: explicit param > env var > auth file + get_token = None + if auth_token: + def get_token(): + return auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + env_token = os.getenv('AMP_AUTH_TOKEN') + + def get_token(): + return env_token + elif auth: + from amp.auth import AuthService + + auth_service = AuthService() + get_token = auth_service.get_token + + # Initialize Flight SQL client + if query_url: + if get_token: + middleware = [AuthMiddlewareFactory(get_token)] + self.conn = flight.connect(query_url, middleware=middleware) + else: + self.conn = flight.connect(query_url) + else: + raise ValueError('Either url or query_url must be provided for Flight SQL connection') + + # Initialize managers + self.connection_manager = ConnectionManager() + self.label_manager = LabelManager() + self.logger = logging.getLogger(__name__) + + # Store URLs and auth params for lazy initialization of async clients + self._admin_url = admin_url + self._registry_url = registry_url + self._auth_token = auth_token + self._auth = auth + + # Lazy-initialized async clients + self._admin_client = None + self._registry_client = None + + def sql(self, query: str) -> AsyncQueryBuilder: + """ + Create an async chainable query builder. + + Args: + query: SQL query string + + Returns: + AsyncQueryBuilder instance for chaining operations + """ + return AsyncQueryBuilder(self, query) + + def configure_connection(self, name: str, loader: str, config: Dict[str, any]) -> None: + """Configure a named connection for reuse.""" + self.connection_manager.add_connection(name, loader, config) + + def configure_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None: + """Configure a label dataset from a CSV file for joining with streaming data.""" + self.label_manager.add_label(name, csv_path, binary_columns) + + def list_connections(self) -> Dict[str, str]: + """List all configured connections.""" + return self.connection_manager.list_connections() + + def get_available_loaders(self) -> List[str]: + """Get list of available data loaders.""" + return get_available_loaders() + + # Async Admin API access (optional, requires admin_url) + @property + def datasets(self): + """Access async datasets client for Admin API operations. + + Returns: + AsyncDatasetsClient for dataset registration, deployment, and management + + Raises: + ValueError: If admin_url was not provided during Client initialization + """ + if not self._admin_url: + raise ValueError( + 'Admin API not configured. Provide admin_url parameter to AsyncAmpClient() ' + 'to enable dataset management operations.' 
+ ) + if not self._admin_client: + from amp.admin.async_client import AsyncAdminClient + + if self._auth: + self._admin_client = AsyncAdminClient(self._admin_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AsyncAdminClient(self._admin_url, auth_token=token) + else: + self._admin_client = AsyncAdminClient(self._admin_url) + return self._admin_client.datasets + + @property + def jobs(self): + """Access async jobs client for Admin API operations. + + Returns: + AsyncJobsClient for job monitoring and management + + Raises: + ValueError: If admin_url was not provided during Client initialization + """ + if not self._admin_url: + raise ValueError( + 'Admin API not configured. Provide admin_url parameter to AsyncAmpClient() ' + 'to enable job monitoring operations.' + ) + if not self._admin_client: + from amp.admin.async_client import AsyncAdminClient + + if self._auth: + self._admin_client = AsyncAdminClient(self._admin_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AsyncAdminClient(self._admin_url, auth_token=token) + else: + self._admin_client = AsyncAdminClient(self._admin_url) + return self._admin_client.jobs + + @property + def schema(self): + """Access async schema client for Admin API operations. + + Returns: + AsyncSchemaClient for SQL query schema analysis + + Raises: + ValueError: If admin_url was not provided during Client initialization + """ + if not self._admin_url: + raise ValueError( + 'Admin API not configured. Provide admin_url parameter to AsyncAmpClient() ' + 'to enable schema analysis operations.' + ) + if not self._admin_client: + from amp.admin.async_client import AsyncAdminClient + + if self._auth: + self._admin_client = AsyncAdminClient(self._admin_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AsyncAdminClient(self._admin_url, auth_token=token) + else: + self._admin_client = AsyncAdminClient(self._admin_url) + return self._admin_client.schema + + @property + def registry(self): + """Access async registry client for Registry API operations. + + Returns: + AsyncRegistryClient for dataset discovery, search, and publishing + + Raises: + ValueError: If registry_url was not provided during Client initialization + """ + if not self._registry_url: + raise ValueError( + 'Registry API not configured. Provide registry_url parameter to AsyncAmpClient() ' + 'to enable dataset discovery and search operations.' + ) + if not self._registry_client: + from amp.registry.async_client import AsyncRegistryClient + + if self._auth: + self._registry_client = AsyncRegistryClient(self._registry_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._registry_client = AsyncRegistryClient(self._registry_url, auth_token=token) + else: + self._registry_client = AsyncRegistryClient(self._registry_url) + return self._registry_client + + # Synchronous Flight SQL methods (run in executor for async context) + def get_sql_sync(self, query: str, read_all: bool = False): + """Execute SQL query and return Arrow data (synchronous). + + This is the underlying synchronous method used by async wrappers. 
+ """ + command_query = FlightSql_pb2.CommandStatementQuery() + command_query.query = query + + any_command = Any() + any_command.Pack(command_query) + cmd = any_command.SerializeToString() + + flight_descriptor = flight.FlightDescriptor.for_command(cmd) + info = self.conn.get_flight_info(flight_descriptor) + reader = self.conn.do_get(info.endpoints[0].ticket) + + if read_all: + return reader.read_all() + else: + return self._batch_generator(reader) + + def _batch_generator(self, reader) -> Iterator[pa.RecordBatch]: + """Generate batches from Flight reader.""" + while True: + try: + chunk = reader.read_chunk() + yield chunk.data + except StopIteration: + break + + async def get_sql(self, query: str, read_all: bool = False): + """Execute SQL query asynchronously and return Arrow data. + + Runs the synchronous Flight SQL operation in a thread pool executor. + """ + loop = asyncio.get_event_loop() + if read_all: + return await loop.run_in_executor(None, lambda: self.get_sql_sync(query, read_all=True)) + else: + batches = await loop.run_in_executor(None, lambda: list(self.get_sql_sync(query, read_all=False))) + return batches + + async def query_and_load( + self, + query: str, + destination: str, + connection_name: str, + config: Optional[Dict[str, any]] = None, + label_config: Optional[LabelJoinConfig] = None, + **kwargs, + ) -> Union[LoadResult, AsyncIterator[LoadResult]]: + """Execute query and load results directly into target system asynchronously. + + Runs the data loading operation in a thread pool executor. + """ + loop = asyncio.get_event_loop() + + # Run the synchronous query_and_load in executor + def sync_load(): + # Get connection configuration and determine loader type + if connection_name: + try: + connection_info = self.connection_manager.get_connection_info(connection_name) + loader_config = connection_info['config'] + loader_type = connection_info['loader'] + except ValueError as e: + self.logger.error(f'Connection error: {e}') + raise + elif config: + loader_type = config.pop('loader_type', None) + if not loader_type: + raise ValueError("When using inline config, 'loader_type' must be specified") + loader_config = config + else: + raise ValueError('Either connection_name or config must be provided') + + # Extract load options + read_all = kwargs.pop('read_all', False) + load_config = LoadConfig( + batch_size=kwargs.pop('batch_size', 10000), + mode=LoadMode(kwargs.pop('mode', 'append')), + create_table=kwargs.pop('create_table', True), + schema_evolution=kwargs.pop('schema_evolution', False), + **{k: v for k, v in kwargs.items() if k in ['max_retries', 'retry_delay']}, + ) + + for key in ['max_retries', 'retry_delay']: + kwargs.pop(key, None) + + loader_specific_kwargs = kwargs + + if read_all: + table = self.get_sql_sync(query, read_all=True) + return self._load_table_sync( + table, + loader_type, + destination, + loader_config, + load_config, + label_config=label_config, + **loader_specific_kwargs, + ) + else: + batch_stream = self.get_sql_sync(query, read_all=False) + return list( + self._load_stream_sync( + batch_stream, + loader_type, + destination, + loader_config, + load_config, + label_config=label_config, + **loader_specific_kwargs, + ) + ) + + return await loop.run_in_executor(None, sync_load) + + def _load_table_sync( + self, + table: pa.Table, + loader: str, + table_name: str, + config: Dict[str, any], + load_config: LoadConfig, + **kwargs, + ) -> LoadResult: + """Load a complete Arrow Table synchronously.""" + try: + loader_instance = create_loader(loader, 
config, label_manager=self.label_manager) + + with loader_instance: + return loader_instance.load_table(table, table_name, **load_config.__dict__, **kwargs) + except Exception as e: + self.logger.error(f'Failed to load table: {e}') + return LoadResult( + rows_loaded=0, + duration=0.0, + ops_per_second=0.0, + table_name=table_name, + loader_type=loader, + success=False, + error=str(e), + ) + + def _load_stream_sync( + self, + batch_stream: Iterator[pa.RecordBatch], + loader: str, + table_name: str, + config: Dict[str, any], + load_config: LoadConfig, + **kwargs, + ) -> Iterator[LoadResult]: + """Load from a stream of batches synchronously.""" + try: + loader_instance = create_loader(loader, config, label_manager=self.label_manager) + + with loader_instance: + yield from loader_instance.load_stream(batch_stream, table_name, **load_config.__dict__, **kwargs) + except Exception as e: + self.logger.error(f'Failed to load stream: {e}') + yield LoadResult( + rows_loaded=0, + duration=0.0, + ops_per_second=0.0, + table_name=table_name, + loader_type=loader, + success=False, + error=str(e), + ) + + async def query_and_load_streaming( + self, + query: str, + destination: str, + connection_name: str, + config: Optional[Dict[str, any]] = None, + label_config: Optional[LabelJoinConfig] = None, + with_reorg_detection: bool = True, + resume_watermark: Optional[ResumeWatermark] = None, + **kwargs, + ) -> AsyncIterator[LoadResult]: + """Execute a streaming query and continuously load results asynchronously. + + Runs the streaming operation in a thread pool executor and yields results. + """ + loop = asyncio.get_event_loop() + + # Run streaming query synchronously and collect results + def sync_streaming(): + # Get connection configuration + if connection_name: + try: + connection_info = self.connection_manager.get_connection_info(connection_name) + loader_config = connection_info['config'] + loader_type = connection_info['loader'] + except ValueError as e: + self.logger.error(f'Connection error: {e}') + raise + elif config: + loader_type = config.pop('loader_type', None) + if not loader_type: + raise ValueError("When using inline config, 'loader_type' must be specified") + loader_config = config + else: + raise ValueError('Either connection_name or config must be provided') + + # Extract load config + load_config = LoadConfig( + batch_size=kwargs.pop('batch_size', 10000), + mode=LoadMode(kwargs.pop('mode', 'append')), + create_table=kwargs.pop('create_table', True), + schema_evolution=kwargs.pop('schema_evolution', False), + **{k: v for k, v in kwargs.items() if k in ['max_retries', 'retry_delay']}, + ) + + self.logger.info(f'Starting async streaming query to {loader_type}:{destination}') + + loader_instance = create_loader(loader_type, loader_config, label_manager=self.label_manager) + + results = [] + + try: + # Execute streaming query with Flight SQL + command_query = FlightSql_pb2.CommandStatementQuery() + command_query.query = query + + any_command = Any() + any_command.Pack(command_query) + cmd = any_command.SerializeToString() + + flight_descriptor = flight.FlightDescriptor.for_command(cmd) + info = self.conn.get_flight_info(flight_descriptor) + reader = self.conn.do_get(info.endpoints[0].ticket) + + stream_iterator = StreamingResultIterator(reader) + + if with_reorg_detection: + stream_iterator = ReorgAwareStream(stream_iterator) + + with loader_instance: + for result in loader_instance.load_stream_continuous( + stream_iterator, destination, connection_name=connection_name, **load_config.__dict__ + 
): + results.append(result) + + except Exception as e: + self.logger.error(f'Streaming query failed: {e}') + results.append( + LoadResult( + rows_loaded=0, + duration=0.0, + ops_per_second=0.0, + table_name=destination, + loader_type=loader_type, + success=False, + error=str(e), + metadata={'streaming_error': True}, + ) + ) + + return results + + results = await loop.run_in_executor(None, sync_streaming) + for result in results: + yield result + + async def close(self): + """Close all connections and release resources.""" + # Close Flight SQL connection + if hasattr(self, 'conn') and self.conn: + try: + self.conn.close() + except Exception as e: + self.logger.warning(f'Error closing Flight connection: {e}') + + # Close async admin client if initialized + if self._admin_client: + await self._admin_client.close() + + # Close async registry client if initialized + if self._registry_client: + await self._registry_client.close() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/amp/registry/async_client.py b/src/amp/registry/async_client.py new file mode 100644 index 0000000..acb454f --- /dev/null +++ b/src/amp/registry/async_client.py @@ -0,0 +1,180 @@ +"""Async Registry API client.""" + +import logging +import os +from typing import Optional + +import httpx + +from . import errors + +logger = logging.getLogger(__name__) + + +class AsyncRegistryClient: + """Async client for interacting with the Amp Registry API. + + The Registry API provides dataset discovery, search, and publishing capabilities. + + Args: + base_url: Base URL for the Registry API (default: staging registry) + auth_token: Optional Bearer token for authenticated operations (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth + + Example: + >>> # Read-only operations (no auth required) + >>> async with AsyncRegistryClient() as client: + ... datasets = await client.datasets.search('ethereum') + >>> + >>> # Authenticated operations with explicit token + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... await client.datasets.publish(...) + """ + + def __init__( + self, + base_url: str = 'https://api.registry.amp.staging.thegraph.com', + auth_token: Optional[str] = None, + auth: bool = False, + ): + """Initialize async Registry client. + + Args: + base_url: Base URL for the Registry API + auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp/cache + + Raises: + ValueError: If both auth=True and auth_token are provided + """ + if auth and auth_token: + raise ValueError('Cannot specify both auth=True and auth_token. 
Choose one authentication method.') + + self.base_url = base_url.rstrip('/') + + # Resolve auth token provider with priority: explicit param > env var > auth file + self._get_token = None + if auth_token: + # Priority 1: Explicit auth_token parameter (static token) + def get_token(): + return auth_token + + self._get_token = get_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + + def get_token(): + return env_token + + self._get_token = get_token + elif auth: + # Priority 3: Load from ~/.amp/cache/amp_cli_auth (auto-refreshing) + from amp.auth import AuthService + + auth_service = AuthService() + self._get_token = auth_service.get_token # Callable that auto-refreshes + + # Create async HTTP client (no auth header yet - will be added per-request) + self._http = httpx.AsyncClient( + base_url=self.base_url, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + timeout=30.0, + ) + + logger.info(f'Initialized async Registry client for {base_url}') + + @property + def datasets(self): + """Access the async datasets client. + + Returns: + AsyncRegistryDatasetsClient: Client for dataset operations + """ + from .async_datasets import AsyncRegistryDatasetsClient + + return AsyncRegistryDatasetsClient(self) + + async def _request( + self, + method: str, + path: str, + **kwargs, + ) -> httpx.Response: + """Make an async HTTP request to the Registry API. + + Args: + method: HTTP method (GET, POST, etc.) + path: API path (without base URL) + **kwargs: Additional arguments to pass to httpx + + Returns: + httpx.Response: HTTP response + + Raises: + RegistryError: If the request fails + """ + url = path if path.startswith('http') else f'{self.base_url}{path}' + + # Add auth header dynamically (auto-refreshes if needed) + headers = kwargs.get('headers', {}) + if self._get_token: + headers['Authorization'] = f'Bearer {self._get_token()}' + kwargs['headers'] = headers + + try: + response = await self._http.request(method, url, **kwargs) + + # Handle error responses + if response.status_code >= 400: + self._handle_error(response) + + return response + + except httpx.RequestError as e: + raise errors.RegistryError(f'Request failed: {e}') from e + + def _handle_error(self, response: httpx.Response) -> None: + """Handle error responses from the API. 
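+
+        Attempts to parse a structured error body (error_code, error_message,
+        request_id) and map it to a specific exception; if the body is not
+        valid JSON, a generic RegistryError carrying the HTTP status is
+        raised instead.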
+ + Args: + response: HTTP error response + + Raises: + RegistryError: Mapped exception for the error + """ + try: + error_data = response.json() + error_code = error_data.get('error_code', '') + error_message = error_data.get('error_message', response.text) + request_id = error_data.get('request_id', '') + + # Map to specific exception + raise errors.map_error(error_code, error_message, request_id) + + except (ValueError, KeyError): + # Couldn't parse error response, raise generic error + raise errors.RegistryError( + f'HTTP {response.status_code}: {response.text}', + error_code=str(response.status_code), + ) from None + + async def close(self): + """Close the HTTP client.""" + await self._http.aclose() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/amp/registry/async_datasets.py b/src/amp/registry/async_datasets.py new file mode 100644 index 0000000..012ae0b --- /dev/null +++ b/src/amp/registry/async_datasets.py @@ -0,0 +1,437 @@ +"""Async Registry datasets client.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional + +from amp.utils.manifest_inspector import describe_manifest + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncRegistryClient + +logger = logging.getLogger(__name__) + + +class AsyncRegistryDatasetsClient: + """Async client for dataset operations in the Registry API. + + Provides async methods for: + - Searching and discovering datasets + - Fetching dataset details and manifests + - Publishing datasets (requires authentication) + - Managing dataset visibility and versions + + Args: + registry_client: Parent AsyncRegistryClient instance + """ + + def __init__(self, registry_client: AsyncRegistryClient): + """Initialize async datasets client. + + Args: + registry_client: Parent AsyncRegistryClient instance + """ + self._registry = registry_client + + # Read Operations (Public - No Auth Required) + + async def list( + self, limit: int = 50, page: int = 1, sort_by: Optional[str] = None, direction: Optional[str] = None + ) -> models.DatasetListResponse: + """List all published datasets with pagination. + + Args: + limit: Maximum number of datasets to return (default: 50, max: 1000) + page: Page number (1-indexed, default: 1) + sort_by: Field to sort by (e.g., 'name', 'created_at', 'updated_at') + direction: Sort direction ('asc' or 'desc') + + Returns: + DatasetListResponse: Paginated list of datasets + + Example: + >>> async with AsyncRegistryClient() as client: + ... response = await client.datasets.list(limit=10, page=1) + ... print(f"Found {response.total_count} datasets") + """ + params: Dict[str, Any] = {'limit': limit, 'page': page} + if sort_by: + params['sort_by'] = sort_by + if direction: + params['direction'] = direction + + response = await self._registry._request('GET', '/api/v1/datasets', params=params) + return models.DatasetListResponse.model_validate(response.json()) + + async def search(self, query: str, limit: int = 50, page: int = 1) -> models.DatasetSearchResponse: + """Search datasets using full-text keyword search. + + Results are ranked by relevance score. 
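+
+        A sketch of paging through results, assuming an empty page marks
+        the end of the result set:
+
+            >>> async with AsyncRegistryClient() as client:
+            ...     page = 1
+            ...     while True:
+            ...         resp = await client.datasets.search('ethereum', limit=50, page=page)
+            ...         if not resp.datasets:
+            ...             break
+            ...         page += 1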
+ + Args: + query: Search query string + limit: Maximum number of results (default: 50, max: 1000) + page: Page number (1-indexed, default: 1) + + Returns: + DatasetSearchResponse: Search results with relevance scores + + Example: + >>> async with AsyncRegistryClient() as client: + ... results = await client.datasets.search('ethereum blocks') + ... for dataset in results.datasets: + ... print(f"[{dataset.score}] {dataset.namespace}/{dataset.name}") + """ + params = {'search': query, 'limit': limit, 'page': page} + response = await self._registry._request('GET', '/api/v1/datasets/search', params=params) + return models.DatasetSearchResponse.model_validate(response.json()) + + async def ai_search(self, query: str, limit: int = 50) -> list[models.DatasetWithScore]: + """Search datasets using AI-powered semantic search. + + Uses embeddings and natural language processing for better matching. + + Args: + query: Natural language search query + limit: Maximum number of results (default: 50) + + Returns: + list[DatasetWithScore]: List of datasets with relevance scores + + Example: + >>> async with AsyncRegistryClient() as client: + ... results = await client.datasets.ai_search('find NFT transfer data') + ... for dataset in results: + ... print(f"[{dataset.score}] {dataset.namespace}/{dataset.name}") + """ + params = {'search': query, 'limit': limit} + response = await self._registry._request('GET', '/api/v1/datasets/search/ai', params=params) + return [models.DatasetWithScore.model_validate(d) for d in response.json()] + + async def get(self, namespace: str, name: str) -> models.Dataset: + """Get detailed information about a specific dataset. + + Args: + namespace: Dataset namespace (e.g., 'edgeandnode', 'edgeandnode') + name: Dataset name (e.g., 'ethereum-mainnet') + + Returns: + Dataset: Complete dataset information + + Example: + >>> async with AsyncRegistryClient() as client: + ... dataset = await client.datasets.get('edgeandnode', 'ethereum-mainnet') + ... print(f"Latest version: {dataset.latest_version}") + """ + path = f'/api/v1/datasets/{namespace}/{name}' + response = await self._registry._request('GET', path) + return models.Dataset.model_validate(response.json()) + + async def list_versions(self, namespace: str, name: str) -> list[models.DatasetVersion]: + """List all versions of a dataset. + + Versions are returned sorted by latest first. + + Args: + namespace: Dataset namespace + name: Dataset name + + Returns: + list[DatasetVersion]: List of dataset versions + + Example: + >>> async with AsyncRegistryClient() as client: + ... versions = await client.datasets.list_versions('edgeandnode', 'ethereum-mainnet') + ... for version in versions: + ... print(f" - v{version.version} ({version.status})") + """ + path = f'/api/v1/datasets/{namespace}/{name}/versions' + response = await self._registry._request('GET', path) + return [models.DatasetVersion.model_validate(v) for v in response.json()] + + async def get_version(self, namespace: str, name: str, version: str) -> models.DatasetVersion: + """Get details of a specific dataset version. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag (e.g., '1.0.0', 'latest') + + Returns: + DatasetVersion: Version details + + Example: + >>> async with AsyncRegistryClient() as client: + ... version = await client.datasets.get_version('edgeandnode', 'ethereum-mainnet', 'latest') + ... 
print(f"Version: {version.version}") + """ + path = f'/api/v1/datasets/{namespace}/{name}/versions/{version}' + response = await self._registry._request('GET', path) + return models.DatasetVersion.model_validate(response.json()) + + async def get_manifest(self, namespace: str, name: str, version: str) -> dict: + """Fetch the manifest JSON for a specific dataset version. + + Manifests define the dataset structure, dependencies, and ETL logic. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag (e.g., '1.0.0', 'latest') + + Returns: + dict: Manifest JSON content + + Example: + >>> async with AsyncRegistryClient() as client: + ... manifest = await client.datasets.get_manifest('edgeandnode', 'ethereum-mainnet', 'latest') + ... print(f"Tables: {list(manifest.get('tables', {}).keys())}") + """ + path = f'/api/v1/datasets/{namespace}/{name}/versions/{version}/manifest' + response = await self._registry._request('GET', path) + return response.json() + + async def describe( + self, namespace: str, name: str, version: str = 'latest' + ) -> Dict[str, list[Dict[str, str | bool]]]: + """Get a structured summary of tables and columns in a dataset. + + Returns a dictionary mapping table names to lists of column information, + making it easy to programmatically inspect the dataset schema. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag (default: 'latest') + + Returns: + dict: Mapping of table names to column information. + + Example: + >>> async with AsyncRegistryClient() as client: + ... schema = await client.datasets.describe('edgeandnode', 'ethereum-mainnet', 'latest') + ... for table_name, columns in schema.items(): + ... print(f"Table: {table_name}") + """ + manifest = await self.get_manifest(namespace, name, version) + return describe_manifest(manifest) + + # Write Operations (Require Authentication) + + async def publish( + self, + namespace: str, + name: str, + version: str, + manifest: dict, + visibility: str = 'public', + description: Optional[str] = None, + tags: Optional[list[str]] = None, + chains: Optional[list[str]] = None, + sources: Optional[list[str]] = None, + ) -> models.Dataset: + """Publish a new dataset to the registry. + + Requires authentication (Bearer token). + + Args: + namespace: Dataset namespace (owner's username or org) + name: Dataset name + version: Initial version tag (e.g., '1.0.0') + manifest: Dataset manifest JSON + visibility: Dataset visibility ('public' or 'private', default: 'public') + description: Dataset description + tags: Optional list of tags/keywords + chains: Optional list of blockchain networks + sources: Optional list of data sources + + Returns: + Dataset: Created dataset + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... dataset = await client.datasets.publish( + ... namespace='myuser', + ... name='my_dataset', + ... version='1.0.0', + ... manifest=manifest, + ... description='My custom dataset' + ... 
+    async def publish_version(
+        self,
+        namespace: str,
+        name: str,
+        version: str,
+        manifest: dict,
+        description: Optional[str] = None,
+    ) -> models.DatasetVersion:
+        """Publish a new version of an existing dataset.
+
+        Requires authentication and ownership of the dataset.
+
+        Args:
+            namespace: Dataset namespace
+            name: Dataset name
+            version: New version tag (e.g., '1.1.0')
+            manifest: Dataset manifest JSON for this version
+            description: Optional version description
+
+        Returns:
+            DatasetVersion: Created version
+
+        Example:
+            >>> async with AsyncRegistryClient(auth_token='your-token') as client:
+            ...     version = await client.datasets.publish_version(
+            ...         namespace='myuser',
+            ...         name='my_dataset',
+            ...         version='1.1.0',
+            ...         manifest=manifest
+            ...     )
+        """
+        body = {'version': version, 'manifest': manifest}
+        if description:
+            body['description'] = description
+
+        path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/versions/publish'
+        response = await self._registry._request('POST', path, json=body)
+        return models.DatasetVersion.model_validate(response.json())
+
+    async def update(
+        self,
+        namespace: str,
+        name: str,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        chains: Optional[list[str]] = None,
+        sources: Optional[list[str]] = None,
+    ) -> models.Dataset:
+        """Update dataset metadata.
+
+        Requires authentication and ownership of the dataset.
+
+        Args:
+            namespace: Dataset namespace
+            name: Dataset name
+            description: Updated description
+            tags: Updated tags
+            chains: Updated chains
+            sources: Updated sources
+
+        Returns:
+            Dataset: Updated dataset
+
+        Example:
+            >>> async with AsyncRegistryClient(auth_token='your-token') as client:
+            ...     dataset = await client.datasets.update(
+            ...         namespace='myuser',
+            ...         name='my_dataset',
+            ...         description='Updated description'
+            ...     )
+        """
+        body = {}
+        if description is not None:
+            body['description'] = description
+        if tags is not None:
+            body['tags'] = tags
+        if chains is not None:
+            body['chains'] = chains
+        if sources is not None:
+            body['sources'] = sources
+
+        path = f'/api/v1/owners/@me/datasets/{namespace}/{name}'
+        response = await self._registry._request('PUT', path, json=body)
+        return models.Dataset.model_validate(response.json())
+
+    async def update_visibility(self, namespace: str, name: str, visibility: str) -> models.Dataset:
+        """Update dataset visibility (public/private).
+
+        Requires authentication and ownership of the dataset.
+
+        Args:
+            namespace: Dataset namespace
+            name: Dataset name
+            visibility: New visibility ('public' or 'private')
+
+        Returns:
+            Dataset: Updated dataset
+
+        Example:
+            >>> async with AsyncRegistryClient(auth_token='your-token') as client:
+            ...     dataset = await client.datasets.update_visibility('myuser', 'my_dataset', 'private')
+        """
+        body = {'visibility': visibility}
+        path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/visibility'
+        response = await self._registry._request('PATCH', path, json=body)
+        return models.Dataset.model_validate(response.json())
+
+    async def update_version_status(
+        self, namespace: str, name: str, version: str, status: str
+    ) -> models.DatasetVersion:
+        """Update the status of a dataset version.
+
+        Requires authentication and ownership of the dataset.
+
+        Args:
+            namespace: Dataset namespace
+            name: Dataset name
+            version: Version tag
+            status: New status ('draft', 'published', 'deprecated', or 'archived')
+
+        Returns:
+            DatasetVersion: Updated version
+
+        Example:
+            >>> async with AsyncRegistryClient(auth_token='your-token') as client:
+            ...     version = await client.datasets.update_version_status(
+            ...         'myuser', 'my_dataset', '1.0.0', 'deprecated'
+            ...     )
+        """
+        body = {'status': status}
+        path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/versions/{version}'
+        response = await self._registry._request('PATCH', path, json=body)
+        return models.DatasetVersion.model_validate(response.json())
+
+    async def delete_version(
+        self, namespace: str, name: str, version: str
+    ) -> models.ArchiveDatasetVersionResponse:
+        """Delete (archive) a dataset version.
+
+        Requires authentication and ownership of the dataset.
+
+        Args:
+            namespace: Dataset namespace
+            name: Dataset name
+            version: Version tag to delete
+
+        Returns:
+            ArchiveDatasetVersionResponse: Confirmation of deletion
+
+        Example:
+            >>> async with AsyncRegistryClient(auth_token='your-token') as client:
+            ...     response = await client.datasets.delete_version('myuser', 'my_dataset', '0.1.0')
+        """
+        path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/versions/{version}'
+        response = await self._registry._request('DELETE', path)
+        return models.ArchiveDatasetVersionResponse.model_validate(response.json())
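
The read and write paths above compose naturally: fetch a manifest, adjust it, and publish it back as a new version. A minimal sketch under the same assumptions as the docstring examples (a reachable registry, a valid token; 'myuser'/'my_dataset' are placeholders):

    import asyncio

    async def bump_version() -> None:
        async with AsyncRegistryClient(auth_token='your-token') as client:
            # Read side: pull the current manifest for a dataset we own
            manifest = await client.datasets.get_manifest('myuser', 'my_dataset', 'latest')
            # ...adjust tables or ETL logic in the manifest here...
            # Write side: publish the edited manifest as a new version
            await client.datasets.publish_version(
                namespace='myuser',
                name='my_dataset',
                version='1.1.0',
                manifest=manifest,
            )

    asyncio.run(bump_version())
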
diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py
new file mode 100644
index 0000000..ef41044
--- /dev/null
+++ b/tests/unit/test_async_client.py
@@ -0,0 +1,328 @@
+"""
+Unit tests for AsyncAmpClient and AsyncQueryBuilder API methods.
+
+These tests focus on the pure logic and data structures without requiring
+actual Flight SQL connections or Admin API calls.
+"""
+
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from amp.async_client import AsyncAmpClient, AsyncQueryBuilder
+
+
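+# NOTE: The @patch targets below reference 'amp.async_client.*'. The import above
+# must resolve to that same module path; importing the client via 'src.amp' would
+# bind a second module object that the patches never touch.
+
+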
+@pytest.mark.unit
+class TestAsyncQueryBuilder:
+    """Test AsyncQueryBuilder pure methods and logic"""
+
+    def test_with_dependency_chaining(self):
+        """Test adding and chaining dependencies"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks JOIN btc.blocks')
+
+        result = qb.with_dependency('eth', '_/eth_firehose@0.0.0').with_dependency('btc', '_/btc_firehose@1.2.3')
+
+        assert result is qb  # Returns self for chaining
+        assert qb._dependencies == {'eth': '_/eth_firehose@0.0.0', 'btc': '_/btc_firehose@1.2.3'}
+
+    def test_with_dependency_overwrites_existing_alias(self):
+        """Test that the same alias overwrites the previous dependency"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks')
+        qb.with_dependency('eth', '_/eth_firehose@0.0.0')
+        qb.with_dependency('eth', '_/eth_firehose@1.0.0')
+
+        assert qb._dependencies == {'eth': '_/eth_firehose@1.0.0'}
+
+    def test_ensure_streaming_query_adds_settings(self):
+        """Test that streaming settings are added when not present"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks')
+
+        result = qb._ensure_streaming_query('SELECT * FROM eth.blocks')
+        assert result == 'SELECT * FROM eth.blocks SETTINGS stream = true'
+
+        # Strips trailing semicolons before appending settings
+        result = qb._ensure_streaming_query('SELECT * FROM eth.blocks;')
+        assert result == 'SELECT * FROM eth.blocks SETTINGS stream = true'
+
+    def test_ensure_streaming_query_preserves_existing_settings(self):
+        """Test that an existing SETTINGS stream = true clause is preserved"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks')
+
+        # Should not duplicate the clause when already present
+        result = qb._ensure_streaming_query('SELECT * FROM eth.blocks SETTINGS stream = true')
+        assert 'SETTINGS stream = true' in result
+
+    def test_querybuilder_repr(self):
+        """Test AsyncQueryBuilder string representation"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks')
+        repr_str = repr(qb)
+
+        assert 'AsyncQueryBuilder' in repr_str
+        assert 'SELECT * FROM eth.blocks' in repr_str
+
+        # Long queries are truncated
+        long_query = 'SELECT ' + ', '.join([f'col{i}' for i in range(100)]) + ' FROM eth.blocks'
+        qb_long = AsyncQueryBuilder(client=None, query=long_query)
+        assert '...' in repr(qb_long)
+
+    def test_dependencies_initialized_empty(self):
+        """Test that dependencies and cache are initialized correctly"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks')
+
+        assert qb._dependencies == {}
+        assert qb._result_cache is None
+
+
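+# In normal use the builder above is obtained from a client rather than built
+# directly: client.sql('SELECT ...') returns an AsyncQueryBuilder bound to that
+# client (see TestAsyncClientSqlMethod below), so dependency chaining reads as:
+#
+#     client.sql('SELECT * FROM eth.blocks').with_dependency('eth', '_/eth_firehose@0.0.0')
+#
+
+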
+@pytest.mark.unit
+class TestAsyncClientInitialization:
+    """Test AsyncAmpClient initialization logic"""
+
+    def test_client_requires_url_or_query_url(self):
+        """Test that AsyncAmpClient requires either url or query_url"""
+        with pytest.raises(ValueError, match='Either url or query_url must be provided'):
+            AsyncAmpClient()
+
+
+@pytest.mark.unit
+class TestAsyncClientAuthPriority:
+    """Test AsyncAmpClient authentication priority (explicit token > env var > auth file)"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_explicit_token_highest_priority(self, mock_connect, mock_getenv):
+        """Test that the explicit auth_token parameter has the highest priority"""
+        mock_getenv.return_value = 'env-var-token'
+
+        AsyncAmpClient(query_url='grpc://localhost:1602', auth_token='explicit-token')
+
+        # Verify that the explicit token was used (not the env var)
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware', [])
+        assert len(middleware) == 1
+        assert middleware[0].get_token() == 'explicit-token'
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_env_var_second_priority(self, mock_connect, mock_getenv):
+        """Test that the AMP_AUTH_TOKEN env var has second priority"""
+
+        # Return 'env-var-token' for AMP_AUTH_TOKEN, None for others
+        def getenv_side_effect(key, default=None):
+            if key == 'AMP_AUTH_TOKEN':
+                return 'env-var-token'
+            return default
+
+        mock_getenv.side_effect = getenv_side_effect
+
+        AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        # Verify the env var was checked
+        calls = [str(call) for call in mock_getenv.call_args_list]
+        assert any('AMP_AUTH_TOKEN' in call for call in calls)
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware', [])
+        assert len(middleware) == 1
+        assert middleware[0].get_token() == 'env-var-token'
+
+    @patch('amp.auth.AuthService')
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_auth_file_lowest_priority(self, mock_connect, mock_getenv, mock_auth_service):
+        """Test that auth=True has the lowest priority"""
+
+        # Return None for all getenv calls
+        def getenv_side_effect(key, default=None):
+            return default
+
+        mock_getenv.side_effect = getenv_side_effect
+
+        mock_service_instance = Mock()
+        mock_service_instance.get_token.return_value = 'file-token'
+        mock_auth_service.return_value = mock_service_instance
+
+        AsyncAmpClient(query_url='grpc://localhost:1602', auth=True)
+
+        # Verify the auth file was used
+        mock_auth_service.assert_called_once()
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware', [])
+        assert len(middleware) == 1
+        # The middleware should use the auth service's get_token method directly
+        assert middleware[0].get_token == mock_service_instance.get_token
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_no_auth_when_nothing_provided(self, mock_connect, mock_getenv):
+        """Test that no auth middleware is added when no auth is provided"""
+
+        # Return None/default for all getenv calls
+        def getenv_side_effect(key, default=None):
+            return default
+
+        mock_getenv.side_effect = getenv_side_effect
+
+        AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        # Verify no middleware was added
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware')
+        assert middleware is None or len(middleware) == 0
+
+
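+# For reference, the three configurations exercised above map to:
+#
+#     AsyncAmpClient(query_url=..., auth_token='...')  # 1. explicit token (wins)
+#     AMP_AUTH_TOKEN=... in the environment            # 2. env var fallback
+#     AsyncAmpClient(query_url=..., auth=True)         # 3. token from the CLI auth cache
+#
+
+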
+@pytest.mark.unit
+class TestAsyncClientSqlMethod:
+    """Test AsyncAmpClient.sql() method"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_sql_returns_async_query_builder(self, mock_connect, mock_getenv):
+        """Test that sql() returns an AsyncQueryBuilder instance"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        result = client.sql('SELECT * FROM eth.blocks')
+
+        assert isinstance(result, AsyncQueryBuilder)
+        assert result.query == 'SELECT * FROM eth.blocks'
+        assert result.client is client
+
+
+@pytest.mark.unit
+class TestAsyncClientProperties:
+    """Test AsyncAmpClient properties for Admin and Registry access"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_datasets_raises_without_admin_url(self, mock_connect, mock_getenv):
+        """Test that the datasets property raises when admin_url is not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        with pytest.raises(ValueError, match='Admin API not configured'):
+            _ = client.datasets
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_jobs_raises_without_admin_url(self, mock_connect, mock_getenv):
+        """Test that the jobs property raises when admin_url is not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        with pytest.raises(ValueError, match='Admin API not configured'):
+            _ = client.jobs
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_schema_raises_without_admin_url(self, mock_connect, mock_getenv):
+        """Test that the schema property raises when admin_url is not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        with pytest.raises(ValueError, match='Admin API not configured'):
+            _ = client.schema
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_registry_raises_without_registry_url(self, mock_connect, mock_getenv):
+        """Test that the registry property raises when registry_url is not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602', registry_url=None)
+
+        with pytest.raises(ValueError, match='Registry API not configured'):
+            _ = client.registry
+
+
+@pytest.mark.unit
+class TestAsyncClientConfigurationMethods:
+    """Test AsyncAmpClient configuration methods"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_configure_connection(self, mock_connect, mock_getenv):
+        """Test that configure_connection stores the connection config"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        client.configure_connection('test_conn', 'postgresql', {'host': 'localhost', 'database': 'test'})
+
+        # Verify the connection was stored in the manager
+        connections = client.list_connections()
+        assert 'test_conn' in connections
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_get_available_loaders(self, mock_connect, mock_getenv):
+        """Test that get_available_loaders returns a list of loaders"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        loaders = client.get_available_loaders()
+
+        assert isinstance(loaders, list)
+        # Built-in loaders (e.g., postgresql, redis) should be registered
+        assert len(loaders) > 0
+
+
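+# The load() contract tested below pairs with configure_connection() above; a
+# sketch of the intended flow (connection options are illustrative only, and
+# stream=True is required whenever parallel_config is passed):
+#
+#     client.configure_connection('test_conn', 'postgresql', {'host': 'localhost', 'database': 'test'})
+#     await client.sql('SELECT * FROM eth.blocks').load(
+#         connection='test_conn',
+#         destination='test_table',
+#         stream=True,
+#         parallel_config={'partitions': 4},
+#     )
+#
+
+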
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestAsyncQueryBuilderLoad:
+    """Test AsyncQueryBuilder.load() method"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    async def test_load_raises_for_parallel_config_without_stream(self, mock_connect, mock_getenv):
+        """Test that load() raises an error when parallel_config is used without stream=True"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        qb = client.sql('SELECT * FROM eth.blocks')
+
+        with pytest.raises(ValueError, match='parallel_config requires stream=True'):
+            await qb.load(
+                connection='test_conn',
+                destination='test_table',
+                parallel_config={'partitions': 4},
+            )
+
+
+@pytest.mark.unit
+class TestAsyncClientContextManager:
+    """Test AsyncAmpClient context manager support"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    @pytest.mark.asyncio
+    async def test_async_context_manager(self, mock_connect, mock_getenv):
+        """Test that AsyncAmpClient works as an async context manager"""
+        mock_getenv.return_value = None
+        mock_conn = Mock()
+        mock_connect.return_value = mock_conn
+
+        async with AsyncAmpClient(query_url='grpc://localhost:1602') as client:
+            assert client is not None
+            assert client.conn is mock_conn
+
+        # Verify the connection was closed on exit
+        mock_conn.close.assert_called_once()
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    @pytest.mark.asyncio
+    async def test_close_method(self, mock_connect, mock_getenv):
+        """Test that close() properly closes all connections"""
+        mock_getenv.return_value = None
+        mock_conn = Mock()
+        mock_connect.return_value = mock_conn
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        await client.close()
+
+        mock_conn.close.assert_called_once()
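
To run just this suite (assuming the `unit` and `asyncio` markers are registered in the project's pytest configuration):

    pytest tests/unit/test_async_client.py -m unit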