From 5aedc34962f7259cfa69df9235d071edb8827b99 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Thu, 5 Feb 2026 15:09:21 +0800 Subject: [PATCH] feat(tests): add unit tests for knowledgebase module and API This commit introduces a comprehensive suite of unit tests for the knowledgebase module, including tests for the KnowledgeBaseClient, KnowledgeBase, and various provider settings. The tests cover creation, deletion, and update functionalities, ensuring robust validation of the knowledgebase operations. Additionally, new test files for the API and model components have been added to enhance test coverage and reliability. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/knowledgebase/__init__.py | 6 + .../__knowledgebase_async_template.py | 50 + .../api/__data_async_template.py | 226 ++- agentrun/knowledgebase/api/__init__.py | 2 + agentrun/knowledgebase/api/data.py | 275 ++- agentrun/knowledgebase/knowledgebase.py | 50 + agentrun/knowledgebase/model.py | 66 +- agentrun/utils/control_api.py | 40 + examples/knowledgebase.py | 224 ++- pyproject.toml | 1 + tests/unittests/knowledgebase/__init__.py | 0 tests/unittests/knowledgebase/api/__init__.py | 1 + .../unittests/knowledgebase/api/test_data.py | 1594 +++++++++++++++++ tests/unittests/knowledgebase/test_client.py | 639 +++++++ .../knowledgebase/test_knowledgebase.py | 1306 ++++++++++++++ tests/unittests/knowledgebase/test_model.py | 650 +++++++ 16 files changed, 5117 insertions(+), 13 deletions(-) create mode 100644 tests/unittests/knowledgebase/__init__.py create mode 100644 tests/unittests/knowledgebase/api/__init__.py create mode 100644 tests/unittests/knowledgebase/api/test_data.py create mode 100644 tests/unittests/knowledgebase/test_client.py create mode 100644 tests/unittests/knowledgebase/test_knowledgebase.py create mode 100644 tests/unittests/knowledgebase/test_model.py diff --git a/agentrun/knowledgebase/__init__.py b/agentrun/knowledgebase/__init__.py index df0f1a8..6a01376 100644 --- a/agentrun/knowledgebase/__init__.py +++ b/agentrun/knowledgebase/__init__.py @@ -1,6 +1,7 @@ """KnowledgeBase 模块 / KnowledgeBase Module""" from .api import ( + ADBDataAPI, BailianDataAPI, get_data_api, KnowledgeBaseControlAPI, @@ -10,6 +11,8 @@ from .client import KnowledgeBaseClient from .knowledgebase import KnowledgeBase from .model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseCreateInput, @@ -33,6 +36,7 @@ "KnowledgeBaseDataAPI", "RagFlowDataAPI", "BailianDataAPI", + "ADBDataAPI", "get_data_api", # enums "KnowledgeBaseProvider", @@ -40,10 +44,12 @@ "ProviderSettings", "RagFlowProviderSettings", "BailianProviderSettings", + "ADBProviderSettings", # retrieve settings "RetrieveSettings", "RagFlowRetrieveSettings", "BailianRetrieveSettings", + "ADBRetrieveSettings", # api model "KnowledgeBaseCreateInput", "KnowledgeBaseUpdateInput", diff --git a/agentrun/knowledgebase/__knowledgebase_async_template.py b/agentrun/knowledgebase/__knowledgebase_async_template.py index 96ee94c..501449e 100644 --- a/agentrun/knowledgebase/__knowledgebase_async_template.py +++ b/agentrun/knowledgebase/__knowledgebase_async_template.py @@ -14,6 +14,8 @@ from .api.data import get_data_api from .model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseCreateInput, @@ -294,6 +296,54 @@ def _get_data_api(self, config: Optional[Config] = None): **self.retrieve_settings ) + elif provider == KnowledgeBaseProvider.ADB: + # ADB 设置 / ADB settings + if self.provider_settings: + if isinstance(self.provider_settings, ADBProviderSettings): + converted_provider_settings = self.provider_settings + elif isinstance(self.provider_settings, dict): + # ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB provider_settings uses PascalCase keys, need to convert to snake_case + converted_provider_settings = ADBProviderSettings( + db_instance_id=self.provider_settings.get( + "DBInstanceId", "" + ), + namespace=self.provider_settings.get("Namespace", ""), + namespace_password=self.provider_settings.get( + "NamespacePassword", "" + ), + embedding_model=self.provider_settings.get( + "EmbeddingModel" + ), + metrics=self.provider_settings.get("Metrics"), + metadata=self.provider_settings.get("Metadata"), + ) + + if self.retrieve_settings: + if isinstance(self.retrieve_settings, ADBRetrieveSettings): + converted_retrieve_settings = self.retrieve_settings + elif isinstance(self.retrieve_settings, dict): + # ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB retrieve_settings uses PascalCase keys, need to convert to snake_case + converted_retrieve_settings = ADBRetrieveSettings( + top_k=self.retrieve_settings.get("TopK"), + use_full_text_retrieval=self.retrieve_settings.get( + "UseFullTextRetrieval" + ), + rerank_factor=self.retrieve_settings.get( + "RerankFactor" + ), + recall_window=self.retrieve_settings.get( + "RecallWindow" + ), + hybrid_search=self.retrieve_settings.get( + "HybridSearch" + ), + hybrid_search_args=self.retrieve_settings.get( + "HybridSearchArgs" + ), + ) + return get_data_api( provider=provider, knowledge_base_name=self.knowledge_base_name or "", diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index 4f8a1ae..51e9c43 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -3,14 +3,15 @@ 提供知识库检索功能的数据链路 API。 Provides data API for knowledge base retrieval operations. -根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。 -Dispatches to different implementations based on provider type (ragflow / bailian). +根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。 +Dispatches to different implementations based on provider type (ragflow / bailian / adb). """ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union from alibabacloud_bailian20231229 import models as bailian_models +from alibabacloud_gpdb20160503 import models as gpdb_models import httpx from agentrun.utils.config import Config @@ -19,6 +20,8 @@ from agentrun.utils.log import logger from ..model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseProvider, @@ -347,15 +350,213 @@ async def retrieve_async( } +class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI): + """ADB (AnalyticDB for PostgreSQL) 知识库数据链路 API / ADB KnowledgeBase Data API + + 实现 ADB 知识库的检索逻辑,通过 GPDB SDK 调用 QueryContent 接口。 + Implements retrieval logic for ADB knowledge base via GPDB SDK QueryContent API. + """ + + def __init__( + self, + knowledge_base_name: str, + config: Optional[Config] = None, + provider_settings: Optional[ADBProviderSettings] = None, + retrieve_settings: Optional[ADBRetrieveSettings] = None, + ): + """初始化 ADB 知识库数据链路 API / Initialize ADB KnowledgeBase Data API + + Args: + knowledge_base_name: 知识库名称 / Knowledge base name + config: 配置 / Configuration + provider_settings: ADB 提供商设置 / ADB provider settings + retrieve_settings: ADB 检索设置 / ADB retrieve settings + """ + KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config) + ControlAPI.__init__(self, config) + self.provider_settings = provider_settings + self.retrieve_settings = retrieve_settings + + def _build_query_content_request( + self, query: str, config: Optional[Config] = None + ) -> gpdb_models.QueryContentRequest: + """构建 QueryContent 请求 / Build QueryContent request + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + QueryContentRequest: GPDB QueryContent 请求对象 + """ + if self.provider_settings is None: + raise ValueError("provider_settings is required for ADB retrieval") + + cfg = Config.with_configs(self.config, config) + + # 构建基础请求参数 / Build base request parameters + request_params: Dict[str, Any] = { + "content": query, + "dbinstance_id": self.provider_settings.db_instance_id, + "namespace": self.provider_settings.namespace, + "namespace_password": self.provider_settings.namespace_password, + "collection": self.knowledge_base_name, + "region_id": cfg.get_region_id(), + } + + # 添加可选的提供商设置 / Add optional provider settings + if self.provider_settings.metrics is not None: + request_params["metrics"] = self.provider_settings.metrics + + # 添加检索设置 / Add retrieve settings + if self.retrieve_settings: + if self.retrieve_settings.top_k is not None: + request_params["top_k"] = self.retrieve_settings.top_k + if self.retrieve_settings.use_full_text_retrieval is not None: + request_params["use_full_text_retrieval"] = ( + self.retrieve_settings.use_full_text_retrieval + ) + if self.retrieve_settings.rerank_factor is not None: + request_params["rerank_factor"] = ( + self.retrieve_settings.rerank_factor + ) + if self.retrieve_settings.recall_window is not None: + request_params["recall_window"] = ( + self.retrieve_settings.recall_window + ) + if self.retrieve_settings.hybrid_search is not None: + request_params["hybrid_search"] = ( + self.retrieve_settings.hybrid_search + ) + if self.retrieve_settings.hybrid_search_args is not None: + request_params["hybrid_search_args"] = ( + self.retrieve_settings.hybrid_search_args + ) + + return gpdb_models.QueryContentRequest(**request_params) + + def _parse_query_content_response( + self, response: gpdb_models.QueryContentResponse, query: str + ) -> Dict[str, Any]: + """解析 QueryContent 响应 / Parse QueryContent response + + Args: + response: GPDB QueryContent 响应对象 + query: 原始查询文本 / Original query text + + Returns: + Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results + """ + all_matches: List[Dict[str, Any]] = [] + + if response.body and response.body.matches: + match_list = response.body.matches.match_list or [] + for match in match_list: + all_matches.append({ + "content": ( + match.content if hasattr(match, "content") else None + ), + "score": match.score if hasattr(match, "score") else None, + "id": match.id if hasattr(match, "id") else None, + "file_name": ( + match.file_name if hasattr(match, "file_name") else None + ), + "file_url": ( + match.file_url if hasattr(match, "file_url") else None + ), + "metadata": ( + match.metadata if hasattr(match, "metadata") else None + ), + "rerank_score": ( + match.rerank_score + if hasattr(match, "rerank_score") + else None + ), + "retrieval_source": ( + match.retrieval_source + if hasattr(match, "retrieval_source") + else None + ), + }) + + return { + "data": all_matches, + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "request_id": ( + response.body.request_id + if response.body and hasattr(response.body, "request_id") + else None + ), + } + + async def retrieve_async( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """ADB 检索(异步)/ ADB retrieval asynchronously + + 通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。 + Retrieves from ADB knowledge base via GPDB SDK QueryContent API. + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for ADB retrieval" + ) + + # 获取 GPDB 客户端 / Get GPDB client + client = self._get_gpdb_client(config) + + # 构建请求 / Build request + request = self._build_query_content_request(query, config) + logger.debug(f"ADB QueryContent request: {request}") + + # 调用 QueryContent API / Call QueryContent API + response = await client.query_content_async(request) + logger.debug(f"ADB QueryContent response: {response}") + + # 解析并返回结果 / Parse and return results + return self._parse_query_content_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from ADB knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, config: Optional[Config] = None, provider_settings: Optional[ - Union[RagFlowProviderSettings, BailianProviderSettings] + Union[ + RagFlowProviderSettings, + BailianProviderSettings, + ADBProviderSettings, + ] ] = None, retrieve_settings: Optional[ - Union[RagFlowRetrieveSettings, BailianRetrieveSettings] + Union[ + RagFlowRetrieveSettings, + BailianRetrieveSettings, + ADBRetrieveSettings, + ] ] = None, credential_name: Optional[str] = None, ) -> KnowledgeBaseDataAPI: @@ -410,5 +611,22 @@ def get_data_api( provider_settings=bailian_provider_settings, retrieve_settings=bailian_retrieve_settings, ) + elif provider == KnowledgeBaseProvider.ADB or provider == "adb": + adb_provider_settings = ( + provider_settings + if isinstance(provider_settings, ADBProviderSettings) + else None + ) + adb_retrieve_settings = ( + retrieve_settings + if isinstance(retrieve_settings, ADBRetrieveSettings) + else None + ) + return ADBDataAPI( + knowledge_base_name, + config, + provider_settings=adb_provider_settings, + retrieve_settings=adb_retrieve_settings, + ) else: raise ValueError(f"Unsupported provider type: {provider}") diff --git a/agentrun/knowledgebase/api/__init__.py b/agentrun/knowledgebase/api/__init__.py index 2746a9e..bcfc80c 100644 --- a/agentrun/knowledgebase/api/__init__.py +++ b/agentrun/knowledgebase/api/__init__.py @@ -2,6 +2,7 @@ from .control import KnowledgeBaseControlAPI from .data import ( + ADBDataAPI, BailianDataAPI, get_data_api, KnowledgeBaseDataAPI, @@ -15,5 +16,6 @@ "KnowledgeBaseDataAPI", "RagFlowDataAPI", "BailianDataAPI", + "ADBDataAPI", "get_data_api", ] diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index 350a302..747157c 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -13,14 +13,15 @@ 提供知识库检索功能的数据链路 API。 Provides data API for knowledge base retrieval operations. -根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。 -Dispatches to different implementations based on provider type (ragflow / bailian). +根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。 +Dispatches to different implementations based on provider type (ragflow / bailian / adb). """ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union from alibabacloud_bailian20231229 import models as bailian_models +from alibabacloud_gpdb20160503 import models as gpdb_models import httpx from agentrun.utils.config import Config @@ -29,6 +30,8 @@ from agentrun.utils.log import logger from ..model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseProvider, @@ -557,15 +560,262 @@ def retrieve( } +class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI): + """ADB (AnalyticDB for PostgreSQL) 知识库数据链路 API / ADB KnowledgeBase Data API + + 实现 ADB 知识库的检索逻辑,通过 GPDB SDK 调用 QueryContent 接口。 + Implements retrieval logic for ADB knowledge base via GPDB SDK QueryContent API. + """ + + def __init__( + self, + knowledge_base_name: str, + config: Optional[Config] = None, + provider_settings: Optional[ADBProviderSettings] = None, + retrieve_settings: Optional[ADBRetrieveSettings] = None, + ): + """初始化 ADB 知识库数据链路 API / Initialize ADB KnowledgeBase Data API + + Args: + knowledge_base_name: 知识库名称 / Knowledge base name + config: 配置 / Configuration + provider_settings: ADB 提供商设置 / ADB provider settings + retrieve_settings: ADB 检索设置 / ADB retrieve settings + """ + KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config) + ControlAPI.__init__(self, config) + self.provider_settings = provider_settings + self.retrieve_settings = retrieve_settings + + def _build_query_content_request( + self, query: str, config: Optional[Config] = None + ) -> gpdb_models.QueryContentRequest: + """构建 QueryContent 请求 / Build QueryContent request + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + QueryContentRequest: GPDB QueryContent 请求对象 + """ + if self.provider_settings is None: + raise ValueError("provider_settings is required for ADB retrieval") + + cfg = Config.with_configs(self.config, config) + + # 构建基础请求参数 / Build base request parameters + request_params: Dict[str, Any] = { + "content": query, + "dbinstance_id": self.provider_settings.db_instance_id, + "namespace": self.provider_settings.namespace, + "namespace_password": self.provider_settings.namespace_password, + "collection": self.knowledge_base_name, + "region_id": cfg.get_region_id(), + } + + # 添加可选的提供商设置 / Add optional provider settings + if self.provider_settings.metrics is not None: + request_params["metrics"] = self.provider_settings.metrics + + # 添加检索设置 / Add retrieve settings + if self.retrieve_settings: + if self.retrieve_settings.top_k is not None: + request_params["top_k"] = self.retrieve_settings.top_k + if self.retrieve_settings.use_full_text_retrieval is not None: + request_params["use_full_text_retrieval"] = ( + self.retrieve_settings.use_full_text_retrieval + ) + if self.retrieve_settings.rerank_factor is not None: + request_params["rerank_factor"] = ( + self.retrieve_settings.rerank_factor + ) + if self.retrieve_settings.recall_window is not None: + request_params["recall_window"] = ( + self.retrieve_settings.recall_window + ) + if self.retrieve_settings.hybrid_search is not None: + request_params["hybrid_search"] = ( + self.retrieve_settings.hybrid_search + ) + if self.retrieve_settings.hybrid_search_args is not None: + request_params["hybrid_search_args"] = ( + self.retrieve_settings.hybrid_search_args + ) + + return gpdb_models.QueryContentRequest(**request_params) + + def _parse_query_content_response( + self, response: gpdb_models.QueryContentResponse, query: str + ) -> Dict[str, Any]: + """解析 QueryContent 响应 / Parse QueryContent response + + Args: + response: GPDB QueryContent 响应对象 + query: 原始查询文本 / Original query text + + Returns: + Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results + """ + all_matches: List[Dict[str, Any]] = [] + + if response.body and response.body.matches: + match_list = response.body.matches.match_list or [] + for match in match_list: + all_matches.append({ + "content": ( + match.content if hasattr(match, "content") else None + ), + "score": match.score if hasattr(match, "score") else None, + "id": match.id if hasattr(match, "id") else None, + "file_name": ( + match.file_name if hasattr(match, "file_name") else None + ), + "file_url": ( + match.file_url if hasattr(match, "file_url") else None + ), + "metadata": ( + match.metadata if hasattr(match, "metadata") else None + ), + "rerank_score": ( + match.rerank_score + if hasattr(match, "rerank_score") + else None + ), + "retrieval_source": ( + match.retrieval_source + if hasattr(match, "retrieval_source") + else None + ), + }) + + return { + "data": all_matches, + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "request_id": ( + response.body.request_id + if response.body and hasattr(response.body, "request_id") + else None + ), + } + + async def retrieve_async( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """ADB 检索(异步)/ ADB retrieval asynchronously + + 通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。 + Retrieves from ADB knowledge base via GPDB SDK QueryContent API. + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for ADB retrieval" + ) + + # 获取 GPDB 客户端 / Get GPDB client + client = self._get_gpdb_client(config) + + # 构建请求 / Build request + request = self._build_query_content_request(query, config) + logger.debug(f"ADB QueryContent request: {request}") + + # 调用 QueryContent API / Call QueryContent API + response = await client.query_content_async(request) + logger.debug(f"ADB QueryContent response: {response}") + + # 解析并返回结果 / Parse and return results + return self._parse_query_content_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from ADB knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def retrieve( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """ADB 检索(同步)/ ADB retrieval synchronously + + 通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。 + Retrieves from ADB knowledge base via GPDB SDK QueryContent API. + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for ADB retrieval" + ) + + # 获取 GPDB 客户端 / Get GPDB client + client = self._get_gpdb_client(config) + + # 构建请求 / Build request + request = self._build_query_content_request(query, config) + logger.debug(f"ADB QueryContent request: {request}") + + # 调用 QueryContent API / Call QueryContent API + response = client.query_content(request) + logger.debug(f"ADB QueryContent response: {response}") + + # 解析并返回结果 / Parse and return results + return self._parse_query_content_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from ADB knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, config: Optional[Config] = None, provider_settings: Optional[ - Union[RagFlowProviderSettings, BailianProviderSettings] + Union[ + RagFlowProviderSettings, + BailianProviderSettings, + ADBProviderSettings, + ] ] = None, retrieve_settings: Optional[ - Union[RagFlowRetrieveSettings, BailianRetrieveSettings] + Union[ + RagFlowRetrieveSettings, + BailianRetrieveSettings, + ADBRetrieveSettings, + ] ] = None, credential_name: Optional[str] = None, ) -> KnowledgeBaseDataAPI: @@ -620,5 +870,22 @@ def get_data_api( provider_settings=bailian_provider_settings, retrieve_settings=bailian_retrieve_settings, ) + elif provider == KnowledgeBaseProvider.ADB or provider == "adb": + adb_provider_settings = ( + provider_settings + if isinstance(provider_settings, ADBProviderSettings) + else None + ) + adb_retrieve_settings = ( + retrieve_settings + if isinstance(retrieve_settings, ADBRetrieveSettings) + else None + ) + return ADBDataAPI( + knowledge_base_name, + config, + provider_settings=adb_provider_settings, + retrieve_settings=adb_retrieve_settings, + ) else: raise ValueError(f"Unsupported provider type: {provider}") diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index 2e453da..74c5c50 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -24,6 +24,8 @@ from .api.data import get_data_api from .model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseCreateInput, @@ -472,6 +474,54 @@ def _get_data_api(self, config: Optional[Config] = None): **self.retrieve_settings ) + elif provider == KnowledgeBaseProvider.ADB: + # ADB 设置 / ADB settings + if self.provider_settings: + if isinstance(self.provider_settings, ADBProviderSettings): + converted_provider_settings = self.provider_settings + elif isinstance(self.provider_settings, dict): + # ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB provider_settings uses PascalCase keys, need to convert to snake_case + converted_provider_settings = ADBProviderSettings( + db_instance_id=self.provider_settings.get( + "DBInstanceId", "" + ), + namespace=self.provider_settings.get("Namespace", ""), + namespace_password=self.provider_settings.get( + "NamespacePassword", "" + ), + embedding_model=self.provider_settings.get( + "EmbeddingModel" + ), + metrics=self.provider_settings.get("Metrics"), + metadata=self.provider_settings.get("Metadata"), + ) + + if self.retrieve_settings: + if isinstance(self.retrieve_settings, ADBRetrieveSettings): + converted_retrieve_settings = self.retrieve_settings + elif isinstance(self.retrieve_settings, dict): + # ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB retrieve_settings uses PascalCase keys, need to convert to snake_case + converted_retrieve_settings = ADBRetrieveSettings( + top_k=self.retrieve_settings.get("TopK"), + use_full_text_retrieval=self.retrieve_settings.get( + "UseFullTextRetrieval" + ), + rerank_factor=self.retrieve_settings.get( + "RerankFactor" + ), + recall_window=self.retrieve_settings.get( + "RecallWindow" + ), + hybrid_search=self.retrieve_settings.get( + "HybridSearch" + ), + hybrid_search_args=self.retrieve_settings.get( + "HybridSearchArgs" + ), + ) + return get_data_api( provider=provider, knowledge_base_name=self.knowledge_base_name or "", diff --git a/agentrun/knowledgebase/model.py b/agentrun/knowledgebase/model.py index 69ce23c..1c7c227 100644 --- a/agentrun/knowledgebase/model.py +++ b/agentrun/knowledgebase/model.py @@ -18,6 +18,8 @@ class KnowledgeBaseProvider(str, Enum): """RagFlow 知识库 / RagFlow knowledge base""" BAILIAN = "bailian" """百炼知识库 / Bailian knowledge base""" + ADB = "adb" + """ADB (AnalyticDB for PostgreSQL) 知识库 / ADB knowledge base""" # ============================================================================= @@ -75,17 +77,77 @@ class BailianRetrieveSettings(BaseModel): """重排序返回的 Top N 数量 / Rerank top N""" +# ============================================================================= +# ADB 配置模型 / ADB Configuration Models +# ============================================================================= + + +class ADBProviderSettings(BaseModel): + """ADB (AnalyticDB for PostgreSQL) 提供商设置 / ADB Provider Settings + + 配置 ADB 知识库的连接和访问参数。 + Configure ADB knowledge base connection and access parameters. + """ + + db_instance_id: str + """ADB 实例 ID / ADB instance ID""" + namespace: str + """命名空间,默认为 public / Namespace, defaults to public""" + namespace_password: str + """命名空间密码 / Namespace password""" + embedding_model: Optional[str] = None + """向量化模型名称,如 text-embedding-v3 / Embedding model name""" + metrics: Optional[str] = None + """相似度算法:l2(欧氏距离)、ip(内积)、cosine(余弦相似度) + Similarity algorithm: l2 (Euclidean), ip (inner product), cosine""" + metadata: Optional[str] = None + """元数据配置,JSON 字符串格式 / Metadata configuration in JSON string format""" + + +class ADBRetrieveSettings(BaseModel): + """ADB 检索设置 / ADB Retrieve Settings + + 配置 ADB 知识库的检索参数,支持向量检索和全文检索的混合模式。 + Configure ADB knowledge base retrieval parameters, supporting hybrid + vector and full-text retrieval modes. + """ + + top_k: Optional[int] = None + """返回结果的数量 / Number of results to return""" + use_full_text_retrieval: Optional[bool] = None + """是否启用全文检索(双路召回),默认 false 仅使用向量检索 + Enable full-text retrieval (dual recall), defaults to false (vector only)""" + rerank_factor: Optional[float] = None + """重排序因子,取值范围 1 < RerankFactor <= 5 + Re-ranking factor, value range: 1 < RerankFactor <= 5""" + recall_window: Optional[List[int]] = None + """召回窗口,格式为 [A, B],其中 -10 <= A <= 0,0 <= B <= 10 + Recall window, format [A, B] where -10 <= A <= 0, 0 <= B <= 10""" + hybrid_search: Optional[str] = None + """混合检索算法:RRF(倒数排名融合)、Weight(加权排序)、Cascaded(级联检索) + Hybrid search algorithm: RRF, Weight, or Cascaded""" + hybrid_search_args: Optional[Dict[str, Any]] = None + """混合检索算法参数,如 {"RRF": {"k": 60}} 或 {"Weight": {"alpha": 0.5}} + Hybrid search algorithm parameters""" + + # ============================================================================= # 联合类型定义 / Union Type Definitions # ============================================================================= ProviderSettings = Union[ - RagFlowProviderSettings, BailianProviderSettings, Dict[str, Any] + RagFlowProviderSettings, + BailianProviderSettings, + ADBProviderSettings, + Dict[str, Any], ] """提供商设置联合类型 / Provider settings union type""" RetrieveSettings = Union[ - RagFlowRetrieveSettings, BailianRetrieveSettings, Dict[str, Any] + RagFlowRetrieveSettings, + BailianRetrieveSettings, + ADBRetrieveSettings, + Dict[str, Any], ] """检索设置联合类型 / Retrieve settings union type""" diff --git a/agentrun/utils/control_api.py b/agentrun/utils/control_api.py index d9db600..b74a822 100644 --- a/agentrun/utils/control_api.py +++ b/agentrun/utils/control_api.py @@ -9,6 +9,7 @@ from alibabacloud_agentrun20250910.client import Client as AgentRunClient from alibabacloud_bailian20231229.client import Client as BailianClient from alibabacloud_devs20230714.client import Client as DevsClient +from alibabacloud_gpdb20160503.client import Client as GPDBClient from alibabacloud_tea_openapi import utils_models as open_api_util_models from agentrun.utils.config import Config @@ -103,3 +104,42 @@ def _get_bailian_client( read_timeout=cfg.get_read_timeout(), # type: ignore ) ) + + def _get_gpdb_client(self, config: Optional[Config] = None) -> "GPDBClient": + """ + 获取 GPDB (AnalyticDB for PostgreSQL) API 客户端实例 + Get GPDB (AnalyticDB for PostgreSQL) API client instance + + Args: + config: 配置对象,可选 / Configuration object, optional + + Returns: + GPDBClient: GPDB API 客户端实例 / GPDB API client instance + """ + + cfg = Config.with_configs(self.config, config) + # GPDB 使用区域级别的 endpoint / GPDB uses region-level endpoint + region_id = cfg.get_region_id() + if region_id in ( + "cn-beijing", + "cn-hangzhou", + "cn-shanghai", + "cn-shenzhen", + "cn-hongkong", + "ap-southeast-1", + ): + endpoint = "gpdb.aliyuncs.com" + else: + endpoint = f"gpdb.{region_id}.aliyuncs.com" + + return GPDBClient( + open_api_util_models.Config( + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=cfg.get_security_token(), + region_id=cfg.get_region_id(), + endpoint=endpoint, + connect_timeout=cfg.get_timeout(), # type: ignore + read_timeout=cfg.get_read_timeout(), # type: ignore + ) + ) diff --git a/examples/knowledgebase.py b/examples/knowledgebase.py index b9ddef5..61b5722 100644 --- a/examples/knowledgebase.py +++ b/examples/knowledgebase.py @@ -1,9 +1,9 @@ """ 知识库模块示例 / KnowledgeBase Module Example -本示例演示如何使用 AgentRun SDK 管理知识库,包括百炼和 RagFlow 两种类型: +本示例演示如何使用 AgentRun SDK 管理知识库,包括百炼、RagFlow 和 ADB 三种类型: This example demonstrates how to use the AgentRun SDK to manage knowledge bases, -including both Bailian and RagFlow types: +including Bailian, RagFlow and ADB types: 1. 创建知识库 / Create knowledge base (Bailian & RagFlow) 2. 获取知识库信息 / Get knowledge base info @@ -25,6 +25,12 @@ - RAGFLOW_BASE_URL: RagFlow 服务地址 - RAGFLOW_DATASET_IDS: RagFlow 数据集 ID 列表(逗号分隔) - RAGFLOW_CREDENTIAL_NAME: RagFlow API Key 凭证名称 + +ADB 知识库额外配置 / Additional config for ADB: +- ADB_INSTANCE_ID: ADB 实例 ID +- ADB_NAMESPACE: ADB 命名空间 +- ADB_NAMESPACE_PASSWORD: ADB 命名空间密码 +- ADB_COLLECTION: ADB 文档集合名称 """ import json @@ -32,6 +38,8 @@ import time from agentrun.knowledgebase import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBase, @@ -100,6 +108,34 @@ "RAGFLOW_CREDENTIAL_NAME", "ragflow-api-key" ) +# ----------------------------------------------------------------------------- +# ADB 知识库配置 / ADB Knowledge Base Configuration +# ----------------------------------------------------------------------------- + +# ADB 知识库名称 +# ADB knowledge base name +ADB_KB_NAME = f"sdk-test-adb-kb-{TIMESTAMP}" + +# ADB 实例 ID,请替换为您的实际值 +# ADB instance ID, please replace with your actual value +ADB_INSTANCE_ID = os.getenv("ADB_INSTANCE_ID", "gp-your-instance-id") + +# ADB 命名空间,默认为 public +# ADB namespace, defaults to public +ADB_NAMESPACE = os.getenv("ADB_NAMESPACE", "public") + +# ADB 命名空间密码,请替换为您的实际值 +# ADB namespace password, please replace with your actual value +ADB_NAMESPACE_PASSWORD = os.getenv("ADB_NAMESPACE_PASSWORD", "your-password") + +# ADB 文档集合名称,请替换为您的实际值 +# ADB collection name, please replace with your actual value +ADB_COLLECTION = os.getenv("ADB_COLLECTION", "your-collection") + +# ADB 向量化模型名称(可选) +# ADB embedding model name (optional) +ADB_EMBEDDING_MODEL = os.getenv("ADB_EMBEDDING_MODEL", "text-embedding-v3") + # ============================================================================ # 客户端初始化 / Client Initialization # ============================================================================ @@ -362,6 +398,153 @@ def delete_ragflow_kb(kb: KnowledgeBase): ) +# ============================================================================ +# ADB 知识库示例函数 / ADB Knowledge Base Example Functions +# ============================================================================ + + +def create_or_get_adb_kb() -> KnowledgeBase: + """创建或获取已有的 ADB 知识库 / Create or get existing ADB knowledge base + + Returns: + KnowledgeBase: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("创建或获取 ADB 知识库") + logger.info("Create or get ADB knowledge base") + logger.info("=" * 60) + + try: + # 创建 ADB 知识库 / Create ADB knowledge base + kb = KnowledgeBase.create( + KnowledgeBaseCreateInput( + knowledge_base_name=ADB_KB_NAME, + description=( + "通过 SDK 创建的 ADB 知识库示例 / ADB KB example" + " created via SDK" + ), + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id=ADB_INSTANCE_ID, + namespace=ADB_NAMESPACE, + namespace_password=ADB_NAMESPACE_PASSWORD, + embedding_model=ADB_EMBEDDING_MODEL, + metrics="cosine", # 使用余弦相似度 / Use cosine similarity + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + use_full_text_retrieval=False, # 仅使用向量检索 / Vector only + rerank_factor=2.0, # 重排序因子 / Rerank factor + ), + ) + ) + logger.info("✅ ADB 知识库创建成功 / ADB KB created successfully") + + except ResourceAlreadyExistError: + logger.info( + "ℹ️ ADB 知识库已存在,获取已有资源 / ADB KB exists, getting" + " existing" + ) + kb = client.get(ADB_KB_NAME) + + _log_kb_info(kb) + return kb + + +def query_adb_kb(kb: KnowledgeBase): + """查询 ADB 知识库 / Query ADB knowledge base + + Args: + kb: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("查询 ADB 知识库") + logger.info("Query ADB knowledge base") + logger.info("=" * 60) + + query_text = "什么是云原生数据库" + logger.info("查询文本 / Query text: %s", query_text) + + try: + results = kb.retrieve(query=query_text) + logger.info("✅ 查询成功 / Query successful") + logger.info("检索结果 / Retrieval results: %s", results) + logger.info( + " - 结果数量 / Result count: %s", len(results.get("data", [])) + ) + except Exception as e: + logger.warning("⚠️ 查询失败(可能是配置或连接问题): %s", e) + + +def query_adb_kb_by_name(knowledgebase_name: str): + """查询 ADB 知识库 / Query ADB knowledge base + Args: + knowledgebase_name: 知识库名称 / Knowledge base name + """ + + try: + kb = KnowledgeBase.get_by_name(knowledgebase_name) + results = kb.retrieve(query="什么是云原生数据库") + logger.info("✅ 查询成功 / Query successful") + logger.info("检索结果 / Retrieval results: %s", results) + logger.info( + " - 结果数量 / Result count: %s", len(results.get("data", [])) + ) + except Exception as e: + logger.warning("⚠️ 查询失败(可能是配置或连接问题): %s", e) + + +def update_adb_kb(kb: KnowledgeBase): + """更新 ADB 知识库配置 / Update ADB knowledge base configuration + + Args: + kb: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("更新 ADB 知识库配置") + logger.info("Update ADB knowledge base configuration") + logger.info("=" * 60) + + new_description = f"[ADB] 更新于 {time.strftime('%Y-%m-%d %H:%M:%S')}" + + kb.update( + KnowledgeBaseUpdateInput( + description=new_description, + retrieve_settings=ADBRetrieveSettings( + top_k=20, # 增加返回数量 / Increase result count + use_full_text_retrieval=True, # 启用双路召回 / Enable dual recall + rerank_factor=3.0, # 调整重排序因子 / Adjust rerank factor + hybrid_search="RRF", # 使用 RRF 混合检索 / Use RRF hybrid search + hybrid_search_args={"RRF": {"k": 60}}, + ), + ) + ) + + logger.info("✅ ADB 知识库更新成功 / ADB KB updated successfully") + logger.info(" - 新描述 / New description: %s", kb.description) + + +def delete_adb_kb(kb: KnowledgeBase): + """删除 ADB 知识库 / Delete ADB knowledge base + + Args: + kb: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("删除 ADB 知识库") + logger.info("Delete ADB knowledge base") + logger.info("=" * 60) + + kb.delete() + logger.info("✅ ADB 知识库删除请求已发送 / ADB KB delete request sent") + + try: + client.get(ADB_KB_NAME) + logger.warning("⚠️ ADB 知识库仍然存在 / ADB KB still exists") + except ResourceNotExistError: + logger.info("✅ ADB 知识库已成功删除 / ADB KB deleted successfully") + + # ============================================================================ # 通用工具函数 / Common Utility Functions # ============================================================================ @@ -404,8 +587,10 @@ def list_knowledge_bases(): ragflow_list = KnowledgeBase.list_all( provider=KnowledgeBaseProvider.RAGFLOW.value ) + adb_list = KnowledgeBase.list_all(provider=KnowledgeBaseProvider.ADB.value) logger.info(" - 百炼知识库 / Bailian KBs: %d 个", len(bailian_list)) logger.info(" - RagFlow 知识库 / RagFlow KBs: %d 个", len(ragflow_list)) + logger.info(" - ADB 知识库 / ADB KBs: %d 个", len(adb_list)) # ============================================================================ @@ -457,6 +642,28 @@ def ragflow_example(): logger.info("") +def adb_example(): + """ADB 知识库完整示例 / Complete ADB knowledge base example""" + logger.info("") + logger.info("🔹 ADB 知识库示例 / ADB Knowledge Base Example") + logger.info("=" * 60) + + # 创建 ADB 知识库 / Create ADB KB + kb = create_or_get_adb_kb() + + # 查询 ADB 知识库 / Query ADB KB + query_adb_kb(kb) + + # 更新 ADB 知识库 / Update ADB KB + update_adb_kb(kb) + + # 删除 ADB 知识库 / Delete ADB KB + delete_adb_kb(kb) + + logger.info("🔹 ADB 知识库示例完成 / ADB KB Example Complete") + logger.info("") + + def knowledgebase_example(): """知识库模块完整示例 / Complete knowledge base module example @@ -501,6 +708,15 @@ def ragflow_only_example(): logger.info("🎉 完成 / Complete") +def adb_only_example(): + """仅运行 ADB 知识库示例 / Run ADB knowledge base example only""" + logger.info("🚀 ADB 知识库示例 / ADB KB Example") + list_knowledge_bases() + adb_example() + list_knowledge_bases() + logger.info("🎉 完成 / Complete") + + def multiple_knowledgebase_query(): """多知识库检索 / Multi knowledge base retrieval 根据知识库名称列表进行检索,自动获取各知识库的配置并执行检索。 @@ -509,7 +725,7 @@ def multiple_knowledgebase_query(): """ multi_query_result = KnowledgeBase.multi_retrieve( query="什么是Serverless", - knowledge_base_names=["ragflow-test", "jingsu-bailian"], + knowledge_base_names=["jingsu-bailian", "logantest"], ) logger.info( "多知识库检索结果 / Multi knowledge base retrieval result:\n%s", @@ -536,5 +752,7 @@ def update_ragflow_kb_config(): if __name__ == "__main__": # bailian_only_example() # ragflow_only_example() + # adb_only_example() multiple_knowledgebase_query() + # query_adb_kb_by_name("") # update_ragflow_kb_config() diff --git a/pyproject.toml b/pyproject.toml index effdf68..251c959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", "agentrun-mem0ai>=0.0.10", + "alibabacloud_gpdb20160503>=5.0.1" ] [project.optional-dependencies] diff --git a/tests/unittests/knowledgebase/__init__.py b/tests/unittests/knowledgebase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/knowledgebase/api/__init__.py b/tests/unittests/knowledgebase/api/__init__.py new file mode 100644 index 0000000..24785e9 --- /dev/null +++ b/tests/unittests/knowledgebase/api/__init__.py @@ -0,0 +1 @@ +"""KnowledgeBase API 单元测试模块 / KnowledgeBase API Unit Test Module""" diff --git a/tests/unittests/knowledgebase/api/test_data.py b/tests/unittests/knowledgebase/api/test_data.py new file mode 100644 index 0000000..8e15902 --- /dev/null +++ b/tests/unittests/knowledgebase/api/test_data.py @@ -0,0 +1,1594 @@ +"""测试 agentrun.knowledgebase.api.data 模块 / Test agentrun.knowledgebase.api.data module""" + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.api.data import ( + ADBDataAPI, + BailianDataAPI, + get_data_api, + KnowledgeBaseDataAPI, + RagFlowDataAPI, +) +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + ADBRetrieveSettings, + BailianProviderSettings, + BailianRetrieveSettings, + KnowledgeBaseProvider, + RagFlowProviderSettings, + RagFlowRetrieveSettings, +) +from agentrun.utils.config import Config + + +class TestKnowledgeBaseDataAPIBase: + """测试 KnowledgeBaseDataAPI 基类""" + + def test_abstract_methods(self): + """测试抽象方法""" + # KnowledgeBaseDataAPI 是抽象类,不能直接实例化 + with pytest.raises(TypeError): + KnowledgeBaseDataAPI("test-kb") # type: ignore + + +class TestRagFlowDataAPIInit: + """测试 RagFlowDataAPI 初始化""" + + def test_init(self): + """测试初始化""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + credential_name="test-credential", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is not None + assert api.retrieve_settings is not None + assert api.credential_name == "test-credential" + + def test_init_minimal(self): + """测试最小化初始化""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is None + assert api.retrieve_settings is None + assert api.credential_name is None + + def test_init_with_config(self): + """测试带配置初始化""" + config = Config(access_key_id="test-ak") + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + config=config, + ) + + assert api.knowledge_base_name == "test-kb" + + +class TestRagFlowDataAPIGetApiKey: + """测试 RagFlowDataAPI._get_api_key 方法""" + + @patch("agentrun.credential.Credential.get_by_name") + def test_get_api_key_sync(self, mock_get_credential): + """测试同步获取 API Key""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = api._get_api_key() + assert result == "test-api-key" + + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_get_api_key_async(self, mock_get_credential): + """测试异步获取 API Key""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = await api._get_api_key_async() + assert result == "test-api-key" + + def test_get_api_key_without_credential_name(self): + """测试无凭证名称时获取 API Key""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="credential_name is required"): + api._get_api_key() + + @pytest.mark.asyncio + async def test_get_api_key_async_without_credential_name(self): + """测试异步无凭证名称时获取 API Key""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="credential_name is required"): + await api._get_api_key_async() + + @patch("agentrun.credential.Credential.get_by_name") + def test_get_api_key_empty_secret(self, mock_get_credential): + """测试凭证密钥为空时获取 API Key""" + mock_credential = MagicMock() + mock_credential.credential_secret = None + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + with pytest.raises(ValueError, match="has no secret configured"): + api._get_api_key() + + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_get_api_key_async_empty_secret(self, mock_get_credential): + """测试异步凭证密钥为空时获取 API Key""" + mock_credential = MagicMock() + mock_credential.credential_secret = None + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + with pytest.raises(ValueError, match="has no secret configured"): + await api._get_api_key_async() + + +class TestRagFlowDataAPIBuildRequestBody: + """测试 RagFlowDataAPI._build_request_body 方法""" + + def test_build_request_body(self): + """测试构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1", "ds-2"], + ), + ) + + body = api._build_request_body("test query") + assert body["question"] == "test query" + assert body["dataset_ids"] == ["ds-1", "ds-2"] + assert body["page"] == 1 + assert body["page_size"] == 30 + + def test_build_request_body_with_retrieve_settings(self): + """测试带检索设置构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + vector_similarity_weight=0.5, + cross_languages=["English", "Chinese"], + ), + ) + + body = api._build_request_body("test query") + assert body["similarity_threshold"] == 0.8 + assert body["vector_similarity_weight"] == 0.5 + assert body["cross_languages"] == ["English", "Chinese"] + + def test_build_request_body_without_provider_settings(self): + """测试无提供商设置构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="provider_settings is required"): + api._build_request_body("test query") + + def test_build_request_body_with_partial_retrieve_settings_only_threshold( + self, + ): + """测试仅设置 similarity_threshold 的部分检索设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + ) + + body = api._build_request_body("test query") + assert body["similarity_threshold"] == 0.8 + assert "vector_similarity_weight" not in body + assert "cross_languages" not in body + + def test_build_request_body_with_partial_retrieve_settings_only_weight( + self, + ): + """测试仅设置 vector_similarity_weight 的部分检索设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + vector_similarity_weight=0.5, + ), + ) + + body = api._build_request_body("test query") + assert "similarity_threshold" not in body + assert body["vector_similarity_weight"] == 0.5 + assert "cross_languages" not in body + + def test_build_request_body_with_partial_retrieve_settings_only_languages( + self, + ): + """测试仅设置 cross_languages 的部分检索设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + cross_languages=["English"], + ), + ) + + body = api._build_request_body("test query") + assert "similarity_threshold" not in body + assert "vector_similarity_weight" not in body + assert body["cross_languages"] == ["English"] + + def test_build_request_body_without_retrieve_settings(self): + """测试无检索设置构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + body = api._build_request_body("test query") + assert "similarity_threshold" not in body + assert "vector_similarity_weight" not in body + assert "cross_languages" not in body + + +class TestRagFlowDataAPIRetrieve: + """测试 RagFlowDataAPI.retrieve 方法""" + + @patch("httpx.Client") + @patch("agentrun.credential.Credential.get_by_name") + def test_retrieve_sync(self, mock_get_credential, mock_httpx_client): + """测试同步检索""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": {"chunks": [{"content": "test content"}]} + } + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post.return_value = mock_response + mock_client_instance.__enter__ = MagicMock( + return_value=mock_client_instance + ) + mock_client_instance.__exit__ = MagicMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + assert "data" in result + + @patch("httpx.AsyncClient") + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_get_credential, mock_httpx_client): + """测试异步检索""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": {"chunks": [{"content": "test content"}]} + } + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + def test_retrieve_without_provider_settings(self): + """测试无提供商设置检索""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch("httpx.Client") + @patch("agentrun.credential.Credential.get_by_name") + def test_retrieve_with_false_data( + self, mock_get_credential, mock_httpx_client + ): + """测试检索返回 False 数据""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = {"data": False} + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post.return_value = mock_response + mock_client_instance.__enter__ = MagicMock( + return_value=mock_client_instance + ) + mock_client_instance.__exit__ = MagicMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @pytest.mark.asyncio + async def test_retrieve_async_without_provider_settings(self): + """测试异步检索无提供商设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + @patch("httpx.AsyncClient") + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_retrieve_async_with_false_data( + self, mock_get_credential, mock_httpx_client + ): + """测试异步检索返回 False 数据""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = {"data": False} + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + @patch("httpx.AsyncClient") + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_retrieve_async_exception( + self, mock_get_credential, mock_httpx_client + ): + """测试异步检索异常处理""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock( + side_effect=Exception("Connection error") + ) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + assert "Connection error" in result["data"] + + +class TestBailianDataAPIInit: + """测试 BailianDataAPI 初始化""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init(self): + """测试初始化""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is not None + assert api.retrieve_settings is not None + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init_minimal(self): + """测试最小化初始化""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is None + + +class TestBailianDataAPIRetrieve: + """测试 BailianDataAPI.retrieve 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_sync(self, mock_get_bailian_client): + """测试同步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [ + MagicMock(text="test content", score=0.9, metadata={}), + ] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_get_bailian_client): + """测试异步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [ + MagicMock(text="test content", score=0.9, metadata={}), + ] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_retrieve_without_provider_settings(self): + """测试无提供商设置检索""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @pytest.mark.asyncio + async def test_retrieve_async_without_provider_settings(self): + """测试异步检索无提供商设置""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_retrieve_settings(self, mock_get_bailian_client): + """测试带检索设置检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + sparse_similarity_top_k=5, + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_partial_retrieve_settings_dense_only( + self, mock_get_bailian_client + ): + """测试仅设置 dense_similarity_top_k 的部分检索设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_partial_retrieve_settings_sparse_only( + self, mock_get_bailian_client + ): + """测试仅设置 sparse_similarity_top_k 的部分检索设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + sparse_similarity_top_k=5, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_partial_retrieve_settings_rerank_only( + self, mock_get_bailian_client + ): + """测试仅设置 rerank 相关的部分检索设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_partial_settings( + self, mock_get_bailian_client + ): + """测试异步检索带部分设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_sparse_only_settings( + self, mock_get_bailian_client + ): + """测试异步检索仅设置 sparse_similarity_top_k""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + sparse_similarity_top_k=5, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_rerank_settings( + self, mock_get_bailian_client + ): + """测试异步检索仅设置 rerank 相关参数""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_exception(self, mock_get_bailian_client): + """测试异步检索异常处理""" + mock_client = MagicMock() + mock_client.retrieve_async = AsyncMock( + side_effect=Exception("API error") + ) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + assert "API error" in result["data"] + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_all_settings( + self, mock_get_bailian_client + ): + """测试异步检索带完整设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + sparse_similarity_top_k=5, + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + +class TestADBDataAPIInit: + """测试 ADBDataAPI 初始化""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init(self): + """测试初始化""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + ), + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is not None + assert api.retrieve_settings is not None + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init_minimal(self): + """测试最小化初始化""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is None + + +class TestADBDataAPIBuildQueryContentRequest: + """测试 ADBDataAPI._build_query_content_request 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request(self): + """测试构建 QueryContent 请求""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + request = api._build_query_content_request("test query") + assert request.content == "test query" + assert request.dbinstance_id == "gp-123456" + assert request.namespace == "public" + assert request.collection == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_settings(self): + """测试带设置构建 QueryContent 请求""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + metrics="cosine", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + use_full_text_retrieval=True, + rerank_factor=1.5, + recall_window=[-5, 5], + hybrid_search="RRF", + hybrid_search_args={"RRF": {"k": 60}}, + ), + ) + + request = api._build_query_content_request("test query") + assert request.metrics == "cosine" + assert request.top_k == 10 + assert request.use_full_text_retrieval is True + assert request.rerank_factor == 1.5 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_without_provider_settings(self): + """测试无提供商设置构建 QueryContent 请求""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="provider_settings is required"): + api._build_query_content_request("test query") + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_top_k_only(self): + """测试仅设置 top_k 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + ), + ) + + request = api._build_query_content_request("test query") + assert request.top_k == 10 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_full_text_only( + self, + ): + """测试仅设置 use_full_text_retrieval 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + use_full_text_retrieval=True, + ), + ) + + request = api._build_query_content_request("test query") + assert request.use_full_text_retrieval is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_rerank_only( + self, + ): + """测试仅设置 rerank_factor 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + rerank_factor=1.5, + ), + ) + + request = api._build_query_content_request("test query") + assert request.rerank_factor == 1.5 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_recall_window_only( + self, + ): + """测试仅设置 recall_window 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + recall_window=[-5, 5], + ), + ) + + request = api._build_query_content_request("test query") + assert request.recall_window == [-5, 5] + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_hybrid_only( + self, + ): + """测试仅设置 hybrid_search 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + hybrid_search="RRF", + ), + ) + + request = api._build_query_content_request("test query") + assert request.hybrid_search == "RRF" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_hybrid_args_only( + self, + ): + """测试仅设置 hybrid_search_args 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + hybrid_search_args={"RRF": {"k": 60}}, + ), + ) + + request = api._build_query_content_request("test query") + assert request.hybrid_search_args == {"RRF": {"k": 60}} + + +class TestADBDataAPIParseQueryContentResponse: + """测试 ADBDataAPI._parse_query_content_response 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_parse_query_content_response(self): + """测试解析 QueryContent 响应""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + mock_response = MagicMock() + mock_response.body.matches.match_list = [ + MagicMock( + content="test content", + score=0.9, + id="1", + file_name="test.txt", + metadata={}, + rerank_score=0.95, + retrieval_source="vector", + ), + ] + mock_response.body.request_id = "req-123" + + result = api._parse_query_content_response(mock_response, "test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + assert result["request_id"] == "req-123" + assert len(result["data"]) == 1 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_parse_query_content_response_empty(self): + """测试解析空 QueryContent 响应""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + mock_response = MagicMock() + mock_response.body.matches = None + mock_response.body.request_id = "req-123" + + result = api._parse_query_content_response(mock_response, "test query") + assert result["data"] == [] + + +class TestADBDataAPIRetrieve: + """测试 ADBDataAPI.retrieve 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + def test_retrieve_sync(self, mock_get_gpdb_client): + """测试同步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.matches.match_list = [ + MagicMock( + content="test content", + score=0.9, + id="1", + file_name="test.txt", + metadata={}, + rerank_score=0.95, + retrieval_source="vector", + ), + ] + mock_response.body.request_id = "req-123" + mock_client.query_content.return_value = mock_response + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_get_gpdb_client): + """测试异步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.matches.match_list = [ + MagicMock( + content="test content", + score=0.9, + id="1", + file_name="test.txt", + metadata={}, + rerank_score=0.95, + retrieval_source="vector", + ), + ] + mock_response.body.request_id = "req-123" + mock_client.query_content_async = AsyncMock(return_value=mock_response) + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_retrieve_without_provider_settings(self): + """测试无提供商设置检索""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + def test_retrieve_exception(self, mock_get_gpdb_client): + """测试检索异常处理""" + mock_client = MagicMock() + mock_client.query_content.side_effect = Exception("Query failed") + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + @pytest.mark.asyncio + async def test_retrieve_async_exception(self, mock_get_gpdb_client): + """测试异步检索异常处理""" + mock_client = MagicMock() + mock_client.query_content_async = AsyncMock( + side_effect=Exception("Query failed") + ) + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + assert "Query failed" in result["data"] + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @pytest.mark.asyncio + async def test_retrieve_async_without_provider_settings(self): + """测试异步检索无提供商设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + +class TestGetDataAPI: + """测试 get_data_api 工厂函数""" + + def test_get_data_api_ragflow(self): + """测试获取 RagFlow 数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + assert isinstance(api, RagFlowDataAPI) + + def test_get_data_api_ragflow_string(self): + """测试使用字符串获取 RagFlow 数据链路 API""" + api = get_data_api( + provider="ragflow", # type: ignore + knowledge_base_name="test-kb", + ) + + assert isinstance(api, RagFlowDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_bailian(self): + """测试获取百炼数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.BAILIAN, + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + assert isinstance(api, BailianDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_bailian_string(self): + """测试使用字符串获取百炼数据链路 API""" + api = get_data_api( + provider="bailian", # type: ignore + knowledge_base_name="test-kb", + ) + + assert isinstance(api, BailianDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_adb(self): + """测试获取 ADB 数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.ADB, + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + assert isinstance(api, ADBDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_adb_string(self): + """测试使用字符串获取 ADB 数据链路 API""" + api = get_data_api( + provider="adb", # type: ignore + knowledge_base_name="test-kb", + ) + + assert isinstance(api, ADBDataAPI) + + def test_get_data_api_unsupported_provider(self): + """测试不支持的提供商""" + with pytest.raises(ValueError, match="Unsupported provider type"): + get_data_api( + provider="unsupported", # type: ignore + knowledge_base_name="test-kb", + ) + + def test_get_data_api_with_wrong_settings_type(self): + """测试使用错误类型的设置""" + # 传入 Bailian 设置给 RagFlow + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + # 应该返回 RagFlowDataAPI,但 provider_settings 会是 None + assert isinstance(api, RagFlowDataAPI) + assert api.provider_settings is None + + def test_get_data_api_with_config(self): + """测试带配置获取数据链路 API""" + config = Config(access_key_id="test-ak") + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + config=config, + ) + + assert isinstance(api, RagFlowDataAPI) + + def test_get_data_api_with_retrieve_settings(self): + """测试带检索设置获取数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + credential_name="test-credential", + ) + + assert isinstance(api, RagFlowDataAPI) + assert api.retrieve_settings is not None diff --git a/tests/unittests/knowledgebase/test_client.py b/tests/unittests/knowledgebase/test_client.py new file mode 100644 index 0000000..9601c6c --- /dev/null +++ b/tests/unittests/knowledgebase/test_client.py @@ -0,0 +1,639 @@ +"""测试 agentrun.knowledgebase.client 模块 / Test agentrun.knowledgebase.client module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.client import KnowledgeBaseClient +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + BailianProviderSettings, + KnowledgeBaseCreateInput, + KnowledgeBaseListInput, + KnowledgeBaseProvider, + KnowledgeBaseUpdateInput, + RagFlowProviderSettings, +) +from agentrun.utils.config import Config +from agentrun.utils.exception import ( + HTTPError, + ResourceAlreadyExistError, + ResourceNotExistError, +) + + +class MockKnowledgeBaseData: + """模拟知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-123", + "knowledgeBaseName": "test-kb", + "provider": "ragflow", + "description": "Test knowledge base", + "credentialName": "test-credential", + "providerSettings": { + "baseUrl": "https://ragflow.example.com", + "datasetIds": ["ds-1"], + }, + "retrieveSettings": { + "similarityThreshold": 0.8, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockBailianKnowledgeBaseData: + """模拟百炼知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-456", + "knowledgeBaseName": "test-bailian-kb", + "provider": "bailian", + "description": "Test Bailian knowledge base", + "providerSettings": { + "workspaceId": "ws-123", + "indexIds": ["idx-1"], + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockADBKnowledgeBaseData: + """模拟 ADB 知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-789", + "knowledgeBaseName": "test-adb-kb", + "provider": "adb", + "description": "Test ADB knowledge base", + "providerSettings": { + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockListResult: + """模拟列表结果""" + + def __init__(self, items): + self.items = items + + +class TestKnowledgeBaseClientInit: + """测试 KnowledgeBaseClient 初始化""" + + def test_init_without_config(self): + """测试不带配置的初始化""" + client = KnowledgeBaseClient() + assert client is not None + + def test_init_with_config(self): + """测试带配置的初始化""" + config = Config(access_key_id="test-ak") + client = KnowledgeBaseClient(config=config) + assert client is not None + + +class TestKnowledgeBaseClientCreate: + """测试 KnowledgeBaseClient.create 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_sync(self, mock_control_api_class): + """测试同步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + description="Test knowledge base", + ) + + result = client.create(input_obj) + assert result.knowledge_base_name == "test-kb" + assert mock_control_api.create_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_create_async(self, mock_control_api_class): + """测试异步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + result = await client.create_async(input_obj) + assert result.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_bailian_kb(self, mock_control_api_class): + """测试创建百炼知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockBailianKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-bailian-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = client.create(input_obj) + assert result.knowledge_base_name == "test-bailian-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_adb_kb(self, mock_control_api_class): + """测试创建 ADB 知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockADBKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-adb-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = client.create(input_obj) + assert result.knowledge_base_name == "test-adb-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_with_config(self, mock_control_api_class): + """测试带配置创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + config = Config(access_key_id="custom-ak") + + result = client.create(input_obj, config=config) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_already_exists(self, mock_control_api_class): + """测试创建已存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.side_effect = HTTPError( + status_code=409, + message="Resource already exists", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="existing-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + with pytest.raises(ResourceAlreadyExistError): + client.create(input_obj) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_create_async_already_exists(self, mock_control_api_class): + """测试异步创建已存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=409, + message="Resource already exists", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="existing-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + with pytest.raises(ResourceAlreadyExistError): + await client.create_async(input_obj) + + +class TestKnowledgeBaseClientDelete: + """测试 KnowledgeBaseClient.delete 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_sync(self, mock_control_api_class): + """测试同步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.delete("test-kb") + assert result is not None + assert mock_control_api.delete_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_async(self, mock_control_api_class): + """测试异步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = await client.delete_async("test-kb") + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_with_config(self, mock_control_api_class): + """测试带配置删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + config = Config(access_key_id="custom-ak") + result = client.delete("test-kb", config=config) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_not_exist(self, mock_control_api_class): + """测试删除不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + client.delete("nonexistent-kb") + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_async_not_exist(self, mock_control_api_class): + """测试异步删除不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + await client.delete_async("nonexistent-kb") + + +class TestKnowledgeBaseClientUpdate: + """测试 KnowledgeBaseClient.update 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_sync(self, mock_control_api_class): + """测试同步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated description") + result = client.update("test-kb", input_obj) + assert result is not None + assert mock_control_api.update_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_async(self, mock_control_api_class): + """测试异步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = await client.update_async("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_with_provider_settings(self, mock_control_api_class): + """测试更新知识库(带提供商设置)""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput( + provider_settings=RagFlowProviderSettings( + base_url="https://new-ragflow.example.com", + dataset_ids=["ds-new"], + ), + ) + result = client.update("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_with_credential(self, mock_control_api_class): + """测试更新知识库(带凭证)""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput( + credential_name="new-credential", + ) + result = client.update("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_not_exist(self, mock_control_api_class): + """测试更新不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ResourceNotExistError): + client.update("nonexistent-kb", input_obj) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_async_not_exist(self, mock_control_api_class): + """测试异步更新不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ResourceNotExistError): + await client.update_async("nonexistent-kb", input_obj) + + +class TestKnowledgeBaseClientGet: + """测试 KnowledgeBaseClient.get 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_sync(self, mock_control_api_class): + """测试同步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.get("test-kb") + assert result.knowledge_base_name == "test-kb" + assert mock_control_api.get_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_async(self, mock_control_api_class): + """测试异步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = await client.get_async("test-kb") + assert result.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_with_config(self, mock_control_api_class): + """测试带配置获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + config = Config(access_key_id="custom-ak") + result = client.get("test-kb", config=config) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_not_exist(self, mock_control_api_class): + """测试获取不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + client.get("nonexistent-kb") + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_async_not_exist(self, mock_control_api_class): + """测试异步获取不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + await client.get_async("nonexistent-kb") + + +class TestKnowledgeBaseClientList: + """测试 KnowledgeBaseClient.list 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_sync(self, mock_control_api_class): + """测试同步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult([ + MockKnowledgeBaseData(), + MockBailianKnowledgeBaseData(), + ]) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.list() + assert len(result) == 2 + assert mock_control_api.list_knowledge_bases.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_list_async(self, mock_control_api_class): + """测试异步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases_async = AsyncMock( + return_value=MockListResult([MockKnowledgeBaseData()]) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = await client.list_async() + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_with_input(self, mock_control_api_class): + """测试同步列出知识库(带输入参数)""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseListInput( + page_number=1, + page_size=10, + provider=KnowledgeBaseProvider.RAGFLOW, + ) + result = client.list(input=input_obj) + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_list_async_with_input(self, mock_control_api_class): + """测试异步列出知识库(带输入参数)""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases_async = AsyncMock( + return_value=MockListResult([MockKnowledgeBaseData()]) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseListInput(page_number=1, page_size=10) + result = await client.list_async(input=input_obj) + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_empty(self, mock_control_api_class): + """测试列出空知识库列表""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult([]) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.list() + assert len(result) == 0 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_with_none_input(self, mock_control_api_class): + """测试列出知识库(输入为 None)""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.list(input=None) + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_with_config(self, mock_control_api_class): + """测试带配置列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + config = Config(access_key_id="custom-ak") + result = client.list(config=config) + assert len(result) == 1 diff --git a/tests/unittests/knowledgebase/test_knowledgebase.py b/tests/unittests/knowledgebase/test_knowledgebase.py new file mode 100644 index 0000000..8936d9d --- /dev/null +++ b/tests/unittests/knowledgebase/test_knowledgebase.py @@ -0,0 +1,1306 @@ +"""测试 agentrun.knowledgebase.knowledgebase 模块 / Test agentrun.knowledgebase.knowledgebase module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.knowledgebase import KnowledgeBase +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + ADBRetrieveSettings, + BailianProviderSettings, + BailianRetrieveSettings, + KnowledgeBaseCreateInput, + KnowledgeBaseProvider, + KnowledgeBaseUpdateInput, + RagFlowProviderSettings, + RagFlowRetrieveSettings, +) +from agentrun.utils.config import Config + + +class MockKnowledgeBaseData: + """模拟知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-123", + "knowledgeBaseName": "test-kb", + "provider": "ragflow", + "description": "Test knowledge base", + "credentialName": "test-credential", + "providerSettings": { + "baseUrl": "https://ragflow.example.com", + "datasetIds": ["ds-1"], + }, + "retrieveSettings": { + "similarityThreshold": 0.8, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockBailianKnowledgeBaseData: + """模拟百炼知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-456", + "knowledgeBaseName": "test-bailian-kb", + "provider": "bailian", + "description": "Test Bailian knowledge base", + "providerSettings": { + "workspaceId": "ws-123", + "indexIds": ["idx-1"], + }, + "retrieveSettings": { + "denseSimilarityTopK": 10, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockADBKnowledgeBaseData: + """模拟 ADB 知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-789", + "knowledgeBaseName": "test-adb-kb", + "provider": "adb", + "description": "Test ADB knowledge base", + "providerSettings": { + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + }, + "retrieveSettings": { + "TopK": 10, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockListResult: + """模拟列表结果""" + + def __init__(self, items): + self.items = items + + +class TestKnowledgeBaseCreate: + """测试 KnowledgeBase.create 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_sync(self, mock_control_api_class): + """测试同步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + description="Test knowledge base", + ) + + result = KnowledgeBase.create(input_obj) + assert result.knowledge_base_name == "test-kb" + assert result.knowledge_base_id == "kb-123" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_create_async(self, mock_control_api_class): + """测试异步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + result = await KnowledgeBase.create_async(input_obj) + assert result.knowledge_base_name == "test-kb" + + +class TestKnowledgeBaseDelete: + """测试 KnowledgeBase.delete 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_by_name_sync(self, mock_control_api_class): + """测试根据名称同步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.delete_by_name("test-kb") + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_by_name_async(self, mock_control_api_class): + """测试根据名称异步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase.delete_by_name_async("test-kb") + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_instance_sync(self, mock_control_api_class): + """测试实例同步删除""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = KnowledgeBase.get_by_name("test-kb") + + # 删除实例 + result = kb.delete() + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_instance_async(self, mock_control_api_class): + """测试实例异步删除""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = await KnowledgeBase.get_by_name_async("test-kb") + + # 删除实例 + result = await kb.delete_async() + assert result is not None + + def test_delete_instance_without_name(self): + """测试实例删除(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + kb.delete() + + @pytest.mark.asyncio + async def test_delete_instance_async_without_name(self): + """测试异步实例删除(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await kb.delete_async() + + +class TestKnowledgeBaseUpdate: + """测试 KnowledgeBase.update 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_by_name_sync(self, mock_control_api_class): + """测试根据名称同步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseUpdateInput(description="Updated description") + result = KnowledgeBase.update_by_name("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_by_name_async(self, mock_control_api_class): + """测试根据名称异步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = await KnowledgeBase.update_by_name_async("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_instance_sync(self, mock_control_api_class): + """测试实例同步更新""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = KnowledgeBase.get_by_name("test-kb") + + # 更新实例 + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = kb.update(input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_instance_async(self, mock_control_api_class): + """测试实例异步更新""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = await KnowledgeBase.get_by_name_async("test-kb") + + # 更新实例 + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = await kb.update_async(input_obj) + assert result is not None + + def test_update_instance_without_name(self): + """测试实例更新(无名称)""" + kb = KnowledgeBase() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ValueError, match="knowledge_base_name is required"): + kb.update(input_obj) + + @pytest.mark.asyncio + async def test_update_instance_async_without_name(self): + """测试异步实例更新(无名称)""" + kb = KnowledgeBase() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await kb.update_async(input_obj) + + +class TestKnowledgeBaseGet: + """测试 KnowledgeBase.get_by_name 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_by_name_sync(self, mock_control_api_class): + """测试同步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.get_by_name("test-kb") + assert result.knowledge_base_name == "test-kb" + assert result.knowledge_base_id == "kb-123" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_by_name_async(self, mock_control_api_class): + """测试异步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase.get_by_name_async("test-kb") + assert result.knowledge_base_name == "test-kb" + + +class TestKnowledgeBaseRefresh: + """测试 KnowledgeBase.refresh 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_refresh_sync(self, mock_control_api_class): + """测试同步刷新知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = KnowledgeBase.get_by_name("test-kb") + + # 刷新实例 + kb.refresh() + assert kb.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_refresh_async(self, mock_control_api_class): + """测试异步刷新知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = await KnowledgeBase.get_by_name_async("test-kb") + + # 刷新实例 + await kb.refresh_async() + assert kb.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_instance_sync(self, mock_control_api_class): + """测试实例 get 方法同步""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + kb = KnowledgeBase.get_by_name("test-kb") + result = kb.get() + assert result.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_instance_async(self, mock_control_api_class): + """测试实例 get 方法异步""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + kb = await KnowledgeBase.get_by_name_async("test-kb") + result = await kb.get_async() + assert result.knowledge_base_name == "test-kb" + + def test_get_instance_without_name(self): + """测试实例获取(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + kb.get() + + @pytest.mark.asyncio + async def test_get_instance_async_without_name(self): + """测试异步实例获取(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await kb.get_async() + + +class TestKnowledgeBaseList: + """测试 KnowledgeBase.list_all 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_all_sync(self, mock_control_api_class): + """测试同步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult([ + MockKnowledgeBaseData(), + MockBailianKnowledgeBaseData(), + ]) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.list_all() + # list_all 会对结果去重,所以相同 ID 的记录只会返回一个 + assert len(result) >= 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_list_all_async(self, mock_control_api_class): + """测试异步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases_async = AsyncMock( + return_value=MockListResult([MockKnowledgeBaseData()]) + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase.list_all_async() + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_all_with_provider(self, mock_control_api_class): + """测试按提供商列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.list_all(provider="ragflow") + assert len(result) >= 1 + + +class TestKnowledgeBaseFromInnerObject: + """测试 KnowledgeBase.from_inner_object 方法""" + + def test_from_inner_object(self): + """测试从内部对象创建知识库""" + mock_data = MockKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-123" + assert kb.knowledge_base_name == "test-kb" + assert kb.provider == "ragflow" + assert kb.description == "Test knowledge base" + + def test_from_inner_object_with_extra(self): + """测试从内部对象创建知识库(带额外字段)""" + mock_data = MockKnowledgeBaseData() + extra = {"custom_field": "custom_value"} + kb = KnowledgeBase.from_inner_object(mock_data, extra) + + assert kb.knowledge_base_name == "test-kb" + + def test_from_inner_object_bailian(self): + """测试从内部对象创建百炼知识库""" + mock_data = MockBailianKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-456" + assert kb.knowledge_base_name == "test-bailian-kb" + assert kb.provider == "bailian" + + def test_from_inner_object_adb(self): + """测试从内部对象创建 ADB 知识库""" + mock_data = MockADBKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-789" + assert kb.knowledge_base_name == "test-adb-kb" + assert kb.provider == "adb" + + +class TestKnowledgeBaseGetDataAPI: + """测试 KnowledgeBase._get_data_api 方法""" + + def test_get_data_api_ragflow(self): + """测试获取 RagFlow 数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_bailian(self): + """测试获取百炼数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + + def test_get_data_api_adb(self): + """测试获取 ADB 数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + ), + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + + def test_get_data_api_with_dict_settings(self): + """测试使用字典设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings={ + "base_url": "https://ragflow.example.com", + "dataset_ids": ["ds-1"], + }, + retrieve_settings={ + "similarity_threshold": 0.8, + }, + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_bailian_with_dict_settings(self): + """测试百炼使用字典设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings={ + "workspace_id": "ws-123", + "index_ids": ["idx-1"], + }, + retrieve_settings={ + "dense_similarity_top_k": 10, + }, + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + + def test_get_data_api_adb_with_dict_settings(self): + """测试 ADB 使用字典设置获取数据链路 API(PascalCase 键名)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings={ + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + }, + retrieve_settings={ + "TopK": 10, + }, + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + + def test_get_data_api_bailian_with_raw_dict_settings(self): + """测试百炼使用原始字典设置获取数据链路 API(绕过 Pydantic 转换)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + ) + # Bypass Pydantic validation to set raw dict + object.__setattr__( + kb, + "provider_settings", + { + "workspace_id": "ws-123", + "index_ids": ["idx-1"], + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "dense_similarity_top_k": 10, + }, + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + + def test_get_data_api_ragflow_with_raw_dict_settings(self): + """测试 RagFlow 使用原始字典设置获取数据链路 API(绕过 Pydantic 转换)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + credential_name="test-credential", + ) + # Bypass Pydantic validation to set raw dict + object.__setattr__( + kb, + "provider_settings", + { + "base_url": "https://ragflow.example.com", + "dataset_ids": ["ds-1"], + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "similarity_threshold": 0.8, + }, + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_adb_with_raw_dict_settings(self): + """测试 ADB 使用原始字典设置获取数据链路 API(绕过 Pydantic 转换,PascalCase 键名)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + ) + # Bypass Pydantic validation to set raw dict with PascalCase keys + object.__setattr__( + kb, + "provider_settings", + { + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + "EmbeddingModel": "text-embedding-v1", + "Metrics": "cosine", + "Metadata": '{"key": "value"}', + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "TopK": 10, + "UseFullTextRetrieval": True, + "RerankFactor": 1.5, + "RecallWindow": [-5, 5], + "HybridSearch": "RRF", + "HybridSearchArgs": {"RRF": {"k": 60}}, + }, + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + + def test_get_data_api_without_provider(self): + """测试获取数据链路 API(无提供商)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="provider is required"): + kb._get_data_api() + + def test_get_data_api_with_string_provider(self): + """测试使用字符串提供商获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider="ragflow", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_bailian_without_settings(self): + """测试百炼无设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + assert data_api.provider_settings is None + assert data_api.retrieve_settings is None + + def test_get_data_api_bailian_without_retrieve_settings(self): + """测试百炼无检索设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_ragflow_without_settings(self): + """测试 RagFlow 无设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + assert data_api.provider_settings is None + assert data_api.retrieve_settings is None + + def test_get_data_api_ragflow_without_retrieve_settings(self): + """测试 RagFlow 无检索设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_adb_without_settings(self): + """测试 ADB 无设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + assert data_api.provider_settings is None + assert data_api.retrieve_settings is None + + def test_get_data_api_adb_without_retrieve_settings(self): + """测试 ADB 无检索设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_bailian_with_invalid_provider_settings_type(self): + """测试百炼使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + ) + # Set an invalid type (not BailianProviderSettings or dict) + object.__setattr__(kb, "provider_settings", "invalid") + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + # converted_provider_settings should be None as the type is invalid + assert data_api.provider_settings is None + + def test_get_data_api_bailian_with_invalid_retrieve_settings_type(self): + """测试百炼使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + # Set an invalid type (not BailianRetrieveSettings or dict) + object.__setattr__(kb, "retrieve_settings", "invalid") + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + assert data_api.provider_settings is not None + # converted_retrieve_settings should be None as the type is invalid + assert data_api.retrieve_settings is None + + def test_get_data_api_ragflow_with_invalid_provider_settings_type(self): + """测试 RagFlow 使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + credential_name="test-credential", + ) + # Set an invalid type (not RagFlowProviderSettings or dict) + object.__setattr__(kb, "provider_settings", "invalid") + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + # converted_provider_settings should be None as the type is invalid + assert data_api.provider_settings is None + + def test_get_data_api_ragflow_with_invalid_retrieve_settings_type(self): + """测试 RagFlow 使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + # Set an invalid type (not RagFlowRetrieveSettings or dict) + object.__setattr__(kb, "retrieve_settings", "invalid") + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + assert data_api.provider_settings is not None + # converted_retrieve_settings should be None as the type is invalid + assert data_api.retrieve_settings is None + + def test_get_data_api_adb_with_invalid_provider_settings_type(self): + """测试 ADB 使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + ) + # Set an invalid type (not ADBProviderSettings or dict) + object.__setattr__(kb, "provider_settings", "invalid") + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + # converted_provider_settings should be None as the type is invalid + assert data_api.provider_settings is None + + def test_get_data_api_adb_with_invalid_retrieve_settings_type(self): + """测试 ADB 使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + # Set an invalid type (not ADBRetrieveSettings or dict) + object.__setattr__(kb, "retrieve_settings", "invalid") + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + assert data_api.provider_settings is not None + # converted_retrieve_settings should be None as the type is invalid + assert data_api.retrieve_settings is None + + def test_get_data_api_with_unknown_provider(self): + """测试使用未知提供商获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + ) + # Set an unknown provider that's not in KnowledgeBaseProvider enum + object.__setattr__(kb, "provider", "unknown_provider") + + # get_data_api should raise an error for unsupported provider + # The error comes from trying to convert to KnowledgeBaseProvider enum + with pytest.raises( + ValueError, match="is not a valid KnowledgeBaseProvider" + ): + kb._get_data_api() + + +class TestKnowledgeBaseRetrieve: + """测试 KnowledgeBase.retrieve 方法""" + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_retrieve_sync(self, mock_retrieve): + """测试同步检索""" + mock_retrieve.return_value = { + "data": [{"content": "test content", "score": 0.9}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = kb.retrieve("test query") + assert result["query"] == "test query" + assert "data" in result + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_retrieve_async): + """测试异步检索""" + mock_retrieve_async.return_value = { + "data": [{"content": "test content", "score": 0.9}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await kb.retrieve_async("test query") + assert result["query"] == "test query" + assert "data" in result + + +class TestKnowledgeBaseSafeGetKB: + """测试 KnowledgeBase._safe_get_kb 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_safe_get_kb_success(self, mock_control_api_class): + """测试安全获取知识库成功""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase._safe_get_kb("test-kb") + assert isinstance(result, KnowledgeBase) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_safe_get_kb_failure(self, mock_control_api_class): + """测试安全获取知识库失败""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.side_effect = Exception("Not found") + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase._safe_get_kb("test-kb") + assert isinstance(result, Exception) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_safe_get_kb_async_success(self, mock_control_api_class): + """测试异步安全获取知识库成功""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase._safe_get_kb_async("test-kb") + assert isinstance(result, KnowledgeBase) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_safe_get_kb_async_failure(self, mock_control_api_class): + """测试异步安全获取知识库失败""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + side_effect=Exception("Not found") + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase._safe_get_kb_async("test-kb") + assert isinstance(result, Exception) + + +class TestKnowledgeBaseSafeRetrieveKB: + """测试 KnowledgeBase._safe_retrieve_kb 方法""" + + def test_safe_retrieve_kb_with_exception(self): + """测试安全检索知识库(传入异常)""" + error = Exception("Not found") + result = KnowledgeBase._safe_retrieve_kb("test-kb", error, "test query") + + assert result["error"] is True + assert "Failed to retrieve" in result["data"] + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_safe_retrieve_kb_success(self, mock_retrieve): + """测试安全检索知识库成功""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = KnowledgeBase._safe_retrieve_kb("test-kb", kb, "test query") + assert "data" in result + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_safe_retrieve_kb_failure(self, mock_retrieve): + """测试安全检索知识库失败""" + mock_retrieve.side_effect = Exception("Retrieve failed") + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = KnowledgeBase._safe_retrieve_kb("test-kb", kb, "test query") + assert result["error"] is True + + @pytest.mark.asyncio + async def test_safe_retrieve_kb_async_with_exception(self): + """测试异步安全检索知识库(传入异常)""" + error = Exception("Not found") + result = await KnowledgeBase._safe_retrieve_kb_async( + "test-kb", error, "test query" + ) + + assert result["error"] is True + assert "Failed to retrieve" in result["data"] + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_safe_retrieve_kb_async_success(self, mock_retrieve_async): + """测试异步安全检索知识库成功""" + mock_retrieve_async.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await KnowledgeBase._safe_retrieve_kb_async( + "test-kb", kb, "test query" + ) + assert "data" in result + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_safe_retrieve_kb_async_failure(self, mock_retrieve_async): + """测试异步安全检索知识库失败""" + mock_retrieve_async.side_effect = Exception("Retrieve failed") + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await KnowledgeBase._safe_retrieve_kb_async( + "test-kb", kb, "test query" + ) + assert result["error"] is True + + +class TestKnowledgeBaseMultiRetrieve: + """测试 KnowledgeBase.multi_retrieve 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_multi_retrieve_sync(self, mock_retrieve, mock_control_api_class): + """测试同步多知识库检索""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + result = KnowledgeBase.multi_retrieve( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result + assert "query" in result + assert result["query"] == "test query" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_multi_retrieve_async( + self, mock_retrieve_async, mock_control_api_class + ): + """测试异步多知识库检索""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_retrieve_async.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + result = await KnowledgeBase.multi_retrieve_async( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result + assert "query" in result + assert result["query"] == "test query" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_multi_retrieve_with_partial_failure(self, mock_control_api_class): + """测试同步多知识库检索(部分失败)""" + mock_control_api = MagicMock() + # 第一个成功,第二个失败 + mock_control_api.get_knowledge_base.side_effect = [ + MockKnowledgeBaseData(), + Exception("Not found"), + ] + mock_control_api_class.return_value = mock_control_api + + with patch( + "agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "kb-1", + } + + result = KnowledgeBase.multi_retrieve( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result + # kb-2 应该有错误 + assert "kb-2" in result["results"] + assert result["results"]["kb-2"]["error"] is True + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_multi_retrieve_async_with_partial_failure( + self, mock_control_api_class + ): + """测试异步多知识库检索(部分失败)""" + mock_control_api = MagicMock() + + # 创建一个返回不同结果的 side_effect 函数 + call_count = [0] + + async def mock_get_async(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return MockKnowledgeBaseData() + else: + raise Exception("Not found") + + mock_control_api.get_knowledge_base_async = mock_get_async + mock_control_api_class.return_value = mock_control_api + + with patch( + "agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async" + ) as mock_retrieve_async: + mock_retrieve_async.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "kb-1", + } + + result = await KnowledgeBase.multi_retrieve_async( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result diff --git a/tests/unittests/knowledgebase/test_model.py b/tests/unittests/knowledgebase/test_model.py new file mode 100644 index 0000000..5f5e335 --- /dev/null +++ b/tests/unittests/knowledgebase/test_model.py @@ -0,0 +1,650 @@ +"""测试 agentrun.knowledgebase.model 模块 / Test agentrun.knowledgebase.model module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + ADBRetrieveSettings, + BailianProviderSettings, + BailianRetrieveSettings, + KnowledgeBaseCreateInput, + KnowledgeBaseImmutableProps, + KnowledgeBaseListInput, + KnowledgeBaseListOutput, + KnowledgeBaseMutableProps, + KnowledgeBaseProvider, + KnowledgeBaseSystemProps, + KnowledgeBaseUpdateInput, + RagFlowProviderSettings, + RagFlowRetrieveSettings, + RetrieveInput, +) + + +class TestKnowledgeBaseProvider: + """测试 KnowledgeBaseProvider 枚举""" + + def test_ragflow_value(self): + """测试 RAGFLOW 枚举值""" + assert KnowledgeBaseProvider.RAGFLOW.value == "ragflow" + + def test_bailian_value(self): + """测试 BAILIAN 枚举值""" + assert KnowledgeBaseProvider.BAILIAN.value == "bailian" + + def test_adb_value(self): + """测试 ADB 枚举值""" + assert KnowledgeBaseProvider.ADB.value == "adb" + + def test_provider_is_string_enum(self): + """测试 Provider 是字符串枚举""" + assert isinstance(KnowledgeBaseProvider.RAGFLOW, str) + assert KnowledgeBaseProvider.RAGFLOW == "ragflow" + + +class TestRagFlowProviderSettings: + """测试 RagFlowProviderSettings 模型""" + + def test_create_ragflow_provider_settings(self): + """测试创建 RagFlow 提供商设置""" + settings = RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1", "ds-2"], + ) + assert settings.base_url == "https://ragflow.example.com" + assert settings.dataset_ids == ["ds-1", "ds-2"] + + def test_ragflow_provider_settings_required_fields(self): + """测试 RagFlow 提供商设置必填字段""" + with pytest.raises(Exception): # Pydantic ValidationError + RagFlowProviderSettings() # type: ignore + + def test_ragflow_provider_settings_model_dump(self): + """测试 RagFlow 提供商设置序列化""" + settings = RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ) + data = settings.model_dump() + assert "base_url" in data or "baseUrl" in data + + +class TestRagFlowRetrieveSettings: + """测试 RagFlowRetrieveSettings 模型""" + + def test_create_ragflow_retrieve_settings(self): + """测试创建 RagFlow 检索设置""" + settings = RagFlowRetrieveSettings( + similarity_threshold=0.8, + vector_similarity_weight=0.5, + cross_languages=["English", "Chinese"], + ) + assert settings.similarity_threshold == 0.8 + assert settings.vector_similarity_weight == 0.5 + assert settings.cross_languages == ["English", "Chinese"] + + def test_ragflow_retrieve_settings_optional(self): + """测试 RagFlow 检索设置可选字段""" + settings = RagFlowRetrieveSettings() + assert settings.similarity_threshold is None + assert settings.vector_similarity_weight is None + assert settings.cross_languages is None + + def test_ragflow_retrieve_settings_partial(self): + """测试 RagFlow 检索设置部分字段""" + settings = RagFlowRetrieveSettings(similarity_threshold=0.7) + assert settings.similarity_threshold == 0.7 + assert settings.vector_similarity_weight is None + + +class TestBailianProviderSettings: + """测试 BailianProviderSettings 模型""" + + def test_create_bailian_provider_settings(self): + """测试创建百炼提供商设置""" + settings = BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1", "idx-2"], + ) + assert settings.workspace_id == "ws-123" + assert settings.index_ids == ["idx-1", "idx-2"] + + def test_bailian_provider_settings_required_fields(self): + """测试百炼提供商设置必填字段""" + with pytest.raises(Exception): + BailianProviderSettings() # type: ignore + + def test_bailian_provider_settings_single_index(self): + """测试百炼提供商设置单个索引""" + settings = BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ) + assert len(settings.index_ids) == 1 + + +class TestBailianRetrieveSettings: + """测试 BailianRetrieveSettings 模型""" + + def test_create_bailian_retrieve_settings(self): + """测试创建百炼检索设置""" + settings = BailianRetrieveSettings( + dense_similarity_top_k=10, + sparse_similarity_top_k=5, + rerank_min_score=0.5, + rerank_top_n=3, + ) + assert settings.dense_similarity_top_k == 10 + assert settings.sparse_similarity_top_k == 5 + assert settings.rerank_min_score == 0.5 + assert settings.rerank_top_n == 3 + + def test_bailian_retrieve_settings_optional(self): + """测试百炼检索设置可选字段""" + settings = BailianRetrieveSettings() + assert settings.dense_similarity_top_k is None + assert settings.sparse_similarity_top_k is None + assert settings.rerank_min_score is None + assert settings.rerank_top_n is None + + def test_bailian_retrieve_settings_partial(self): + """测试百炼检索设置部分字段""" + settings = BailianRetrieveSettings( + dense_similarity_top_k=20, + rerank_top_n=5, + ) + assert settings.dense_similarity_top_k == 20 + assert settings.rerank_top_n == 5 + assert settings.sparse_similarity_top_k is None + + +class TestADBProviderSettings: + """测试 ADBProviderSettings 模型""" + + def test_create_adb_provider_settings(self): + """测试创建 ADB 提供商设置""" + settings = ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + embedding_model="text-embedding-v3", + metrics="cosine", + metadata='{"key": "value"}', + ) + assert settings.db_instance_id == "gp-123456" + assert settings.namespace == "public" + assert settings.namespace_password == "password123" + assert settings.embedding_model == "text-embedding-v3" + assert settings.metrics == "cosine" + assert settings.metadata == '{"key": "value"}' + + def test_adb_provider_settings_required_fields(self): + """测试 ADB 提供商设置必填字段""" + with pytest.raises(Exception): + ADBProviderSettings() # type: ignore + + def test_adb_provider_settings_minimal(self): + """测试 ADB 提供商设置最小配置""" + settings = ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ) + assert settings.db_instance_id == "gp-123456" + assert settings.embedding_model is None + assert settings.metrics is None + assert settings.metadata is None + + +class TestADBRetrieveSettings: + """测试 ADBRetrieveSettings 模型""" + + def test_create_adb_retrieve_settings(self): + """测试创建 ADB 检索设置""" + settings = ADBRetrieveSettings( + top_k=10, + use_full_text_retrieval=True, + rerank_factor=1.5, + recall_window=[-5, 5], + hybrid_search="RRF", + hybrid_search_args={"RRF": {"k": 60}}, + ) + assert settings.top_k == 10 + assert settings.use_full_text_retrieval is True + assert settings.rerank_factor == 1.5 + assert settings.recall_window == [-5, 5] + assert settings.hybrid_search == "RRF" + assert settings.hybrid_search_args == {"RRF": {"k": 60}} + + def test_adb_retrieve_settings_optional(self): + """测试 ADB 检索设置可选字段""" + settings = ADBRetrieveSettings() + assert settings.top_k is None + assert settings.use_full_text_retrieval is None + assert settings.rerank_factor is None + assert settings.recall_window is None + assert settings.hybrid_search is None + assert settings.hybrid_search_args is None + + def test_adb_retrieve_settings_weight_hybrid(self): + """测试 ADB 检索设置加权混合检索""" + settings = ADBRetrieveSettings( + hybrid_search="Weight", + hybrid_search_args={"Weight": {"alpha": 0.5}}, + ) + assert settings.hybrid_search == "Weight" + assert settings.hybrid_search_args["Weight"]["alpha"] == 0.5 + + +class TestKnowledgeBaseMutableProps: + """测试 KnowledgeBaseMutableProps 模型""" + + def test_create_mutable_props(self): + """测试创建可变属性""" + props = KnowledgeBaseMutableProps( + description="Test description", + credential_name="test-credential", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + ) + assert props.description == "Test description" + assert props.credential_name == "test-credential" + assert isinstance(props.provider_settings, RagFlowProviderSettings) + assert isinstance(props.retrieve_settings, RagFlowRetrieveSettings) + + def test_mutable_props_optional(self): + """测试可变属性可选字段""" + props = KnowledgeBaseMutableProps() + assert props.description is None + assert props.credential_name is None + assert props.provider_settings is None + assert props.retrieve_settings is None + + def test_mutable_props_with_typed_settings(self): + """测试可变属性使用类型化设置""" + provider_settings = RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ) + props = KnowledgeBaseMutableProps( + provider_settings=provider_settings, + ) + assert isinstance(props.provider_settings, RagFlowProviderSettings) + + +class TestKnowledgeBaseImmutableProps: + """测试 KnowledgeBaseImmutableProps 模型""" + + def test_create_immutable_props(self): + """测试创建不可变属性""" + props = KnowledgeBaseImmutableProps( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + ) + assert props.knowledge_base_name == "test-kb" + assert props.provider == KnowledgeBaseProvider.RAGFLOW + + def test_immutable_props_optional(self): + """测试不可变属性可选字段""" + props = KnowledgeBaseImmutableProps() + assert props.knowledge_base_name is None + assert props.provider is None + + def test_immutable_props_with_string_provider(self): + """测试不可变属性使用字符串提供商""" + props = KnowledgeBaseImmutableProps( + knowledge_base_name="test-kb", + provider="bailian", + ) + assert props.provider == "bailian" + + +class TestKnowledgeBaseSystemProps: + """测试 KnowledgeBaseSystemProps 模型""" + + def test_create_system_props(self): + """测试创建系统属性""" + props = KnowledgeBaseSystemProps( + knowledge_base_id="kb-123", + created_at="2024-01-01T00:00:00Z", + last_updated_at="2024-01-02T00:00:00Z", + ) + assert props.knowledge_base_id == "kb-123" + assert props.created_at == "2024-01-01T00:00:00Z" + assert props.last_updated_at == "2024-01-02T00:00:00Z" + + def test_system_props_optional(self): + """测试系统属性可选字段""" + props = KnowledgeBaseSystemProps() + assert props.knowledge_base_id is None + assert props.created_at is None + assert props.last_updated_at is None + + +class TestKnowledgeBaseCreateInput: + """测试 KnowledgeBaseCreateInput 模型""" + + def test_create_minimal_input(self): + """测试创建最小输入参数""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + assert input_obj.knowledge_base_name == "test-kb" + assert input_obj.provider == KnowledgeBaseProvider.RAGFLOW + assert isinstance(input_obj.provider_settings, RagFlowProviderSettings) + + def test_create_full_input(self): + """测试创建完整输入参数""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + description="Test knowledge base", + credential_name="test-credential", + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + assert input_obj.knowledge_base_name == "test-kb" + assert input_obj.description == "Test knowledge base" + assert input_obj.credential_name == "test-credential" + + def test_create_input_with_adb(self): + """测试创建 ADB 输入参数""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-adb-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + assert input_obj.provider == KnowledgeBaseProvider.ADB + + def test_create_input_with_dict_settings(self): + """测试创建输入参数(字典设置会自动转换为类型化对象)""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider="ragflow", + provider_settings={ + "base_url": "https://ragflow.example.com", + "dataset_ids": ["ds-1"], + }, + ) + assert input_obj.provider == "ragflow" + # Pydantic 会自动将字典转换为类型化对象 + assert isinstance(input_obj.provider_settings, RagFlowProviderSettings) + assert ( + input_obj.provider_settings.base_url + == "https://ragflow.example.com" + ) + + def test_model_dump(self): + """测试模型序列化""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings={ + "base_url": "https://example.com", + "dataset_ids": [], + }, + ) + data = input_obj.model_dump() + assert input_obj.knowledge_base_name == "test-kb" + + +class TestKnowledgeBaseUpdateInput: + """测试 KnowledgeBaseUpdateInput 模型""" + + def test_create_update_input(self): + """测试创建更新输入参数""" + input_obj = KnowledgeBaseUpdateInput( + description="Updated description", + ) + assert input_obj.description == "Updated description" + + def test_update_input_with_credential(self): + """测试更新输入参数(带凭证)""" + input_obj = KnowledgeBaseUpdateInput( + credential_name="new-credential", + ) + assert input_obj.credential_name == "new-credential" + + def test_update_input_with_provider_settings(self): + """测试更新输入参数(带提供商设置)""" + input_obj = KnowledgeBaseUpdateInput( + provider_settings=RagFlowProviderSettings( + base_url="https://new-ragflow.example.com", + dataset_ids=["ds-new"], + ), + ) + assert input_obj.provider_settings is not None + + def test_update_input_with_retrieve_settings(self): + """测试更新输入参数(带检索设置)""" + input_obj = KnowledgeBaseUpdateInput( + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=20, + ), + ) + assert input_obj.retrieve_settings is not None + + def test_update_input_optional(self): + """测试更新输入参数可选字段""" + input_obj = KnowledgeBaseUpdateInput() + assert input_obj.description is None + assert input_obj.credential_name is None + assert input_obj.provider_settings is None + assert input_obj.retrieve_settings is None + + +class TestKnowledgeBaseListInput: + """测试 KnowledgeBaseListInput 模型""" + + def test_create_list_input(self): + """测试创建列表输入参数""" + input_obj = KnowledgeBaseListInput( + page_number=1, + page_size=10, + provider=KnowledgeBaseProvider.RAGFLOW, + ) + assert input_obj.page_number == 1 + assert input_obj.page_size == 10 + assert input_obj.provider == KnowledgeBaseProvider.RAGFLOW + + def test_list_input_default(self): + """测试列表输入参数默认值""" + input_obj = KnowledgeBaseListInput() + assert input_obj.provider is None + + def test_list_input_with_pagination(self): + """测试列表输入参数(带分页)""" + input_obj = KnowledgeBaseListInput( + page_number=2, + page_size=20, + ) + assert input_obj.page_number == 2 + assert input_obj.page_size == 20 + + def test_list_input_with_string_provider(self): + """测试列表输入参数(字符串提供商)""" + input_obj = KnowledgeBaseListInput( + provider="bailian", + ) + assert input_obj.provider == "bailian" + + +class TestKnowledgeBaseListOutput: + """测试 KnowledgeBaseListOutput 模型""" + + def test_create_list_output(self): + """测试创建列表输出""" + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + description="Test knowledge base", + credential_name="test-credential", + created_at="2024-01-01T00:00:00Z", + last_updated_at="2024-01-01T00:00:00Z", + ) + assert output.knowledge_base_id == "kb-123" + assert output.knowledge_base_name == "test-kb" + assert output.provider == KnowledgeBaseProvider.RAGFLOW + assert output.description == "Test knowledge base" + + def test_list_output_optional(self): + """测试列表输出可选字段""" + output = KnowledgeBaseListOutput() + assert output.knowledge_base_id is None + assert output.knowledge_base_name is None + assert output.provider is None + assert output.description is None + assert output.credential_name is None + assert output.created_at is None + assert output.last_updated_at is None + + def test_list_output_with_settings(self): + """测试列表输出带设置(字典会自动转换为类型化对象)""" + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + provider="bailian", + provider_settings={ + "workspace_id": "ws-123", + "index_ids": ["idx-1"], + }, + retrieve_settings={"dense_similarity_top_k": 10}, + ) + # Pydantic 会自动将字典转换为类型化对象 + assert isinstance(output.provider_settings, BailianProviderSettings) + assert output.provider_settings.workspace_id == "ws-123" + assert isinstance(output.retrieve_settings, BailianRetrieveSettings) + assert output.retrieve_settings.dense_similarity_top_k == 10 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseClient") + def test_to_knowledge_base_sync(self, mock_client_class): + """测试同步转换为 KnowledgeBase 对象""" + mock_client = MagicMock() + mock_kb = MagicMock() + mock_client.get.return_value = mock_kb + mock_client_class.return_value = mock_client + + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + ) + + result = output.to_knowledge_base() + assert result == mock_kb + mock_client.get.assert_called_once() + + @patch("agentrun.knowledgebase.client.KnowledgeBaseClient") + @pytest.mark.asyncio + async def test_to_knowledge_base_async(self, mock_client_class): + """测试异步转换为 KnowledgeBase 对象""" + mock_client = MagicMock() + mock_kb = MagicMock() + mock_client.get_async = AsyncMock(return_value=mock_kb) + mock_client_class.return_value = mock_client + + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + ) + + result = await output.to_knowledge_base_async() + assert result == mock_kb + + +class TestRetrieveInput: + """测试 RetrieveInput 模型""" + + def test_create_retrieve_input(self): + """测试创建检索输入参数""" + input_obj = RetrieveInput( + knowledge_base_names=["kb-1", "kb-2"], + query="What is AI?", + ) + assert input_obj.knowledge_base_names == ["kb-1", "kb-2"] + assert input_obj.query == "What is AI?" + + def test_retrieve_input_with_optional_fields(self): + """测试检索输入参数可选字段""" + input_obj = RetrieveInput( + knowledge_base_names=["kb-1"], + query="Test query", + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + provider="ragflow", + description="Test description", + credential_name="test-credential", + created_at="2024-01-01T00:00:00Z", + last_updated_at="2024-01-01T00:00:00Z", + ) + assert input_obj.knowledge_base_id == "kb-123" + assert input_obj.knowledge_base_name == "test-kb" + assert input_obj.provider == "ragflow" + + def test_retrieve_input_default_optional(self): + """测试检索输入参数默认可选值""" + input_obj = RetrieveInput( + knowledge_base_names=["kb-1"], + query="Test query", + ) + assert input_obj.knowledge_base_id is None + assert input_obj.knowledge_base_name is None + assert input_obj.provider is None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseClient") + def test_retrieve_input_to_knowledge_base_sync(self, mock_client_class): + """测试检索输入同步转换为 KnowledgeBase""" + mock_client = MagicMock() + mock_kb = MagicMock() + mock_client.get.return_value = mock_kb + mock_client_class.return_value = mock_client + + input_obj = RetrieveInput( + knowledge_base_names=["kb-1"], + query="Test query", + knowledge_base_name="test-kb", + ) + + result = input_obj.to_knowledge_base() + assert result == mock_kb + + @patch("agentrun.knowledgebase.client.KnowledgeBaseClient") + @pytest.mark.asyncio + async def test_retrieve_input_to_knowledge_base_async( + self, mock_client_class + ): + """测试检索输入异步转换为 KnowledgeBase""" + mock_client = MagicMock() + mock_kb = MagicMock() + mock_client.get_async = AsyncMock(return_value=mock_kb) + mock_client_class.return_value = mock_client + + input_obj = RetrieveInput( + knowledge_base_names=["kb-1"], + query="Test query", + knowledge_base_name="test-kb", + ) + + result = await input_obj.to_knowledge_base_async() + assert result == mock_kb