From ec684c4785363f83b42e5346f4623cc1a2636dfc Mon Sep 17 00:00:00 2001 From: Nathan Allen Date: Mon, 9 Feb 2026 22:18:07 -0500 Subject: [PATCH] feat!: refactor tool discovery into extensible parser registry pattern Implement Parser Registry Pattern to enable support for multiple API specification formats beyond OpenAPI. This creates a clean, extensible architecture for future Postman Collections, GraphQL, and other formats. Changes: - Create parsers module with APISpecParser abstract base class - Implement ParserRegistry for parser management and discovery - Extract all OpenAPI logic into OpenAPIParser class (~600 lines) - Refactor OCPSchemaDiscovery to orchestration layer using registry - Add comprehensive tests for parser components - Update exports to include new parser classes BREAKING CHANGE: Private methods moved from OCPSchemaDiscovery to OpenAPIParser class. Methods like _parse_openapi_spec(), _normalize_tool_name(), _filter_tools_by_resources(), and others are no longer accessible on OCPSchemaDiscovery. Code using these private methods should use the OpenAPIParser class directly or the public discover_api() method. --- src/ocp_agent/__init__.py | 8 +- src/ocp_agent/parsers/__init__.py | 16 + src/ocp_agent/parsers/base.py | 93 ++++ src/ocp_agent/parsers/openapi_parser.py | 604 ++++++++++++++++++++++ src/ocp_agent/parsers/registry.py | 90 ++++ src/ocp_agent/schema_discovery.py | 651 ++++-------------------- tests/test_openapi_parser.py | 342 +++++++++++++ tests/test_parser_registry.py | 99 ++++ tests/test_schema_discovery.py | 444 +--------------- 9 files changed, 1363 insertions(+), 984 deletions(-) create mode 100644 src/ocp_agent/parsers/__init__.py create mode 100644 src/ocp_agent/parsers/base.py create mode 100644 src/ocp_agent/parsers/openapi_parser.py create mode 100644 src/ocp_agent/parsers/registry.py create mode 100644 tests/test_openapi_parser.py create mode 100644 tests/test_parser_registry.py diff --git a/src/ocp_agent/__init__.py b/src/ocp_agent/__init__.py index 15e893b..410e0c5 100644 --- a/src/ocp_agent/__init__.py +++ b/src/ocp_agent/__init__.py @@ -17,6 +17,7 @@ from .agent import OCPAgent from .storage import OCPStorage from .errors import OCPError, RegistryUnavailable, APINotFound, SchemaDiscoveryError, ValidationError +from .parsers import APISpecParser, ParserRegistry, OpenAPIParser __version__ = "0.1.0" __all__ = [ @@ -30,7 +31,7 @@ "extract_context_from_response", "validate_context", - # Convenience functions for cleaner API + # Convenience functions "parse_context", "add_context_headers", @@ -39,6 +40,11 @@ "OCPTool", "OCPAPISpec", + # Parser system + "APISpecParser", + "ParserRegistry", + "OpenAPIParser", + # Registry integration "OCPRegistry", diff --git a/src/ocp_agent/parsers/__init__.py b/src/ocp_agent/parsers/__init__.py new file mode 100644 index 0000000..0650fcd --- /dev/null +++ b/src/ocp_agent/parsers/__init__.py @@ -0,0 +1,16 @@ +""" +API Specification Parsers + +This module provides an extensible parser system for converting various API +specification formats (OpenAPI, Postman, GraphQL, etc.) into OCP tools. 
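+
+A minimal usage sketch (spec_data here stands for an already-loaded spec dict;
+the names below all refer to classes defined in this package):
+
+    registry = ParserRegistry()
+    parser = registry.find_parser(spec_data)
+    if parser is not None:
+        api_spec = parser.parse(spec_data)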
+""" + +from .base import APISpecParser +from .registry import ParserRegistry +from .openapi_parser import OpenAPIParser + +__all__ = [ + 'APISpecParser', + 'ParserRegistry', + 'OpenAPIParser', +] diff --git a/src/ocp_agent/parsers/base.py b/src/ocp_agent/parsers/base.py new file mode 100644 index 0000000..6c9cc4a --- /dev/null +++ b/src/ocp_agent/parsers/base.py @@ -0,0 +1,93 @@ +""" +Base Parser Interface + +Defines the abstract interface that all API specification parsers must implement. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from dataclasses import dataclass + + +@dataclass +class OCPTool: + """Represents a discovered API tool/endpoint""" + name: str + description: str + method: str + path: str + parameters: Dict[str, Any] + response_schema: Optional[Dict[str, Any]] + operation_id: Optional[str] = None + tags: Optional[List[str]] = None + + +@dataclass +class OCPAPISpec: + """Represents a parsed API specification""" + base_url: str + title: str + version: str + description: str + tools: List[OCPTool] + raw_spec: Dict[str, Any] + name: Optional[str] = None + + +class APISpecParser(ABC): + """ + Abstract base class for API specification parsers. + + All parsers must implement three methods: + - can_parse(): Detect if this parser can handle a given spec + - parse(): Convert the spec into an OCPAPISpec with tools + - get_format_name(): Return a human-readable format name + """ + + @abstractmethod + def can_parse(self, spec_data: Dict[str, Any]) -> bool: + """ + Determine if this parser can handle the given specification. + + Args: + spec_data: The raw specification data as a dictionary + + Returns: + True if this parser can handle the format, False otherwise + """ + pass + + @abstractmethod + def parse( + self, + spec_data: Dict[str, Any], + base_url_override: Optional[str] = None, + include_resources: Optional[List[str]] = None, + path_prefix: Optional[str] = None + ) -> OCPAPISpec: + """ + Parse the specification and extract tools. + + Args: + spec_data: The raw specification data as a dictionary + base_url_override: Optional override for the API base URL + include_resources: Optional list of resource names to filter tools by + path_prefix: Optional path prefix to strip before filtering + + Returns: + OCPAPISpec containing extracted tools and metadata + + Raises: + Exception: If parsing fails + """ + pass + + @abstractmethod + def get_format_name(self) -> str: + """ + Get the human-readable name of the format this parser handles. + + Returns: + Format name (e.g., "OpenAPI", "Postman Collection", "GraphQL") + """ + pass diff --git a/src/ocp_agent/parsers/openapi_parser.py b/src/ocp_agent/parsers/openapi_parser.py new file mode 100644 index 0000000..8eddf7a --- /dev/null +++ b/src/ocp_agent/parsers/openapi_parser.py @@ -0,0 +1,604 @@ +""" +OpenAPI Parser + +Parses OpenAPI 2.0 (Swagger) and OpenAPI 3.x specifications into OCP tools. +""" + +import re +import logging +from typing import Dict, List, Any, Optional + +from .base import APISpecParser, OCPAPISpec, OCPTool +from ..errors import SchemaDiscoveryError + +logger = logging.getLogger(__name__) + +# Configuration constants +DEFAULT_API_TITLE = 'Unknown API' +DEFAULT_API_VERSION = '1.0.0' +SUPPORTED_HTTP_METHODS = ['get', 'post', 'put', 'patch', 'delete'] + + +class OpenAPIParser(APISpecParser): + """ + Parser for OpenAPI 2.0 (Swagger) and OpenAPI 3.x specifications. 
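+
+    The detected spec version drives the version-specific parsing paths
+    (base URL extraction, request body handling, and response schemas).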
+ + Supports: + - Swagger 2.0 + - OpenAPI 3.0.x + - OpenAPI 3.1.x + - OpenAPI 3.2.x + """ + + def __init__(self): + self._spec_version: Optional[str] = None + + def can_parse(self, spec_data: Dict[str, Any]) -> bool: + """Check if this is an OpenAPI or Swagger specification.""" + return "openapi" in spec_data or "swagger" in spec_data + + def get_format_name(self) -> str: + """Return the format name.""" + return "OpenAPI" + + def parse( + self, + spec_data: Dict[str, Any], + base_url_override: Optional[str] = None, + include_resources: Optional[List[str]] = None, + path_prefix: Optional[str] = None + ) -> OCPAPISpec: + """ + Parse OpenAPI specification into OCP tools. + + Args: + spec_data: The OpenAPI specification as a dictionary + base_url_override: Optional override for API base URL + include_resources: Optional list of resource names to filter tools by + path_prefix: Optional path prefix to strip before filtering + + Returns: + OCPAPISpec with discovered tools and capabilities + """ + # Detect version + self._spec_version = self._detect_spec_version(spec_data) + + # Parse the specification + api_spec = self._parse_openapi_spec(spec_data, base_url_override) + + # Apply resource filtering if specified + if include_resources: + filtered_tools = self._filter_tools_by_resources( + api_spec.tools, include_resources, path_prefix + ) + return OCPAPISpec( + base_url=api_spec.base_url, + title=api_spec.title, + version=api_spec.version, + description=api_spec.description, + tools=filtered_tools, + raw_spec=api_spec.raw_spec + ) + + return api_spec + + def _detect_spec_version(self, spec: Dict[str, Any]) -> str: + """Detect OpenAPI/Swagger version from spec. + + Returns: + Version string: 'swagger_2', 'openapi_3.0', 'openapi_3.1', 'openapi_3.2' + """ + if "swagger" in spec: + swagger_version = spec["swagger"] + if swagger_version.startswith("2."): + return "swagger_2" + raise SchemaDiscoveryError(f"Unsupported Swagger version: {swagger_version}") + elif "openapi" in spec: + openapi_version = spec["openapi"] + if openapi_version.startswith("3.0"): + return "openapi_3.0" + elif openapi_version.startswith("3.1"): + return "openapi_3.1" + elif openapi_version.startswith("3.2"): + return "openapi_3.2" + raise SchemaDiscoveryError(f"Unsupported OpenAPI version: {openapi_version}") + + raise SchemaDiscoveryError("Unable to detect spec version: missing 'swagger' or 'openapi' field") + + def _parse_openapi_spec(self, spec_data: Dict[str, Any], base_url_override: Optional[str] = None) -> OCPAPISpec: + """Parse OpenAPI specification into OCP tools""" + + # Extract basic info + info = spec_data.get('info', {}) + title = info.get('title', DEFAULT_API_TITLE) + version = info.get('version', DEFAULT_API_VERSION) + description = info.get('description', '') + + # Determine base URL (version-specific) + base_url = base_url_override + if not base_url: + base_url = self._extract_base_url(spec_data) + + # Create memoization cache for $ref resolution + memo_cache = {} + + # Parse paths into tools + tools = [] + paths = spec_data.get('paths', {}) + + for path, path_item in paths.items(): + for method, operation in path_item.items(): + if method.lower() in SUPPORTED_HTTP_METHODS: + tool = self._create_tool_from_operation( + path, method.upper(), operation, spec_data, memo_cache + ) + if tool: + tools.append(tool) + + return OCPAPISpec( + base_url=base_url, + title=title, + version=version, + description=description, + tools=tools, + raw_spec=spec_data + ) + + def _extract_base_url(self, spec_data: Dict[str, Any]) -> 
str: + """Extract base URL from spec (version-aware).""" + if self._spec_version == "swagger_2": + # Swagger 2.0: construct from host, basePath, and schemes + schemes = spec_data.get('schemes', ['https']) + host = spec_data.get('host', '') + base_path = spec_data.get('basePath', '') + + if host: + scheme = schemes[0] if schemes else 'https' + return f"{scheme}://{host}{base_path}" + return '' + else: + # OpenAPI 3.x: use servers array + servers = spec_data.get('servers', []) + if servers: + return servers[0].get('url', '') + return '' + + def _normalize_tool_name(self, name: str) -> str: + """Normalize tool name to camelCase, removing special characters. + + Converts various naming patterns to consistent camelCase: + - 'meta/root' → 'metaRoot' + - 'repos/disable-vulnerability-alerts' → 'reposDisableVulnerabilityAlerts' + - 'admin_apps_approve' → 'adminAppsApprove' + - 'FetchAccount' → 'fetchAccount' + - 'v2010/Accounts' → 'v2010Accounts' + - 'get_users_list' → 'getUsersList' + - 'SMS/send' → 'smsSend' + """ + # Handle empty or None names + if not name: + return name + + # First, split PascalCase/camelCase words (e.g., "FetchAccount" -> "Fetch Account") + # Insert space before uppercase letters that follow lowercase letters or digits + pascal_split = re.sub(r'([a-z0-9])([A-Z])', r'\1 \2', name) + + # Replace separators (/, _, -, .) with spaces for processing + # Also handle multiple consecutive separators like // + normalized = re.sub(r'[/_.-]+', ' ', pascal_split) + + # Split into words and filter out empty strings + words = [word for word in normalized.split() if word] + + if not words: + return name + + # Convert to camelCase: first word lowercase, rest capitalize + camel_case_words = [words[0].lower()] + for word in words[1:]: + camel_case_words.append(word.capitalize()) + + return ''.join(camel_case_words) + + def _is_valid_tool_name(self, name: str) -> bool: + """Check if a normalized tool name is valid. 
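+
+        For example, 'metaRoot' and 'getUsersList' are valid, while '',
+        '123invalid', and '___' are not.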
+ + A valid tool name must: + - Not be empty + - Not consist only of special characters + - Start with a letter + - Contain at least one alphanumeric character + """ + if not name: + return False + + # Must start with a letter + if not name[0].isalpha(): + return False + + # Must contain at least one alphanumeric character + if not any(c.isalnum() for c in name): + return False + + return True + + def _create_tool_from_operation( + self, + path: str, + method: str, + operation: Dict[str, Any], + spec_data: Dict[str, Any], + memo_cache: Dict[str, Any] + ) -> Optional[OCPTool]: + """Create OCP tool from OpenAPI operation""" + + # Generate tool name with proper validation and fallback logic + operation_id = operation.get('operationId') + tool_name = None + + # Try operationId first + if operation_id: + normalized_name = self._normalize_tool_name(operation_id) + if self._is_valid_tool_name(normalized_name): + tool_name = normalized_name + + # If operationId failed, try fallback naming + if not tool_name: + # Generate name from path and method + clean_path = path.replace('/', '_').replace('{', '').replace('}', '') + fallback_name = f"{method.lower()}{clean_path}" + normalized_fallback = self._normalize_tool_name(fallback_name) + if self._is_valid_tool_name(normalized_fallback): + tool_name = normalized_fallback + + # If we can't generate a valid tool name, skip this operation + if not tool_name: + logger.warning(f"Skipping operation {method} {path}: unable to generate valid tool name") + return None + + # Get description + description = operation.get('summary', '') or operation.get('description', '') + if not description: + description = "No description provided" + + # Parse parameters (version-aware) + parameters = self._parse_parameters(operation.get('parameters', []), spec_data, memo_cache) + + # Add request body parameters (version-specific) + if method in ['POST', 'PUT', 'PATCH']: + if self._spec_version == "swagger_2": + # Swagger 2.0: body is in parameters array + for param in operation.get('parameters', []): + body_params = self._parse_swagger2_body_parameter(param, spec_data, memo_cache) + parameters.update(body_params) + else: + # OpenAPI 3.x: separate requestBody field + if 'requestBody' in operation: + body_params = self._parse_openapi3_request_body(operation['requestBody'], spec_data, memo_cache) + parameters.update(body_params) + + # Parse response schema + response_schema = self._parse_responses(operation.get('responses', {}), spec_data, memo_cache) + + # Get tags + tags = operation.get('tags', []) + + return OCPTool( + name=tool_name, + description=description, + method=method, + path=path, + parameters=parameters, + response_schema=response_schema, + operation_id=operation_id, + tags=tags + ) + + def _parse_parameters( + self, + parameters: List[Dict[str, Any]], + spec_data: Dict[str, Any], + memo_cache: Dict[str, Any] + ) -> Dict[str, Any]: + """Parse OpenAPI parameters into tool parameter schema""" + parsed_params = {} + + for param in parameters: + name = param.get('name') + if not name: + continue + + param_schema = { + 'description': param.get('description', ''), + 'required': param.get('required', False), + 'location': param.get('in', 'query'), # query, path, header, cookie + 'type': 'string' # Default type + } + + # Extract type from schema + schema = param.get('schema', {}) + if schema: + # Resolve any $refs in the parameter schema + schema = self._resolve_refs(schema, spec_data, [], memo_cache) + param_schema['type'] = schema.get('type', 'string') + if 'enum' in 
schema: + param_schema['enum'] = schema['enum'] + if 'format' in schema: + param_schema['format'] = schema['format'] + + parsed_params[name] = param_schema + + return parsed_params + + def _parse_openapi3_request_body( + self, + request_body: Dict[str, Any], + spec_data: Dict[str, Any], + memo_cache: Dict[str, Any] + ) -> Dict[str, Any]: + """Parse request body into parameters (OpenAPI 3.x only)""" + parameters = {} + + content = request_body.get('content', {}) + + # Look for JSON content first + json_content = content.get('application/json', {}) + if json_content and 'schema' in json_content: + schema = json_content['schema'] + + # Resolve any $refs in the request body schema + schema = self._resolve_refs(schema, spec_data, [], memo_cache) + + # Handle object schemas + if schema.get('type') == 'object': + properties = schema.get('properties', {}) + required_fields = schema.get('required', []) + + for prop_name, prop_schema in properties.items(): + parameters[prop_name] = { + 'description': prop_schema.get('description', ''), + 'required': prop_name in required_fields, + 'location': 'body', + 'type': prop_schema.get('type', 'string') + } + + if 'enum' in prop_schema: + parameters[prop_name]['enum'] = prop_schema['enum'] + + return parameters + + def _parse_swagger2_body_parameter( + self, + param: Dict[str, Any], + spec_data: Dict[str, Any], + memo_cache: Dict[str, Any] + ) -> Dict[str, Any]: + """Parse Swagger 2.0 body parameter into parameters.""" + parameters = {} + + if param.get('in') == 'body' and 'schema' in param: + schema = param['schema'] + + # Resolve any $refs in the body schema + schema = self._resolve_refs(schema, spec_data, [], memo_cache) + + # Handle object schemas + if schema.get('type') == 'object': + properties = schema.get('properties', {}) + required_fields = schema.get('required', []) + + for prop_name, prop_schema in properties.items(): + parameters[prop_name] = { + 'description': prop_schema.get('description', ''), + 'required': prop_name in required_fields, + 'location': 'body', + 'type': prop_schema.get('type', 'string') + } + + if 'enum' in prop_schema: + parameters[prop_name]['enum'] = prop_schema['enum'] + + return parameters + + def _parse_responses( + self, + responses: Dict[str, Any], + spec_data: Dict[str, Any], + memo_cache: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """Parse response schemas (version-aware)""" + # Look for successful response (200, 201, etc.) 
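+        # The first 2xx response that defines a schema wins; any remaining
+        # success responses are ignored.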
+ for status_code, response in responses.items(): + if str(status_code).startswith('2'): # 2xx success codes + if self._spec_version == "swagger_2": + # Swagger 2.0: schema is directly in response + if 'schema' in response: + schema = response['schema'] + # Resolve any $refs in the response schema + return self._resolve_refs(schema, spec_data, [], memo_cache) + else: + # OpenAPI 3.x: schema is in content.application/json + content = response.get('content', {}) + json_content = content.get('application/json', {}) + + if json_content and 'schema' in json_content: + schema = json_content['schema'] + # Resolve any $refs in the response schema + return self._resolve_refs(schema, spec_data, [], memo_cache) + + return None + + def _filter_tools_by_resources( + self, + tools: List[OCPTool], + include_resources: List[str], + path_prefix: Optional[str] = None + ) -> List[OCPTool]: + """Filter tools to only include those whose first resource segment matches include_resources""" + if not include_resources: + return tools + + # Normalize resource names to lowercase for case-insensitive matching + normalized_resources = [resource.lower() for resource in include_resources] + + filtered_tools = [] + for tool in tools: + path = tool.path + + # Strip path prefix if provided + if path_prefix: + prefix_lower = path_prefix.lower() + path_lower = path.lower() + if path_lower.startswith(prefix_lower): + path = path[len(path_prefix):] + + # Extract path segments by splitting on both '/' and '.' + path_lower = path.lower() + # Replace dots with slashes for uniform splitting + path_normalized = path_lower.replace('.', '/') + # Split by '/' and filter out empty segments and parameter placeholders + segments = [seg for seg in path_normalized.split('/') if seg and not seg.startswith('{')] + + # Check if the first segment matches any of the include_resources + if segments and segments[0] in normalized_resources: + filtered_tools.append(tool) + + return filtered_tools + + def _resolve_refs( + self, + obj: Any, + root: Optional[Dict[str, Any]] = None, + resolution_stack: Optional[List[str]] = None, + memo: Optional[Dict[str, Any]] = None, + inside_polymorphic_keyword: bool = False + ) -> Any: + """Recursively resolve $ref references in OpenAPI spec with polymorphic keyword handling + + Args: + obj: Current object being processed (dict, list, or primitive) + root: Root spec document for looking up references + resolution_stack: Stack of refs currently being resolved (for circular detection) + memo: Memoization cache for already-resolved refs + inside_polymorphic_keyword: True if currently inside anyOf/oneOf/allOf + + Returns: + Object with all resolvable $refs replaced by their definitions + """ + # Initialize on first call + if root is None: + root = obj + if resolution_stack is None: + resolution_stack = [] + if memo is None: + memo = {} + + # Handle dict objects + if isinstance(obj, dict): + # Check for polymorphic keywords - process with flag set + if 'anyOf' in obj: + result = {'anyOf': [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword=True) for item in obj['anyOf']]} + # Include other keys if present + for k, v in obj.items(): + if k != 'anyOf': + result[k] = self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) + return result + + if 'oneOf' in obj: + result = {'oneOf': [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword=True) for item in obj['oneOf']]} + for k, v in obj.items(): + if k != 'oneOf': + result[k] = 
self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) + return result + + if 'allOf' in obj: + result = {'allOf': [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword=True) for item in obj['allOf']]} + for k, v in obj.items(): + if k != 'allOf': + result[k] = self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) + return result + + # Check if this is a $ref + if '$ref' in obj and len(obj) == 1: + ref_path = obj['$ref'] + + # Only handle internal refs (start with #/) + if not ref_path.startswith('#/'): + return obj + + # If inside polymorphic keyword, check if ref points to an object + if inside_polymorphic_keyword: + try: + resolved = self._lookup_ref(root, ref_path) + if resolved is not None: + # Check if it's an object schema + if resolved.get('type') == 'object' or 'properties' in resolved: + # Keep the $ref unresolved for object schemas + return obj + except Exception: + # If lookup fails, keep the ref + return obj + + # Check memo cache + if ref_path in memo: + return memo[ref_path] + + # Check for circular reference + if ref_path in resolution_stack: + # Return a placeholder to break the cycle + placeholder = {'type': 'object', 'description': 'Circular reference'} + memo[ref_path] = placeholder + return placeholder + + # Resolve the reference + try: + resolved = self._lookup_ref(root, ref_path) + if resolved is not None: + # Recursively resolve the resolved object with updated stack + new_stack = resolution_stack + [ref_path] + result = self._resolve_refs(resolved, root, new_stack, memo, inside_polymorphic_keyword) + memo[ref_path] = result + return result + except Exception: + # If lookup fails, return a placeholder + placeholder = {'type': 'object', 'description': 'Unresolved reference'} + memo[ref_path] = placeholder + return placeholder + + return obj + + # Not a $ref, recursively process all values + return {k: self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) for k, v in obj.items()} + + # Handle list objects + elif isinstance(obj, list): + return [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword) for item in obj] + + # Primitives pass through unchanged + return obj + + def _lookup_ref(self, root: Dict[str, Any], ref_path: str) -> Any: + """Look up a reference path in the spec document + + Args: + root: Root spec document + ref_path: Reference path like '#/components/schemas/User' + + Returns: + The referenced object, or None if not found + """ + # Remove the leading '#/' and split by '/' + if not ref_path.startswith('#/'): + return None + + path_parts = ref_path[2:].split('/') + + # Navigate through the spec + current = root + for part in path_parts: + if isinstance(current, dict) and part in current: + current = current[part] + else: + return None + + return current diff --git a/src/ocp_agent/parsers/registry.py b/src/ocp_agent/parsers/registry.py new file mode 100644 index 0000000..f0f392f --- /dev/null +++ b/src/ocp_agent/parsers/registry.py @@ -0,0 +1,90 @@ +""" +Parser Registry + +Manages registration and discovery of API specification parsers. +""" + +from typing import Dict, Any, Optional, List +from .base import APISpecParser + + +class ParserRegistry: + """ + Registry for API specification parsers. 
+ + Maintains a list of registered parsers and provides methods to: + - Register new parsers + - Find the appropriate parser for a given spec + - List all supported formats + """ + + def __init__(self, auto_register_builtin: bool = True): + """ + Initialize the parser registry. + + Args: + auto_register_builtin: If True, automatically register built-in parsers + """ + self._parsers: List[APISpecParser] = [] + + if auto_register_builtin: + self._register_builtin_parsers() + + def register(self, parser: APISpecParser) -> None: + """ + Register a new parser. + + Args: + parser: An instance of APISpecParser to register + """ + self._parsers.append(parser) + + def _register_builtin_parsers(self) -> None: + """Register all built-in parsers.""" + from .openapi_parser import OpenAPIParser + + # Register built-in parsers + self.register(OpenAPIParser()) + + # Future parsers will be registered here: + # from .postman_parser import PostmanCollectionParser + # self.register(PostmanCollectionParser()) + # from .graphql_parser import GraphQLParser + # self.register(GraphQLParser()) + + def find_parser(self, spec_data: Dict[str, Any]) -> Optional[APISpecParser]: + """ + Find a parser that can handle the given specification. + + Parsers are checked in registration order. The first parser that + returns True from can_parse() is returned. + + Args: + spec_data: The raw specification data as a dictionary + + Returns: + An APISpecParser instance that can handle the spec, or None if + no suitable parser is found + """ + for parser in self._parsers: + if parser.can_parse(spec_data): + return parser + return None + + def get_supported_formats(self) -> List[str]: + """ + Get a list of all supported format names. + + Returns: + List of human-readable format names + """ + return [parser.get_format_name() for parser in self._parsers] + + def get_parser_count(self) -> int: + """ + Get the number of registered parsers. + + Returns: + Number of registered parsers + """ + return len(self._parsers) diff --git a/src/ocp_agent/schema_discovery.py b/src/ocp_agent/schema_discovery.py index 99c013d..8656eae 100644 --- a/src/ocp_agent/schema_discovery.py +++ b/src/ocp_agent/schema_discovery.py @@ -1,110 +1,125 @@ """ OCP Schema Discovery -Provides automatic API discovery and tool generation from OpenAPI specifications, -enabling context-aware API interactions with zero infrastructure requirements. +Provides automatic API discovery and tool generation from various API specification +formats (OpenAPI, Postman, GraphQL, etc.), enabling context-aware API interactions +with zero infrastructure requirements. 
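+
+A minimal usage sketch (the spec URL below is purely illustrative):
+
+    discovery = OCPSchemaDiscovery()
+    api_spec = discovery.discover_api("https://example.com/openapi.json")
+    for tool in api_spec.tools:
+        print(tool.name, tool.method, tool.path)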
""" import json -import re import requests import logging import yaml from typing import Dict, List, Any, Optional -from dataclasses import dataclass -from urllib.parse import urljoin from pathlib import Path from .errors import SchemaDiscoveryError +from .parsers import ParserRegistry +from .parsers.base import OCPTool, OCPAPISpec logger = logging.getLogger(__name__) # Configuration constants DEFAULT_SPEC_TIMEOUT = 30 -DEFAULT_API_TITLE = 'Unknown API' -DEFAULT_API_VERSION = '1.0.0' -SUPPORTED_HTTP_METHODS = ['get', 'post', 'put', 'patch', 'delete'] -@dataclass -class OCPTool: - """Represents a discovered API tool/endpoint""" - name: str - description: str - method: str - path: str - parameters: Dict[str, Any] - response_schema: Optional[Dict[str, Any]] - operation_id: Optional[str] = None - tags: Optional[List[str]] = None - -@dataclass -class OCPAPISpec: - """Represents a parsed OpenAPI specification""" - base_url: str - title: str - version: str - description: str - tools: List[OCPTool] - raw_spec: Dict[str, Any] - name: Optional[str] = None +# Re-export for backward compatibility +__all__ = ['OCPSchemaDiscovery', 'OCPTool', 'OCPAPISpec'] class OCPSchemaDiscovery: """ - Automatic API discovery and tool generation from OpenAPI specifications. + Automatic API discovery and tool generation from various API specification formats. This enables automatic API discovery while maintaining OCP's zero-infrastructure - approach by parsing OpenAPI specs directly. + approach by parsing API specifications directly. + + Supported formats: + - OpenAPI 2.0 (Swagger) + - OpenAPI 3.0.x, 3.1.x, 3.2.x + - Future: Postman Collections, GraphQL schemas, Google Discovery format """ - def __init__(self): + def __init__(self, parser_registry: Optional[ParserRegistry] = None): + """ + Initialize schema discovery. + + Args: + parser_registry: Optional custom parser registry. If None, uses default + registry with built-in parsers. + """ self.cached_specs: Dict[str, OCPAPISpec] = {} - self._spec_version: Optional[str] = None + self.parser_registry = parser_registry or ParserRegistry() def discover_api(self, spec_path: str, base_url: Optional[str] = None, include_resources: Optional[List[str]] = None, path_prefix: Optional[str] = None) -> OCPAPISpec: """ - Discover API capabilities from OpenAPI specification. + Discover API capabilities from various API specification formats. 
Args: - spec_path: URL or file path to OpenAPI specification (JSON or YAML) + spec_path: URL or file path to API specification (JSON or YAML) base_url: Optional override for API base URL include_resources: Optional list of resource names to filter tools by (case-insensitive, first resource segment matching) path_prefix: Optional path prefix to strip before filtering (e.g., '/v1', '/api/v2') Returns: OCPAPISpec with discovered tools and capabilities + + Raises: + SchemaDiscoveryError: If the specification format is unsupported or parsing fails """ # Normalize cache key (absolute path for files, URL as-is) cache_key = self._normalize_cache_key(spec_path) # Check cache first if cache_key in self.cached_specs: - return self.cached_specs[cache_key] + cached_spec = self.cached_specs[cache_key] + + # If filters are specified, apply them to cached spec + if include_resources: + # Get the appropriate parser to apply filtering + parser = self.parser_registry.find_parser(cached_spec.raw_spec) + if parser and hasattr(parser, '_filter_tools_by_resources'): + filtered_tools = parser._filter_tools_by_resources( + cached_spec.tools, include_resources, path_prefix + ) + return OCPAPISpec( + base_url=cached_spec.base_url, + title=cached_spec.title, + version=cached_spec.version, + description=cached_spec.description, + tools=filtered_tools, + raw_spec=cached_spec.raw_spec + ) + + return cached_spec try: - # Fetch, detect version, and parse OpenAPI spec + # Fetch specification data spec_data = self._fetch_spec(spec_path) - self._spec_version = self._detect_spec_version(spec_data) - parsed_spec = self._parse_openapi_spec(spec_data, base_url) - - # Cache for future use - self.cached_specs[cache_key] = parsed_spec - # Apply resource filtering if specified (only on newly parsed specs) - if include_resources: - filtered_tools = self._filter_tools_by_resources(parsed_spec.tools, include_resources, path_prefix) - return OCPAPISpec( - base_url=parsed_spec.base_url, - title=parsed_spec.title, - version=parsed_spec.version, - description=parsed_spec.description, - tools=filtered_tools, - raw_spec=parsed_spec.raw_spec + # Find appropriate parser + parser = self.parser_registry.find_parser(spec_data) + if not parser: + supported_formats = self.parser_registry.get_supported_formats() + raise SchemaDiscoveryError( + f"Unsupported API specification format. " + f"Supported formats: {', '.join(supported_formats)}" ) + # Parse using the appropriate parser + parsed_spec = parser.parse(spec_data, base_url, include_resources, path_prefix) + + # Cache for future use (cache the unfiltered version) + if not include_resources: + self.cached_specs[cache_key] = parsed_spec + else: + # Cache the original without filters + unfiltered_spec = parser.parse(spec_data, base_url, None, None) + self.cached_specs[cache_key] = unfiltered_spec + return parsed_spec + + except SchemaDiscoveryError: + raise except Exception as e: - if isinstance(e, SchemaDiscoveryError): - raise raise SchemaDiscoveryError(f"Failed to discover API: {e}") def _normalize_cache_key(self, spec_path: str) -> str: @@ -162,495 +177,6 @@ def _fetch_from_file(self, file_path: str) -> Dict[str, Any]: except Exception as e: raise SchemaDiscoveryError(f"Failed to load spec from {file_path}: {e}") - def _detect_spec_version(self, spec: Dict[str, Any]) -> str: - """Detect OpenAPI/Swagger version from spec. 
- - Returns: - Version string: 'swagger_2', 'openapi_3.0', 'openapi_3.1', 'openapi_3.2' - """ - if "swagger" in spec: - swagger_version = spec["swagger"] - if swagger_version.startswith("2."): - return "swagger_2" - raise SchemaDiscoveryError(f"Unsupported Swagger version: {swagger_version}") - elif "openapi" in spec: - openapi_version = spec["openapi"] - if openapi_version.startswith("3.0"): - return "openapi_3.0" - elif openapi_version.startswith("3.1"): - return "openapi_3.1" - elif openapi_version.startswith("3.2"): - return "openapi_3.2" - raise SchemaDiscoveryError(f"Unsupported OpenAPI version: {openapi_version}") - - raise SchemaDiscoveryError("Unable to detect spec version: missing 'swagger' or 'openapi' field") - - def _resolve_refs( - self, - obj: Any, - root: Optional[Dict[str, Any]] = None, - resolution_stack: Optional[List[str]] = None, - memo: Optional[Dict[str, Any]] = None, - inside_polymorphic_keyword: bool = False - ) -> Any: - """Recursively resolve $ref references in OpenAPI spec with polymorphic keyword handling - - Args: - obj: Current object being processed (dict, list, or primitive) - root: Root spec document for looking up references - resolution_stack: Stack of refs currently being resolved (for circular detection) - memo: Memoization cache for already-resolved refs - inside_polymorphic_keyword: True if currently inside anyOf/oneOf/allOf - - Returns: - Object with all resolvable $refs replaced by their definitions - """ - # Initialize on first call - if root is None: - root = obj - if resolution_stack is None: - resolution_stack = [] - if memo is None: - memo = {} - - # Handle dict objects - if isinstance(obj, dict): - # Check for polymorphic keywords - process with flag set - if 'anyOf' in obj: - result = {'anyOf': [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword=True) for item in obj['anyOf']]} - # Include other keys if present - for k, v in obj.items(): - if k != 'anyOf': - result[k] = self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) - return result - - if 'oneOf' in obj: - result = {'oneOf': [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword=True) for item in obj['oneOf']]} - for k, v in obj.items(): - if k != 'oneOf': - result[k] = self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) - return result - - if 'allOf' in obj: - result = {'allOf': [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword=True) for item in obj['allOf']]} - for k, v in obj.items(): - if k != 'allOf': - result[k] = self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) - return result - - # Check if this is a $ref - if '$ref' in obj and len(obj) == 1: - ref_path = obj['$ref'] - - # Only handle internal refs (start with #/) - if not ref_path.startswith('#/'): - return obj - - # If inside polymorphic keyword, check if ref points to an object - if inside_polymorphic_keyword: - try: - resolved = self._lookup_ref(root, ref_path) - if resolved is not None: - # Check if it's an object schema - if resolved.get('type') == 'object' or 'properties' in resolved: - # Keep the $ref unresolved for object schemas - return obj - except Exception: - # If lookup fails, keep the ref - return obj - - # Check memo cache - if ref_path in memo: - return memo[ref_path] - - # Check for circular reference - if ref_path in resolution_stack: - # Return a placeholder to break the cycle - placeholder = {'type': 'object', 'description': 
'Circular reference'} - memo[ref_path] = placeholder - return placeholder - - # Resolve the reference - try: - resolved = self._lookup_ref(root, ref_path) - if resolved is not None: - # Recursively resolve the resolved object with updated stack - new_stack = resolution_stack + [ref_path] - result = self._resolve_refs(resolved, root, new_stack, memo, inside_polymorphic_keyword) - memo[ref_path] = result - return result - except Exception: - # If lookup fails, return a placeholder - placeholder = {'type': 'object', 'description': 'Unresolved reference'} - memo[ref_path] = placeholder - return placeholder - - return obj - - # Not a $ref, recursively process all values - return {k: self._resolve_refs(v, root, resolution_stack, memo, inside_polymorphic_keyword) for k, v in obj.items()} - - # Handle list objects - elif isinstance(obj, list): - return [self._resolve_refs(item, root, resolution_stack, memo, inside_polymorphic_keyword) for item in obj] - - # Primitives pass through unchanged - return obj - - def _lookup_ref(self, root: Dict[str, Any], ref_path: str) -> Any: - """Look up a reference path in the spec document - - Args: - root: Root spec document - ref_path: Reference path like '#/components/schemas/User' - - Returns: - The referenced object, or None if not found - """ - # Remove the leading '#/' and split by '/' - if not ref_path.startswith('#/'): - return None - - path_parts = ref_path[2:].split('/') - - # Navigate through the spec - current = root - for part in path_parts: - if isinstance(current, dict) and part in current: - current = current[part] - else: - return None - - return current - - def _parse_openapi_spec(self, spec_data: Dict[str, Any], base_url_override: Optional[str] = None) -> OCPAPISpec: - """Parse OpenAPI specification into OCP tools""" - - # Extract basic info - info = spec_data.get('info', {}) - title = info.get('title', DEFAULT_API_TITLE) - version = info.get('version', DEFAULT_API_VERSION) - description = info.get('description', '') - - # Determine base URL (version-specific) - base_url = base_url_override - if not base_url: - base_url = self._extract_base_url(spec_data) - - # Create memoization cache for $ref resolution - memo_cache = {} - - # Parse paths into tools - tools = [] - paths = spec_data.get('paths', {}) - - for path, path_item in paths.items(): - for method, operation in path_item.items(): - if method.lower() in SUPPORTED_HTTP_METHODS: - tool = self._create_tool_from_operation( - path, method.upper(), operation, spec_data, memo_cache - ) - if tool: - tools.append(tool) - - return OCPAPISpec( - base_url=base_url, - title=title, - version=version, - description=description, - tools=tools, - raw_spec=spec_data - ) - - def _extract_base_url(self, spec_data: Dict[str, Any]) -> str: - """Extract base URL from spec (version-aware).""" - if self._spec_version == "swagger_2": - # Swagger 2.0: construct from host, basePath, and schemes - schemes = spec_data.get('schemes', ['https']) - host = spec_data.get('host', '') - base_path = spec_data.get('basePath', '') - - if host: - scheme = schemes[0] if schemes else 'https' - return f"{scheme}://{host}{base_path}" - return '' - else: - # OpenAPI 3.x: use servers array - servers = spec_data.get('servers', []) - if servers: - return servers[0].get('url', '') - return '' - - def _normalize_tool_name(self, name: str) -> str: - """Normalize tool name to camelCase, removing special characters. 
- - Converts various naming patterns to consistent camelCase: - - 'meta/root' → 'metaRoot' - - 'repos/disable-vulnerability-alerts' → 'reposDisableVulnerabilityAlerts' - - 'admin_apps_approve' → 'adminAppsApprove' - - 'FetchAccount' → 'fetchAccount' - - 'v2010/Accounts' → 'v2010Accounts' - - 'get_users_list' → 'getUsersList' - - 'SMS/send' → 'smsSend' - """ - # Handle empty or None names - if not name: - return name - - # First, split PascalCase/camelCase words (e.g., "FetchAccount" -> "Fetch Account") - # Insert space before uppercase letters that follow lowercase letters or digits - pascal_split = re.sub(r'([a-z0-9])([A-Z])', r'\1 \2', name) - - # Replace separators (/, _, -, .) with spaces for processing - # Also handle multiple consecutive separators like // - normalized = re.sub(r'[/_.-]+', ' ', pascal_split) - - # Split into words and filter out empty strings - words = [word for word in normalized.split() if word] - - if not words: - return name - - # Convert to camelCase: first word lowercase, rest capitalize - camel_case_words = [words[0].lower()] - for word in words[1:]: - camel_case_words.append(word.capitalize()) - - return ''.join(camel_case_words) - - def _is_valid_tool_name(self, name: str) -> bool: - """Check if a normalized tool name is valid. - - A valid tool name must: - - Not be empty - - Not consist only of special characters - - Start with a letter - - Contain at least one alphanumeric character - """ - if not name: - return False - - # Must start with a letter - if not name[0].isalpha(): - return False - - # Must contain at least one alphanumeric character - if not any(c.isalnum() for c in name): - return False - - return True - - def _create_tool_from_operation(self, path: str, method: str, operation: Dict[str, Any], spec_data: Dict[str, Any], memo_cache: Dict[str, Any]) -> Optional[OCPTool]: - """Create OCP tool from OpenAPI operation""" - - # Generate tool name with proper validation and fallback logic - operation_id = operation.get('operationId') - tool_name = None - - # Try operationId first - if operation_id: - normalized_name = self._normalize_tool_name(operation_id) - if self._is_valid_tool_name(normalized_name): - tool_name = normalized_name - - # If operationId failed, try fallback naming - if not tool_name: - # Generate name from path and method - clean_path = path.replace('/', '_').replace('{', '').replace('}', '') - fallback_name = f"{method.lower()}{clean_path}" - normalized_fallback = self._normalize_tool_name(fallback_name) - if self._is_valid_tool_name(normalized_fallback): - tool_name = normalized_fallback - - # If we can't generate a valid tool name, skip this operation - if not tool_name: - logger.warning(f"Skipping operation {method} {path}: unable to generate valid tool name") - return None - - # Get description - description = operation.get('summary', '') or operation.get('description', '') - if not description: - description = "No description provided" - - # Parse parameters (version-aware) - parameters = self._parse_parameters(operation.get('parameters', []), spec_data, memo_cache) - - # Add request body parameters (version-specific) - if method in ['POST', 'PUT', 'PATCH']: - if self._spec_version == "swagger_2": - # Swagger 2.0: body is in parameters array - for param in operation.get('parameters', []): - body_params = self._parse_swagger2_body_parameter(param, spec_data, memo_cache) - parameters.update(body_params) - else: - # OpenAPI 3.x: separate requestBody field - if 'requestBody' in operation: - body_params = 
self._parse_openapi3_request_body(operation['requestBody'], spec_data, memo_cache) - parameters.update(body_params) - - # Parse response schema - response_schema = self._parse_responses(operation.get('responses', {}), spec_data, memo_cache) - - # Get tags - tags = operation.get('tags', []) - - return OCPTool( - name=tool_name, - description=description, - method=method, - path=path, - parameters=parameters, - response_schema=response_schema, - operation_id=operation_id, - tags=tags - ) - - def _parse_parameters(self, parameters: List[Dict[str, Any]], spec_data: Dict[str, Any], memo_cache: Dict[str, Any]) -> Dict[str, Any]: - """Parse OpenAPI parameters into tool parameter schema""" - parsed_params = {} - - for param in parameters: - name = param.get('name') - if not name: - continue - - param_schema = { - 'description': param.get('description', ''), - 'required': param.get('required', False), - 'location': param.get('in', 'query'), # query, path, header, cookie - 'type': 'string' # Default type - } - - # Extract type from schema - schema = param.get('schema', {}) - if schema: - # Resolve any $refs in the parameter schema - schema = self._resolve_refs(schema, spec_data, [], memo_cache) - param_schema['type'] = schema.get('type', 'string') - if 'enum' in schema: - param_schema['enum'] = schema['enum'] - if 'format' in schema: - param_schema['format'] = schema['format'] - - parsed_params[name] = param_schema - - return parsed_params - - def _parse_openapi3_request_body(self, request_body: Dict[str, Any], spec_data: Dict[str, Any], memo_cache: Dict[str, Any]) -> Dict[str, Any]: - """Parse request body into parameters (OpenAPI 3.x only)""" - parameters = {} - - content = request_body.get('content', {}) - - # Look for JSON content first - json_content = content.get('application/json', {}) - if json_content and 'schema' in json_content: - schema = json_content['schema'] - - # Resolve any $refs in the request body schema - schema = self._resolve_refs(schema, spec_data, [], memo_cache) - - # Handle object schemas - if schema.get('type') == 'object': - properties = schema.get('properties', {}) - required_fields = schema.get('required', []) - - for prop_name, prop_schema in properties.items(): - parameters[prop_name] = { - 'description': prop_schema.get('description', ''), - 'required': prop_name in required_fields, - 'location': 'body', - 'type': prop_schema.get('type', 'string') - } - - if 'enum' in prop_schema: - parameters[prop_name]['enum'] = prop_schema['enum'] - - return parameters - - def _parse_swagger2_body_parameter(self, param: Dict[str, Any], spec_data: Dict[str, Any], memo_cache: Dict[str, Any]) -> Dict[str, Any]: - """Parse Swagger 2.0 body parameter into parameters.""" - parameters = {} - - if param.get('in') == 'body' and 'schema' in param: - schema = param['schema'] - - # Resolve any $refs in the body schema - schema = self._resolve_refs(schema, spec_data, [], memo_cache) - - # Handle object schemas - if schema.get('type') == 'object': - properties = schema.get('properties', {}) - required_fields = schema.get('required', []) - - for prop_name, prop_schema in properties.items(): - parameters[prop_name] = { - 'description': prop_schema.get('description', ''), - 'required': prop_name in required_fields, - 'location': 'body', - 'type': prop_schema.get('type', 'string') - } - - if 'enum' in prop_schema: - parameters[prop_name]['enum'] = prop_schema['enum'] - - return parameters - - def _parse_responses(self, responses: Dict[str, Any], spec_data: Dict[str, Any], memo_cache: Dict[str, 
Any]) -> Optional[Dict[str, Any]]: - """Parse response schemas (version-aware)""" - # Look for successful response (200, 201, etc.) - for status_code, response in responses.items(): - if str(status_code).startswith('2'): # 2xx success codes - if self._spec_version == "swagger_2": - # Swagger 2.0: schema is directly in response - if 'schema' in response: - schema = response['schema'] - # Resolve any $refs in the response schema - return self._resolve_refs(schema, spec_data, [], memo_cache) - else: - # OpenAPI 3.x: schema is in content.application/json - content = response.get('content', {}) - json_content = content.get('application/json', {}) - - if json_content and 'schema' in json_content: - schema = json_content['schema'] - # Resolve any $refs in the response schema - return self._resolve_refs(schema, spec_data, [], memo_cache) - - return None - - def _filter_tools_by_resources(self, tools: List[OCPTool], include_resources: List[str], path_prefix: Optional[str] = None) -> List[OCPTool]: - """Filter tools to only include those whose first resource segment matches include_resources""" - if not include_resources: - return tools - - # Normalize resource names to lowercase for case-insensitive matching - normalized_resources = [resource.lower() for resource in include_resources] - - filtered_tools = [] - for tool in tools: - path = tool.path - - # Strip path prefix if provided - if path_prefix: - prefix_lower = path_prefix.lower() - path_lower = path.lower() - if path_lower.startswith(prefix_lower): - path = path[len(path_prefix):] - - # Extract path segments by splitting on both '/' and '.' - path_lower = path.lower() - # Replace dots with slashes for uniform splitting - path_normalized = path_lower.replace('.', '/') - # Split by '/' and filter out empty segments and parameter placeholders - segments = [seg for seg in path_normalized.split('/') if seg and not seg.startswith('{')] - - # Check if the first segment matches any of the include_resources - if segments and segments[0] in normalized_resources: - filtered_tools.append(tool) - - return filtered_tools - def get_tools_by_tag(self, api_spec: OCPAPISpec, tag: str) -> List[OCPTool]: """Get tools filtered by tag""" return [tool for tool in api_spec.tools if tag in (tool.tags or [])] @@ -693,4 +219,39 @@ def generate_tool_documentation(self, tool: OCPTool) -> str: def clear_cache(self): """Clear cached API specifications""" - self.cached_specs.clear() \ No newline at end of file + self.cached_specs.clear() + + def get_supported_formats(self) -> List[str]: + """ + Get list of supported API specification formats. + + Returns: + List of format names (e.g., ['OpenAPI', 'Postman Collection']) + """ + return self.parser_registry.get_supported_formats() + + def register_parser(self, parser) -> None: + """ + Register a custom parser for additional API specification formats. 
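+
+        Parsers are consulted in registration order, so with the default
+        registry the built-in parsers are checked before custom ones.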
+ + Args: + parser: An instance of APISpecParser + + Example: + from ocp_agent.parsers.base import APISpecParser + + class MyCustomParser(APISpecParser): + def can_parse(self, spec_data): + return 'myformat' in spec_data + + def parse(self, spec_data, base_url_override=None, **kwargs): + # Parse logic here + pass + + def get_format_name(self): + return "My Custom Format" + + discovery = OCPSchemaDiscovery() + discovery.register_parser(MyCustomParser()) + """ + self.parser_registry.register(parser) \ No newline at end of file diff --git a/tests/test_openapi_parser.py b/tests/test_openapi_parser.py new file mode 100644 index 0000000..f431b33 --- /dev/null +++ b/tests/test_openapi_parser.py @@ -0,0 +1,342 @@ +""" +Tests for OpenAPI parser functionality. +""" + +import pytest +from ocp_agent.parsers import OpenAPIParser +from ocp_agent.parsers.base import OCPAPISpec, OCPTool +from ocp_agent.errors import SchemaDiscoveryError + + +class TestOpenAPIParser: + """Test OpenAPI parser functionality.""" + + @pytest.fixture + def parser(self): + """Create an OpenAPI parser instance.""" + return OpenAPIParser() + + @pytest.fixture + def sample_openapi3_spec(self): + """Sample OpenAPI 3.0 specification.""" + return { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + "description": "A test API" + }, + "servers": [ + {"url": "https://api.example.com"} + ], + "paths": { + "/users": { + "get": { + "operationId": "listUsers", + "summary": "List users", + "parameters": [ + { + "name": "limit", + "in": "query", + "schema": {"type": "integer"}, + "required": False + } + ], + "responses": { + "200": { + "description": "List of users", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"type": "object"} + } + } + } + } + } + }, + "post": { + "operationId": "createUser", + "summary": "Create user", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"} + }, + "required": ["name", "email"] + } + } + } + }, + "responses": { + "201": { + "description": "User created" + } + } + } + } + } + } + + @pytest.fixture + def swagger2_spec(self): + """Sample Swagger 2.0 specification.""" + return { + "swagger": "2.0", + "info": { + "title": "Swagger API", + "version": "1.0.0" + }, + "host": "api.example.com", + "basePath": "/v1", + "schemes": ["https"], + "paths": { + "/users": { + "get": { + "operationId": "getUsers", + "summary": "Get users", + "parameters": [ + { + "name": "id", + "in": "path", + "required": True, + "type": "string" + } + ], + "responses": { + "200": { + "description": "User details", + "schema": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "name": {"type": "string"} + } + } + } + } + }, + "post": { + "operationId": "createUser", + "summary": "Create user", + "parameters": [ + { + "name": "body", + "in": "body", + "required": True, + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"} + }, + "required": ["name", "email"] + } + } + ], + "responses": { + "201": { + "description": "User created" + } + } + } + } + } + } + + # Parser detection tests + def test_can_parse_openapi3(self, parser): + """Test parser can detect OpenAPI 3.x spec.""" + spec = {"openapi": "3.0.0", "info": {}, "paths": {}} + assert parser.can_parse(spec) is True + + def test_can_parse_swagger2(self, parser): + """Test parser can detect Swagger 2.0 spec.""" + spec = {"swagger": 
"2.0", "info": {}, "paths": {}} + assert parser.can_parse(spec) is True + + def test_cannot_parse_unknown(self, parser): + """Test parser rejects unknown formats.""" + spec = {"someformat": "1.0", "data": {}} + assert parser.can_parse(spec) is False + + def test_get_format_name(self, parser): + """Test parser returns correct format name.""" + assert parser.get_format_name() == "OpenAPI" + + # Version detection tests + def test_detect_openapi_3_0(self, parser): + """Test detection of OpenAPI 3.0.x.""" + spec = {"openapi": "3.0.0"} + version = parser._detect_spec_version(spec) + assert version == "openapi_3.0" + + def test_detect_openapi_3_1(self, parser): + """Test detection of OpenAPI 3.1.x.""" + spec = {"openapi": "3.1.0"} + version = parser._detect_spec_version(spec) + assert version == "openapi_3.1" + + def test_detect_swagger_2(self, parser): + """Test detection of Swagger 2.0.""" + spec = {"swagger": "2.0"} + version = parser._detect_spec_version(spec) + assert version == "swagger_2" + + def test_detect_unsupported_version(self, parser): + """Test error on unsupported version.""" + spec = {"openapi": "4.0.0"} + with pytest.raises(SchemaDiscoveryError, match="Unsupported OpenAPI version"): + parser._detect_spec_version(spec) + + # OpenAPI 3.x parsing tests + def test_parse_openapi3_spec(self, parser, sample_openapi3_spec): + """Test parsing complete OpenAPI 3.x spec.""" + api_spec = parser.parse(sample_openapi3_spec) + + assert isinstance(api_spec, OCPAPISpec) + assert api_spec.title == "Test API" + assert api_spec.version == "1.0.0" + assert api_spec.description == "A test API" + assert api_spec.base_url == "https://api.example.com" + assert len(api_spec.tools) == 2 # GET and POST /users + + def test_parse_openapi3_tools(self, parser, sample_openapi3_spec): + """Test tool extraction from OpenAPI 3.x spec.""" + api_spec = parser.parse(sample_openapi3_spec) + + # Find the GET tool + get_tool = next((t for t in api_spec.tools if t.method == "GET"), None) + assert get_tool is not None + assert get_tool.name == "listUsers" + assert get_tool.path == "/users" + assert "limit" in get_tool.parameters + + # Find the POST tool + post_tool = next((t for t in api_spec.tools if t.method == "POST"), None) + assert post_tool is not None + assert post_tool.name == "createUser" + assert "name" in post_tool.parameters + assert "email" in post_tool.parameters + assert post_tool.parameters["name"]["required"] is True + + # Swagger 2.0 parsing tests + def test_parse_swagger2_spec(self, parser, swagger2_spec): + """Test parsing Swagger 2.0 spec.""" + api_spec = parser.parse(swagger2_spec) + + assert isinstance(api_spec, OCPAPISpec) + assert api_spec.title == "Swagger API" + assert api_spec.version == "1.0.0" + assert api_spec.base_url == "https://api.example.com/v1" + assert len(api_spec.tools) == 2 + + def test_swagger2_base_url_extraction(self, parser, swagger2_spec): + """Test base URL extraction from Swagger 2.0.""" + api_spec = parser.parse(swagger2_spec) + # Should combine scheme + host + basePath + assert api_spec.base_url == "https://api.example.com/v1" + + def test_swagger2_body_parameters(self, parser, swagger2_spec): + """Test Swagger 2.0 body parameter parsing.""" + api_spec = parser.parse(swagger2_spec) + + post_tool = next((t for t in api_spec.tools if t.method == "POST"), None) + assert post_tool is not None + assert "name" in post_tool.parameters + assert "email" in post_tool.parameters + assert post_tool.parameters["name"]["location"] == "body" + + # Tool name normalization tests + def 
test_normalize_tool_name_slash_separators(self, parser): + """Test normalization of operationId with slash separators.""" + assert parser._normalize_tool_name("meta/root") == "metaRoot" + assert parser._normalize_tool_name("repos/disable-vulnerability-alerts") == "reposDisableVulnerabilityAlerts" + + def test_normalize_tool_name_underscore_separators(self, parser): + """Test normalization of operationId with underscore separators.""" + assert parser._normalize_tool_name("admin_apps_approve") == "adminAppsApprove" + assert parser._normalize_tool_name("get_users_list") == "getUsersList" + + def test_normalize_tool_name_pascal_case(self, parser): + """Test normalization of PascalCase operationIds.""" + assert parser._normalize_tool_name("FetchAccount") == "fetchAccount" + assert parser._normalize_tool_name("GetUserProfile") == "getUserProfile" + + def test_normalize_tool_name_numbers(self, parser): + """Test that numbers are preserved in normalization.""" + assert parser._normalize_tool_name("v2010/Accounts") == "v2010Accounts" + + def test_normalize_tool_name_acronyms(self, parser): + """Test that acronyms are handled.""" + assert parser._normalize_tool_name("SMS/send") == "smsSend" + + def test_valid_tool_name(self, parser): + """Test tool name validation.""" + assert parser._is_valid_tool_name("metaRoot") is True + assert parser._is_valid_tool_name("getUsersList") is True + assert parser._is_valid_tool_name("") is False + assert parser._is_valid_tool_name("123invalid") is False + assert parser._is_valid_tool_name("___") is False + + # Resource filtering tests + def test_filter_tools_by_resources(self, parser): + """Test filtering tools by resource names.""" + tools = [ + OCPTool(name="reposGet", description="Get repo", method="GET", + path="/repos/{owner}", parameters={}, response_schema=None), + OCPTool(name="issuesGet", description="Get issue", method="GET", + path="/issues/{id}", parameters={}, response_schema=None), + ] + + filtered = parser._filter_tools_by_resources(tools, ["repos"]) + assert len(filtered) == 1 + assert filtered[0].name == "reposGet" + + def test_filter_tools_case_insensitive(self, parser): + """Test resource filtering is case-insensitive.""" + tools = [ + OCPTool(name="reposGet", description="Get repo", method="GET", + path="/repos/{owner}", parameters={}, response_schema=None), + ] + + filtered = parser._filter_tools_by_resources(tools, ["REPOS"]) + assert len(filtered) == 1 + + def test_filter_tools_with_path_prefix(self, parser): + """Test filtering with path prefix stripping.""" + tools = [ + OCPTool(name="paymentsGet", description="Get payment", method="GET", + path="/v1/payments", parameters={}, response_schema=None), + ] + + filtered = parser._filter_tools_by_resources(tools, ["payments"], path_prefix="/v1") + assert len(filtered) == 1 + + # Base URL override tests + def test_base_url_override(self, parser, sample_openapi3_spec): + """Test that base_url_override is respected.""" + api_spec = parser.parse(sample_openapi3_spec, base_url_override="https://custom.api.com") + assert api_spec.base_url == "https://custom.api.com" + + # Resource filtering integration test + def test_parse_with_resource_filtering(self, parser, sample_openapi3_spec): + """Test parsing with resource filtering.""" + # Add a second path to test filtering + sample_openapi3_spec["paths"]["/repos/{id}"] = { + "get": { + "operationId": "getRepo", + "summary": "Get repo", + "responses": {"200": {"description": "OK"}} + } + } + + api_spec = parser.parse(sample_openapi3_spec, 
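            # Descriptive note at this step: resource filtering matches the first
            # path segment case-insensitively (see the _filter_tools_by_resources
            # tests above), so only the /users operations should survive here: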
include_resources=["users"]) + # Should only include /users tools, not /repos + assert len(api_spec.tools) == 2 # GET and POST /users + assert all("/users" in tool.path for tool in api_spec.tools) diff --git a/tests/test_parser_registry.py b/tests/test_parser_registry.py new file mode 100644 index 0000000..51bcb5f --- /dev/null +++ b/tests/test_parser_registry.py @@ -0,0 +1,99 @@ +""" +Tests for parser registry functionality. +""" + +import pytest +from ocp_agent.parsers import ParserRegistry, APISpecParser, OpenAPIParser +from ocp_agent.parsers.base import OCPAPISpec + + +class TestParserRegistry: + """Test parser registry functionality.""" + + @pytest.fixture + def registry(self): + """Create a parser registry instance.""" + return ParserRegistry(auto_register_builtin=False) + + @pytest.fixture + def registry_with_builtin(self): + """Create a parser registry with built-in parsers.""" + return ParserRegistry(auto_register_builtin=True) + + def test_registry_initialization_empty(self, registry): + """Test registry initialization without built-in parsers.""" + assert registry.get_parser_count() == 0 + assert registry.get_supported_formats() == [] + + def test_registry_initialization_with_builtin(self, registry_with_builtin): + """Test registry initialization with built-in parsers.""" + assert registry_with_builtin.get_parser_count() > 0 + assert "OpenAPI" in registry_with_builtin.get_supported_formats() + + def test_register_parser(self, registry): + """Test registering a custom parser.""" + parser = OpenAPIParser() + registry.register(parser) + + assert registry.get_parser_count() == 1 + assert "OpenAPI" in registry.get_supported_formats() + + def test_find_parser_openapi(self, registry_with_builtin): + """Test finding parser for OpenAPI spec.""" + openapi_spec = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0.0"}, + "paths": {} + } + + parser = registry_with_builtin.find_parser(openapi_spec) + assert parser is not None + assert parser.get_format_name() == "OpenAPI" + + def test_find_parser_swagger(self, registry_with_builtin): + """Test finding parser for Swagger 2.0 spec.""" + swagger_spec = { + "swagger": "2.0", + "info": {"title": "Test", "version": "1.0.0"}, + "paths": {} + } + + parser = registry_with_builtin.find_parser(swagger_spec) + assert parser is not None + assert parser.get_format_name() == "OpenAPI" + + def test_find_parser_no_match(self, registry_with_builtin): + """Test that None is returned when no parser matches.""" + unknown_spec = { + "some_format": "1.0", + "data": {} + } + + parser = registry_with_builtin.find_parser(unknown_spec) + assert parser is None + + def test_multiple_parsers_registration(self, registry): + """Test registering multiple parsers.""" + parser1 = OpenAPIParser() + + registry.register(parser1) + + assert registry.get_parser_count() == 1 + formats = registry.get_supported_formats() + assert "OpenAPI" in formats + + def test_parser_order_matters(self, registry): + """Test that parsers are checked in registration order.""" + # First parser registered will be checked first + parser1 = OpenAPIParser() + registry.register(parser1) + + openapi_spec = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0.0"}, + "paths": {} + } + + found_parser = registry.find_parser(openapi_spec) + assert found_parser is not None + assert found_parser.get_format_name() == "OpenAPI" diff --git a/tests/test_schema_discovery.py b/tests/test_schema_discovery.py index 02490b9..b077fdd 100644 --- a/tests/test_schema_discovery.py +++ 
b/tests/test_schema_discovery.py @@ -6,6 +6,7 @@ import json from unittest.mock import Mock, patch, MagicMock from ocp_agent.schema_discovery import OCPSchemaDiscovery, OCPTool, OCPAPISpec +from ocp_agent.parsers import OpenAPIParser from ocp_agent.errors import SchemaDiscoveryError @@ -17,6 +18,11 @@ def discovery(self): """Create a schema discovery instance.""" return OCPSchemaDiscovery() + @pytest.fixture + def openapi_parser(self): + """Create an OpenAPI parser instance for testing parser-specific functionality.""" + return OpenAPIParser() + @pytest.fixture def sample_openapi_spec(self): """Basic OpenAPI specification without operationIds for testing fallback naming.""" @@ -256,217 +262,6 @@ def openapi_spec_edge_cases(self): } } - def test_parse_openapi_spec(self, discovery, sample_openapi_spec): - """Test parsing OpenAPI specification.""" - api_spec = discovery._parse_openapi_spec( - sample_openapi_spec, - "https://api.example.com" - ) - - assert isinstance(api_spec, OCPAPISpec) - assert api_spec.title == "Test API" - assert api_spec.version == "1.0.0" - assert api_spec.base_url == "https://api.example.com" - assert len(api_spec.tools) == 3 # GET /users, POST /users, GET /users/{id} - - def test_generate_tools_from_spec(self, discovery, sample_openapi_spec): - """Test tool generation from OpenAPI specification.""" - api_spec = discovery._parse_openapi_spec( - sample_openapi_spec, - "https://api.example.com" - ) - - tools = api_spec.tools - assert len(tools) == 3 # GET /users, POST /users, GET /users/{id} - - # Check that we have the expected tools with deterministic names - tool_names = [t.name for t in tools] - expected_names = ["getUsers", "postUsers", "getUsersId"] # camelCase naming - - for expected_name in expected_names: - assert expected_name in tool_names, f"Expected tool name '{expected_name}' not found in {tool_names}" - - # Check GET /users tool - get_users = next((t for t in tools if t.name == "getUsers"), None) - assert get_users is not None - assert get_users.method == "GET" - assert get_users.path == "/users" - assert get_users.description == "List users" - assert "limit" in get_users.parameters - assert get_users.parameters["limit"]["type"] == "integer" - assert get_users.parameters["limit"]["location"] == "query" - assert not get_users.parameters["limit"]["required"] - assert get_users.response_schema is not None - assert get_users.response_schema["type"] == "array" - - # Check POST /users tool - post_users = next((t for t in tools if t.name == "postUsers"), None) - assert post_users is not None - assert post_users.method == "POST" - assert post_users.path == "/users" - assert "name" in post_users.parameters - assert "email" in post_users.parameters - assert post_users.parameters["name"]["required"] - assert post_users.parameters["email"]["required"] - assert post_users.response_schema is not None - assert post_users.response_schema["type"] == "object" - - # Check GET /users/{id} tool - get_users_id = next((t for t in tools if t.name == "getUsersId"), None) - assert get_users_id is not None - assert get_users_id.method == "GET" - assert get_users_id.path == "/users/{id}" - assert "id" in get_users_id.parameters - assert get_users_id.parameters["id"]["required"] - assert get_users_id.parameters["id"]["location"] == "path" - assert get_users_id.response_schema is not None - assert get_users_id.response_schema["type"] == "object" - - def test_normalize_tool_name_slash_separators(self, discovery): - """Test normalization of operationId with slash separators.""" - assert 
discovery._normalize_tool_name("meta/root") == "metaRoot" - assert discovery._normalize_tool_name("repos/disable-vulnerability-alerts") == "reposDisableVulnerabilityAlerts" - assert discovery._normalize_tool_name("users/list-followers") == "usersListFollowers" - - def test_normalize_tool_name_underscore_separators(self, discovery): - """Test normalization of operationId with underscore separators.""" - assert discovery._normalize_tool_name("admin_apps_approve") == "adminAppsApprove" - assert discovery._normalize_tool_name("chat_post_message") == "chatPostMessage" - assert discovery._normalize_tool_name("users_list_all") == "usersListAll" - - def test_normalize_tool_name_pascal_case(self, discovery): - """Test normalization of PascalCase operationIds.""" - assert discovery._normalize_tool_name("FetchAccount") == "fetchAccount" - assert discovery._normalize_tool_name("CreateAccount") == "createAccount" - assert discovery._normalize_tool_name("ListAvailablePhoneNumberLocal") == "listAvailablePhoneNumberLocal" - - def test_normalize_tool_name_numbers_preserved(self, discovery): - """Test that numbers are preserved in normalization.""" - assert discovery._normalize_tool_name("v2010/Accounts") == "v2010Accounts" - assert discovery._normalize_tool_name("api_v2_users") == "apiV2Users" - assert discovery._normalize_tool_name("get-v3-repos") == "getV3Repos" - - def test_normalize_tool_name_acronyms_preserved(self, discovery): - """Test that acronyms are converted to camelCase.""" - assert discovery._normalize_tool_name("SMS/send") == "smsSend" - assert discovery._normalize_tool_name("api/HTTP_request") == "apiHttpRequest" - assert discovery._normalize_tool_name("get_API_key") == "getApiKey" - - def test_normalize_tool_name_fallback_patterns(self, discovery): - """Test normalization of fallback generated names.""" - assert discovery._normalize_tool_name("get_users") == "getUsers" - assert discovery._normalize_tool_name("post_users") == "postUsers" - assert discovery._normalize_tool_name("get_users_id") == "getUsersId" - assert discovery._normalize_tool_name("delete_repos_issues_comments_id") == "deleteReposIssuesCommentsId" - - def test_normalize_tool_name_multiple_separators(self, discovery): - """Test handling of multiple consecutive separators.""" - assert discovery._normalize_tool_name("api//users") == "apiUsers" - assert discovery._normalize_tool_name("admin___apps") == "adminApps" - assert discovery._normalize_tool_name("repos---list") == "reposList" - assert discovery._normalize_tool_name("api./..users") == "apiUsers" - - def test_normalize_tool_name_edge_cases(self, discovery): - """Test edge cases for normalization.""" - # Empty and None - assert discovery._normalize_tool_name("") == "" - assert discovery._normalize_tool_name(None) == None - - # Single character - assert discovery._normalize_tool_name("a") == "a" - assert discovery._normalize_tool_name("A") == "a" - - # Only separators should return original (but will be caught by validation) - assert discovery._normalize_tool_name("///") == "///" - assert discovery._normalize_tool_name("___") == "___" - - # Single word - assert discovery._normalize_tool_name("users") == "users" - assert discovery._normalize_tool_name("USERS") == "users" - - def test_valid_tool_name_validation(self, discovery): - """Test tool name validation logic.""" - # Valid names - assert discovery._is_valid_tool_name("metaRoot") == True - assert discovery._is_valid_tool_name("a") == True - assert discovery._is_valid_tool_name("test123") == True - assert 
discovery._is_valid_tool_name("getUserId") == True - - # Invalid names - assert discovery._is_valid_tool_name("") == False - assert discovery._is_valid_tool_name("///") == False - assert discovery._is_valid_tool_name("___") == False - assert discovery._is_valid_tool_name("123abc") == False # Starts with number - assert discovery._is_valid_tool_name("!@#") == False # Only special chars - - def test_operation_id_integration(self, discovery, openapi_spec_with_operation_ids): - """Test that operationId normalization works in full tool generation flow.""" - api_spec = discovery._parse_openapi_spec( - openapi_spec_with_operation_ids, - "https://api.example.com" - ) - - tools = api_spec.tools - tool_names = [t.name for t in tools] - - # Verify normalized operationId names - expected_names = [ - "metaRoot", # meta/root - "reposDisableVulnerabilityAlerts", # repos/disable-vulnerability-alerts - "adminAppsApprove", # admin_apps_approve - "fetchAccount", # FetchAccount - "createAccount", # CreateAccount - "v2010Accounts", # v2010/Accounts - "smsSend", # SMS/send - "getUsersNoOperationId", # fallback: get + /users/no-operation-id - "apiUsers" # api//users - ] - - for expected_name in expected_names: - assert expected_name in tool_names, f"Expected tool name '{expected_name}' not found in {tool_names}" - - # Verify specific tools have correct properties - meta_tool = next((t for t in tools if t.name == "metaRoot"), None) - assert meta_tool is not None - assert meta_tool.operation_id == "meta/root" # Original preserved - assert meta_tool.method == "GET" - assert meta_tool.path == "/meta" - - # Test acronym preservation - sms_tool = next((t for t in tools if t.name == "smsSend"), None) - assert sms_tool is not None - assert sms_tool.operation_id == "SMS/send" - - # Test fallback naming for missing operationId - fallback_tool = next((t for t in tools if t.name == "getUsersNoOperationId"), None) - assert fallback_tool is not None - assert fallback_tool.operation_id is None # No operationId in spec - - def test_edge_cases_integration(self, discovery, openapi_spec_edge_cases): - """Test edge cases in full tool generation flow.""" - api_spec = discovery._parse_openapi_spec( - openapi_spec_edge_cases, - "https://api.example.com" - ) - - tools = api_spec.tools - tool_names = [t.name for t in tools] - - expected_tools = [ - "getEmptyOperationId", # Empty operationId falls back to path - "a", # Single character preserved (valid) - "getNoOperationId", # Missing operationId uses path-based naming - "apiV1UsersListAll", # Mixed separators normalized - "getApiHttpUrl" # Multiple acronyms preserved - ] - - # Check expected tools are present - for expected_name in expected_tools: - assert expected_name in tool_names, f"Expected tool name '{expected_name}' not found in {tool_names}" - - # Verify total tool count - all 5 operations should create valid tools - assert len(tools) == 5, f"Expected 5 tools, got {len(tools)}: {tool_names}" - @patch('requests.get') def test_discover_api_success(self, mock_get, discovery, sample_openapi_spec): """Test successful API discovery.""" @@ -1117,67 +912,6 @@ def swagger2_spec(self): } } - def test_detect_swagger2_version(self, discovery, swagger2_spec): - """Test that Swagger 2.0 version is correctly detected.""" - version = discovery._detect_spec_version(swagger2_spec) - assert version == "swagger_2" - - def test_swagger2_base_url_extraction(self, discovery, swagger2_spec): - """Test base URL extraction from Swagger 2.0 (host + basePath + schemes).""" - discovery._spec_version = 
"swagger_2" - base_url = discovery._extract_base_url(swagger2_spec) - assert base_url == "https://api.example.com/v1" - - def test_swagger2_base_url_multiple_schemes(self, discovery): - """Test base URL extraction with multiple schemes (uses first one).""" - spec = { - "swagger": "2.0", - "host": "api.example.com", - "basePath": "/api", - "schemes": ["http", "https"] - } - discovery._spec_version = "swagger_2" - base_url = discovery._extract_base_url(spec) - assert base_url == "http://api.example.com/api" - - def test_swagger2_base_url_no_schemes(self, discovery): - """Test base URL extraction defaults to https when no schemes.""" - spec = { - "swagger": "2.0", - "host": "api.example.com", - "basePath": "/v2" - } - discovery._spec_version = "swagger_2" - base_url = discovery._extract_base_url(spec) - assert base_url == "https://api.example.com/v2" - - def test_swagger2_response_schema_parsing(self, discovery, swagger2_spec): - """Test that Swagger 2.0 response schemas are correctly parsed.""" - discovery._spec_version = "swagger_2" - responses = swagger2_spec["paths"]["/users"]["get"]["responses"] - - schema = discovery._parse_responses(responses, swagger2_spec, {}) - - assert schema is not None - assert schema["type"] == "array" - assert "items" in schema - assert schema["items"]["type"] == "object" - - def test_swagger2_body_parameter_parsing(self, discovery, swagger2_spec): - """Test that Swagger 2.0 body parameters are correctly parsed.""" - discovery._spec_version = "swagger_2" - post_operation = swagger2_spec["paths"]["/users"]["post"] - body_param = post_operation["parameters"][0] - - params = discovery._parse_swagger2_body_parameter(body_param, swagger2_spec, {}) - - assert "name" in params - assert "email" in params - assert params["name"]["type"] == "string" - assert params["name"]["required"] == True - assert params["name"]["location"] == "body" - assert params["email"]["required"] == True - @patch('ocp_agent.schema_discovery.requests.get') def test_discover_swagger2_api(self, mock_get, discovery, swagger2_spec): """Test full API discovery with Swagger 2.0 spec.""" @@ -1323,172 +1057,6 @@ def tools_with_resources(self): ) ] - def test_filter_tools_by_resources_single_resource(self, discovery, tools_with_resources): - """Test filtering tools by a single resource name.""" - # Filter for repos resources only (first segment matching) - filtered_tools = discovery._filter_tools_by_resources(tools_with_resources, ["repos"]) - - assert len(filtered_tools) == 2 # /repos/{owner}/{repo}, /repos/{owner}/{repo}/issues (NOT /user/repos) - path_set = {tool.path for tool in filtered_tools} - assert "/repos/{owner}/{repo}" in path_set - assert "/repos/{owner}/{repo}/issues" in path_set - - def test_filter_tools_by_resources_multiple_resources(self, discovery, tools_with_resources): - """Test filtering tools by multiple resource names.""" - # Filter for both repos and orgs resources (first segment matching) - filtered_tools = discovery._filter_tools_by_resources(tools_with_resources, ["repos", "orgs"]) - - assert len(filtered_tools) == 3 # /repos/..., /repos/.../issues, /orgs/... 
(NOT /user/repos) - - def test_filter_tools_by_resources_case_insensitive(self, discovery, tools_with_resources): - """Test that resource filtering is case-insensitive.""" - # Filter with different case (first segment matching) - filtered_tools = discovery._filter_tools_by_resources(tools_with_resources, ["REPOS", "Orgs"]) - - assert len(filtered_tools) == 3 - - def test_filter_tools_by_resources_no_matches(self, discovery, tools_with_resources): - """Test filtering tools with resources that don't match any paths.""" - # Filter for resources that don't exist - filtered_tools = discovery._filter_tools_by_resources(tools_with_resources, ["payments", "customers"]) - - assert len(filtered_tools) == 0 - - def test_filter_tools_by_resources_empty_list(self, discovery, tools_with_resources): - """Test filtering with empty include_resources list returns all tools.""" - # Empty list should return all tools - filtered_tools = discovery._filter_tools_by_resources(tools_with_resources, []) - - assert len(filtered_tools) == 4 - assert filtered_tools == tools_with_resources - - def test_filter_tools_by_resources_none(self, discovery, tools_with_resources): - """Test filtering with None include_resources returns all tools.""" - # None should return all tools - filtered_tools = discovery._filter_tools_by_resources(tools_with_resources, None) - - assert len(filtered_tools) == 4 - assert filtered_tools == tools_with_resources - - def test_filter_tools_by_resources_exact_match(self, discovery): - """Test that only exact segment matches are included, not substring matches.""" - tools = [ - OCPTool(name="listPaymentMethods", description="List payment methods", method="GET", - path="/payment_methods", parameters={}, response_schema=None), - OCPTool(name="createPaymentIntent", description="Create payment intent", method="POST", - path="/payment_intents", parameters={}, response_schema=None), - OCPTool(name="listPayments", description="List payments", method="GET", - path="/payments", parameters={}, response_schema=None) - ] - - # Filter for "payment" should not match any (no exact segment match) - filtered_tools = discovery._filter_tools_by_resources(tools, ["payment"]) - assert len(filtered_tools) == 0 # "payment" doesn't exactly match any first segment - - # Filter for "payments" should match the exact first segment - filtered_tools = discovery._filter_tools_by_resources(tools, ["payments"]) - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/payments" - - # Filter for "payment_methods" should match - filtered_tools = discovery._filter_tools_by_resources(tools, ["payment_methods"]) - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/payment_methods" - - def test_filter_tools_by_resources_with_dots(self, discovery): - """Test that dot-separated paths work correctly (e.g., Slack API).""" - tools = [ - OCPTool(name="conversationsReplies", description="Get conversation replies", method="GET", - path="/conversations.replies", parameters={}, response_schema=None), - OCPTool(name="conversationsHistory", description="Get conversation history", method="GET", - path="/conversations.history", parameters={}, response_schema=None), - OCPTool(name="chatPostMessage", description="Post a message", method="POST", - path="/chat.postMessage", parameters={}, response_schema=None) - ] - - # Filter for "conversations" should match both conversation endpoints - filtered_tools = discovery._filter_tools_by_resources(tools, ["conversations"]) - assert len(filtered_tools) == 2 - assert 
all("conversations" in tool.path for tool in filtered_tools) - - # Filter for "chat" should match the chat endpoint - filtered_tools = discovery._filter_tools_by_resources(tools, ["chat"]) - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/chat.postMessage" - - def test_filter_tools_by_resources_no_substring_match(self, discovery): - """Test that substring matching doesn't work - only exact segment matches.""" - tools = [ - OCPTool(name="listRepos", description="List repos", method="GET", - path="/repos/{owner}/{repo}", parameters={}, response_schema=None), - OCPTool(name="listRepositories", description="List enterprise repositories", method="GET", - path="/enterprises/{enterprise}/code-security/configurations/{config_id}/repositories", - parameters={}, response_schema=None) - ] - - # Filter for "repos" should match "/repos/{owner}/{repo}" - # Should NOT match "/enterprises/.../repositories" (repos != repositories) - filtered_tools = discovery._filter_tools_by_resources(tools, ["repos"]) - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/repos/{owner}/{repo}" - - # Filter for "repositories" should match the enterprise endpoint (but first segment is "enterprises") - filtered_tools = discovery._filter_tools_by_resources(tools, ["repositories"]) - assert len(filtered_tools) == 0 # "repositories" is not the first segment - - # Filter for "enterprises" should match the enterprise endpoint - filtered_tools = discovery._filter_tools_by_resources(tools, ["enterprises"]) - assert len(filtered_tools) == 1 - assert "/enterprises" in filtered_tools[0].path - - def test_filter_tools_by_resources_with_path_prefix(self, discovery): - """Test filtering with path_prefix to strip version prefixes.""" - tools = [ - OCPTool(name="listPayments", description="List payments", method="GET", - path="/v1/payments", parameters={}, response_schema=None), - OCPTool(name="createCharge", description="Create charge", method="POST", - path="/v1/charges", parameters={}, response_schema=None), - OCPTool(name="legacyPayment", description="Legacy payment", method="GET", - path="/v2/payments", parameters={}, response_schema=None) - ] - - # Filter for "payments" with /v1 prefix - filtered_tools = discovery._filter_tools_by_resources(tools, ["payments"], path_prefix="/v1") - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/v1/payments" - - # Filter for "payments" with /v2 prefix - filtered_tools = discovery._filter_tools_by_resources(tools, ["payments"], path_prefix="/v2") - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/v2/payments" - - # Filter without prefix - no matches (first segment is "v1" or "v2") - filtered_tools = discovery._filter_tools_by_resources(tools, ["payments"]) - assert len(filtered_tools) == 0 - - def test_filter_tools_by_resources_first_segment_only(self, discovery): - """Test that only the first resource segment is matched.""" - tools = [ - OCPTool(name="listRepoIssues", description="List repo issues", method="GET", - path="/repos/{owner}/{repo}/issues", parameters={}, response_schema=None), - OCPTool(name="listUserRepos", description="List user repos", method="GET", - path="/user/repos", parameters={}, response_schema=None) - ] - - # Filter for "repos" - should match /repos/... 
but NOT /user/repos (first segment is "user") - filtered_tools = discovery._filter_tools_by_resources(tools, ["repos"]) - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/repos/{owner}/{repo}/issues" - - # Filter for "user" - should match /user/repos - filtered_tools = discovery._filter_tools_by_resources(tools, ["user"]) - assert len(filtered_tools) == 1 - assert filtered_tools[0].path == "/user/repos" - - # Filter for "issues" - should NOT match anything (issues is not first segment) - filtered_tools = discovery._filter_tools_by_resources(tools, ["issues"]) - assert len(filtered_tools) == 0 - @patch('ocp_agent.schema_discovery.requests.get') def test_discover_api_with_include_resources(self, mock_get, discovery, openapi_spec_with_resources): """Test discover_api method with include_resources parameter."""
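        # Illustrative sketch only (not executed by this test): after this refactor
        # the parser-specific helpers exercised above, such as
        # _filter_tools_by_resources(), _normalize_tool_name() and
        # _detect_spec_version(), live on OpenAPIParser, while OCPSchemaDiscovery
        # orchestrates parsers through its registry. A minimal migration sketch,
        # reusing identifiers from these tests (spec_data is a placeholder dict):
        #
        #     from ocp_agent.parsers import ParserRegistry
        #
        #     registry = ParserRegistry(auto_register_builtin=True)
        #     parser = registry.find_parser(spec_data)   # OpenAPIParser for OpenAPI/Swagger specs
        #     api_spec = parser.parse(
        #         spec_data,
        #         include_resources=["repos"],           # first-segment, case-insensitive match
        #         path_prefix="/v1",                     # optional: strip a version prefix before matching
        #     )
        #     tool_names = [tool.name for tool in api_spec.tools]
        #
        # Callers that relied on the removed private OCPSchemaDiscovery helpers can
        # use OpenAPIParser directly as above, or stay on the public discover_api()
        # flow exercised by this test.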