diff --git a/examples/README.md b/examples/README.md
index 39ef8f6..80eb234 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -83,6 +83,21 @@ python examples/run_x402_gemini_tools.py
 - Demonstrates tool/function calling with Gemini models
 - Uses x402 protocol for payment processing
 
+#### `run_x402_structured_output.py`
+Demonstrates structured outputs with schema validation.
+
+```bash
+python examples/run_x402_structured_output.py
+```
+
+**What it does:**
+- Shows how to use `response_format` to constrain LLM outputs to valid JSON
+- Demonstrates simple JSON mode and schema-validated outputs
+- Includes examples with chat, completion, and streaming
+- Useful for reliable data extraction and structured AI workflows
+
+**Note:** Not all models support structured outputs. Works with GPT-4o, GPT-4o-mini, and other compatible models.
+
 ## Alpha Testnet Examples
 
 Examples for features only available on the **Alpha Testnet** are located in the [`alpha/`](./alpha/) folder. These include:
diff --git a/examples/run_x402_structured_output.py b/examples/run_x402_structured_output.py
new file mode 100644
index 0000000..1776cca
--- /dev/null
+++ b/examples/run_x402_structured_output.py
@@ -0,0 +1,240 @@
+"""
+Example: Structured outputs with OpenGradient TEE LLM inference
+
+This example demonstrates how to use response_format to constrain LLM outputs
+to valid JSON matching a specific schema. This is useful for reliable data extraction,
+structured AI workflows, and programmatic consumption of LLM outputs.
+
+Requirements:
+- Set OG_PRIVATE_KEY environment variable with your wallet private key
+- The model must support structured outputs (e.g., GPT-4o, GPT-4o-mini)
+
+Usage:
+    export OG_PRIVATE_KEY="0x..."
+    python examples/run_x402_structured_output.py
+
+Note:
+    Structured output support depends on both the model and backend implementation.
+    If you see markdown-formatted responses instead of pure JSON, the backend may
+    not fully support response_format yet, or the model may need additional configuration.
+"""
+
+import json
+import os
+import re
+
+import opengradient as og
+
+
+def extract_json_from_response(content: str) -> str:
+    """
+    Extract JSON from response content.
+
+    Handles cases where the backend returns markdown-formatted JSON
+    instead of pure JSON.
+
+    Args:
+        content: The response content (may be JSON or markdown with JSON)
+
+    Returns:
+        Extracted JSON string
+    """
+    # Try to parse as-is first
+    try:
+        json.loads(content)
+        return content
+    except json.JSONDecodeError:
+        pass
+
+    # Try to extract JSON from markdown code blocks
+    json_match = re.search(r'```(?:json)?\s*\n(.*?)\n```', content, re.DOTALL)
+    if json_match:
+        return json_match.group(1).strip()
+
+    # Try to find JSON object in text
+    json_match = re.search(r'\{.*\}', content, re.DOTALL)
+    if json_match:
+        try:
+            json.loads(json_match.group(0))
+            return json_match.group(0)
+        except json.JSONDecodeError:
+            pass
+
+    # If all else fails, return original content
+    return content
+
+
+def example_1_simple_json_mode():
+    """Example 1: Simple JSON mode with chat"""
+    print("\n=== Example 1: Simple JSON Mode ===")
+    print("Using type='json_object' to return any valid JSON structure\n")
+
+    private_key = os.environ.get("OG_PRIVATE_KEY")
+    if not private_key:
+        raise ValueError("Please set OG_PRIVATE_KEY environment variable")
+
+    client = og.init(private_key=private_key)
+
+    # Simple JSON mode - returns valid JSON, any structure
+    result = client.llm.chat(
+        model=og.TEE_LLM.GPT_4O,
+        messages=[{"role": "user", "content": "List 3 colors as a JSON object with a 'colors' array"}],
+        max_tokens=200,
+        response_format={"type": "json_object"},
+    )
+
+    print(f"Raw response: {result.chat_output['content']}")
+    data = json.loads(extract_json_from_response(result.chat_output["content"]))
+    print(f"Parsed JSON: {data}")
+    print(f"Payment hash: {result.payment_hash}")
+
+
+def example_2_schema_validated_chat():
+    """Example 2: Schema-validated output with chat"""
+    print("\n=== Example 2: Schema-Validated Output ===")
+    print("Using type='json_schema' to enforce a specific structure\n")
+
+    private_key = os.environ.get("OG_PRIVATE_KEY")
+    if not private_key:
+        raise ValueError("Please set OG_PRIVATE_KEY environment variable")
+
+    client = og.init(private_key=private_key)
+
+    # Define a strict schema for the response
+    color_schema = {
+        "type": "object",
+        "properties": {
+            "colors": {
+                "type": "array",
+                "items": {"type": "string"},
+                "description": "List of color names",
+            },
+            "count": {"type": "integer", "description": "Number of colors"},
+        },
+        "required": ["colors", "count"],
+        "additionalProperties": False,
+    }
+
+    # Schema-validated mode - must match the schema exactly
+    result = client.llm.chat(
+        model=og.TEE_LLM.GPT_4O,
+        messages=[{"role": "user", "content": "Give me 5 primary and secondary colors"}],
+        max_tokens=300,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "color_list",
+                "schema": color_schema,
+                "strict": True,
+            },
+        },
+    )
+
+    print(f"Raw response: {result.chat_output['content']}")
+    data = json.loads(extract_json_from_response(result.chat_output["content"]))
+    print(f"\nColors returned: {data['colors']}")
+    print(f"Count: {data['count']}")
+    print(f"Payment hash: {result.payment_hash}")
+
+
+def example_3_completion_structured():
+    """Example 3: Structured output with completion"""
+    print("\n=== Example 3: Structured Completion ===")
+    print("Using response_format with the completion API\n")
+
+    private_key = os.environ.get("OG_PRIVATE_KEY")
+    if not private_key:
+        raise ValueError("Please set OG_PRIVATE_KEY environment variable")
+
+    client = og.init(private_key=private_key)
+
+    # Schema for math problem response
+    math_schema = {
+        "type": "object",
+        "properties": {
+            "answer": {"type": "number", "description": "The numerical answer"},
+            "explanation": {"type": "string", "description": "Step-by-step explanation"},
+        },
+        "required": ["answer", "explanation"],
+    }
+
+    result = client.llm.completion(
+        model=og.TEE_LLM.GPT_4O,
+        prompt="What is 15% of 240? Respond with JSON containing 'answer' and 'explanation' fields.",
+        max_tokens=200,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "math_response",
+                "schema": math_schema,
+            },
+        },
+    )
+
+    print(f"Raw response: {result.completion_output}")
+    data = json.loads(extract_json_from_response(result.completion_output))
+    print(f"\nAnswer: {data['answer']}")
+    print(f"Explanation: {data['explanation']}")
+
+
+def example_4_streaming_structured():
+    """Example 4: Streaming with structured output"""
+    print("\n=== Example 4: Streaming Structured Output ===")
+    print("Streaming JSON chunks (note: JSON is only valid when complete)\n")
+
+    private_key = os.environ.get("OG_PRIVATE_KEY")
+    if not private_key:
+        raise ValueError("Please set OG_PRIVATE_KEY environment variable")
+
+    client = og.init(private_key=private_key)
+
+    # Define schema for a story response
+    story_schema = {
+        "type": "object",
+        "properties": {
+            "title": {"type": "string"},
+            "content": {"type": "string"},
+            "word_count": {"type": "integer"},
+        },
+        "required": ["title", "content", "word_count"],
+    }
+
+    print("Streaming chunks:")
+    chunks = []
+    for chunk in client.llm.chat(
+        model=og.TEE_LLM.GPT_4O,
+        messages=[{"role": "user", "content": "Write a very short story (2 sentences) about a robot"}],
+        max_tokens=300,
+        stream=True,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "story_response",
+                "schema": story_schema,
+            },
+        },
+    ):
+        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
+            content = chunk.choices[0].delta.content
+            print(content, end="", flush=True)
+            chunks.append(content)
+
+    print("\n")
+    full_response = "".join(chunks)
+    data = json.loads(extract_json_from_response(full_response))
+    print("\nParsed structured response:")
+    print(f"Title: {data['title']}")
+    print(f"Content: {data['content']}")
+    print(f"Word count: {data['word_count']}")
+
+
+if __name__ == "__main__":
+    try:
+        example_1_simple_json_mode()
+        example_2_schema_validated_chat()
+        example_3_completion_structured()
+        example_4_streaming_structured()
+        print("\n✅ All examples completed successfully!")
+    except Exception as e:
+        print(f"\n❌ Error: {e}")
+        raise
diff --git a/pyproject.toml b/pyproject.toml
index 7317ef7..4104536 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "opengradient"
-version = "0.6.0"
+version = "0.6.1"
 description = "Python SDK for OpenGradient decentralized model management & inference services"
 authors = [{name = "OpenGradient", email = "kyle@vannalabs.ai"}]
 readme = "README.md"
diff --git a/src/opengradient/__init__.py b/src/opengradient/__init__.py
index 3236897..b512ef2 100644
--- a/src/opengradient/__init__.py
+++ b/src/opengradient/__init__.py
@@ -85,8 +85,10 @@
     HistoricalInputQuery,
     InferenceMode,
     InferenceResult,
+    JSONSchemaDefinition,
     ModelOutput,
     ModelRepository,
+    ResponseFormat,
     SchedulerParams,
     TextGenerationOutput,
     TextGenerationStream,
@@ -137,6 +139,8 @@ def init(
     "SchedulerParams",
     "CandleType",
     "CandleOrder",
+    "JSONSchemaDefinition",
+    "ResponseFormat",
     "agents",
     "alphasense",
 ]
diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index 2d13224..20da50c 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -67,6 +67,46 @@ def _og_payment_selector(self, accepts, network_filter=DEFAULT_NETWORK_FILTER, s
             max_value=max_value,
         )
 
+    def _validate_response_format(self, response_format: Optional[Dict]) -> None:
+        """
+        Validate response_format structure.
+
+        Performs lightweight structural validation only. Does NOT validate
+        the actual JSON Schema content - that's handled by the backend.
+
+        Args:
+            response_format: Response format dict to validate.
+
+        Raises:
+            OpenGradientError: If the structure is invalid.
+        """
+        if response_format is None:
+            return
+
+        if not isinstance(response_format, dict):
+            raise OpenGradientError("response_format must be a dict")
+
+        if "type" not in response_format:
+            raise OpenGradientError("response_format must have a 'type' field")
+
+        format_type = response_format["type"]
+        if format_type not in ("json_object", "json_schema"):
+            raise OpenGradientError(f"response_format type must be 'json_object' or 'json_schema', got: {format_type}")
+
+        if format_type == "json_schema":
+            if "json_schema" not in response_format:
+                raise OpenGradientError("response_format with type='json_schema' must have a 'json_schema' field")
+
+            json_schema = response_format["json_schema"]
+            if not isinstance(json_schema, dict):
+                raise OpenGradientError("json_schema must be a dict")
+
+            if "name" not in json_schema:
+                raise OpenGradientError("json_schema must have a 'name' field")
+
+            if "schema" not in json_schema:
+                raise OpenGradientError("json_schema must have a 'schema' field")
+
     def completion(
         self,
         model: TEE_LLM,
@@ -74,6 +114,7 @@ def completion(
         max_tokens: int = 100,
         stop_sequence: Optional[List[str]] = None,
         temperature: float = 0.0,
+        response_format: Optional[Dict] = None,
         x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
     ) -> TextGenerationOutput:
         """
@@ -85,6 +126,11 @@ def completion(
             max_tokens (int): Maximum number of tokens for LLM output. Default is 100.
             stop_sequence (List[str], optional): List of stop sequences for LLM. Default is None.
             temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0.
+            response_format (Dict, optional): Format for structured outputs. Supports:
+                - `{"type": "json_object"}`: Returns valid JSON (any structure)
+                - `{"type": "json_schema", "json_schema": {...}}`: Returns JSON matching schema
+                Example: `{"type": "json_schema", "json_schema": {"name": "response", "schema": {...}}}`
+                Note: Not all models support structured outputs. Default is None.
             x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments.
                 - SETTLE: Records input/output hashes only (most privacy-preserving).
                 - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient).
@@ -98,14 +144,16 @@
                 - Payment hash for x402 transactions
 
         Raises:
-            OpenGradientError: If the inference fails.
+            OpenGradientError: If the inference fails or validation fails.
""" + self._validate_response_format(response_format) return self._tee_llm_completion( model=model.split("/")[1], prompt=prompt, max_tokens=max_tokens, stop_sequence=stop_sequence, temperature=temperature, + response_format=response_format, x402_settlement_mode=x402_settlement_mode, ) @@ -116,6 +164,7 @@ def _tee_llm_completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, + response_format: Optional[Dict] = None, x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ @@ -146,6 +195,9 @@ async def make_request(): if stop_sequence: payload["stop"] = stop_sequence + if response_format: + payload["response_format"] = response_format + try: response = await client.post("/v1/completions", json=payload, headers=headers, timeout=60) @@ -180,6 +232,7 @@ def chat( temperature: float = 0.0, tools: Optional[List[Dict]] = [], tool_choice: Optional[str] = None, + response_format: Optional[Dict] = None, x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, stream: bool = False, ) -> Union[TextGenerationOutput, TextGenerationStream]: @@ -194,6 +247,11 @@ def chat( temperature (float): Temperature for LLM inference, between 0 and 1. tools (List[dict], optional): Set of tools for function calling. tool_choice (str, optional): Sets a specific tool to choose. + response_format (Dict, optional): Format for structured outputs. Supports: + - `{"type": "json_object"}`: Returns valid JSON (any structure) + - `{"type": "json_schema", "json_schema": {...}}`: Returns JSON matching schema + Example: `{"type": "json_schema", "json_schema": {"name": "colors", "schema": {...}}}` + Note: Not all models support structured outputs. Default is None. x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments. - SETTLE: Records input/output hashes only (most privacy-preserving). - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient). @@ -207,8 +265,9 @@ def chat( - If stream=True: TextGenerationStream yielding StreamChunk objects with typed deltas (true streaming via threading) Raises: - OpenGradientError: If the inference fails. + OpenGradientError: If the inference fails or validation fails. 
""" + self._validate_response_format(response_format) if stream: # Use threading bridge for true sync streaming return self._tee_llm_chat_stream_sync( @@ -219,6 +278,7 @@ def chat( temperature=temperature, tools=tools, tool_choice=tool_choice, + response_format=response_format, x402_settlement_mode=x402_settlement_mode, ) else: @@ -231,6 +291,7 @@ def chat( temperature=temperature, tools=tools, tool_choice=tool_choice, + response_format=response_format, x402_settlement_mode=x402_settlement_mode, ) @@ -243,6 +304,7 @@ def _tee_llm_chat( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, + response_format: Optional[Dict] = None, x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ @@ -277,6 +339,9 @@ async def make_request(): payload["tools"] = tools payload["tool_choice"] = tool_choice or "auto" + if response_format: + payload["response_format"] = response_format + try: # Non-streaming with x402 endpoint = "/v1/chat/completions" @@ -320,6 +385,7 @@ def _tee_llm_chat_stream_sync( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, + response_format: Optional[Dict] = None, x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ): """ @@ -351,6 +417,7 @@ async def _stream(): temperature=temperature, tools=tools, tool_choice=tool_choice, + response_format=response_format, x402_settlement_mode=x402_settlement_mode, ): queue.put(chunk) # Put chunk immediately @@ -404,6 +471,7 @@ async def _tee_llm_chat_stream_async( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, + response_format: Optional[Dict] = None, x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ): """ @@ -440,6 +508,8 @@ async def _tee_llm_chat_stream_async( if tools: payload["tools"] = tools payload["tool_choice"] = tool_choice or "auto" + if response_format: + payload["response_format"] = response_format async with client.stream( "POST", diff --git a/src/opengradient/types.py b/src/opengradient/types.py index fa89c98..888f20a 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -5,10 +5,15 @@ import time from dataclasses import dataclass from enum import Enum, IntEnum -from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, TypedDict, Union import numpy as np +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + class x402SettlementMode(str, Enum): """ @@ -487,3 +492,67 @@ class ModelRepository: class FileUploadResult: modelCid: str size: int + + +class JSONSchemaDefinition(TypedDict): + """ + JSON Schema definition for structured output validation. + + This follows the OpenAI standard for schema-validated responses. + The schema must be a valid JSON Schema object that defines the + structure the model's output should conform to. + + Attributes: + name: A descriptive name for the schema (e.g., "math_response", "user_profile"). + schema: A valid JSON Schema object (dict) defining the expected output structure. + strict: Whether to enforce strict schema validation. Defaults to True. + When True, the model output must exactly match the schema. + """ + + name: str + schema: Dict[str, Any] + strict: NotRequired[bool] + + +class ResponseFormat(TypedDict): + """ + Response format configuration for structured outputs. 
+
+    Used to constrain LLM outputs to valid JSON matching a specific schema.
+    Follows the OpenAI standard response_format parameter structure.
+
+    Attributes:
+        type: The response format type. Must be one of:
+            - "json_object": Model outputs valid JSON (any structure)
+            - "json_schema": Model outputs JSON matching the provided schema
+        json_schema: Required when type="json_schema". Defines the schema
+            the output must conform to.
+
+    Usage:
+        # Simple JSON mode
+        response_format = {"type": "json_object"}
+
+        # Schema-validated JSON
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "math_response",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "result": {"type": "number"},
+                        "explanation": {"type": "string"}
+                    },
+                    "required": ["result", "explanation"]
+                }
+            }
+        }
+
+    Note:
+        Not all models support structured outputs. Check model capabilities
+        before using this feature. The backend will return an error if the
+        model does not support structured outputs.
+    """
+
+    type: Literal["json_object", "json_schema"]
+    json_schema: NotRequired[JSONSchemaDefinition]
diff --git a/tests/client_test.py b/tests/client_test.py
index da8e733..897d5b3 100644
--- a/tests/client_test.py
+++ b/tests/client_test.py
@@ -319,3 +319,200 @@ def test_settlement_mode_aliases(self):
         """Test settlement mode aliases."""
         assert x402SettlementMode.SETTLE_INDIVIDUAL == x402SettlementMode.SETTLE
         assert x402SettlementMode.SETTLE_INDIVIDUAL_WITH_METADATA == x402SettlementMode.SETTLE_METADATA
+
+
+class TestResponseFormatValidation:
+    """Test response_format validation logic."""
+
+    @pytest.fixture
+    def client(self, mock_web3, mock_abi_files):
+        """Create a test client instance."""
+        return Client(private_key="0x" + "1" * 64)
+
+    def test_none_response_format(self, client):
+        """Test that None response_format is valid."""
+        # Should not raise an exception
+        client.llm._validate_response_format(None)
+
+    def test_json_object_format(self, client):
+        """Test json_object response format validation."""
+        response_format = {"type": "json_object"}
+        # Should not raise an exception
+        client.llm._validate_response_format(response_format)
+
+    def test_json_schema_format(self, client):
+        """Test json_schema response format validation."""
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "test_schema",
+                "schema": {
+                    "type": "object",
+                    "properties": {"field": {"type": "string"}},
+                },
+            },
+        }
+        # Should not raise an exception
+        client.llm._validate_response_format(response_format)
+
+    def test_invalid_type_not_dict(self, client):
+        """Test that non-dict response_format raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="response_format must be a dict"):
+            client.llm._validate_response_format("invalid")
+
+    def test_missing_type_field(self, client):
+        """Test that missing 'type' field raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="response_format must have a 'type' field"):
+            client.llm._validate_response_format({})
+
+    def test_invalid_type_value(self, client):
+        """Test that invalid type value raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="response_format type must be"):
+            client.llm._validate_response_format({"type": "invalid_type"})
+
+    def test_json_schema_missing_json_schema_field(self, client):
+        """Test that json_schema type without json_schema field raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="must have a 'json_schema' field"):
+            client.llm._validate_response_format({"type": "json_schema"})
+
+    def test_json_schema_not_dict(self, client):
+        """Test that non-dict json_schema raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="json_schema must be a dict"):
+            client.llm._validate_response_format({"type": "json_schema", "json_schema": "invalid"})
+
+    def test_json_schema_missing_name(self, client):
+        """Test that json_schema without name raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="json_schema must have a 'name' field"):
+            client.llm._validate_response_format({"type": "json_schema", "json_schema": {"schema": {}}})
+
+    def test_json_schema_missing_schema(self, client):
+        """Test that json_schema without schema raises error."""
+        from src.opengradient.client.exceptions import OpenGradientError
+
+        with pytest.raises(OpenGradientError, match="json_schema must have a 'schema' field"):
+            client.llm._validate_response_format({"type": "json_schema", "json_schema": {"name": "test"}})
+
+
+class TestStructuredOutputs:
+    """Test structured output functionality with mocked backends."""
+
+    @pytest.fixture
+    def client(self, mock_web3, mock_abi_files):
+        """Create a test client instance."""
+        return Client(private_key="0x" + "1" * 64)
+
+    @pytest.fixture
+    def mock_x402_client(self):
+        """Mock x402HttpxClient for testing."""
+        with patch("src.opengradient.client.llm.x402HttpxClient") as mock_client_class:
+            mock_client = MagicMock()
+            mock_response = MagicMock()
+            mock_response.headers = {}
+            mock_client_class.return_value.__aenter__.return_value = mock_client
+            yield mock_client, mock_response
+
+    @pytest.fixture
+    def mock_httpx_client(self):
+        """Mock httpx.AsyncClient for streaming tests."""
+        with patch("src.opengradient.client.llm.httpx.AsyncClient") as mock_client_class:
+            mock_client = MagicMock()
+            mock_client_class.return_value.__aenter__.return_value = mock_client
+            yield mock_client
+
+    def test_chat_with_json_object(self, client, mock_x402_client):
+        """Test chat with json_object response format."""
+        mock_client, mock_response = mock_x402_client
+        mock_response.aread = MagicMock(
+            return_value=json.dumps(
+                {"choices": [{"message": {"role": "assistant", "content": '{"colors": ["red", "blue"]}'}, "finish_reason": "stop"}]}
+            ).encode()
+        )
+        mock_client.post = MagicMock(return_value=mock_response)
+
+        result = client.llm.chat(
+            model=TEE_LLM.GPT_4O,
+            messages=[{"role": "user", "content": "List colors"}],
+            max_tokens=100,
+            response_format={"type": "json_object"},
+        )
+
+        # Verify the call was made with response_format in payload
+        call_kwargs = mock_client.post.call_args[1]
+        assert "json" in call_kwargs
+        assert call_kwargs["json"]["response_format"] == {"type": "json_object"}
+        assert result.chat_output["content"] == '{"colors": ["red", "blue"]}'
+
+    def test_chat_with_json_schema(self, client, mock_x402_client):
+        """Test chat with json_schema response format."""
+        mock_client, mock_response = mock_x402_client
+        mock_response.aread = MagicMock(
+            return_value=json.dumps(
+                {"choices": [{"message": {"role": "assistant", "content": '{"count": 2}'}, "finish_reason": "stop"}]}
+            ).encode()
+        )
+        mock_client.post = MagicMock(return_value=mock_response)
+
+        schema = {"type": "object", "properties": {"count": {"type": "integer"}}}
+        result = client.llm.chat(
+            model=TEE_LLM.GPT_4O,
+            messages=[{"role": "user", "content": "Count items"}],
+            max_tokens=100,
+            response_format={"type": "json_schema", "json_schema": {"name": "count", "schema": schema}},
+        )
+
+        # Verify the call was made with response_format in payload
+        call_kwargs = mock_client.post.call_args[1]
+        assert "response_format" in call_kwargs["json"]
+        assert call_kwargs["json"]["response_format"]["type"] == "json_schema"
+        assert result.chat_output["content"] == '{"count": 2}'
+
+    def test_completion_with_response_format(self, client, mock_x402_client):
+        """Test completion with response_format."""
+        mock_client, mock_response = mock_x402_client
+        mock_response.aread = MagicMock(return_value=json.dumps({"completion": '{"answer": 42}'}).encode())
+        mock_client.post = MagicMock(return_value=mock_response)
+
+        result = client.llm.completion(
+            model=TEE_LLM.GPT_4O,
+            prompt="Answer in JSON",
+            max_tokens=100,
+            response_format={"type": "json_object"},
+        )
+
+        # Verify the call was made with response_format in payload
+        call_kwargs = mock_client.post.call_args[1]
+        assert "response_format" in call_kwargs["json"]
+        assert result.completion_output == '{"answer": 42}'
+
+    def test_chat_without_response_format(self, client, mock_x402_client):
+        """Test that chat works without response_format (backward compatibility)."""
+        mock_client, mock_response = mock_x402_client
+        mock_response.aread = MagicMock(
+            return_value=json.dumps(
+                {"choices": [{"message": {"role": "assistant", "content": "Hello"}, "finish_reason": "stop"}]}
+            ).encode()
+        )
+        mock_client.post = MagicMock(return_value=mock_response)
+
+        result = client.llm.chat(
+            model=TEE_LLM.GPT_4O,
+            messages=[{"role": "user", "content": "Hi"}],
+            max_tokens=100,
+        )
+
+        # Verify response_format was not in the payload
+        call_kwargs = mock_client.post.call_args[1]
+        assert "response_format" not in call_kwargs["json"]
+        assert result.chat_output["content"] == "Hello"
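---

Usage sketch (not part of the patch): this change exports `JSONSchemaDefinition` and `ResponseFormat` from the package root, but the example file builds plain dicts. The snippet below shows how a caller might build a type-checked `response_format` with those TypedDicts instead, assuming the 0.6.1 API surface introduced above (`og.init`, `client.llm.chat`, `og.TEE_LLM.GPT_4O`, the `response_format` parameter). The `sentiment` schema and prompt are illustrative only.

```python
# Sketch only: assumes the exports and chat API added in this diff.
import json
import os

import opengradient as og
from opengradient import JSONSchemaDefinition, ResponseFormat

# A typed schema definition; a type checker can flag missing "name"/"schema" keys.
sentiment_schema: JSONSchemaDefinition = {
    "name": "sentiment",
    "schema": {
        "type": "object",
        "properties": {"label": {"type": "string"}, "score": {"type": "number"}},
        "required": ["label", "score"],
    },
    "strict": True,
}
response_format: ResponseFormat = {"type": "json_schema", "json_schema": sentiment_schema}

client = og.init(private_key=os.environ["OG_PRIVATE_KEY"])
result = client.llm.chat(
    model=og.TEE_LLM.GPT_4O,
    messages=[{"role": "user", "content": "Classify the sentiment of: 'great product'"}],
    max_tokens=100,
    response_format=response_format,
)
print(json.loads(result.chat_output["content"]))
```

Because both types are plain TypedDicts, the values passed at runtime are still ordinary dicts, so `_validate_response_format` and the backend see exactly the same payload as the hand-written dicts in `examples/run_x402_structured_output.py`.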