diff --git a/agentrun/knowledgebase/__init__.py b/agentrun/knowledgebase/__init__.py index df0f1a8..6a01376 100644 --- a/agentrun/knowledgebase/__init__.py +++ b/agentrun/knowledgebase/__init__.py @@ -1,6 +1,7 @@ """KnowledgeBase 模块 / KnowledgeBase Module""" from .api import ( + ADBDataAPI, BailianDataAPI, get_data_api, KnowledgeBaseControlAPI, @@ -10,6 +11,8 @@ from .client import KnowledgeBaseClient from .knowledgebase import KnowledgeBase from .model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseCreateInput, @@ -33,6 +36,7 @@ "KnowledgeBaseDataAPI", "RagFlowDataAPI", "BailianDataAPI", + "ADBDataAPI", "get_data_api", # enums "KnowledgeBaseProvider", @@ -40,10 +44,12 @@ "ProviderSettings", "RagFlowProviderSettings", "BailianProviderSettings", + "ADBProviderSettings", # retrieve settings "RetrieveSettings", "RagFlowRetrieveSettings", "BailianRetrieveSettings", + "ADBRetrieveSettings", # api model "KnowledgeBaseCreateInput", "KnowledgeBaseUpdateInput", diff --git a/agentrun/knowledgebase/__knowledgebase_async_template.py b/agentrun/knowledgebase/__knowledgebase_async_template.py index 96ee94c..501449e 100644 --- a/agentrun/knowledgebase/__knowledgebase_async_template.py +++ b/agentrun/knowledgebase/__knowledgebase_async_template.py @@ -14,6 +14,8 @@ from .api.data import get_data_api from .model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseCreateInput, @@ -294,6 +296,54 @@ def _get_data_api(self, config: Optional[Config] = None): **self.retrieve_settings ) + elif provider == KnowledgeBaseProvider.ADB: + # ADB 设置 / ADB settings + if self.provider_settings: + if isinstance(self.provider_settings, ADBProviderSettings): + converted_provider_settings = self.provider_settings + elif isinstance(self.provider_settings, dict): + # ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB provider_settings uses PascalCase keys, 
need to convert to snake_case + converted_provider_settings = ADBProviderSettings( + db_instance_id=self.provider_settings.get( + "DBInstanceId", "" + ), + namespace=self.provider_settings.get("Namespace", ""), + namespace_password=self.provider_settings.get( + "NamespacePassword", "" + ), + embedding_model=self.provider_settings.get( + "EmbeddingModel" + ), + metrics=self.provider_settings.get("Metrics"), + metadata=self.provider_settings.get("Metadata"), + ) + + if self.retrieve_settings: + if isinstance(self.retrieve_settings, ADBRetrieveSettings): + converted_retrieve_settings = self.retrieve_settings + elif isinstance(self.retrieve_settings, dict): + # ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB retrieve_settings uses PascalCase keys, need to convert to snake_case + converted_retrieve_settings = ADBRetrieveSettings( + top_k=self.retrieve_settings.get("TopK"), + use_full_text_retrieval=self.retrieve_settings.get( + "UseFullTextRetrieval" + ), + rerank_factor=self.retrieve_settings.get( + "RerankFactor" + ), + recall_window=self.retrieve_settings.get( + "RecallWindow" + ), + hybrid_search=self.retrieve_settings.get( + "HybridSearch" + ), + hybrid_search_args=self.retrieve_settings.get( + "HybridSearchArgs" + ), + ) + return get_data_api( provider=provider, knowledge_base_name=self.knowledge_base_name or "", diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index 4f8a1ae..51e9c43 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -3,14 +3,15 @@ 提供知识库检索功能的数据链路 API。 Provides data API for knowledge base retrieval operations. -根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。 -Dispatches to different implementations based on provider type (ragflow / bailian). +根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。 +Dispatches to different implementations based on provider type (ragflow / bailian / adb). 
""" from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union from alibabacloud_bailian20231229 import models as bailian_models +from alibabacloud_gpdb20160503 import models as gpdb_models import httpx from agentrun.utils.config import Config @@ -19,6 +20,8 @@ from agentrun.utils.log import logger from ..model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseProvider, @@ -347,15 +350,213 @@ async def retrieve_async( } +class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI): + """ADB (AnalyticDB for PostgreSQL) 知识库数据链路 API / ADB KnowledgeBase Data API + + 实现 ADB 知识库的检索逻辑,通过 GPDB SDK 调用 QueryContent 接口。 + Implements retrieval logic for ADB knowledge base via GPDB SDK QueryContent API. + """ + + def __init__( + self, + knowledge_base_name: str, + config: Optional[Config] = None, + provider_settings: Optional[ADBProviderSettings] = None, + retrieve_settings: Optional[ADBRetrieveSettings] = None, + ): + """初始化 ADB 知识库数据链路 API / Initialize ADB KnowledgeBase Data API + + Args: + knowledge_base_name: 知识库名称 / Knowledge base name + config: 配置 / Configuration + provider_settings: ADB 提供商设置 / ADB provider settings + retrieve_settings: ADB 检索设置 / ADB retrieve settings + """ + KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config) + ControlAPI.__init__(self, config) + self.provider_settings = provider_settings + self.retrieve_settings = retrieve_settings + + def _build_query_content_request( + self, query: str, config: Optional[Config] = None + ) -> gpdb_models.QueryContentRequest: + """构建 QueryContent 请求 / Build QueryContent request + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + QueryContentRequest: GPDB QueryContent 请求对象 + """ + if self.provider_settings is None: + raise ValueError("provider_settings is required for ADB retrieval") + + cfg = Config.with_configs(self.config, config) + + # 构建基础请求参数 / Build base request parameters + 
request_params: Dict[str, Any] = { + "content": query, + "dbinstance_id": self.provider_settings.db_instance_id, + "namespace": self.provider_settings.namespace, + "namespace_password": self.provider_settings.namespace_password, + "collection": self.knowledge_base_name, + "region_id": cfg.get_region_id(), + } + + # 添加可选的提供商设置 / Add optional provider settings + if self.provider_settings.metrics is not None: + request_params["metrics"] = self.provider_settings.metrics + + # 添加检索设置 / Add retrieve settings + if self.retrieve_settings: + if self.retrieve_settings.top_k is not None: + request_params["top_k"] = self.retrieve_settings.top_k + if self.retrieve_settings.use_full_text_retrieval is not None: + request_params["use_full_text_retrieval"] = ( + self.retrieve_settings.use_full_text_retrieval + ) + if self.retrieve_settings.rerank_factor is not None: + request_params["rerank_factor"] = ( + self.retrieve_settings.rerank_factor + ) + if self.retrieve_settings.recall_window is not None: + request_params["recall_window"] = ( + self.retrieve_settings.recall_window + ) + if self.retrieve_settings.hybrid_search is not None: + request_params["hybrid_search"] = ( + self.retrieve_settings.hybrid_search + ) + if self.retrieve_settings.hybrid_search_args is not None: + request_params["hybrid_search_args"] = ( + self.retrieve_settings.hybrid_search_args + ) + + return gpdb_models.QueryContentRequest(**request_params) + + def _parse_query_content_response( + self, response: gpdb_models.QueryContentResponse, query: str + ) -> Dict[str, Any]: + """解析 QueryContent 响应 / Parse QueryContent response + + Args: + response: GPDB QueryContent 响应对象 + query: 原始查询文本 / Original query text + + Returns: + Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results + """ + all_matches: List[Dict[str, Any]] = [] + + if response.body and response.body.matches: + match_list = response.body.matches.match_list or [] + for match in match_list: + all_matches.append({ + "content": ( + match.content if 
hasattr(match, "content") else None + ), + "score": match.score if hasattr(match, "score") else None, + "id": match.id if hasattr(match, "id") else None, + "file_name": ( + match.file_name if hasattr(match, "file_name") else None + ), + "file_url": ( + match.file_url if hasattr(match, "file_url") else None + ), + "metadata": ( + match.metadata if hasattr(match, "metadata") else None + ), + "rerank_score": ( + match.rerank_score + if hasattr(match, "rerank_score") + else None + ), + "retrieval_source": ( + match.retrieval_source + if hasattr(match, "retrieval_source") + else None + ), + }) + + return { + "data": all_matches, + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "request_id": ( + response.body.request_id + if response.body and hasattr(response.body, "request_id") + else None + ), + } + + async def retrieve_async( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """ADB 检索(异步)/ ADB retrieval asynchronously + + 通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。 + Retrieves from ADB knowledge base via GPDB SDK QueryContent API. 
+ + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for ADB retrieval" + ) + + # 获取 GPDB 客户端 / Get GPDB client + client = self._get_gpdb_client(config) + + # 构建请求 / Build request + request = self._build_query_content_request(query, config) + logger.debug(f"ADB QueryContent request: {request}") + + # 调用 QueryContent API / Call QueryContent API + response = await client.query_content_async(request) + logger.debug(f"ADB QueryContent response: {response}") + + # 解析并返回结果 / Parse and return results + return self._parse_query_content_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from ADB knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, config: Optional[Config] = None, provider_settings: Optional[ - Union[RagFlowProviderSettings, BailianProviderSettings] + Union[ + RagFlowProviderSettings, + BailianProviderSettings, + ADBProviderSettings, + ] ] = None, retrieve_settings: Optional[ - Union[RagFlowRetrieveSettings, BailianRetrieveSettings] + Union[ + RagFlowRetrieveSettings, + BailianRetrieveSettings, + ADBRetrieveSettings, + ] ] = None, credential_name: Optional[str] = None, ) -> KnowledgeBaseDataAPI: @@ -410,5 +611,22 @@ def get_data_api( provider_settings=bailian_provider_settings, retrieve_settings=bailian_retrieve_settings, ) + elif provider == KnowledgeBaseProvider.ADB or provider == "adb": + adb_provider_settings = ( + provider_settings + if isinstance(provider_settings, ADBProviderSettings) + else None + ) + adb_retrieve_settings = ( + retrieve_settings + if isinstance(retrieve_settings, ADBRetrieveSettings) + else None + ) 
+ return ADBDataAPI( + knowledge_base_name, + config, + provider_settings=adb_provider_settings, + retrieve_settings=adb_retrieve_settings, + ) else: raise ValueError(f"Unsupported provider type: {provider}") diff --git a/agentrun/knowledgebase/api/__init__.py b/agentrun/knowledgebase/api/__init__.py index 2746a9e..bcfc80c 100644 --- a/agentrun/knowledgebase/api/__init__.py +++ b/agentrun/knowledgebase/api/__init__.py @@ -2,6 +2,7 @@ from .control import KnowledgeBaseControlAPI from .data import ( + ADBDataAPI, BailianDataAPI, get_data_api, KnowledgeBaseDataAPI, @@ -15,5 +16,6 @@ "KnowledgeBaseDataAPI", "RagFlowDataAPI", "BailianDataAPI", + "ADBDataAPI", "get_data_api", ] diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index 350a302..747157c 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -13,14 +13,15 @@ 提供知识库检索功能的数据链路 API。 Provides data API for knowledge base retrieval operations. -根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。 -Dispatches to different implementations based on provider type (ragflow / bailian). +根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。 +Dispatches to different implementations based on provider type (ragflow / bailian / adb). 
""" from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union from alibabacloud_bailian20231229 import models as bailian_models +from alibabacloud_gpdb20160503 import models as gpdb_models import httpx from agentrun.utils.config import Config @@ -29,6 +30,8 @@ from agentrun.utils.log import logger from ..model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseProvider, @@ -557,15 +560,262 @@ def retrieve( } +class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI): + """ADB (AnalyticDB for PostgreSQL) 知识库数据链路 API / ADB KnowledgeBase Data API + + 实现 ADB 知识库的检索逻辑,通过 GPDB SDK 调用 QueryContent 接口。 + Implements retrieval logic for ADB knowledge base via GPDB SDK QueryContent API. + """ + + def __init__( + self, + knowledge_base_name: str, + config: Optional[Config] = None, + provider_settings: Optional[ADBProviderSettings] = None, + retrieve_settings: Optional[ADBRetrieveSettings] = None, + ): + """初始化 ADB 知识库数据链路 API / Initialize ADB KnowledgeBase Data API + + Args: + knowledge_base_name: 知识库名称 / Knowledge base name + config: 配置 / Configuration + provider_settings: ADB 提供商设置 / ADB provider settings + retrieve_settings: ADB 检索设置 / ADB retrieve settings + """ + KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config) + ControlAPI.__init__(self, config) + self.provider_settings = provider_settings + self.retrieve_settings = retrieve_settings + + def _build_query_content_request( + self, query: str, config: Optional[Config] = None + ) -> gpdb_models.QueryContentRequest: + """构建 QueryContent 请求 / Build QueryContent request + + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + QueryContentRequest: GPDB QueryContent 请求对象 + """ + if self.provider_settings is None: + raise ValueError("provider_settings is required for ADB retrieval") + + cfg = Config.with_configs(self.config, config) + + # 构建基础请求参数 / Build base request parameters + request_params: 
Dict[str, Any] = { + "content": query, + "dbinstance_id": self.provider_settings.db_instance_id, + "namespace": self.provider_settings.namespace, + "namespace_password": self.provider_settings.namespace_password, + "collection": self.knowledge_base_name, + "region_id": cfg.get_region_id(), + } + + # 添加可选的提供商设置 / Add optional provider settings + if self.provider_settings.metrics is not None: + request_params["metrics"] = self.provider_settings.metrics + + # 添加检索设置 / Add retrieve settings + if self.retrieve_settings: + if self.retrieve_settings.top_k is not None: + request_params["top_k"] = self.retrieve_settings.top_k + if self.retrieve_settings.use_full_text_retrieval is not None: + request_params["use_full_text_retrieval"] = ( + self.retrieve_settings.use_full_text_retrieval + ) + if self.retrieve_settings.rerank_factor is not None: + request_params["rerank_factor"] = ( + self.retrieve_settings.rerank_factor + ) + if self.retrieve_settings.recall_window is not None: + request_params["recall_window"] = ( + self.retrieve_settings.recall_window + ) + if self.retrieve_settings.hybrid_search is not None: + request_params["hybrid_search"] = ( + self.retrieve_settings.hybrid_search + ) + if self.retrieve_settings.hybrid_search_args is not None: + request_params["hybrid_search_args"] = ( + self.retrieve_settings.hybrid_search_args + ) + + return gpdb_models.QueryContentRequest(**request_params) + + def _parse_query_content_response( + self, response: gpdb_models.QueryContentResponse, query: str + ) -> Dict[str, Any]: + """解析 QueryContent 响应 / Parse QueryContent response + + Args: + response: GPDB QueryContent 响应对象 + query: 原始查询文本 / Original query text + + Returns: + Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results + """ + all_matches: List[Dict[str, Any]] = [] + + if response.body and response.body.matches: + match_list = response.body.matches.match_list or [] + for match in match_list: + all_matches.append({ + "content": ( + match.content if hasattr(match, 
"content") else None + ), + "score": match.score if hasattr(match, "score") else None, + "id": match.id if hasattr(match, "id") else None, + "file_name": ( + match.file_name if hasattr(match, "file_name") else None + ), + "file_url": ( + match.file_url if hasattr(match, "file_url") else None + ), + "metadata": ( + match.metadata if hasattr(match, "metadata") else None + ), + "rerank_score": ( + match.rerank_score + if hasattr(match, "rerank_score") + else None + ), + "retrieval_source": ( + match.retrieval_source + if hasattr(match, "retrieval_source") + else None + ), + }) + + return { + "data": all_matches, + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "request_id": ( + response.body.request_id + if response.body and hasattr(response.body, "request_id") + else None + ), + } + + async def retrieve_async( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """ADB 检索(异步)/ ADB retrieval asynchronously + + 通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。 + Retrieves from ADB knowledge base via GPDB SDK QueryContent API. 
+ + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for ADB retrieval" + ) + + # 获取 GPDB 客户端 / Get GPDB client + client = self._get_gpdb_client(config) + + # 构建请求 / Build request + request = self._build_query_content_request(query, config) + logger.debug(f"ADB QueryContent request: {request}") + + # 调用 QueryContent API / Call QueryContent API + response = await client.query_content_async(request) + logger.debug(f"ADB QueryContent response: {response}") + + # 解析并返回结果 / Parse and return results + return self._parse_query_content_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from ADB knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def retrieve( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """ADB 检索(同步)/ ADB retrieval synchronously + + 通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。 + Retrieves from ADB knowledge base via GPDB SDK QueryContent API. 
+ + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for ADB retrieval" + ) + + # 获取 GPDB 客户端 / Get GPDB client + client = self._get_gpdb_client(config) + + # 构建请求 / Build request + request = self._build_query_content_request(query, config) + logger.debug(f"ADB QueryContent request: {request}") + + # 调用 QueryContent API / Call QueryContent API + response = client.query_content(request) + logger.debug(f"ADB QueryContent response: {response}") + + # 解析并返回结果 / Parse and return results + return self._parse_query_content_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from ADB knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, config: Optional[Config] = None, provider_settings: Optional[ - Union[RagFlowProviderSettings, BailianProviderSettings] + Union[ + RagFlowProviderSettings, + BailianProviderSettings, + ADBProviderSettings, + ] ] = None, retrieve_settings: Optional[ - Union[RagFlowRetrieveSettings, BailianRetrieveSettings] + Union[ + RagFlowRetrieveSettings, + BailianRetrieveSettings, + ADBRetrieveSettings, + ] ] = None, credential_name: Optional[str] = None, ) -> KnowledgeBaseDataAPI: @@ -620,5 +870,22 @@ def get_data_api( provider_settings=bailian_provider_settings, retrieve_settings=bailian_retrieve_settings, ) + elif provider == KnowledgeBaseProvider.ADB or provider == "adb": + adb_provider_settings = ( + provider_settings + if isinstance(provider_settings, ADBProviderSettings) + else None + ) + adb_retrieve_settings = ( + retrieve_settings + if isinstance(retrieve_settings, ADBRetrieveSettings) + else None + ) + return 
ADBDataAPI( + knowledge_base_name, + config, + provider_settings=adb_provider_settings, + retrieve_settings=adb_retrieve_settings, + ) else: raise ValueError(f"Unsupported provider type: {provider}") diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index 2e453da..74c5c50 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -24,6 +24,8 @@ from .api.data import get_data_api from .model import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseCreateInput, @@ -472,6 +474,54 @@ def _get_data_api(self, config: Optional[Config] = None): **self.retrieve_settings ) + elif provider == KnowledgeBaseProvider.ADB: + # ADB 设置 / ADB settings + if self.provider_settings: + if isinstance(self.provider_settings, ADBProviderSettings): + converted_provider_settings = self.provider_settings + elif isinstance(self.provider_settings, dict): + # ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB provider_settings uses PascalCase keys, need to convert to snake_case + converted_provider_settings = ADBProviderSettings( + db_instance_id=self.provider_settings.get( + "DBInstanceId", "" + ), + namespace=self.provider_settings.get("Namespace", ""), + namespace_password=self.provider_settings.get( + "NamespacePassword", "" + ), + embedding_model=self.provider_settings.get( + "EmbeddingModel" + ), + metrics=self.provider_settings.get("Metrics"), + metadata=self.provider_settings.get("Metadata"), + ) + + if self.retrieve_settings: + if isinstance(self.retrieve_settings, ADBRetrieveSettings): + converted_retrieve_settings = self.retrieve_settings + elif isinstance(self.retrieve_settings, dict): + # ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case + # ADB retrieve_settings uses PascalCase keys, need to convert to snake_case + converted_retrieve_settings = ADBRetrieveSettings( + top_k=self.retrieve_settings.get("TopK"), + 
use_full_text_retrieval=self.retrieve_settings.get( + "UseFullTextRetrieval" + ), + rerank_factor=self.retrieve_settings.get( + "RerankFactor" + ), + recall_window=self.retrieve_settings.get( + "RecallWindow" + ), + hybrid_search=self.retrieve_settings.get( + "HybridSearch" + ), + hybrid_search_args=self.retrieve_settings.get( + "HybridSearchArgs" + ), + ) + return get_data_api( provider=provider, knowledge_base_name=self.knowledge_base_name or "", diff --git a/agentrun/knowledgebase/model.py b/agentrun/knowledgebase/model.py index 69ce23c..1c7c227 100644 --- a/agentrun/knowledgebase/model.py +++ b/agentrun/knowledgebase/model.py @@ -18,6 +18,8 @@ class KnowledgeBaseProvider(str, Enum): """RagFlow 知识库 / RagFlow knowledge base""" BAILIAN = "bailian" """百炼知识库 / Bailian knowledge base""" + ADB = "adb" + """ADB (AnalyticDB for PostgreSQL) 知识库 / ADB knowledge base""" # ============================================================================= @@ -75,17 +77,77 @@ class BailianRetrieveSettings(BaseModel): """重排序返回的 Top N 数量 / Rerank top N""" +# ============================================================================= +# ADB 配置模型 / ADB Configuration Models +# ============================================================================= + + +class ADBProviderSettings(BaseModel): + """ADB (AnalyticDB for PostgreSQL) 提供商设置 / ADB Provider Settings + + 配置 ADB 知识库的连接和访问参数。 + Configure ADB knowledge base connection and access parameters. 
+ """ + + db_instance_id: str + """ADB 实例 ID / ADB instance ID""" + namespace: str + """命名空间(必填,本模型不提供默认值) / Namespace (required; this model defines no default)""" + namespace_password: str + """命名空间密码 / Namespace password""" + embedding_model: Optional[str] = None + """向量化模型名称,如 text-embedding-v3 / Embedding model name""" + metrics: Optional[str] = None + """相似度算法:l2(欧氏距离)、ip(内积)、cosine(余弦相似度) + Similarity algorithm: l2 (Euclidean), ip (inner product), cosine""" + metadata: Optional[str] = None + """元数据配置,JSON 字符串格式 / Metadata configuration in JSON string format""" + + +class ADBRetrieveSettings(BaseModel): + """ADB 检索设置 / ADB Retrieve Settings + + 配置 ADB 知识库的检索参数,支持向量检索和全文检索的混合模式。 + Configure ADB knowledge base retrieval parameters, supporting hybrid + vector and full-text retrieval modes. + """ + + top_k: Optional[int] = None + """返回结果的数量 / Number of results to return""" + use_full_text_retrieval: Optional[bool] = None + """是否启用全文检索(双路召回),默认 false 仅使用向量检索 + Enable full-text retrieval (dual recall), defaults to false (vector only)""" + rerank_factor: Optional[float] = None + """重排序因子,取值范围 1 < RerankFactor <= 5 + Re-ranking factor, value range: 1 < RerankFactor <= 5""" + recall_window: Optional[List[int]] = None + """召回窗口,格式为 [A, B],其中 -10 <= A <= 0,0 <= B <= 10 + Recall window, format [A, B] where -10 <= A <= 0, 0 <= B <= 10""" + hybrid_search: Optional[str] = None + """混合检索算法:RRF(倒数排名融合)、Weight(加权排序)、Cascaded(级联检索) + Hybrid search algorithm: RRF, Weight, or Cascaded""" + hybrid_search_args: Optional[Dict[str, Any]] = None + """混合检索算法参数,如 {"RRF": {"k": 60}} 或 {"Weight": {"alpha": 0.5}} + Hybrid search algorithm parameters""" + + # ============================================================================= # 联合类型定义 / Union Type Definitions # ============================================================================= ProviderSettings = Union[ - RagFlowProviderSettings, BailianProviderSettings, Dict[str, Any] + RagFlowProviderSettings, + BailianProviderSettings, + ADBProviderSettings, + 
Dict[str, Any], ] """提供商设置联合类型 / Provider settings union type""" RetrieveSettings = Union[ - RagFlowRetrieveSettings, BailianRetrieveSettings, Dict[str, Any] + RagFlowRetrieveSettings, + BailianRetrieveSettings, + ADBRetrieveSettings, + Dict[str, Any], ] """检索设置联合类型 / Retrieve settings union type""" diff --git a/agentrun/utils/control_api.py b/agentrun/utils/control_api.py index d9db600..b74a822 100644 --- a/agentrun/utils/control_api.py +++ b/agentrun/utils/control_api.py @@ -9,6 +9,7 @@ from alibabacloud_agentrun20250910.client import Client as AgentRunClient from alibabacloud_bailian20231229.client import Client as BailianClient from alibabacloud_devs20230714.client import Client as DevsClient +from alibabacloud_gpdb20160503.client import Client as GPDBClient from alibabacloud_tea_openapi import utils_models as open_api_util_models from agentrun.utils.config import Config @@ -103,3 +104,42 @@ def _get_bailian_client( read_timeout=cfg.get_read_timeout(), # type: ignore ) ) + + def _get_gpdb_client(self, config: Optional[Config] = None) -> "GPDBClient": + """ + 获取 GPDB (AnalyticDB for PostgreSQL) API 客户端实例 + Get GPDB (AnalyticDB for PostgreSQL) API client instance + + Args: + config: 配置对象,可选 / Configuration object, optional + + Returns: + GPDBClient: GPDB API 客户端实例 / GPDB API client instance + """ + + cfg = Config.with_configs(self.config, config) + # 下列区域使用共享 endpoint gpdb.aliyuncs.com,其余区域使用区域级 endpoint / The regions listed below use the shared endpoint gpdb.aliyuncs.com; all other regions use a region-specific endpoint + region_id = cfg.get_region_id() + if region_id in ( + "cn-beijing", + "cn-hangzhou", + "cn-shanghai", + "cn-shenzhen", + "cn-hongkong", + "ap-southeast-1", + ): + endpoint = "gpdb.aliyuncs.com" + else: + endpoint = f"gpdb.{region_id}.aliyuncs.com" + + return GPDBClient( + open_api_util_models.Config( + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=cfg.get_security_token(), + region_id=cfg.get_region_id(), + endpoint=endpoint, + connect_timeout=cfg.get_timeout(), # type: ignore + 
read_timeout=cfg.get_read_timeout(), # type: ignore + ) + ) diff --git a/examples/knowledgebase.py b/examples/knowledgebase.py index b9ddef5..61b5722 100644 --- a/examples/knowledgebase.py +++ b/examples/knowledgebase.py @@ -1,9 +1,9 @@ """ 知识库模块示例 / KnowledgeBase Module Example -本示例演示如何使用 AgentRun SDK 管理知识库,包括百炼和 RagFlow 两种类型: +本示例演示如何使用 AgentRun SDK 管理知识库,包括百炼、RagFlow 和 ADB 三种类型: This example demonstrates how to use the AgentRun SDK to manage knowledge bases, -including both Bailian and RagFlow types: +including Bailian, RagFlow and ADB types: 1. 创建知识库 / Create knowledge base (Bailian & RagFlow) 2. 获取知识库信息 / Get knowledge base info @@ -25,6 +25,12 @@ - RAGFLOW_BASE_URL: RagFlow 服务地址 - RAGFLOW_DATASET_IDS: RagFlow 数据集 ID 列表(逗号分隔) - RAGFLOW_CREDENTIAL_NAME: RagFlow API Key 凭证名称 + +ADB 知识库额外配置 / Additional config for ADB: +- ADB_INSTANCE_ID: ADB 实例 ID +- ADB_NAMESPACE: ADB 命名空间 +- ADB_NAMESPACE_PASSWORD: ADB 命名空间密码 +- ADB_COLLECTION: ADB 文档集合名称 """ import json @@ -32,6 +38,8 @@ import time from agentrun.knowledgebase import ( + ADBProviderSettings, + ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, KnowledgeBase, @@ -100,6 +108,34 @@ "RAGFLOW_CREDENTIAL_NAME", "ragflow-api-key" ) +# ----------------------------------------------------------------------------- +# ADB 知识库配置 / ADB Knowledge Base Configuration +# ----------------------------------------------------------------------------- + +# ADB 知识库名称 +# ADB knowledge base name +ADB_KB_NAME = f"sdk-test-adb-kb-{TIMESTAMP}" + +# ADB 实例 ID,请替换为您的实际值 +# ADB instance ID, please replace with your actual value +ADB_INSTANCE_ID = os.getenv("ADB_INSTANCE_ID", "gp-your-instance-id") + +# ADB 命名空间,默认为 public +# ADB namespace, defaults to public +ADB_NAMESPACE = os.getenv("ADB_NAMESPACE", "public") + +# ADB 命名空间密码,请替换为您的实际值 +# ADB namespace password, please replace with your actual value +ADB_NAMESPACE_PASSWORD = os.getenv("ADB_NAMESPACE_PASSWORD", "your-password") + +# ADB 文档集合名称,请替换为您的实际值 +# ADB 
collection name, please replace with your actual value +ADB_COLLECTION = os.getenv("ADB_COLLECTION", "your-collection") + +# ADB 向量化模型名称(可选) +# ADB embedding model name (optional) +ADB_EMBEDDING_MODEL = os.getenv("ADB_EMBEDDING_MODEL", "text-embedding-v3") + # ============================================================================ # 客户端初始化 / Client Initialization # ============================================================================ @@ -362,6 +398,153 @@ def delete_ragflow_kb(kb: KnowledgeBase): ) +# ============================================================================ +# ADB 知识库示例函数 / ADB Knowledge Base Example Functions +# ============================================================================ + + +def create_or_get_adb_kb() -> KnowledgeBase: + """创建或获取已有的 ADB 知识库 / Create or get existing ADB knowledge base + + Returns: + KnowledgeBase: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("创建或获取 ADB 知识库") + logger.info("Create or get ADB knowledge base") + logger.info("=" * 60) + + try: + # 创建 ADB 知识库 / Create ADB knowledge base + kb = KnowledgeBase.create( + KnowledgeBaseCreateInput( + knowledge_base_name=ADB_KB_NAME, + description=( + "通过 SDK 创建的 ADB 知识库示例 / ADB KB example" + " created via SDK" + ), + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id=ADB_INSTANCE_ID, + namespace=ADB_NAMESPACE, + namespace_password=ADB_NAMESPACE_PASSWORD, + embedding_model=ADB_EMBEDDING_MODEL, + metrics="cosine", # 使用余弦相似度 / Use cosine similarity + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + use_full_text_retrieval=False, # 仅使用向量检索 / Vector only + rerank_factor=2.0, # 重排序因子 / Rerank factor + ), + ) + ) + logger.info("✅ ADB 知识库创建成功 / ADB KB created successfully") + + except ResourceAlreadyExistError: + logger.info( + "ℹ️ ADB 知识库已存在,获取已有资源 / ADB KB exists, getting" + " existing" + ) + kb = client.get(ADB_KB_NAME) + + _log_kb_info(kb) + return kb + + +def query_adb_kb(kb: 
KnowledgeBase): + """查询 ADB 知识库 / Query ADB knowledge base + + Args: + kb: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("查询 ADB 知识库") + logger.info("Query ADB knowledge base") + logger.info("=" * 60) + + query_text = "什么是云原生数据库" + logger.info("查询文本 / Query text: %s", query_text) + + try: + results = kb.retrieve(query=query_text) + logger.info("✅ 查询成功 / Query successful") + logger.info("检索结果 / Retrieval results: %s", results) + logger.info( + " - 结果数量 / Result count: %s", len(results.get("data", [])) + ) + except Exception as e: + logger.warning("⚠️ 查询失败(可能是配置或连接问题): %s", e) + + +def query_adb_kb_by_name(knowledgebase_name: str): + """查询 ADB 知识库 / Query ADB knowledge base + Args: + knowledgebase_name: 知识库名称 / Knowledge base name + """ + + try: + kb = KnowledgeBase.get_by_name(knowledgebase_name) + results = kb.retrieve(query="什么是云原生数据库") + logger.info("✅ 查询成功 / Query successful") + logger.info("检索结果 / Retrieval results: %s", results) + logger.info( + " - 结果数量 / Result count: %s", len(results.get("data", [])) + ) + except Exception as e: + logger.warning("⚠️ 查询失败(可能是配置或连接问题): %s", e) + + +def update_adb_kb(kb: KnowledgeBase): + """更新 ADB 知识库配置 / Update ADB knowledge base configuration + + Args: + kb: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("更新 ADB 知识库配置") + logger.info("Update ADB knowledge base configuration") + logger.info("=" * 60) + + new_description = f"[ADB] 更新于 {time.strftime('%Y-%m-%d %H:%M:%S')}" + + kb.update( + KnowledgeBaseUpdateInput( + description=new_description, + retrieve_settings=ADBRetrieveSettings( + top_k=20, # 增加返回数量 / Increase result count + use_full_text_retrieval=True, # 启用双路召回 / Enable dual recall + rerank_factor=3.0, # 调整重排序因子 / Adjust rerank factor + hybrid_search="RRF", # 使用 RRF 混合检索 / Use RRF hybrid search + hybrid_search_args={"RRF": {"k": 60}}, + ), + ) + ) + + logger.info("✅ ADB 知识库更新成功 / ADB KB updated successfully") + logger.info(" - 新描述 / New description: %s", 
kb.description) + + +def delete_adb_kb(kb: KnowledgeBase): + """删除 ADB 知识库 / Delete ADB knowledge base + + Args: + kb: 知识库对象 / Knowledge base object + """ + logger.info("=" * 60) + logger.info("删除 ADB 知识库") + logger.info("Delete ADB knowledge base") + logger.info("=" * 60) + + kb.delete() + logger.info("✅ ADB 知识库删除请求已发送 / ADB KB delete request sent") + + try: + client.get(ADB_KB_NAME) + logger.warning("⚠️ ADB 知识库仍然存在 / ADB KB still exists") + except ResourceNotExistError: + logger.info("✅ ADB 知识库已成功删除 / ADB KB deleted successfully") + + # ============================================================================ # 通用工具函数 / Common Utility Functions # ============================================================================ @@ -404,8 +587,10 @@ def list_knowledge_bases(): ragflow_list = KnowledgeBase.list_all( provider=KnowledgeBaseProvider.RAGFLOW.value ) + adb_list = KnowledgeBase.list_all(provider=KnowledgeBaseProvider.ADB.value) logger.info(" - 百炼知识库 / Bailian KBs: %d 个", len(bailian_list)) logger.info(" - RagFlow 知识库 / RagFlow KBs: %d 个", len(ragflow_list)) + logger.info(" - ADB 知识库 / ADB KBs: %d 个", len(adb_list)) # ============================================================================ @@ -457,6 +642,28 @@ def ragflow_example(): logger.info("") +def adb_example(): + """ADB 知识库完整示例 / Complete ADB knowledge base example""" + logger.info("") + logger.info("🔹 ADB 知识库示例 / ADB Knowledge Base Example") + logger.info("=" * 60) + + # 创建 ADB 知识库 / Create ADB KB + kb = create_or_get_adb_kb() + + # 查询 ADB 知识库 / Query ADB KB + query_adb_kb(kb) + + # 更新 ADB 知识库 / Update ADB KB + update_adb_kb(kb) + + # 删除 ADB 知识库 / Delete ADB KB + delete_adb_kb(kb) + + logger.info("🔹 ADB 知识库示例完成 / ADB KB Example Complete") + logger.info("") + + def knowledgebase_example(): """知识库模块完整示例 / Complete knowledge base module example @@ -501,6 +708,15 @@ def ragflow_only_example(): logger.info("🎉 完成 / Complete") +def adb_only_example(): + """仅运行 ADB 知识库示例 / Run ADB knowledge base example 
only""" + logger.info("🚀 ADB 知识库示例 / ADB KB Example") + list_knowledge_bases() + adb_example() + list_knowledge_bases() + logger.info("🎉 完成 / Complete") + + def multiple_knowledgebase_query(): """多知识库检索 / Multi knowledge base retrieval 根据知识库名称列表进行检索,自动获取各知识库的配置并执行检索。 @@ -509,7 +725,7 @@ def multiple_knowledgebase_query(): """ multi_query_result = KnowledgeBase.multi_retrieve( query="什么是Serverless", - knowledge_base_names=["ragflow-test", "jingsu-bailian"], + knowledge_base_names=["jingsu-bailian", "logantest"], ) logger.info( "多知识库检索结果 / Multi knowledge base retrieval result:\n%s", @@ -536,5 +752,7 @@ def update_ragflow_kb_config(): if __name__ == "__main__": # bailian_only_example() # ragflow_only_example() + # adb_only_example() multiple_knowledgebase_query() + # query_adb_kb_by_name("") # update_ragflow_kb_config() diff --git a/pyproject.toml b/pyproject.toml index effdf68..251c959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", "agentrun-mem0ai>=0.0.10", + "alibabacloud_gpdb20160503>=5.0.1" ] [project.optional-dependencies] diff --git a/tests/unittests/knowledgebase/__init__.py b/tests/unittests/knowledgebase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/knowledgebase/api/__init__.py b/tests/unittests/knowledgebase/api/__init__.py new file mode 100644 index 0000000..24785e9 --- /dev/null +++ b/tests/unittests/knowledgebase/api/__init__.py @@ -0,0 +1 @@ +"""KnowledgeBase API 单元测试模块 / KnowledgeBase API Unit Test Module""" diff --git a/tests/unittests/knowledgebase/api/test_data.py b/tests/unittests/knowledgebase/api/test_data.py new file mode 100644 index 0000000..8e15902 --- /dev/null +++ b/tests/unittests/knowledgebase/api/test_data.py @@ -0,0 +1,1594 @@ +"""测试 agentrun.knowledgebase.api.data 模块 / Test agentrun.knowledgebase.api.data module""" + +import os +from unittest.mock import AsyncMock, MagicMock, patch + 
+import pytest + +from agentrun.knowledgebase.api.data import ( + ADBDataAPI, + BailianDataAPI, + get_data_api, + KnowledgeBaseDataAPI, + RagFlowDataAPI, +) +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + ADBRetrieveSettings, + BailianProviderSettings, + BailianRetrieveSettings, + KnowledgeBaseProvider, + RagFlowProviderSettings, + RagFlowRetrieveSettings, +) +from agentrun.utils.config import Config + + +class TestKnowledgeBaseDataAPIBase: + """测试 KnowledgeBaseDataAPI 基类""" + + def test_abstract_methods(self): + """测试抽象方法""" + # KnowledgeBaseDataAPI 是抽象类,不能直接实例化 + with pytest.raises(TypeError): + KnowledgeBaseDataAPI("test-kb") # type: ignore + + +class TestRagFlowDataAPIInit: + """测试 RagFlowDataAPI 初始化""" + + def test_init(self): + """测试初始化""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + credential_name="test-credential", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is not None + assert api.retrieve_settings is not None + assert api.credential_name == "test-credential" + + def test_init_minimal(self): + """测试最小化初始化""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is None + assert api.retrieve_settings is None + assert api.credential_name is None + + def test_init_with_config(self): + """测试带配置初始化""" + config = Config(access_key_id="test-ak") + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + config=config, + ) + + assert api.knowledge_base_name == "test-kb" + + +class TestRagFlowDataAPIGetApiKey: + """测试 RagFlowDataAPI._get_api_key 方法""" + + @patch("agentrun.credential.Credential.get_by_name") + def test_get_api_key_sync(self, mock_get_credential): + """测试同步获取 API Key""" + mock_credential = MagicMock() 
+ mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = api._get_api_key() + assert result == "test-api-key" + + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_get_api_key_async(self, mock_get_credential): + """测试异步获取 API Key""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = await api._get_api_key_async() + assert result == "test-api-key" + + def test_get_api_key_without_credential_name(self): + """测试无凭证名称时获取 API Key""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="credential_name is required"): + api._get_api_key() + + @pytest.mark.asyncio + async def test_get_api_key_async_without_credential_name(self): + """测试异步无凭证名称时获取 API Key""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="credential_name is required"): + await api._get_api_key_async() + + @patch("agentrun.credential.Credential.get_by_name") + def test_get_api_key_empty_secret(self, mock_get_credential): + """测试凭证密钥为空时获取 API Key""" + mock_credential = MagicMock() + mock_credential.credential_secret = None + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + with pytest.raises(ValueError, match="has no secret configured"): + api._get_api_key() + + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_get_api_key_async_empty_secret(self, mock_get_credential): + """测试异步凭证密钥为空时获取 API Key""" + mock_credential = MagicMock() + 
mock_credential.credential_secret = None + mock_get_credential.return_value = mock_credential + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + with pytest.raises(ValueError, match="has no secret configured"): + await api._get_api_key_async() + + +class TestRagFlowDataAPIBuildRequestBody: + """测试 RagFlowDataAPI._build_request_body 方法""" + + def test_build_request_body(self): + """测试构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1", "ds-2"], + ), + ) + + body = api._build_request_body("test query") + assert body["question"] == "test query" + assert body["dataset_ids"] == ["ds-1", "ds-2"] + assert body["page"] == 1 + assert body["page_size"] == 30 + + def test_build_request_body_with_retrieve_settings(self): + """测试带检索设置构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + vector_similarity_weight=0.5, + cross_languages=["English", "Chinese"], + ), + ) + + body = api._build_request_body("test query") + assert body["similarity_threshold"] == 0.8 + assert body["vector_similarity_weight"] == 0.5 + assert body["cross_languages"] == ["English", "Chinese"] + + def test_build_request_body_without_provider_settings(self): + """测试无提供商设置构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="provider_settings is required"): + api._build_request_body("test query") + + def test_build_request_body_with_partial_retrieve_settings_only_threshold( + self, + ): + """测试仅设置 similarity_threshold 的部分检索设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + 
dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + ) + + body = api._build_request_body("test query") + assert body["similarity_threshold"] == 0.8 + assert "vector_similarity_weight" not in body + assert "cross_languages" not in body + + def test_build_request_body_with_partial_retrieve_settings_only_weight( + self, + ): + """测试仅设置 vector_similarity_weight 的部分检索设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + vector_similarity_weight=0.5, + ), + ) + + body = api._build_request_body("test query") + assert "similarity_threshold" not in body + assert body["vector_similarity_weight"] == 0.5 + assert "cross_languages" not in body + + def test_build_request_body_with_partial_retrieve_settings_only_languages( + self, + ): + """测试仅设置 cross_languages 的部分检索设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + cross_languages=["English"], + ), + ) + + body = api._build_request_body("test query") + assert "similarity_threshold" not in body + assert "vector_similarity_weight" not in body + assert body["cross_languages"] == ["English"] + + def test_build_request_body_without_retrieve_settings(self): + """测试无检索设置构建请求体""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + body = api._build_request_body("test query") + assert "similarity_threshold" not in body + assert "vector_similarity_weight" not in body + assert "cross_languages" not in body + + +class TestRagFlowDataAPIRetrieve: + """测试 RagFlowDataAPI.retrieve 方法""" + + @patch("httpx.Client") + 
@patch("agentrun.credential.Credential.get_by_name") + def test_retrieve_sync(self, mock_get_credential, mock_httpx_client): + """测试同步检索""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": {"chunks": [{"content": "test content"}]} + } + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post.return_value = mock_response + mock_client_instance.__enter__ = MagicMock( + return_value=mock_client_instance + ) + mock_client_instance.__exit__ = MagicMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + assert "data" in result + + @patch("httpx.AsyncClient") + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_get_credential, mock_httpx_client): + """测试异步检索""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": {"chunks": [{"content": "test content"}]} + } + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + 
knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + def test_retrieve_without_provider_settings(self): + """测试无提供商设置检索""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch("httpx.Client") + @patch("agentrun.credential.Credential.get_by_name") + def test_retrieve_with_false_data( + self, mock_get_credential, mock_httpx_client + ): + """测试检索返回 False 数据""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = {"data": False} + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post.return_value = mock_response + mock_client_instance.__enter__ = MagicMock( + return_value=mock_client_instance + ) + mock_client_instance.__exit__ = MagicMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @pytest.mark.asyncio + async def test_retrieve_async_without_provider_settings(self): + """测试异步检索无提供商设置""" + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + @patch("httpx.AsyncClient") + 
@patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_retrieve_async_with_false_data( + self, mock_get_credential, mock_httpx_client + ): + """测试异步检索返回 False 数据""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_response = MagicMock() + mock_response.json.return_value = {"data": False} + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + @patch("httpx.AsyncClient") + @patch("agentrun.credential.Credential.get_by_name_async") + @pytest.mark.asyncio + async def test_retrieve_async_exception( + self, mock_get_credential, mock_httpx_client + ): + """测试异步检索异常处理""" + mock_credential = MagicMock() + mock_credential.credential_secret = "test-api-key" + mock_get_credential.return_value = mock_credential + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock( + side_effect=Exception("Connection error") + ) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_httpx_client.return_value = mock_client_instance + + api = RagFlowDataAPI( + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + 
credential_name="test-credential", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + assert "Connection error" in result["data"] + + +class TestBailianDataAPIInit: + """测试 BailianDataAPI 初始化""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init(self): + """测试初始化""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is not None + assert api.retrieve_settings is not None + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init_minimal(self): + """测试最小化初始化""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is None + + +class TestBailianDataAPIRetrieve: + """测试 BailianDataAPI.retrieve 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_sync(self, mock_get_bailian_client): + """测试同步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [ + MagicMock(text="test content", score=0.9, metadata={}), + ] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + 
@patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_get_bailian_client): + """测试异步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [ + MagicMock(text="test content", score=0.9, metadata={}), + ] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_retrieve_without_provider_settings(self): + """测试无提供商设置检索""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @pytest.mark.asyncio + async def test_retrieve_async_without_provider_settings(self): + """测试异步检索无提供商设置""" + api = BailianDataAPI( + knowledge_base_name="test-kb", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_retrieve_settings(self, mock_get_bailian_client): + """测试带检索设置检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + 
mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + sparse_similarity_top_k=5, + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_partial_retrieve_settings_dense_only( + self, mock_get_bailian_client + ): + """测试仅设置 dense_similarity_top_k 的部分检索设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_partial_retrieve_settings_sparse_only( + self, mock_get_bailian_client + ): + """测试仅设置 sparse_similarity_top_k 的部分检索设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + 
provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + sparse_similarity_top_k=5, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + def test_retrieve_with_partial_retrieve_settings_rerank_only( + self, mock_get_bailian_client + ): + """测试仅设置 rerank 相关的部分检索设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve.return_value = mock_response + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_partial_settings( + self, mock_get_bailian_client + ): + """测试异步检索带部分设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test 
query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_sparse_only_settings( + self, mock_get_bailian_client + ): + """测试异步检索仅设置 sparse_similarity_top_k""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + sparse_similarity_top_k=5, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_rerank_settings( + self, mock_get_bailian_client + ): + """测试异步检索仅设置 rerank 相关参数""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + 
@patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_exception(self, mock_get_bailian_client): + """测试异步检索异常处理""" + mock_client = MagicMock() + mock_client.retrieve_async = AsyncMock( + side_effect=Exception("API error") + ) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + assert "API error" in result["data"] + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_bailian_client") + @pytest.mark.asyncio + async def test_retrieve_async_with_all_settings( + self, mock_get_bailian_client + ): + """测试异步检索带完整设置""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.data.nodes = [] + mock_client.retrieve_async = AsyncMock(return_value=mock_response) + mock_get_bailian_client.return_value = mock_client + + api = BailianDataAPI( + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + sparse_similarity_top_k=5, + rerank_min_score=0.5, + rerank_top_n=3, + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + + +class TestADBDataAPIInit: + """测试 ADBDataAPI 初始化""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init(self): + """测试初始化""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + 
retrieve_settings=ADBRetrieveSettings( + top_k=10, + ), + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is not None + assert api.retrieve_settings is not None + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_init_minimal(self): + """测试最小化初始化""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + assert api.knowledge_base_name == "test-kb" + assert api.provider_settings is None + + +class TestADBDataAPIBuildQueryContentRequest: + """测试 ADBDataAPI._build_query_content_request 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request(self): + """测试构建 QueryContent 请求""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + request = api._build_query_content_request("test query") + assert request.content == "test query" + assert request.dbinstance_id == "gp-123456" + assert request.namespace == "public" + assert request.collection == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_settings(self): + """测试带设置构建 QueryContent 请求""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + metrics="cosine", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + use_full_text_retrieval=True, + rerank_factor=1.5, + recall_window=[-5, 5], + hybrid_search="RRF", + hybrid_search_args={"RRF": {"k": 60}}, + ), + ) + + request = api._build_query_content_request("test query") + assert request.metrics == "cosine" + assert request.top_k == 10 + 
assert request.use_full_text_retrieval is True + assert request.rerank_factor == 1.5 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_without_provider_settings(self): + """测试无提供商设置构建 QueryContent 请求""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="provider_settings is required"): + api._build_query_content_request("test query") + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_top_k_only(self): + """测试仅设置 top_k 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + ), + ) + + request = api._build_query_content_request("test query") + assert request.top_k == 10 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_full_text_only( + self, + ): + """测试仅设置 use_full_text_retrieval 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + use_full_text_retrieval=True, + ), + ) + + request = api._build_query_content_request("test query") + assert request.use_full_text_retrieval is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_rerank_only( + self, + ): + """测试仅设置 rerank_factor 的部分检索设置""" + api = ADBDataAPI( + 
knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + rerank_factor=1.5, + ), + ) + + request = api._build_query_content_request("test query") + assert request.rerank_factor == 1.5 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_recall_window_only( + self, + ): + """测试仅设置 recall_window 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + recall_window=[-5, 5], + ), + ) + + request = api._build_query_content_request("test query") + assert request.recall_window == [-5, 5] + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_hybrid_only( + self, + ): + """测试仅设置 hybrid_search 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + hybrid_search="RRF", + ), + ) + + request = api._build_query_content_request("test query") + assert request.hybrid_search == "RRF" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_build_query_content_request_with_partial_settings_hybrid_args_only( + self, + ): + """测试仅设置 hybrid_search_args 的部分检索设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + 
namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + hybrid_search_args={"RRF": {"k": 60}}, + ), + ) + + request = api._build_query_content_request("test query") + assert request.hybrid_search_args == {"RRF": {"k": 60}} + + +class TestADBDataAPIParseQueryContentResponse: + """测试 ADBDataAPI._parse_query_content_response 方法""" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_parse_query_content_response(self): + """测试解析 QueryContent 响应""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + mock_response = MagicMock() + mock_response.body.matches.match_list = [ + MagicMock( + content="test content", + score=0.9, + id="1", + file_name="test.txt", + metadata={}, + rerank_score=0.95, + retrieval_source="vector", + ), + ] + mock_response.body.request_id = "req-123" + + result = api._parse_query_content_response(mock_response, "test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + assert result["request_id"] == "req-123" + assert len(result["data"]) == 1 + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_parse_query_content_response_empty(self): + """测试解析空 QueryContent 响应""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + mock_response = MagicMock() + mock_response.body.matches = None + mock_response.body.request_id = "req-123" + + result = api._parse_query_content_response(mock_response, "test query") + assert result["data"] == [] + + +class TestADBDataAPIRetrieve: + """测试 ADBDataAPI.retrieve 方法""" + + @patch.dict( + os.environ, + { + 
"AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + def test_retrieve_sync(self, mock_get_gpdb_client): + """测试同步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.matches.match_list = [ + MagicMock( + content="test content", + score=0.9, + id="1", + file_name="test.txt", + metadata={}, + rerank_score=0.95, + retrieval_source="vector", + ), + ] + mock_response.body.request_id = "req-123" + mock_client.query_content.return_value = mock_response + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = api.retrieve("test query") + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_get_gpdb_client): + """测试异步检索""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.body.matches.match_list = [ + MagicMock( + content="test content", + score=0.9, + id="1", + file_name="test.txt", + metadata={}, + rerank_score=0.95, + retrieval_source="vector", + ), + ] + mock_response.body.request_id = "req-123" + mock_client.query_content_async = AsyncMock(return_value=mock_response) + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = await api.retrieve_async("test query") + assert result["query"] == "test query" + assert 
result["knowledge_base_name"] == "test-kb" + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_retrieve_without_provider_settings(self): + """测试无提供商设置检索""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + def test_retrieve_exception(self, mock_get_gpdb_client): + """测试检索异常处理""" + mock_client = MagicMock() + mock_client.query_content.side_effect = Exception("Query failed") + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = api.retrieve("test query") + assert result["error"] is True + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @patch("agentrun.utils.control_api.ControlAPI._get_gpdb_client") + @pytest.mark.asyncio + async def test_retrieve_async_exception(self, mock_get_gpdb_client): + """测试异步检索异常处理""" + mock_client = MagicMock() + mock_client.query_content_async = AsyncMock( + side_effect=Exception("Query failed") + ) + mock_get_gpdb_client.return_value = mock_client + + api = ADBDataAPI( + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + assert "Query failed" in result["data"] + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + @pytest.mark.asyncio + 
async def test_retrieve_async_without_provider_settings(self): + """测试异步检索无提供商设置""" + api = ADBDataAPI( + knowledge_base_name="test-kb", + ) + + result = await api.retrieve_async("test query") + assert result["error"] is True + + +class TestGetDataAPI: + """测试 get_data_api 工厂函数""" + + def test_get_data_api_ragflow(self): + """测试获取 RagFlow 数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + assert isinstance(api, RagFlowDataAPI) + + def test_get_data_api_ragflow_string(self): + """测试使用字符串获取 RagFlow 数据链路 API""" + api = get_data_api( + provider="ragflow", # type: ignore + knowledge_base_name="test-kb", + ) + + assert isinstance(api, RagFlowDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_bailian(self): + """测试获取百炼数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.BAILIAN, + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + assert isinstance(api, BailianDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_bailian_string(self): + """测试使用字符串获取百炼数据链路 API""" + api = get_data_api( + provider="bailian", # type: ignore + knowledge_base_name="test-kb", + ) + + assert isinstance(api, BailianDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_adb(self): + """测试获取 ADB 数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.ADB, + knowledge_base_name="test-kb", + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + 
namespace="public", + namespace_password="password123", + ), + ) + + assert isinstance(api, ADBDataAPI) + + @patch.dict( + os.environ, + { + "AGENTRUN_ACCESS_KEY_ID": "test-ak", + "AGENTRUN_ACCESS_KEY_SECRET": "test-sk", + }, + ) + def test_get_data_api_adb_string(self): + """测试使用字符串获取 ADB 数据链路 API""" + api = get_data_api( + provider="adb", # type: ignore + knowledge_base_name="test-kb", + ) + + assert isinstance(api, ADBDataAPI) + + def test_get_data_api_unsupported_provider(self): + """测试不支持的提供商""" + with pytest.raises(ValueError, match="Unsupported provider type"): + get_data_api( + provider="unsupported", # type: ignore + knowledge_base_name="test-kb", + ) + + def test_get_data_api_with_wrong_settings_type(self): + """测试使用错误类型的设置""" + # 传入 Bailian 设置给 RagFlow + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + # 应该返回 RagFlowDataAPI,但 provider_settings 会是 None + assert isinstance(api, RagFlowDataAPI) + assert api.provider_settings is None + + def test_get_data_api_with_config(self): + """测试带配置获取数据链路 API""" + config = Config(access_key_id="test-ak") + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + config=config, + ) + + assert isinstance(api, RagFlowDataAPI) + + def test_get_data_api_with_retrieve_settings(self): + """测试带检索设置获取数据链路 API""" + api = get_data_api( + provider=KnowledgeBaseProvider.RAGFLOW, + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + credential_name="test-credential", + ) + + assert isinstance(api, RagFlowDataAPI) + assert api.retrieve_settings is not None diff --git a/tests/unittests/knowledgebase/test_client.py b/tests/unittests/knowledgebase/test_client.py new file mode 
100644 index 0000000..9601c6c --- /dev/null +++ b/tests/unittests/knowledgebase/test_client.py @@ -0,0 +1,639 @@ +"""测试 agentrun.knowledgebase.client 模块 / Test agentrun.knowledgebase.client module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.client import KnowledgeBaseClient +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + BailianProviderSettings, + KnowledgeBaseCreateInput, + KnowledgeBaseListInput, + KnowledgeBaseProvider, + KnowledgeBaseUpdateInput, + RagFlowProviderSettings, +) +from agentrun.utils.config import Config +from agentrun.utils.exception import ( + HTTPError, + ResourceAlreadyExistError, + ResourceNotExistError, +) + + +class MockKnowledgeBaseData: + """模拟知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-123", + "knowledgeBaseName": "test-kb", + "provider": "ragflow", + "description": "Test knowledge base", + "credentialName": "test-credential", + "providerSettings": { + "baseUrl": "https://ragflow.example.com", + "datasetIds": ["ds-1"], + }, + "retrieveSettings": { + "similarityThreshold": 0.8, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockBailianKnowledgeBaseData: + """模拟百炼知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-456", + "knowledgeBaseName": "test-bailian-kb", + "provider": "bailian", + "description": "Test Bailian knowledge base", + "providerSettings": { + "workspaceId": "ws-123", + "indexIds": ["idx-1"], + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockADBKnowledgeBaseData: + """模拟 ADB 知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-789", + "knowledgeBaseName": "test-adb-kb", + "provider": "adb", + "description": "Test ADB knowledge base", + "providerSettings": { + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + }, + "createdAt": 
"2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockListResult: + """模拟列表结果""" + + def __init__(self, items): + self.items = items + + +class TestKnowledgeBaseClientInit: + """测试 KnowledgeBaseClient 初始化""" + + def test_init_without_config(self): + """测试不带配置的初始化""" + client = KnowledgeBaseClient() + assert client is not None + + def test_init_with_config(self): + """测试带配置的初始化""" + config = Config(access_key_id="test-ak") + client = KnowledgeBaseClient(config=config) + assert client is not None + + +class TestKnowledgeBaseClientCreate: + """测试 KnowledgeBaseClient.create 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_sync(self, mock_control_api_class): + """测试同步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + description="Test knowledge base", + ) + + result = client.create(input_obj) + assert result.knowledge_base_name == "test-kb" + assert mock_control_api.create_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_create_async(self, mock_control_api_class): + """测试异步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + 
dataset_ids=["ds-1"], + ), + ) + + result = await client.create_async(input_obj) + assert result.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_bailian_kb(self, mock_control_api_class): + """测试创建百炼知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockBailianKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-bailian-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + result = client.create(input_obj) + assert result.knowledge_base_name == "test-bailian-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_adb_kb(self, mock_control_api_class): + """测试创建 ADB 知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockADBKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-adb-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + result = client.create(input_obj) + assert result.knowledge_base_name == "test-adb-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_with_config(self, mock_control_api_class): + """测试带配置创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + 
provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + config = Config(access_key_id="custom-ak") + + result = client.create(input_obj, config=config) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_already_exists(self, mock_control_api_class): + """测试创建已存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.side_effect = HTTPError( + status_code=409, + message="Resource already exists", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="existing-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + with pytest.raises(ResourceAlreadyExistError): + client.create(input_obj) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_create_async_already_exists(self, mock_control_api_class): + """测试异步创建已存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=409, + message="Resource already exists", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="existing-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + with pytest.raises(ResourceAlreadyExistError): + await client.create_async(input_obj) + + +class TestKnowledgeBaseClientDelete: + """测试 KnowledgeBaseClient.delete 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_sync(self, 
mock_control_api_class): + """测试同步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.delete("test-kb") + assert result is not None + assert mock_control_api.delete_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_async(self, mock_control_api_class): + """测试异步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = await client.delete_async("test-kb") + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_with_config(self, mock_control_api_class): + """测试带配置删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + config = Config(access_key_id="custom-ak") + result = client.delete("test-kb", config=config) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_not_exist(self, mock_control_api_class): + """测试删除不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + client.delete("nonexistent-kb") + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_async_not_exist(self, mock_control_api_class): + 
"""测试异步删除不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + await client.delete_async("nonexistent-kb") + + +class TestKnowledgeBaseClientUpdate: + """测试 KnowledgeBaseClient.update 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_sync(self, mock_control_api_class): + """测试同步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated description") + result = client.update("test-kb", input_obj) + assert result is not None + assert mock_control_api.update_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_async(self, mock_control_api_class): + """测试异步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = await client.update_async("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_with_provider_settings(self, mock_control_api_class): + """测试更新知识库(带提供商设置)""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = 
KnowledgeBaseUpdateInput( + provider_settings=RagFlowProviderSettings( + base_url="https://new-ragflow.example.com", + dataset_ids=["ds-new"], + ), + ) + result = client.update("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_with_credential(self, mock_control_api_class): + """测试更新知识库(带凭证)""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput( + credential_name="new-credential", + ) + result = client.update("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_not_exist(self, mock_control_api_class): + """测试更新不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ResourceNotExistError): + client.update("nonexistent-kb", input_obj) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_async_not_exist(self, mock_control_api_class): + """测试异步更新不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ResourceNotExistError): + await client.update_async("nonexistent-kb", input_obj) + + +class 
TestKnowledgeBaseClientGet: + """测试 KnowledgeBaseClient.get 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_sync(self, mock_control_api_class): + """测试同步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.get("test-kb") + assert result.knowledge_base_name == "test-kb" + assert mock_control_api.get_knowledge_base.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_async(self, mock_control_api_class): + """测试异步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = await client.get_async("test-kb") + assert result.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_with_config(self, mock_control_api_class): + """测试带配置获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + config = Config(access_key_id="custom-ak") + result = client.get("test-kb", config=config) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_not_exist(self, mock_control_api_class): + """测试获取不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + client.get("nonexistent-kb") 
+ + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_async_not_exist(self, mock_control_api_class): + """测试异步获取不存在的知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + with pytest.raises(ResourceNotExistError): + await client.get_async("nonexistent-kb") + + +class TestKnowledgeBaseClientList: + """测试 KnowledgeBaseClient.list 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_sync(self, mock_control_api_class): + """测试同步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult([ + MockKnowledgeBaseData(), + MockBailianKnowledgeBaseData(), + ]) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.list() + assert len(result) == 2 + assert mock_control_api.list_knowledge_bases.called + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_list_async(self, mock_control_api_class): + """测试异步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases_async = AsyncMock( + return_value=MockListResult([MockKnowledgeBaseData()]) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = await client.list_async() + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_with_input(self, mock_control_api_class): + """测试同步列出知识库(带输入参数)""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = 
KnowledgeBaseClient() + input_obj = KnowledgeBaseListInput( + page_number=1, + page_size=10, + provider=KnowledgeBaseProvider.RAGFLOW, + ) + result = client.list(input=input_obj) + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_list_async_with_input(self, mock_control_api_class): + """测试异步列出知识库(带输入参数)""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases_async = AsyncMock( + return_value=MockListResult([MockKnowledgeBaseData()]) + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + input_obj = KnowledgeBaseListInput(page_number=1, page_size=10) + result = await client.list_async(input=input_obj) + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_empty(self, mock_control_api_class): + """测试列出空知识库列表""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult([]) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.list() + assert len(result) == 0 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_with_none_input(self, mock_control_api_class): + """测试列出知识库(输入为 None)""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + result = client.list(input=None) + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_with_config(self, mock_control_api_class): + """测试带配置列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = KnowledgeBaseClient() + config 
= Config(access_key_id="custom-ak") + result = client.list(config=config) + assert len(result) == 1 diff --git a/tests/unittests/knowledgebase/test_knowledgebase.py b/tests/unittests/knowledgebase/test_knowledgebase.py new file mode 100644 index 0000000..8936d9d --- /dev/null +++ b/tests/unittests/knowledgebase/test_knowledgebase.py @@ -0,0 +1,1306 @@ +"""测试 agentrun.knowledgebase.knowledgebase 模块 / Test agentrun.knowledgebase.knowledgebase module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.knowledgebase import KnowledgeBase +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + ADBRetrieveSettings, + BailianProviderSettings, + BailianRetrieveSettings, + KnowledgeBaseCreateInput, + KnowledgeBaseProvider, + KnowledgeBaseUpdateInput, + RagFlowProviderSettings, + RagFlowRetrieveSettings, +) +from agentrun.utils.config import Config + + +class MockKnowledgeBaseData: + """模拟知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-123", + "knowledgeBaseName": "test-kb", + "provider": "ragflow", + "description": "Test knowledge base", + "credentialName": "test-credential", + "providerSettings": { + "baseUrl": "https://ragflow.example.com", + "datasetIds": ["ds-1"], + }, + "retrieveSettings": { + "similarityThreshold": 0.8, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockBailianKnowledgeBaseData: + """模拟百炼知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-456", + "knowledgeBaseName": "test-bailian-kb", + "provider": "bailian", + "description": "Test Bailian knowledge base", + "providerSettings": { + "workspaceId": "ws-123", + "indexIds": ["idx-1"], + }, + "retrieveSettings": { + "denseSimilarityTopK": 10, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockADBKnowledgeBaseData: + """模拟 ADB 知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-789", 
+ "knowledgeBaseName": "test-adb-kb", + "provider": "adb", + "description": "Test ADB knowledge base", + "providerSettings": { + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + }, + "retrieveSettings": { + "TopK": 10, + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockListResult: + """模拟列表结果""" + + def __init__(self, items): + self.items = items + + +class TestKnowledgeBaseCreate: + """测试 KnowledgeBase.create 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_create_sync(self, mock_control_api_class): + """测试同步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + description="Test knowledge base", + ) + + result = KnowledgeBase.create(input_obj) + assert result.knowledge_base_name == "test-kb" + assert result.knowledge_base_id == "kb-123" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_create_async(self, mock_control_api_class): + """测试异步创建知识库""" + mock_control_api = MagicMock() + mock_control_api.create_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + ) + + result = await KnowledgeBase.create_async(input_obj) + assert result.knowledge_base_name == "test-kb" + + +class TestKnowledgeBaseDelete: + 
"""测试 KnowledgeBase.delete 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_by_name_sync(self, mock_control_api_class): + """测试根据名称同步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.delete_by_name("test-kb") + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_by_name_async(self, mock_control_api_class): + """测试根据名称异步删除知识库""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase.delete_by_name_async("test-kb") + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_delete_instance_sync(self, mock_control_api_class): + """测试实例同步删除""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = KnowledgeBase.get_by_name("test-kb") + + # 删除实例 + result = kb.delete() + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_delete_instance_async(self, mock_control_api_class): + """测试实例异步删除""" + mock_control_api = MagicMock() + mock_control_api.delete_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = await KnowledgeBase.get_by_name_async("test-kb") + + # 删除实例 + result 
= await kb.delete_async() + assert result is not None + + def test_delete_instance_without_name(self): + """测试实例删除(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + kb.delete() + + @pytest.mark.asyncio + async def test_delete_instance_async_without_name(self): + """测试异步实例删除(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await kb.delete_async() + + +class TestKnowledgeBaseUpdate: + """测试 KnowledgeBase.update 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_by_name_sync(self, mock_control_api_class): + """测试根据名称同步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseUpdateInput(description="Updated description") + result = KnowledgeBase.update_by_name("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_by_name_async(self, mock_control_api_class): + """测试根据名称异步更新知识库""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = await KnowledgeBase.update_by_name_async("test-kb", input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_update_instance_sync(self, mock_control_api_class): + """测试实例同步更新""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + 
kb = KnowledgeBase.get_by_name("test-kb") + + # 更新实例 + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = kb.update(input_obj) + assert result is not None + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_update_instance_async(self, mock_control_api_class): + """测试实例异步更新""" + mock_control_api = MagicMock() + mock_control_api.update_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = await KnowledgeBase.get_by_name_async("test-kb") + + # 更新实例 + input_obj = KnowledgeBaseUpdateInput(description="Updated") + result = await kb.update_async(input_obj) + assert result is not None + + def test_update_instance_without_name(self): + """测试实例更新(无名称)""" + kb = KnowledgeBase() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ValueError, match="knowledge_base_name is required"): + kb.update(input_obj) + + @pytest.mark.asyncio + async def test_update_instance_async_without_name(self): + """测试异步实例更新(无名称)""" + kb = KnowledgeBase() + input_obj = KnowledgeBaseUpdateInput(description="Updated") + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await kb.update_async(input_obj) + + +class TestKnowledgeBaseGet: + """测试 KnowledgeBase.get_by_name 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_by_name_sync(self, mock_control_api_class): + """测试同步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.get_by_name("test-kb") + assert result.knowledge_base_name == "test-kb" + assert result.knowledge_base_id == "kb-123" + + 
@patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_by_name_async(self, mock_control_api_class): + """测试异步获取知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase.get_by_name_async("test-kb") + assert result.knowledge_base_name == "test-kb" + + +class TestKnowledgeBaseRefresh: + """测试 KnowledgeBase.refresh 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_refresh_sync(self, mock_control_api_class): + """测试同步刷新知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = KnowledgeBase.get_by_name("test-kb") + + # 刷新实例 + kb.refresh() + assert kb.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_refresh_async(self, mock_control_api_class): + """测试异步刷新知识库""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + kb = await KnowledgeBase.get_by_name_async("test-kb") + + # 刷新实例 + await kb.refresh_async() + assert kb.knowledge_base_name == "test-kb" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_get_instance_sync(self, mock_control_api_class): + """测试实例 get 方法同步""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + kb = KnowledgeBase.get_by_name("test-kb") + result = kb.get() + assert result.knowledge_base_name == "test-kb" + + 
@patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_get_instance_async(self, mock_control_api_class): + """测试实例 get 方法异步""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + kb = await KnowledgeBase.get_by_name_async("test-kb") + result = await kb.get_async() + assert result.knowledge_base_name == "test-kb" + + def test_get_instance_without_name(self): + """测试实例获取(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + kb.get() + + @pytest.mark.asyncio + async def test_get_instance_async_without_name(self): + """测试异步实例获取(无名称)""" + kb = KnowledgeBase() + with pytest.raises(ValueError, match="knowledge_base_name is required"): + await kb.get_async() + + +class TestKnowledgeBaseList: + """测试 KnowledgeBase.list_all 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_list_all_sync(self, mock_control_api_class): + """测试同步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult([ + MockKnowledgeBaseData(), + MockBailianKnowledgeBaseData(), + ]) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.list_all() + # list_all 会对结果去重,所以相同 ID 的记录只会返回一个 + assert len(result) >= 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_list_all_async(self, mock_control_api_class): + """测试异步列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases_async = AsyncMock( + return_value=MockListResult([MockKnowledgeBaseData()]) + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase.list_all_async() + assert len(result) == 1 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def 
test_list_all_with_provider(self, mock_control_api_class): + """测试按提供商列出知识库""" + mock_control_api = MagicMock() + mock_control_api.list_knowledge_bases.return_value = MockListResult( + [MockKnowledgeBaseData()] + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase.list_all(provider="ragflow") + assert len(result) >= 1 + + +class TestKnowledgeBaseFromInnerObject: + """测试 KnowledgeBase.from_inner_object 方法""" + + def test_from_inner_object(self): + """测试从内部对象创建知识库""" + mock_data = MockKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-123" + assert kb.knowledge_base_name == "test-kb" + assert kb.provider == "ragflow" + assert kb.description == "Test knowledge base" + + def test_from_inner_object_with_extra(self): + """测试从内部对象创建知识库(带额外字段)""" + mock_data = MockKnowledgeBaseData() + extra = {"custom_field": "custom_value"} + kb = KnowledgeBase.from_inner_object(mock_data, extra) + + assert kb.knowledge_base_name == "test-kb" + + def test_from_inner_object_bailian(self): + """测试从内部对象创建百炼知识库""" + mock_data = MockBailianKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-456" + assert kb.knowledge_base_name == "test-bailian-kb" + assert kb.provider == "bailian" + + def test_from_inner_object_adb(self): + """测试从内部对象创建 ADB 知识库""" + mock_data = MockADBKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-789" + assert kb.knowledge_base_name == "test-adb-kb" + assert kb.provider == "adb" + + +class TestKnowledgeBaseGetDataAPI: + """测试 KnowledgeBase._get_data_api 方法""" + + def test_get_data_api_ragflow(self): + """测试获取 RagFlow 数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + 
retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_bailian(self): + """测试获取百炼数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + + def test_get_data_api_adb(self): + """测试获取 ADB 数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + retrieve_settings=ADBRetrieveSettings( + top_k=10, + ), + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + + def test_get_data_api_with_dict_settings(self): + """测试使用字典设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings={ + "base_url": "https://ragflow.example.com", + "dataset_ids": ["ds-1"], + }, + retrieve_settings={ + "similarity_threshold": 0.8, + }, + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_bailian_with_dict_settings(self): + """测试百炼使用字典设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings={ + "workspace_id": "ws-123", + "index_ids": 
["idx-1"], + }, + retrieve_settings={ + "dense_similarity_top_k": 10, + }, + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + + def test_get_data_api_adb_with_dict_settings(self): + """测试 ADB 使用字典设置获取数据链路 API(PascalCase 键名)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings={ + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + }, + retrieve_settings={ + "TopK": 10, + }, + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + + def test_get_data_api_bailian_with_raw_dict_settings(self): + """测试百炼使用原始字典设置获取数据链路 API(绕过 Pydantic 转换)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + ) + # Bypass Pydantic validation to set raw dict + object.__setattr__( + kb, + "provider_settings", + { + "workspace_id": "ws-123", + "index_ids": ["idx-1"], + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "dense_similarity_top_k": 10, + }, + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + + def test_get_data_api_ragflow_with_raw_dict_settings(self): + """测试 RagFlow 使用原始字典设置获取数据链路 API(绕过 Pydantic 转换)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + credential_name="test-credential", + ) + # Bypass Pydantic validation to set raw dict + object.__setattr__( + kb, + "provider_settings", + { + "base_url": "https://ragflow.example.com", + "dataset_ids": ["ds-1"], + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "similarity_threshold": 0.8, + }, + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, 
RagFlowDataAPI) + + def test_get_data_api_adb_with_raw_dict_settings(self): + """测试 ADB 使用原始字典设置获取数据链路 API(绕过 Pydantic 转换,PascalCase 键名)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + ) + # Bypass Pydantic validation to set raw dict with PascalCase keys + object.__setattr__( + kb, + "provider_settings", + { + "DBInstanceId": "gp-123456", + "Namespace": "public", + "NamespacePassword": "password123", + "EmbeddingModel": "text-embedding-v1", + "Metrics": "cosine", + "Metadata": '{"key": "value"}', + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "TopK": 10, + "UseFullTextRetrieval": True, + "RerankFactor": 1.5, + "RecallWindow": [-5, 5], + "HybridSearch": "RRF", + "HybridSearchArgs": {"RRF": {"k": 60}}, + }, + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + + def test_get_data_api_without_provider(self): + """测试获取数据链路 API(无提供商)""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + ) + + with pytest.raises(ValueError, match="provider is required"): + kb._get_data_api() + + def test_get_data_api_with_string_provider(self): + """测试使用字符串提供商获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider="ragflow", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + + def test_get_data_api_bailian_without_settings(self): + """测试百炼无设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + assert data_api.provider_settings is None + assert 
data_api.retrieve_settings is None + + def test_get_data_api_bailian_without_retrieve_settings(self): + """测试百炼无检索设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_ragflow_without_settings(self): + """测试 RagFlow 无设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + assert data_api.provider_settings is None + assert data_api.retrieve_settings is None + + def test_get_data_api_ragflow_without_retrieve_settings(self): + """测试 RagFlow 无检索设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_adb_without_settings(self): + """测试 ADB 无设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + assert data_api.provider_settings is None + assert 
data_api.retrieve_settings is None + + def test_get_data_api_adb_without_retrieve_settings(self): + """测试 ADB 无检索设置获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_bailian_with_invalid_provider_settings_type(self): + """测试百炼使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + ) + # Set an invalid type (not BailianProviderSettings or dict) + object.__setattr__(kb, "provider_settings", "invalid") + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + # converted_provider_settings should be None as the type is invalid + assert data_api.provider_settings is None + + def test_get_data_api_bailian_with_invalid_retrieve_settings_type(self): + """测试百炼使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + ) + # Set an invalid type (not BailianRetrieveSettings or dict) + object.__setattr__(kb, "retrieve_settings", "invalid") + + from agentrun.knowledgebase.api.data import BailianDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, BailianDataAPI) + assert data_api.provider_settings is not None + # converted_retrieve_settings should be None as the type is invalid + assert data_api.retrieve_settings is None + + def test_get_data_api_ragflow_with_invalid_provider_settings_type(self): + 
"""测试 RagFlow 使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + credential_name="test-credential", + ) + # Set an invalid type (not RagFlowProviderSettings or dict) + object.__setattr__(kb, "provider_settings", "invalid") + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + # converted_provider_settings should be None as the type is invalid + assert data_api.provider_settings is None + + def test_get_data_api_ragflow_with_invalid_retrieve_settings_type(self): + """测试 RagFlow 使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + # Set an invalid type (not RagFlowRetrieveSettings or dict) + object.__setattr__(kb, "retrieve_settings", "invalid") + + from agentrun.knowledgebase.api.data import RagFlowDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, RagFlowDataAPI) + assert data_api.provider_settings is not None + # converted_retrieve_settings should be None as the type is invalid + assert data_api.retrieve_settings is None + + def test_get_data_api_adb_with_invalid_provider_settings_type(self): + """测试 ADB 使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + ) + # Set an invalid type (not ADBProviderSettings or dict) + object.__setattr__(kb, "provider_settings", "invalid") + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + # converted_provider_settings should be None as the type is invalid + assert data_api.provider_settings is None + + def 
test_get_data_api_adb_with_invalid_retrieve_settings_type(self): + """测试 ADB 使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + # Set an invalid type (not ADBRetrieveSettings or dict) + object.__setattr__(kb, "retrieve_settings", "invalid") + + from agentrun.knowledgebase.api.data import ADBDataAPI + + data_api = kb._get_data_api() + assert isinstance(data_api, ADBDataAPI) + assert data_api.provider_settings is not None + # converted_retrieve_settings should be None as the type is invalid + assert data_api.retrieve_settings is None + + def test_get_data_api_with_unknown_provider(self): + """测试使用未知提供商获取数据链路 API""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + ) + # Set an unknown provider that's not in KnowledgeBaseProvider enum + object.__setattr__(kb, "provider", "unknown_provider") + + # get_data_api should raise an error for unsupported provider + # The error comes from trying to convert to KnowledgeBaseProvider enum + with pytest.raises( + ValueError, match="is not a valid KnowledgeBaseProvider" + ): + kb._get_data_api() + + +class TestKnowledgeBaseRetrieve: + """测试 KnowledgeBase.retrieve 方法""" + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_retrieve_sync(self, mock_retrieve): + """测试同步检索""" + mock_retrieve.return_value = { + "data": [{"content": "test content", "score": 0.9}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = kb.retrieve("test query") + assert result["query"] == "test query" + assert "data" in result + + 
@patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_retrieve_async(self, mock_retrieve_async): + """测试异步检索""" + mock_retrieve_async.return_value = { + "data": [{"content": "test content", "score": 0.9}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await kb.retrieve_async("test query") + assert result["query"] == "test query" + assert "data" in result + + +class TestKnowledgeBaseSafeGetKB: + """测试 KnowledgeBase._safe_get_kb 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_safe_get_kb_success(self, mock_control_api_class): + """测试安全获取知识库成功""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase._safe_get_kb("test-kb") + assert isinstance(result, KnowledgeBase) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_safe_get_kb_failure(self, mock_control_api_class): + """测试安全获取知识库失败""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.side_effect = Exception("Not found") + mock_control_api_class.return_value = mock_control_api + + result = KnowledgeBase._safe_get_kb("test-kb") + assert isinstance(result, Exception) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_safe_get_kb_async_success(self, mock_control_api_class): + """测试异步安全获取知识库成功""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + result = 
await KnowledgeBase._safe_get_kb_async("test-kb") + assert isinstance(result, KnowledgeBase) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_safe_get_kb_async_failure(self, mock_control_api_class): + """测试异步安全获取知识库失败""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + side_effect=Exception("Not found") + ) + mock_control_api_class.return_value = mock_control_api + + result = await KnowledgeBase._safe_get_kb_async("test-kb") + assert isinstance(result, Exception) + + +class TestKnowledgeBaseSafeRetrieveKB: + """测试 KnowledgeBase._safe_retrieve_kb 方法""" + + def test_safe_retrieve_kb_with_exception(self): + """测试安全检索知识库(传入异常)""" + error = Exception("Not found") + result = KnowledgeBase._safe_retrieve_kb("test-kb", error, "test query") + + assert result["error"] is True + assert "Failed to retrieve" in result["data"] + assert result["query"] == "test query" + assert result["knowledge_base_name"] == "test-kb" + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_safe_retrieve_kb_success(self, mock_retrieve): + """测试安全检索知识库成功""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = KnowledgeBase._safe_retrieve_kb("test-kb", kb, "test query") + assert "data" in result + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_safe_retrieve_kb_failure(self, mock_retrieve): + """测试安全检索知识库失败""" + mock_retrieve.side_effect = Exception("Retrieve failed") + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + 
provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = KnowledgeBase._safe_retrieve_kb("test-kb", kb, "test query") + assert result["error"] is True + + @pytest.mark.asyncio + async def test_safe_retrieve_kb_async_with_exception(self): + """测试异步安全检索知识库(传入异常)""" + error = Exception("Not found") + result = await KnowledgeBase._safe_retrieve_kb_async( + "test-kb", error, "test query" + ) + + assert result["error"] is True + assert "Failed to retrieve" in result["data"] + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_safe_retrieve_kb_async_success(self, mock_retrieve_async): + """测试异步安全检索知识库成功""" + mock_retrieve_async.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await KnowledgeBase._safe_retrieve_kb_async( + "test-kb", kb, "test query" + ) + assert "data" in result + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_safe_retrieve_kb_async_failure(self, mock_retrieve_async): + """测试异步安全检索知识库失败""" + mock_retrieve_async.side_effect = Exception("Retrieve failed") + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + result = await KnowledgeBase._safe_retrieve_kb_async( + "test-kb", kb, "test query" + ) + assert result["error"] is True + + +class TestKnowledgeBaseMultiRetrieve: + """测试 
KnowledgeBase.multi_retrieve 方法""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_multi_retrieve_sync(self, mock_retrieve, mock_control_api_class): + """测试同步多知识库检索""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + result = KnowledgeBase.multi_retrieve( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result + assert "query" in result + assert result["query"] == "test query" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_multi_retrieve_async( + self, mock_retrieve_async, mock_control_api_class + ): + """测试异步多知识库检索""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_retrieve_async.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + result = await KnowledgeBase.multi_retrieve_async( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result + assert "query" in result + assert result["query"] == "test query" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + def test_multi_retrieve_with_partial_failure(self, mock_control_api_class): + """测试同步多知识库检索(部分失败)""" + mock_control_api = MagicMock() + # 第一个成功,第二个失败 + mock_control_api.get_knowledge_base.side_effect = [ + MockKnowledgeBaseData(), + Exception("Not found"), + ] + mock_control_api_class.return_value = 
mock_control_api + + with patch( + "agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "kb-1", + } + + result = KnowledgeBase.multi_retrieve( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result + # kb-2 应该有错误 + assert "kb-2" in result["results"] + assert result["results"]["kb-2"]["error"] is True + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @pytest.mark.asyncio + async def test_multi_retrieve_async_with_partial_failure( + self, mock_control_api_class + ): + """测试异步多知识库检索(部分失败)""" + mock_control_api = MagicMock() + + # 创建一个返回不同结果的 side_effect 函数 + call_count = [0] + + async def mock_get_async(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return MockKnowledgeBaseData() + else: + raise Exception("Not found") + + mock_control_api.get_knowledge_base_async = mock_get_async + mock_control_api_class.return_value = mock_control_api + + with patch( + "agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve_async" + ) as mock_retrieve_async: + mock_retrieve_async.return_value = { + "data": [{"content": "test"}], + "query": "test query", + "knowledge_base_name": "kb-1", + } + + result = await KnowledgeBase.multi_retrieve_async( + query="test query", + knowledge_base_names=["kb-1", "kb-2"], + ) + + assert "results" in result diff --git a/tests/unittests/knowledgebase/test_model.py b/tests/unittests/knowledgebase/test_model.py new file mode 100644 index 0000000..5f5e335 --- /dev/null +++ b/tests/unittests/knowledgebase/test_model.py @@ -0,0 +1,650 @@ +"""测试 agentrun.knowledgebase.model 模块 / Test agentrun.knowledgebase.model module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.knowledgebase.model import ( + ADBProviderSettings, + ADBRetrieveSettings, + BailianProviderSettings, + 
BailianRetrieveSettings, + KnowledgeBaseCreateInput, + KnowledgeBaseImmutableProps, + KnowledgeBaseListInput, + KnowledgeBaseListOutput, + KnowledgeBaseMutableProps, + KnowledgeBaseProvider, + KnowledgeBaseSystemProps, + KnowledgeBaseUpdateInput, + RagFlowProviderSettings, + RagFlowRetrieveSettings, + RetrieveInput, +) + + +class TestKnowledgeBaseProvider: + """测试 KnowledgeBaseProvider 枚举""" + + def test_ragflow_value(self): + """测试 RAGFLOW 枚举值""" + assert KnowledgeBaseProvider.RAGFLOW.value == "ragflow" + + def test_bailian_value(self): + """测试 BAILIAN 枚举值""" + assert KnowledgeBaseProvider.BAILIAN.value == "bailian" + + def test_adb_value(self): + """测试 ADB 枚举值""" + assert KnowledgeBaseProvider.ADB.value == "adb" + + def test_provider_is_string_enum(self): + """测试 Provider 是字符串枚举""" + assert isinstance(KnowledgeBaseProvider.RAGFLOW, str) + assert KnowledgeBaseProvider.RAGFLOW == "ragflow" + + +class TestRagFlowProviderSettings: + """测试 RagFlowProviderSettings 模型""" + + def test_create_ragflow_provider_settings(self): + """测试创建 RagFlow 提供商设置""" + settings = RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1", "ds-2"], + ) + assert settings.base_url == "https://ragflow.example.com" + assert settings.dataset_ids == ["ds-1", "ds-2"] + + def test_ragflow_provider_settings_required_fields(self): + """测试 RagFlow 提供商设置必填字段""" + with pytest.raises(Exception): # Pydantic ValidationError + RagFlowProviderSettings() # type: ignore + + def test_ragflow_provider_settings_model_dump(self): + """测试 RagFlow 提供商设置序列化""" + settings = RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ) + data = settings.model_dump() + assert "base_url" in data or "baseUrl" in data + + +class TestRagFlowRetrieveSettings: + """测试 RagFlowRetrieveSettings 模型""" + + def test_create_ragflow_retrieve_settings(self): + """测试创建 RagFlow 检索设置""" + settings = RagFlowRetrieveSettings( + similarity_threshold=0.8, + 
vector_similarity_weight=0.5, + cross_languages=["English", "Chinese"], + ) + assert settings.similarity_threshold == 0.8 + assert settings.vector_similarity_weight == 0.5 + assert settings.cross_languages == ["English", "Chinese"] + + def test_ragflow_retrieve_settings_optional(self): + """测试 RagFlow 检索设置可选字段""" + settings = RagFlowRetrieveSettings() + assert settings.similarity_threshold is None + assert settings.vector_similarity_weight is None + assert settings.cross_languages is None + + def test_ragflow_retrieve_settings_partial(self): + """测试 RagFlow 检索设置部分字段""" + settings = RagFlowRetrieveSettings(similarity_threshold=0.7) + assert settings.similarity_threshold == 0.7 + assert settings.vector_similarity_weight is None + + +class TestBailianProviderSettings: + """测试 BailianProviderSettings 模型""" + + def test_create_bailian_provider_settings(self): + """测试创建百炼提供商设置""" + settings = BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1", "idx-2"], + ) + assert settings.workspace_id == "ws-123" + assert settings.index_ids == ["idx-1", "idx-2"] + + def test_bailian_provider_settings_required_fields(self): + """测试百炼提供商设置必填字段""" + with pytest.raises(Exception): + BailianProviderSettings() # type: ignore + + def test_bailian_provider_settings_single_index(self): + """测试百炼提供商设置单个索引""" + settings = BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ) + assert len(settings.index_ids) == 1 + + +class TestBailianRetrieveSettings: + """测试 BailianRetrieveSettings 模型""" + + def test_create_bailian_retrieve_settings(self): + """测试创建百炼检索设置""" + settings = BailianRetrieveSettings( + dense_similarity_top_k=10, + sparse_similarity_top_k=5, + rerank_min_score=0.5, + rerank_top_n=3, + ) + assert settings.dense_similarity_top_k == 10 + assert settings.sparse_similarity_top_k == 5 + assert settings.rerank_min_score == 0.5 + assert settings.rerank_top_n == 3 + + def test_bailian_retrieve_settings_optional(self): + """测试百炼检索设置可选字段""" + settings 
= BailianRetrieveSettings() + assert settings.dense_similarity_top_k is None + assert settings.sparse_similarity_top_k is None + assert settings.rerank_min_score is None + assert settings.rerank_top_n is None + + def test_bailian_retrieve_settings_partial(self): + """测试百炼检索设置部分字段""" + settings = BailianRetrieveSettings( + dense_similarity_top_k=20, + rerank_top_n=5, + ) + assert settings.dense_similarity_top_k == 20 + assert settings.rerank_top_n == 5 + assert settings.sparse_similarity_top_k is None + + +class TestADBProviderSettings: + """测试 ADBProviderSettings 模型""" + + def test_create_adb_provider_settings(self): + """测试创建 ADB 提供商设置""" + settings = ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + embedding_model="text-embedding-v3", + metrics="cosine", + metadata='{"key": "value"}', + ) + assert settings.db_instance_id == "gp-123456" + assert settings.namespace == "public" + assert settings.namespace_password == "password123" + assert settings.embedding_model == "text-embedding-v3" + assert settings.metrics == "cosine" + assert settings.metadata == '{"key": "value"}' + + def test_adb_provider_settings_required_fields(self): + """测试 ADB 提供商设置必填字段""" + with pytest.raises(Exception): + ADBProviderSettings() # type: ignore + + def test_adb_provider_settings_minimal(self): + """测试 ADB 提供商设置最小配置""" + settings = ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ) + assert settings.db_instance_id == "gp-123456" + assert settings.embedding_model is None + assert settings.metrics is None + assert settings.metadata is None + + +class TestADBRetrieveSettings: + """测试 ADBRetrieveSettings 模型""" + + def test_create_adb_retrieve_settings(self): + """测试创建 ADB 检索设置""" + settings = ADBRetrieveSettings( + top_k=10, + use_full_text_retrieval=True, + rerank_factor=1.5, + recall_window=[-5, 5], + hybrid_search="RRF", + hybrid_search_args={"RRF": {"k": 60}}, + ) + 
assert settings.top_k == 10 + assert settings.use_full_text_retrieval is True + assert settings.rerank_factor == 1.5 + assert settings.recall_window == [-5, 5] + assert settings.hybrid_search == "RRF" + assert settings.hybrid_search_args == {"RRF": {"k": 60}} + + def test_adb_retrieve_settings_optional(self): + """测试 ADB 检索设置可选字段""" + settings = ADBRetrieveSettings() + assert settings.top_k is None + assert settings.use_full_text_retrieval is None + assert settings.rerank_factor is None + assert settings.recall_window is None + assert settings.hybrid_search is None + assert settings.hybrid_search_args is None + + def test_adb_retrieve_settings_weight_hybrid(self): + """测试 ADB 检索设置加权混合检索""" + settings = ADBRetrieveSettings( + hybrid_search="Weight", + hybrid_search_args={"Weight": {"alpha": 0.5}}, + ) + assert settings.hybrid_search == "Weight" + assert settings.hybrid_search_args["Weight"]["alpha"] == 0.5 + + +class TestKnowledgeBaseMutableProps: + """测试 KnowledgeBaseMutableProps 模型""" + + def test_create_mutable_props(self): + """测试创建可变属性""" + props = KnowledgeBaseMutableProps( + description="Test description", + credential_name="test-credential", + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + retrieve_settings=RagFlowRetrieveSettings( + similarity_threshold=0.8, + ), + ) + assert props.description == "Test description" + assert props.credential_name == "test-credential" + assert isinstance(props.provider_settings, RagFlowProviderSettings) + assert isinstance(props.retrieve_settings, RagFlowRetrieveSettings) + + def test_mutable_props_optional(self): + """测试可变属性可选字段""" + props = KnowledgeBaseMutableProps() + assert props.description is None + assert props.credential_name is None + assert props.provider_settings is None + assert props.retrieve_settings is None + + def test_mutable_props_with_typed_settings(self): + """测试可变属性使用类型化设置""" + provider_settings = RagFlowProviderSettings( + 
base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ) + props = KnowledgeBaseMutableProps( + provider_settings=provider_settings, + ) + assert isinstance(props.provider_settings, RagFlowProviderSettings) + + +class TestKnowledgeBaseImmutableProps: + """测试 KnowledgeBaseImmutableProps 模型""" + + def test_create_immutable_props(self): + """测试创建不可变属性""" + props = KnowledgeBaseImmutableProps( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + ) + assert props.knowledge_base_name == "test-kb" + assert props.provider == KnowledgeBaseProvider.RAGFLOW + + def test_immutable_props_optional(self): + """测试不可变属性可选字段""" + props = KnowledgeBaseImmutableProps() + assert props.knowledge_base_name is None + assert props.provider is None + + def test_immutable_props_with_string_provider(self): + """测试不可变属性使用字符串提供商""" + props = KnowledgeBaseImmutableProps( + knowledge_base_name="test-kb", + provider="bailian", + ) + assert props.provider == "bailian" + + +class TestKnowledgeBaseSystemProps: + """测试 KnowledgeBaseSystemProps 模型""" + + def test_create_system_props(self): + """测试创建系统属性""" + props = KnowledgeBaseSystemProps( + knowledge_base_id="kb-123", + created_at="2024-01-01T00:00:00Z", + last_updated_at="2024-01-02T00:00:00Z", + ) + assert props.knowledge_base_id == "kb-123" + assert props.created_at == "2024-01-01T00:00:00Z" + assert props.last_updated_at == "2024-01-02T00:00:00Z" + + def test_system_props_optional(self): + """测试系统属性可选字段""" + props = KnowledgeBaseSystemProps() + assert props.knowledge_base_id is None + assert props.created_at is None + assert props.last_updated_at is None + + +class TestKnowledgeBaseCreateInput: + """测试 KnowledgeBaseCreateInput 模型""" + + def test_create_minimal_input(self): + """测试创建最小输入参数""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + 
dataset_ids=["ds-1"], + ), + ) + assert input_obj.knowledge_base_name == "test-kb" + assert input_obj.provider == KnowledgeBaseProvider.RAGFLOW + assert isinstance(input_obj.provider_settings, RagFlowProviderSettings) + + def test_create_full_input(self): + """测试创建完整输入参数""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + description="Test knowledge base", + credential_name="test-credential", + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=10, + ), + ) + assert input_obj.knowledge_base_name == "test-kb" + assert input_obj.description == "Test knowledge base" + assert input_obj.credential_name == "test-credential" + + def test_create_input_with_adb(self): + """测试创建 ADB 输入参数""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-adb-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="gp-123456", + namespace="public", + namespace_password="password123", + ), + ) + assert input_obj.provider == KnowledgeBaseProvider.ADB + + def test_create_input_with_dict_settings(self): + """测试创建输入参数(字典设置会自动转换为类型化对象)""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider="ragflow", + provider_settings={ + "base_url": "https://ragflow.example.com", + "dataset_ids": ["ds-1"], + }, + ) + assert input_obj.provider == "ragflow" + # Pydantic 会自动将字典转换为类型化对象 + assert isinstance(input_obj.provider_settings, RagFlowProviderSettings) + assert ( + input_obj.provider_settings.base_url + == "https://ragflow.example.com" + ) + + def test_model_dump(self): + """测试模型序列化""" + input_obj = KnowledgeBaseCreateInput( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings={ + "base_url": "https://example.com", + "dataset_ids": [], + }, + ) + data = input_obj.model_dump() + assert 
input_obj.knowledge_base_name == "test-kb" + + +class TestKnowledgeBaseUpdateInput: + """测试 KnowledgeBaseUpdateInput 模型""" + + def test_create_update_input(self): + """测试创建更新输入参数""" + input_obj = KnowledgeBaseUpdateInput( + description="Updated description", + ) + assert input_obj.description == "Updated description" + + def test_update_input_with_credential(self): + """测试更新输入参数(带凭证)""" + input_obj = KnowledgeBaseUpdateInput( + credential_name="new-credential", + ) + assert input_obj.credential_name == "new-credential" + + def test_update_input_with_provider_settings(self): + """测试更新输入参数(带提供商设置)""" + input_obj = KnowledgeBaseUpdateInput( + provider_settings=RagFlowProviderSettings( + base_url="https://new-ragflow.example.com", + dataset_ids=["ds-new"], + ), + ) + assert input_obj.provider_settings is not None + + def test_update_input_with_retrieve_settings(self): + """测试更新输入参数(带检索设置)""" + input_obj = KnowledgeBaseUpdateInput( + retrieve_settings=BailianRetrieveSettings( + dense_similarity_top_k=20, + ), + ) + assert input_obj.retrieve_settings is not None + + def test_update_input_optional(self): + """测试更新输入参数可选字段""" + input_obj = KnowledgeBaseUpdateInput() + assert input_obj.description is None + assert input_obj.credential_name is None + assert input_obj.provider_settings is None + assert input_obj.retrieve_settings is None + + +class TestKnowledgeBaseListInput: + """测试 KnowledgeBaseListInput 模型""" + + def test_create_list_input(self): + """测试创建列表输入参数""" + input_obj = KnowledgeBaseListInput( + page_number=1, + page_size=10, + provider=KnowledgeBaseProvider.RAGFLOW, + ) + assert input_obj.page_number == 1 + assert input_obj.page_size == 10 + assert input_obj.provider == KnowledgeBaseProvider.RAGFLOW + + def test_list_input_default(self): + """测试列表输入参数默认值""" + input_obj = KnowledgeBaseListInput() + assert input_obj.provider is None + + def test_list_input_with_pagination(self): + """测试列表输入参数(带分页)""" + input_obj = KnowledgeBaseListInput( + page_number=2, + 
page_size=20, + ) + assert input_obj.page_number == 2 + assert input_obj.page_size == 20 + + def test_list_input_with_string_provider(self): + """测试列表输入参数(字符串提供商)""" + input_obj = KnowledgeBaseListInput( + provider="bailian", + ) + assert input_obj.provider == "bailian" + + +class TestKnowledgeBaseListOutput: + """测试 KnowledgeBaseListOutput 模型""" + + def test_create_list_output(self): + """测试创建列表输出""" + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + description="Test knowledge base", + credential_name="test-credential", + created_at="2024-01-01T00:00:00Z", + last_updated_at="2024-01-01T00:00:00Z", + ) + assert output.knowledge_base_id == "kb-123" + assert output.knowledge_base_name == "test-kb" + assert output.provider == KnowledgeBaseProvider.RAGFLOW + assert output.description == "Test knowledge base" + + def test_list_output_optional(self): + """测试列表输出可选字段""" + output = KnowledgeBaseListOutput() + assert output.knowledge_base_id is None + assert output.knowledge_base_name is None + assert output.provider is None + assert output.description is None + assert output.credential_name is None + assert output.created_at is None + assert output.last_updated_at is None + + def test_list_output_with_settings(self): + """测试列表输出带设置(字典会自动转换为类型化对象)""" + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + provider="bailian", + provider_settings={ + "workspace_id": "ws-123", + "index_ids": ["idx-1"], + }, + retrieve_settings={"dense_similarity_top_k": 10}, + ) + # Pydantic 会自动将字典转换为类型化对象 + assert isinstance(output.provider_settings, BailianProviderSettings) + assert output.provider_settings.workspace_id == "ws-123" + assert isinstance(output.retrieve_settings, BailianRetrieveSettings) + assert output.retrieve_settings.dense_similarity_top_k == 10 + + @patch("agentrun.knowledgebase.client.KnowledgeBaseClient") + def 
test_to_knowledge_base_sync(self, mock_client_class): + """测试同步转换为 KnowledgeBase 对象""" + mock_client = MagicMock() + mock_kb = MagicMock() + mock_client.get.return_value = mock_kb + mock_client_class.return_value = mock_client + + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + ) + + result = output.to_knowledge_base() + assert result == mock_kb + mock_client.get.assert_called_once() + + @patch("agentrun.knowledgebase.client.KnowledgeBaseClient") + @pytest.mark.asyncio + async def test_to_knowledge_base_async(self, mock_client_class): + """测试异步转换为 KnowledgeBase 对象""" + mock_client = MagicMock() + mock_kb = MagicMock() + mock_client.get_async = AsyncMock(return_value=mock_kb) + mock_client_class.return_value = mock_client + + output = KnowledgeBaseListOutput( + knowledge_base_id="kb-123", + knowledge_base_name="test-kb", + ) + + result = await output.to_knowledge_base_async() + assert result == mock_kb + + +class TestRetrieveInput: + """测试 RetrieveInput 模型""" + + def test_create_retrieve_input(self): + """测试创建检索输入参数""" + input_obj = RetrieveInput( + knowledge_base_names=["kb-1", "kb-2"], + query="What is AI?", + ) + assert input_obj.knowledge_base_names == ["kb-1", "kb-2"] + assert input_obj.query == "What is AI?" 
+
+    def test_retrieve_input_with_optional_fields(self):
+        """RetrieveInput keeps the optional id, name and provider it is given."""
+        input_obj = RetrieveInput(
+            knowledge_base_names=["kb-1"],
+            query="Test query",
+            knowledge_base_id="kb-123",
+            knowledge_base_name="test-kb",
+            provider="ragflow",
+            description="Test description",  # NOTE(review): not asserted below
+            credential_name="test-credential",  # NOTE(review): not asserted below
+            created_at="2024-01-01T00:00:00Z",  # NOTE(review): not asserted below
+            last_updated_at="2024-01-01T00:00:00Z",  # NOTE(review): not asserted
+        )
+        assert input_obj.knowledge_base_id == "kb-123"
+        assert input_obj.knowledge_base_name == "test-kb"
+        assert input_obj.provider == "ragflow"
+
+    def test_retrieve_input_default_optional(self):
+        """Omitted id, name and provider are expected to default to None."""
+        input_obj = RetrieveInput(
+            knowledge_base_names=["kb-1"],
+            query="Test query",
+        )
+        assert input_obj.knowledge_base_id is None
+        assert input_obj.knowledge_base_name is None
+        assert input_obj.provider is None
+
+    @patch("agentrun.knowledgebase.client.KnowledgeBaseClient")
+    def test_retrieve_input_to_knowledge_base_sync(self, mock_client_class):
+        """to_knowledge_base() returns the mocked client's get() result."""
+        mock_client = MagicMock()  # stands in for a KnowledgeBaseClient instance
+        mock_kb = MagicMock()  # sentinel the conversion is expected to return
+        mock_client.get.return_value = mock_kb
+        mock_client_class.return_value = mock_client
+
+        input_obj = RetrieveInput(
+            knowledge_base_names=["kb-1"],
+            query="Test query",
+            knowledge_base_name="test-kb",
+        )
+
+        result = input_obj.to_knowledge_base()
+        assert result == mock_kb  # NOTE(review): the get() call itself is unchecked
+
+    @patch("agentrun.knowledgebase.client.KnowledgeBaseClient")
+    @pytest.mark.asyncio
+    async def test_retrieve_input_to_knowledge_base_async(
+        self, mock_client_class
+    ):
+        """to_knowledge_base_async() returns the mocked client's get_async() result."""
+        mock_client = MagicMock()  # stands in for a KnowledgeBaseClient instance
+        mock_kb = MagicMock()  # sentinel the conversion is expected to return
+        mock_client.get_async = AsyncMock(return_value=mock_kb)  # awaitable stub
+        mock_client_class.return_value = mock_client
+
+        input_obj = RetrieveInput(
+            knowledge_base_names=["kb-1"],
+            query="Test query",
+            knowledge_base_name="test-kb",
+        )
+
+        result = await input_obj.to_knowledge_base_async()
+        assert result == mock_kb  # NOTE(review): the await on get_async is unchecked