diff --git a/.github/workflows/_test-code-samples.yml b/.github/workflows/_test-code-samples.yml index d7b46874..171e0fed 100644 --- a/.github/workflows/_test-code-samples.yml +++ b/.github/workflows/_test-code-samples.yml @@ -40,7 +40,7 @@ jobs: - name: Tests code samples run: | - ./tests/test_code_samples.sh ${{ secrets.MINDEE_ACCOUNT_SE_TESTS }} ${{ secrets.MINDEE_ENDPOINT_SE_TESTS }} ${{ secrets.MINDEE_API_KEY_SE_TESTS }} ${{ secrets.MINDEE_V2_SE_TESTS_API_KEY }} ${{ secrets.MINDEE_V2_SE_TESTS_FINDOC_MODEL_ID }} + ./tests/test_code_samples.sh ${{ secrets.MINDEE_ACCOUNT_SE_TESTS }} ${{ secrets.MINDEE_ENDPOINT_SE_TESTS }} ${{ secrets.MINDEE_API_KEY_SE_TESTS }} ${{ secrets.MINDEE_V2_SE_TESTS_API_KEY }} ${{ secrets.MINDEE_V2_SE_TESTS_FINDOC_MODEL_ID }} ${{ secrets.MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID }} - name: Notify Slack Action on Failure uses: ravsamhq/notify-slack-action@2.3.0 diff --git a/.github/workflows/_test-integrations.yml b/.github/workflows/_test-integrations.yml index 8b8bfa3c..70b7b0e9 100644 --- a/.github/workflows/_test-integrations.yml +++ b/.github/workflows/_test-integrations.yml @@ -49,6 +49,7 @@ jobs: MINDEE_V2_API_KEY: ${{ secrets.MINDEE_V2_SE_TESTS_API_KEY }} MINDEE_V2_FINDOC_MODEL_ID: ${{ secrets.MINDEE_V2_SE_TESTS_FINDOC_MODEL_ID }} MINDEE_V2_SE_TESTS_BLANK_PDF_URL: ${{ secrets.MINDEE_V2_SE_TESTS_BLANK_PDF_URL }} + MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID: ${{ secrets.MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID }} run: | pytest --cov mindee -m integration diff --git a/docs/extras/code_samples/default_v2.txt b/docs/extras/code_samples/v2_default.txt similarity index 82% rename from docs/extras/code_samples/default_v2.txt rename to docs/extras/code_samples/v2_default.txt index 1e3abd18..7f4f0be2 100644 --- a/docs/extras/code_samples/default_v2.txt +++ b/docs/extras/code_samples/v2_default.txt @@ -1,4 +1,4 @@ -from mindee import ClientV2, InferenceParameters, PathInput +from mindee import ClientV2, InferenceParameters, InferenceResponse, PathInput input_path = "/path/to/the/file.ext" api_key = "MY_API_KEY" @@ -29,8 +29,10 @@ params = InferenceParameters( input_source = PathInput(input_path) # Send for processing -response = mindee_client.enqueue_and_get_inference( - input_source, params +response = mindee_client.enqueue_and_get_result( + InferenceResponse, + input_source, + params, ) # Print a brief summary of the parsed data diff --git a/docs/extras/code_samples/v2_split.txt b/docs/extras/code_samples/v2_split.txt new file mode 100644 index 00000000..1442f028 --- /dev/null +++ b/docs/extras/code_samples/v2_split.txt @@ -0,0 +1,27 @@ +from mindee import ClientV2, SplitParameters, SplitResponse, PathInput + +input_path = "/path/to/the/file.ext" +api_key = "MY_API_KEY" +model_id = "MY_SPLIT_MODEL_ID" + +# Init a new client +mindee_client = ClientV2(api_key) + +# Set inference parameters +params = SplitParameters( + # ID of the model, required. + model_id=model_id, +) + +# Load a file from disk +input_source = PathInput(input_path) + +# Send for processing +response = mindee_client.enqueue_and_get_result( + SplitResponse, + input_source, + params, +) + +# Print a brief summary of the parsed data +print(response.inference) diff --git a/mindee/__init__.py b/mindee/__init__.py index ff5ca240..33a2087c 100644 --- a/mindee/__init__.py +++ b/mindee/__init__.py @@ -1,13 +1,13 @@ from mindee import product from mindee.client import Client from mindee.client_v2 import ClientV2 +from mindee.input import LocalResponse, PageOptions, PollingOptions from mindee.input.inference_parameters import ( - InferenceParameters, - DataSchemaField, DataSchema, + DataSchemaField, DataSchemaReplace, + InferenceParameters, ) -from mindee.input import LocalResponse, PageOptions, PollingOptions from mindee.input.sources import ( Base64Input, BytesInput, @@ -22,29 +22,33 @@ from mindee.parsing.common.predict_response import PredictResponse from mindee.parsing.common.workflow_response import WorkflowResponse from mindee.parsing.v2 import InferenceResponse, JobResponse +from mindee.v2.product.split.split_parameters import SplitParameters +from mindee.v2.product.split.split_response import SplitResponse __all__ = [ + "ApiResponse", + "AsyncPredictResponse", + "Base64Input", + "BytesInput", "Client", "ClientV2", "DataSchema", "DataSchemaField", "DataSchemaReplace", - "InferenceParameters", + "FeedbackResponse", "FileInput", - "PathInput", - "BytesInput", - "Base64Input", - "UrlInputSource", + "InferenceParameters", + "InferenceResponse", + "Job", + "JobResponse", "LocalResponse", "PageOptions", + "PathInput", "PollingOptions", - "ApiResponse", - "AsyncPredictResponse", - "FeedbackResponse", "PredictResponse", + "SplitParameters", + "SplitResponse", + "UrlInputSource", "WorkflowResponse", - "JobResponse", - "Job", - "InferenceResponse", "product", ] diff --git a/mindee/client.py b/mindee/client.py index 6b8d3ba1..5377314c 100644 --- a/mindee/client.py +++ b/mindee/client.py @@ -353,7 +353,7 @@ def enqueue_and_parse( # pylint: disable=too-many-locals if poll_results.job.status == "failed": raise MindeeError("Parsing failed for job {poll_results.job.id}") logger.debug( - "Polling server for parsing result with job id: %s", queue_result.job.id + "Polling server for product result with job id: %s", queue_result.job.id ) retry_counter += 1 sleep(delay_sec) diff --git a/mindee/client_v2.py b/mindee/client_v2.py index 6819b2cf..07c19418 100644 --- a/mindee/client_v2.py +++ b/mindee/client_v2.py @@ -1,10 +1,11 @@ +import warnings from time import sleep -from typing import Optional, Union +from typing import Optional, Union, Type, TypeVar from mindee.client_mixin import ClientMixin from mindee.error.mindee_error import MindeeError from mindee.error.mindee_http_error_v2 import handle_error_v2 -from mindee.input import UrlInputSource +from mindee.input import UrlInputSource, BaseParameters from mindee.input.inference_parameters import InferenceParameters from mindee.input.polling_options import PollingOptions from mindee.input.sources.local_input_source import LocalInputSource @@ -15,9 +16,12 @@ is_valid_post_response, ) from mindee.parsing.v2.common_response import CommonStatus +from mindee.v2.parsing.inference.base_response import BaseResponse from mindee.parsing.v2.inference_response import InferenceResponse from mindee.parsing.v2.job_response import JobResponse +TypeBaseResponse = TypeVar("TypeBaseResponse", bound=BaseResponse) + class ClientV2(ClientMixin): """ @@ -41,20 +45,34 @@ def __init__(self, api_key: Optional[str] = None) -> None: def enqueue_inference( self, input_source: Union[LocalInputSource, UrlInputSource], - params: InferenceParameters, + params: BaseParameters, + disable_redundant_warnings: bool = False, + ) -> JobResponse: + """[Deprecated] Use `enqueue` instead.""" + if not disable_redundant_warnings: + warnings.warn( + "enqueue_inference is deprecated; use enqueue instead", + DeprecationWarning, + stacklevel=2, + ) + return self.enqueue(input_source, params) + + def enqueue( + self, + input_source: Union[LocalInputSource, UrlInputSource], + params: BaseParameters, ) -> JobResponse: """ Enqueues a document to a given model. :param input_source: The document/source file to use. Can be local or remote. - :param params: Parameters to set when sending a file. + :return: A valid inference response. """ logger.debug("Enqueuing inference using model: %s", params.model_id) - response = self.mindee_api.req_post_inference_enqueue( - input_source=input_source, params=params + input_source=input_source, params=params, slug=params.get_enqueue_slug() ) dict_response = response.json() @@ -79,34 +97,49 @@ def get_job(self, job_id: str) -> JobResponse: dict_response = response.json() return JobResponse(dict_response) - def get_inference(self, inference_id: str) -> InferenceResponse: + def get_inference( + self, + inference_id: str, + ) -> BaseResponse: + """[Deprecated] Use `get_result` instead.""" + return self.get_result(InferenceResponse, inference_id) + + def get_result( + self, + response_type: Type[TypeBaseResponse], + inference_id: str, + ) -> TypeBaseResponse: """ Get the result of an inference that was previously enqueued. The inference will only be available after it has finished processing. :param inference_id: UUID of the inference to retrieve. + :param response_type: Class of the product to instantiate. :return: An inference response. """ logger.debug("Fetching inference: %s", inference_id) - response = self.mindee_api.req_get_inference(inference_id) + response = self.mindee_api.req_get_inference( + inference_id, response_type.get_result_slug() + ) if not is_valid_get_response(response): handle_error_v2(response.json()) dict_response = response.json() - return InferenceResponse(dict_response) + return response_type(dict_response) - def enqueue_and_get_inference( + def enqueue_and_get_result( self, + response_type: Type[TypeBaseResponse], input_source: Union[LocalInputSource, UrlInputSource], - params: InferenceParameters, - ) -> InferenceResponse: + params: BaseParameters, + ) -> TypeBaseResponse: """ Enqueues to an asynchronous endpoint and automatically polls for a response. :param input_source: The document/source file to use. Can be local or remote. - :param params: Parameters to set when sending a file. + :param response_type: The product class to use for the response object. :return: A valid inference response. """ @@ -117,14 +150,15 @@ def enqueue_and_get_inference( params.polling_options.delay_sec, params.polling_options.max_retries, ) - enqueue_response = self.enqueue_inference(input_source, params) + enqueue_response = self.enqueue_inference(input_source, params, True) logger.debug( - "Successfully enqueued inference with job id: %s", enqueue_response.job.id + "Successfully enqueued document with job id: %s", enqueue_response.job.id ) sleep(params.polling_options.initial_delay_sec) try_counter = 0 while try_counter < params.polling_options.max_retries: job_response = self.get_job(enqueue_response.job.id) + assert isinstance(job_response, JobResponse) if job_response.job.status == CommonStatus.FAILED.value: if job_response.job.error: detail = job_response.job.error.detail @@ -134,8 +168,31 @@ def enqueue_and_get_inference( f"Parsing failed for job {job_response.job.id}: {detail}" ) if job_response.job.status == CommonStatus.PROCESSED.value: - return self.get_inference(job_response.job.id) + result = self.get_result( + response_type or InferenceResponse, job_response.job.id + ) + assert isinstance(result, response_type), ( + f'Invalid response type "{type(result)}"' + ) + return result try_counter += 1 sleep(params.polling_options.delay_sec) raise MindeeError(f"Couldn't retrieve document after {try_counter + 1} tries.") + + def enqueue_and_get_inference( + self, + input_source: Union[LocalInputSource, UrlInputSource], + params: InferenceParameters, + ) -> InferenceResponse: + """[Deprecated] Use `enqueue_and_get_result` instead.""" + warnings.warn( + "enqueue_and_get_inference is deprecated; use enqueue_and_get_result instead", + DeprecationWarning, + stacklevel=2, + ) + response = self.enqueue_and_get_result(InferenceResponse, input_source, params) + assert isinstance(response, InferenceResponse), ( + f'Invalid response type "{type(response)}"' + ) + return response diff --git a/mindee/error/mindee_http_error_v2.py b/mindee/error/mindee_http_error_v2.py index 99ba40da..a6be90f3 100644 --- a/mindee/error/mindee_http_error_v2.py +++ b/mindee/error/mindee_http_error_v2.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import List, Optional from mindee.parsing.common.string_dict import StringDict from mindee.parsing.v2 import ErrorItem, ErrorResponse @@ -18,7 +18,7 @@ def __init__(self, response: ErrorResponse) -> None: self.title = response.title self.code = response.code self.detail = response.detail - self.errors: list[ErrorItem] = response.errors + self.errors: List[ErrorItem] = response.errors super().__init__( f"HTTP {self.status} - {self.title} :: {self.code} - {self.detail}" ) diff --git a/mindee/input/__init__.py b/mindee/input/__init__.py index 9ed79985..31973802 100644 --- a/mindee/input/__init__.py +++ b/mindee/input/__init__.py @@ -1,4 +1,7 @@ from mindee.input.local_response import LocalResponse +from mindee.input.base_parameters import BaseParameters +from mindee.input.inference_parameters import InferenceParameters +from mindee.v2.product.split.split_parameters import SplitParameters from mindee.input.page_options import PageOptions from mindee.input.polling_options import PollingOptions from mindee.input.sources.base_64_input import Base64Input @@ -11,15 +14,18 @@ from mindee.input.workflow_options import WorkflowOptions __all__ = [ + "Base64Input", + "BaseParameters", + "BytesInput", + "FileInput", "InputType", + "InferenceParameters", "LocalInputSource", - "UrlInputSource", + "LocalResponse", + "PageOptions", "PathInput", - "FileInput", - "Base64Input", - "BytesInput", - "WorkflowOptions", "PollingOptions", - "PageOptions", - "LocalResponse", + "UrlInputSource", + "SplitParameters", + "WorkflowOptions", ] diff --git a/mindee/input/base_parameters.py b/mindee/input/base_parameters.py new file mode 100644 index 00000000..d1159ad2 --- /dev/null +++ b/mindee/input/base_parameters.py @@ -0,0 +1,44 @@ +from abc import ABC +from dataclasses import dataclass, field +from typing import Dict, Optional, List, Union + +from mindee.input.polling_options import PollingOptions + + +@dataclass +class BaseParameters(ABC): + """Base class for parameters accepted by all V2 endpoints.""" + + _slug: str = field(init=False) + """Slug of the endpoint.""" + + model_id: str + """ID of the model, required.""" + alias: Optional[str] = None + """Use an alias to link the file to your own DB. If empty, no alias will be used.""" + webhook_ids: Optional[List[str]] = None + """IDs of webhooks to propagate the API response to.""" + polling_options: Optional[PollingOptions] = None + """Options for polling. Set only if having timeout issues.""" + close_file: bool = True + """Whether to close the file after product.""" + + def get_form_data(self) -> Dict[str, Union[str, List[str]]]: + """ + Return the parameters as a config dictionary. + + :return: A dict of parameters. + """ + data: Dict[str, Union[str, List[str]]] = { + "model_id": self.model_id, + } + if self.alias is not None: + data["alias"] = self.alias + if self.webhook_ids and len(self.webhook_ids) > 0: + data["webhook_ids"] = self.webhook_ids + return data + + @classmethod + def get_enqueue_slug(cls) -> str: + """Getter for the enqueue slug.""" + return cls._slug diff --git a/mindee/input/inference_parameters.py b/mindee/input/inference_parameters.py index 6d4e01fa..807b2595 100644 --- a/mindee/input/inference_parameters.py +++ b/mindee/input/inference_parameters.py @@ -1,8 +1,8 @@ import json -from dataclasses import dataclass, asdict -from typing import List, Optional, Union +from dataclasses import dataclass, asdict, field +from typing import Dict, List, Optional, Union -from mindee.input.polling_options import PollingOptions +from mindee.input.base_parameters import BaseParameters @dataclass @@ -44,7 +44,7 @@ class DataSchemaField(StringDataClass): guidelines: Optional[str] = None """Optional extraction guidelines.""" nested_fields: Optional[dict] = None - """Subfields when type is `nested_object`. Leave empty for other types""" + """Subfields when type is `nested_object`. Leave empty for other types.""" @dataclass @@ -78,11 +78,12 @@ def __post_init__(self) -> None: @dataclass -class InferenceParameters: +class InferenceParameters(BaseParameters): """Inference parameters to set when sending a file.""" - model_id: str - """ID of the model, required.""" + _slug: str = field(init=False, default="inferences") + """Slug of the endpoint.""" + rag: Optional[bool] = None """Enhance extraction accuracy with Retrieval-Augmented Generation.""" raw_text: Optional[bool] = None @@ -94,14 +95,6 @@ class InferenceParameters: Boost the precision and accuracy of all extractions. Calculate confidence scores for all fields, and fill their ``confidence`` attribute. """ - alias: Optional[str] = None - """Use an alias to link the file to your own DB. If empty, no alias will be used.""" - webhook_ids: Optional[List[str]] = None - """IDs of webhooks to propagate the API response to.""" - polling_options: Optional[PollingOptions] = None - """Options for polling. Set only if having timeout issues.""" - close_file: bool = True - """Whether to close the file after parsing.""" text_context: Optional[str] = None """ Additional text context used by the model during inference. @@ -118,3 +111,24 @@ def __post_init__(self): self.data_schema = DataSchema(**json.loads(self.data_schema)) elif isinstance(self.data_schema, dict): self.data_schema = DataSchema(**self.data_schema) + + def get_form_data(self) -> Dict[str, Union[str, List[str]]]: + """ + Return the parameters as a config dictionary. + + :return: A dict of parameters. + """ + data = super().get_form_data() + if self.data_schema is not None: + data["data_schema"] = str(self.data_schema) + if self.rag is not None: + data["rag"] = data["rag"] = str(self.rag).lower() + if self.raw_text is not None: + data["raw_text"] = data["raw_text"] = str(self.raw_text).lower() + if self.polygon is not None: + data["polygon"] = data["polygon"] = str(self.polygon).lower() + if self.confidence is not None: + data["confidence"] = data["confidence"] = str(self.confidence).lower() + if self.text_context is not None: + data["text_context"] = self.text_context + return data diff --git a/mindee/mindee_http/mindee_api_v2.py b/mindee/mindee_http/mindee_api_v2.py index 9990330c..446582ad 100644 --- a/mindee/mindee_http/mindee_api_v2.py +++ b/mindee/mindee_http/mindee_api_v2.py @@ -4,8 +4,7 @@ import requests from mindee.error.mindee_error import MindeeApiV2Error -from mindee.input import LocalInputSource, UrlInputSource -from mindee.input.inference_parameters import InferenceParameters +from mindee.input import LocalInputSource, UrlInputSource, BaseParameters from mindee.logger import logger from mindee.mindee_http.base_settings import USER_AGENT from mindee.mindee_http.settings_mixin import SettingsMixin @@ -74,34 +73,19 @@ def set_from_env(self) -> None: def req_post_inference_enqueue( self, input_source: Union[LocalInputSource, UrlInputSource], - params: InferenceParameters, + params: BaseParameters, + slug: str, ) -> requests.Response: """ Make an asynchronous request to POST a document for prediction on the V2 API. :param input_source: Input object. :param params: Options for the enqueueing of the document. + :param slug: Slug to use for the enqueueing, defaults to 'inferences'. :return: requests response. """ - data: Dict[str, Union[str, list]] = {"model_id": params.model_id} - url = f"{self.url_root}/inferences/enqueue" - - if params.rag is not None: - data["rag"] = str(params.rag).lower() - if params.raw_text is not None: - data["raw_text"] = str(params.raw_text).lower() - if params.confidence is not None: - data["confidence"] = str(params.confidence).lower() - if params.polygon is not None: - data["polygon"] = str(params.polygon).lower() - if params.webhook_ids and len(params.webhook_ids) > 0: - data["webhook_ids"] = params.webhook_ids - if params.alias and len(params.alias): - data["alias"] = params.alias - if params.text_context and len(params.text_context): - data["text_context"] = params.text_context - if params.data_schema is not None: - data["data_schema"] = str(params.data_schema) + data = params.get_form_data() + url = f"{self.url_root}/{slug}/enqueue" if isinstance(input_source, LocalInputSource): files = {"file": input_source.read_contents(params.close_file)} @@ -137,14 +121,17 @@ def req_get_job(self, job_id: str) -> requests.Response: allow_redirects=False, ) - def req_get_inference(self, inference_id: str) -> requests.Response: + def req_get_inference(self, inference_id: str, slug: str) -> requests.Response: """ Sends a request matching a given queue_id. Returns either a Job or a Document. :param inference_id: Inference ID, returned by the job request. + :param slug: Slug of the inference, defaults to nothing. """ + + url = f"{self.url_root}/{slug}/{inference_id}" return requests.get( - f"{self.url_root}/inferences/{inference_id}", + url, headers=self.base_headers, timeout=self.request_timeout, allow_redirects=False, diff --git a/mindee/parsing/common/async_predict_response.py b/mindee/parsing/common/async_predict_response.py index e3101633..5d657532 100644 --- a/mindee/parsing/common/async_predict_response.py +++ b/mindee/parsing/common/async_predict_response.py @@ -23,7 +23,7 @@ def __init__( """ Container wrapper for a raw API response. - Inherits and instantiates a normal PredictResponse if the parsing of + Inherits and instantiates a normal PredictResponse if the product of the current queue is both requested and done. :param inference_type: Type of the inference. diff --git a/mindee/parsing/v2/inference.py b/mindee/parsing/v2/inference.py index 86c076c9..477cc41c 100644 --- a/mindee/parsing/v2/inference.py +++ b/mindee/parsing/v2/inference.py @@ -1,28 +1,19 @@ from mindee.parsing.common.string_dict import StringDict +from mindee.v2.parsing.inference import BaseInference from mindee.parsing.v2.inference_active_options import InferenceActiveOptions -from mindee.parsing.v2.inference_file import InferenceFile -from mindee.parsing.v2.inference_model import InferenceModel from mindee.parsing.v2.inference_result import InferenceResult -class Inference: +class Inference(BaseInference): """Inference object for a V2 API return.""" - id: str - """ID of the inference.""" - model: InferenceModel - """Model info for the inference.""" - file: InferenceFile - """File info for the inference.""" result: InferenceResult """Result of the inference.""" active_options: InferenceActiveOptions """Active options for the inference.""" def __init__(self, raw_response: StringDict): - self.id = raw_response["id"] - self.model = InferenceModel(raw_response["model"]) - self.file = InferenceFile(raw_response["file"]) + super().__init__(raw_response) self.result = InferenceResult(raw_response["result"]) self.active_options = InferenceActiveOptions(raw_response["active_options"]) diff --git a/mindee/parsing/v2/inference_response.py b/mindee/parsing/v2/inference_response.py index f1bb71c2..ff056d36 100644 --- a/mindee/parsing/v2/inference_response.py +++ b/mindee/parsing/v2/inference_response.py @@ -1,13 +1,17 @@ from mindee.parsing.common.string_dict import StringDict -from mindee.parsing.v2.common_response import CommonResponse from mindee.parsing.v2.inference import Inference +from mindee.v2.parsing.inference.base_response import ( + BaseResponse, +) -class InferenceResponse(CommonResponse): +class InferenceResponse(BaseResponse): """Represent an inference response from Mindee V2 API.""" inference: Inference """Inference result.""" + _slug: str = "inferences" + """Slug of the inference.""" def __init__(self, raw_response: StringDict) -> None: super().__init__(raw_response) @@ -15,3 +19,8 @@ def __init__(self, raw_response: StringDict) -> None: def __str__(self) -> str: return str(self.inference) + + @classmethod + def get_result_slug(cls) -> str: + """Getter for the inference slug.""" + return cls._slug diff --git a/mindee/v2/__init__.py b/mindee/v2/__init__.py new file mode 100644 index 00000000..136bbc42 --- /dev/null +++ b/mindee/v2/__init__.py @@ -0,0 +1,7 @@ +from mindee.v2.product.split.split_parameters import SplitParameters +from mindee.v2.product.split.split_response import SplitResponse + +__all__ = [ + "SplitResponse", + "SplitParameters", +] diff --git a/mindee/v2/parsing/__init__.py b/mindee/v2/parsing/__init__.py new file mode 100644 index 00000000..3ab40372 --- /dev/null +++ b/mindee/v2/parsing/__init__.py @@ -0,0 +1,7 @@ +from mindee.v2.parsing.inference.base_inference import BaseInference +from mindee.v2.parsing.inference.base_response import BaseResponse + +__all__ = [ + "BaseInference", + "BaseResponse", +] diff --git a/mindee/v2/parsing/inference/__init__.py b/mindee/v2/parsing/inference/__init__.py new file mode 100644 index 00000000..e59b67ae --- /dev/null +++ b/mindee/v2/parsing/inference/__init__.py @@ -0,0 +1,9 @@ +from mindee.v2.parsing.inference.base_inference import BaseInference +from mindee.v2.parsing.inference.base_response import ( + BaseResponse, +) + +__all__ = [ + "BaseInference", + "BaseResponse", +] diff --git a/mindee/v2/parsing/inference/base_inference.py b/mindee/v2/parsing/inference/base_inference.py new file mode 100644 index 00000000..78462f0f --- /dev/null +++ b/mindee/v2/parsing/inference/base_inference.py @@ -0,0 +1,25 @@ +from abc import ABC +from typing import TypeVar + +from mindee.parsing.common.string_dict import StringDict +from mindee.parsing.v2.inference_file import InferenceFile +from mindee.parsing.v2.inference_model import InferenceModel + + +class BaseInference(ABC): + """Base class for V2 inference objects.""" + + model: InferenceModel + """Model info for the inference.""" + file: InferenceFile + """File info for the inference.""" + id: str + """ID of the inference.""" + + def __init__(self, raw_response: StringDict): + self.id = raw_response["id"] + self.model = InferenceModel(raw_response["model"]) + self.file = InferenceFile(raw_response["file"]) + + +TypeBaseInference = TypeVar("TypeBaseInference", bound=BaseInference) diff --git a/mindee/v2/parsing/inference/base_response.py b/mindee/v2/parsing/inference/base_response.py new file mode 100644 index 00000000..55b6deb6 --- /dev/null +++ b/mindee/v2/parsing/inference/base_response.py @@ -0,0 +1,22 @@ +from abc import ABC + +from mindee.v2.parsing.inference.base_inference import BaseInference + +from mindee.parsing.v2.common_response import CommonResponse + + +class BaseResponse(ABC, CommonResponse): + """Base class for V2 inference responses.""" + + inference: BaseInference + """The inference result for a split utility request""" + _slug: str + """Slug of the inference.""" + + def __str__(self) -> str: + return str(self.inference) + + @classmethod + def get_result_slug(cls) -> str: + """Getter for the inference slug.""" + return cls._slug diff --git a/mindee/v2/product/__init__.py b/mindee/v2/product/__init__.py new file mode 100644 index 00000000..136bbc42 --- /dev/null +++ b/mindee/v2/product/__init__.py @@ -0,0 +1,7 @@ +from mindee.v2.product.split.split_parameters import SplitParameters +from mindee.v2.product.split.split_response import SplitResponse + +__all__ = [ + "SplitResponse", + "SplitParameters", +] diff --git a/mindee/v2/product/split/__init__.py b/mindee/v2/product/split/__init__.py new file mode 100644 index 00000000..9284c63e --- /dev/null +++ b/mindee/v2/product/split/__init__.py @@ -0,0 +1,13 @@ +from mindee.v2.product.split.split_inference import SplitInference +from mindee.v2.product.split.split_parameters import SplitParameters +from mindee.v2.product.split.split_response import SplitResponse +from mindee.v2.product.split.split_result import SplitResult +from mindee.v2.product.split.split_range import SplitRange + +__all__ = [ + "SplitInference", + "SplitParameters", + "SplitResponse", + "SplitResult", + "SplitRange", +] diff --git a/mindee/v2/product/split/split_inference.py b/mindee/v2/product/split/split_inference.py new file mode 100644 index 00000000..37aa6edb --- /dev/null +++ b/mindee/v2/product/split/split_inference.py @@ -0,0 +1,19 @@ +from mindee.parsing.common.string_dict import StringDict +from mindee.v2.parsing.inference.base_inference import BaseInference +from mindee.v2.product.split.split_result import SplitResult + + +class SplitInference(BaseInference): + """Split inference result.""" + + result: SplitResult + """Result of a split inference.""" + _slug: str = "split" + """Slug of the endpoint.""" + + def __init__(self, raw_response: StringDict) -> None: + super().__init__(raw_response) + self.result = SplitResult(raw_response["result"]) + + def __str__(self) -> str: + return f"Inference\n#########\n{self.model}\n{self.file}\n{self.result}\n" diff --git a/mindee/v2/product/split/split_parameters.py b/mindee/v2/product/split/split_parameters.py new file mode 100644 index 00000000..191070f6 --- /dev/null +++ b/mindee/v2/product/split/split_parameters.py @@ -0,0 +1,9 @@ +from mindee.input.base_parameters import BaseParameters + + +class SplitParameters(BaseParameters): + """ + Parameters accepted by the split utility v2 endpoint. + """ + + _slug: str = "utilities/split" diff --git a/mindee/v2/product/split/split_range.py b/mindee/v2/product/split/split_range.py new file mode 100644 index 00000000..21a85405 --- /dev/null +++ b/mindee/v2/product/split/split_range.py @@ -0,0 +1,20 @@ +from typing import List + +from mindee.parsing.common.string_dict import StringDict + + +class SplitRange: + """Split inference result.""" + + page_range: List[int] + """Page range of the split inference.""" + document_type: str + """Document type of the split inference.""" + + def __init__(self, server_response: StringDict): + self.page_range = server_response["page_range"] + self.document_type = server_response["document_type"] + + def __str__(self) -> str: + page_range = ",".join([str(page_index) for page_index in self.page_range]) + return f"* :Page Range: {page_range}\n :Document Type: {self.document_type}" diff --git a/mindee/v2/product/split/split_response.py b/mindee/v2/product/split/split_response.py new file mode 100644 index 00000000..dfb3c6d5 --- /dev/null +++ b/mindee/v2/product/split/split_response.py @@ -0,0 +1,17 @@ +from mindee.parsing.common.string_dict import StringDict +from mindee.v2.parsing.inference import BaseResponse +from mindee.v2.product.split.split_inference import SplitInference + + +class SplitResponse(BaseResponse): + """Represent a split inference response from Mindee V2 API.""" + + inference: SplitInference + """Inference object for split inference.""" + + _slug: str = "utilities/split" + """Slug of the inference.""" + + def __init__(self, raw_response: StringDict) -> None: + super().__init__(raw_response) + self.inference = SplitInference(raw_response["inference"]) diff --git a/mindee/v2/product/split/split_result.py b/mindee/v2/product/split/split_result.py new file mode 100644 index 00000000..efb63060 --- /dev/null +++ b/mindee/v2/product/split/split_result.py @@ -0,0 +1,20 @@ +from typing import List + +from mindee.parsing.common.string_dict import StringDict +from mindee.v2.product.split.split_range import SplitRange + + +class SplitResult: + """Split result info.""" + + splits: List[SplitRange] + + def __init__(self, raw_response: StringDict) -> None: + self.splits = [SplitRange(split) for split in raw_response["splits"]] + + def __str__(self) -> str: + splits = "\n" + if len(self.splits) > 0: + splits += "\n\n".join([str(split) for split in self.splits]) + out_str = f"Splits\n======{splits}" + return out_str diff --git a/tests/data b/tests/data index 0c51e1d3..e6495fb5 160000 --- a/tests/data +++ b/tests/data @@ -1 +1 @@ -Subproject commit 0c51e1d3e2258404c44280f25f4951ba6fe27324 +Subproject commit e6495fb50c992f9c4624ae14a0404c4c194e4519 diff --git a/tests/test_code_samples.sh b/tests/test_code_samples.sh index 81f7bf83..8fd1ee57 100755 --- a/tests/test_code_samples.sh +++ b/tests/test_code_samples.sh @@ -7,6 +7,7 @@ ENDPOINT=$2 API_KEY=$3 API_KEY_V2=$4 MODEL_ID=$5 +SPLIT_MODEL_ID=$6 for f in $(find ./docs/extras/code_samples -maxdepth 1 -name "*.txt" -not -name "workflow_*.txt" | sort -h) do @@ -28,7 +29,7 @@ do sed -i 's/\/path\/to\/the\/file.ext/.\/tests\/data\/file_types\/pdf\/blank_1.pdf/' $OUTPUT_FILE - if echo "${f}" | grep -q "default_v2.txt" + if echo "${f}" | grep -q "v2_default.txt" then sed -i "s/MY_API_KEY/$API_KEY_V2/" $OUTPUT_FILE sed -i "s/MY_MODEL_ID/$MODEL_ID/" $OUTPUT_FILE @@ -36,6 +37,14 @@ do sed -i "s/my-api-key/$API_KEY/" $OUTPUT_FILE fi + if echo "${f}" | grep -q "v2_split.txt" + then + sed -i "s/MY_API_KEY/$API_KEY_V2/" $OUTPUT_FILE + sed -i "s/MY_SPLIT_MODEL_ID/$SPLIT_MODEL_ID/" $OUTPUT_FILE + else + sed -i "s/my-api-key/$API_KEY/" $OUTPUT_FILE + fi + if echo "$f" | grep -q "custom_v1.txt" then sed -i "s/my-account/$ACCOUNT/g" $OUTPUT_FILE diff --git a/tests/utils.py b/tests/utils.py index 252a699c..058e3595 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,6 +17,7 @@ V2_DATA_DIR = ROOT_DATA_DIR / "v2" V2_PRODUCT_DATA_DIR = V2_DATA_DIR / "products" +V2_UTILITIES_DATA_DIR = V2_DATA_DIR / "utilities" def clear_envvars(monkeypatch) -> None: diff --git a/tests/v2/input/test_local_response.py b/tests/v2/input/test_local_response.py index 5ce07fe1..5db8be78 100644 --- a/tests/v2/input/test_local_response.py +++ b/tests/v2/input/test_local_response.py @@ -14,7 +14,7 @@ def file_path() -> Path: def _assert_local_response(local_response): fake_hmac_signing = "ogNjY44MhvKPGTtVsI8zG82JqWQa68woYQH" - signature = "1df388c992d87897fe61dfc56c444c58fc3c7369c31e2b5fd20d867695e93e85" + signature = "f390d9f7f57ac04f47b6309d8a40236b0182610804fc20e91b1f6028aaca07a7" assert local_response._file is not None assert not local_response.is_valid_hmac_signature( diff --git a/tests/v2/parsing/test_split_integration.py b/tests/v2/parsing/test_split_integration.py new file mode 100644 index 00000000..bd577505 --- /dev/null +++ b/tests/v2/parsing/test_split_integration.py @@ -0,0 +1,32 @@ +import os + +import pytest + +from mindee import ClientV2, PathInput +from mindee.input import SplitParameters +from mindee.v2 import SplitResponse +from tests.utils import V2_UTILITIES_DATA_DIR + + +@pytest.fixture(scope="session") +def split_model_id() -> str: + """Identifier of the Financial Document model, supplied through an env var.""" + return os.getenv("MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID") + + +@pytest.fixture(scope="session") +def v2_client() -> ClientV2: + return ClientV2() + + +@pytest.mark.integration +@pytest.mark.v2 +def test_split_blank(v2_client: ClientV2, split_model_id: str): + input_source = PathInput(V2_UTILITIES_DATA_DIR / "split" / "default_sample.pdf") + response = v2_client.enqueue_and_get_result( + SplitResponse, input_source, SplitParameters(split_model_id) + ) # Note: do not use blank_1.pdf for this. + assert response.inference is not None + assert response.inference.file.name == "default_sample.pdf" + assert response.inference.result.splits + assert len(response.inference.result.splits) == 2 diff --git a/tests/v2/parsing/test_split_response.py b/tests/v2/parsing/test_split_response.py new file mode 100644 index 00000000..b04428db --- /dev/null +++ b/tests/v2/parsing/test_split_response.py @@ -0,0 +1,49 @@ +import pytest + +from mindee import LocalResponse +from mindee.v2.product.split.split_range import SplitRange +from mindee.v2.product.split import SplitInference +from mindee.v2.product.split.split_response import SplitResponse +from mindee.v2.product.split.split_result import SplitResult +from tests.utils import V2_UTILITIES_DATA_DIR + + +@pytest.mark.v2 +def test_split_single(): + input_inference = LocalResponse( + V2_UTILITIES_DATA_DIR / "split" / "split_single.json" + ) + split_response = input_inference.deserialize_response(SplitResponse) + assert isinstance(split_response.inference, SplitInference) + assert split_response.inference.result.splits + assert len(split_response.inference.result.splits[0].page_range) == 2 + assert split_response.inference.result.splits[0].page_range[0] == 0 + assert split_response.inference.result.splits[0].page_range[1] == 0 + assert split_response.inference.result.splits[0].document_type == "receipt" + + +@pytest.mark.v2 +def test_split_multiple(): + input_inference = LocalResponse( + V2_UTILITIES_DATA_DIR / "split" / "split_multiple.json" + ) + split_response = input_inference.deserialize_response(SplitResponse) + assert isinstance(split_response.inference, SplitInference) + assert isinstance(split_response.inference.result, SplitResult) + assert isinstance(split_response.inference.result.splits[0], SplitRange) + assert len(split_response.inference.result.splits) == 3 + + assert len(split_response.inference.result.splits[0].page_range) == 2 + assert split_response.inference.result.splits[0].page_range[0] == 0 + assert split_response.inference.result.splits[0].page_range[1] == 0 + assert split_response.inference.result.splits[0].document_type == "invoice" + + assert len(split_response.inference.result.splits[1].page_range) == 2 + assert split_response.inference.result.splits[1].page_range[0] == 1 + assert split_response.inference.result.splits[1].page_range[1] == 3 + assert split_response.inference.result.splits[1].document_type == "invoice" + + assert len(split_response.inference.result.splits[2].page_range) == 2 + assert split_response.inference.result.splits[2].page_range[0] == 4 + assert split_response.inference.result.splits[2].page_range[1] == 4 + assert split_response.inference.result.splits[2].document_type == "invoice" diff --git a/tests/v2/test_client.py b/tests/v2/test_client.py index 866242a4..f51f84cc 100644 --- a/tests/v2/test_client.py +++ b/tests/v2/test_client.py @@ -124,9 +124,7 @@ def test_enqueue_path_with_env_token(custom_base_url_client): f"{FILE_TYPES_DIR}/receipt.jpg" ) with pytest.raises(MindeeHTTPErrorV2): - custom_base_url_client.enqueue_inference( - input_doc, InferenceParameters("dummy-model") - ) + custom_base_url_client.enqueue(input_doc, InferenceParameters("dummy-model")) @pytest.mark.v2 @@ -135,7 +133,8 @@ def test_enqueue_and_parse_path_with_env_token(custom_base_url_client): f"{FILE_TYPES_DIR}/receipt.jpg" ) with pytest.raises(MindeeHTTPErrorV2): - custom_base_url_client.enqueue_and_get_inference( + custom_base_url_client.enqueue_and_get_result( + InferenceResponse, input_doc, InferenceParameters( "dummy-model", @@ -172,8 +171,8 @@ def test_loads_from_prediction(): @pytest.mark.v2 def test_get_inference(custom_base_url_client): - response = custom_base_url_client.get_inference( - "12345678-1234-1234-1234-123456789ABC" + response = custom_base_url_client.get_result( + InferenceResponse, "12345678-1234-1234-1234-123456789ABC" ) _assert_findoc_inference(response) @@ -181,7 +180,7 @@ def test_get_inference(custom_base_url_client): @pytest.mark.v2 def test_error_handling(custom_base_url_client): with pytest.raises(MindeeHTTPErrorV2) as e: - custom_base_url_client.enqueue_inference( + custom_base_url_client.enqueue( PathInput( V2_DATA_DIR / "products" / "financial_document" / "default_sample.jpg" ), diff --git a/tests/v2/test_client_integration.py b/tests/v2/test_client_integration.py index c7d7a5f6..79ad65c4 100644 --- a/tests/v2/test_client_integration.py +++ b/tests/v2/test_client_integration.py @@ -18,12 +18,7 @@ def findoc_model_id() -> str: @pytest.fixture(scope="session") def v2_client() -> ClientV2: - """ - Real V2 client configured with the user-supplied API key - (or skipped when the key is absent). - """ - api_key = os.getenv("MINDEE_V2_API_KEY") - return ClientV2(api_key) + return ClientV2() def _basic_assert_success( @@ -63,8 +58,8 @@ def test_parse_file_empty_multiple_pages_must_succeed( alias="py_integration_empty_multiple", ) - response: InferenceResponse = v2_client.enqueue_and_get_inference( - input_source, params + response: InferenceResponse = v2_client.enqueue_and_get_result( + InferenceResponse, input_source, params ) _basic_assert_success(response=response, page_count=2, model_id=findoc_model_id) @@ -104,8 +99,8 @@ def test_parse_file_empty_single_page_options_must_succeed( confidence=True, alias="py_integration_empty_page_options", ) - response: InferenceResponse = v2_client.enqueue_and_get_inference( - input_source, params + response: InferenceResponse = v2_client.enqueue_and_get_result( + InferenceResponse, input_source, params ) _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id) @@ -142,8 +137,8 @@ def test_parse_file_filled_single_page_must_succeed( text_context="this is an invoice.", ) - response: InferenceResponse = v2_client.enqueue_and_get_inference( - input_source, params + response: InferenceResponse = v2_client.enqueue_and_get_result( + InferenceResponse, input_source, params ) _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id) @@ -183,7 +178,7 @@ def test_invalid_uuid_must_throw_error(v2_client: ClientV2) -> None: ) with pytest.raises(MindeeHTTPErrorV2) as exc_info: - v2_client.enqueue_inference(input_source, params) + v2_client.enqueue(input_source, params) exc: MindeeHTTPErrorV2 = exc_info.value assert exc.status == 422 @@ -204,7 +199,7 @@ def test_unknown_model_must_throw_error(v2_client: ClientV2) -> None: params = InferenceParameters(model_id="fc405e37-4ba4-4d03-aeba-533a8d1f0f21") with pytest.raises(MindeeHTTPErrorV2) as exc_info: - v2_client.enqueue_inference(input_source, params) + v2_client.enqueue(input_source, params) exc: MindeeHTTPErrorV2 = exc_info.value assert exc.status == 404 @@ -237,7 +232,7 @@ def test_unknown_webhook_ids_must_throw_error( ) with pytest.raises(MindeeHTTPErrorV2) as exc_info: - v2_client.enqueue_inference(input_source, params) + v2_client.enqueue(input_source, params) exc: MindeeHTTPErrorV2 = exc_info.value assert exc.status == 422 @@ -268,8 +263,8 @@ def test_blank_url_input_source_must_succeed( webhook_ids=[], alias="py_integration_url_source", ) - response: InferenceResponse = v2_client.enqueue_and_get_inference( - input_source, params + response: InferenceResponse = v2_client.enqueue_and_get_result( + InferenceResponse, input_source, params ) _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id) @@ -299,8 +294,8 @@ def test_data_schema_must_succeed( data_schema=data_schema_replace_path.read_text(), alias="py_integration_data_schema_replace", ) - response: InferenceResponse = v2_client.enqueue_and_get_inference( - input_source, params + response: InferenceResponse = v2_client.enqueue_and_get_result( + InferenceResponse, input_source, params ) _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id) assert response.inference.active_options.data_schema.replace is True