Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions agentrun/knowledgebase/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""KnowledgeBase 模块 / KnowledgeBase Module"""

from .api import (
ADBDataAPI,
BailianDataAPI,
get_data_api,
KnowledgeBaseControlAPI,
Expand All @@ -10,6 +11,8 @@
from .client import KnowledgeBaseClient
from .knowledgebase import KnowledgeBase
from .model import (
ADBProviderSettings,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
KnowledgeBaseCreateInput,
Expand All @@ -33,17 +36,20 @@
"KnowledgeBaseDataAPI",
"RagFlowDataAPI",
"BailianDataAPI",
"ADBDataAPI",
"get_data_api",
# enums
"KnowledgeBaseProvider",
# provider settings
"ProviderSettings",
"RagFlowProviderSettings",
"BailianProviderSettings",
"ADBProviderSettings",
# retrieve settings
"RetrieveSettings",
"RagFlowRetrieveSettings",
"BailianRetrieveSettings",
"ADBRetrieveSettings",
# api model
"KnowledgeBaseCreateInput",
"KnowledgeBaseUpdateInput",
Expand Down
50 changes: 50 additions & 0 deletions agentrun/knowledgebase/__knowledgebase_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from .api.data import get_data_api
from .model import (
ADBProviderSettings,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
KnowledgeBaseCreateInput,
Expand Down Expand Up @@ -294,6 +296,54 @@ def _get_data_api(self, config: Optional[Config] = None):
**self.retrieve_settings
)

elif provider == KnowledgeBaseProvider.ADB:
# ADB 设置 / ADB settings
if self.provider_settings:
if isinstance(self.provider_settings, ADBProviderSettings):
converted_provider_settings = self.provider_settings
elif isinstance(self.provider_settings, dict):
# ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case
# ADB provider_settings uses PascalCase keys, need to convert to snake_case
converted_provider_settings = ADBProviderSettings(
db_instance_id=self.provider_settings.get(
"DBInstanceId", ""
),
namespace=self.provider_settings.get("Namespace", ""),
namespace_password=self.provider_settings.get(
"NamespacePassword", ""
),
embedding_model=self.provider_settings.get(
"EmbeddingModel"
),
metrics=self.provider_settings.get("Metrics"),
metadata=self.provider_settings.get("Metadata"),
)

if self.retrieve_settings:
if isinstance(self.retrieve_settings, ADBRetrieveSettings):
converted_retrieve_settings = self.retrieve_settings
elif isinstance(self.retrieve_settings, dict):
# ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case
# ADB retrieve_settings uses PascalCase keys, need to convert to snake_case
converted_retrieve_settings = ADBRetrieveSettings(
top_k=self.retrieve_settings.get("TopK"),
use_full_text_retrieval=self.retrieve_settings.get(
"UseFullTextRetrieval"
),
rerank_factor=self.retrieve_settings.get(
"RerankFactor"
),
recall_window=self.retrieve_settings.get(
"RecallWindow"
),
hybrid_search=self.retrieve_settings.get(
"HybridSearch"
),
hybrid_search_args=self.retrieve_settings.get(
"HybridSearchArgs"
),
)

return get_data_api(
provider=provider,
knowledge_base_name=self.knowledge_base_name or "",
Expand Down
226 changes: 222 additions & 4 deletions agentrun/knowledgebase/api/__data_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
提供知识库检索功能的数据链路 API。
Provides data API for knowledge base retrieval operations.

根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。
Dispatches to different implementations based on provider type (ragflow / bailian).
根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from alibabacloud_bailian20231229 import models as bailian_models
from alibabacloud_gpdb20160503 import models as gpdb_models
import httpx

from agentrun.utils.config import Config
Expand All @@ -19,6 +20,8 @@
from agentrun.utils.log import logger

from ..model import (
ADBProviderSettings,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
KnowledgeBaseProvider,
Expand Down Expand Up @@ -347,15 +350,213 @@ async def retrieve_async(
}


class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI):
    """ADB (AnalyticDB for PostgreSQL) knowledge base data API.

    Implements retrieval for ADB knowledge bases by calling the GPDB SDK
    ``QueryContent`` operation.
    """

    # Optional retrieve-settings attributes that are forwarded to the request
    # under the same keyword name whenever they are not None.
    _RETRIEVE_SETTING_FIELDS = (
        "top_k",
        "use_full_text_retrieval",
        "rerank_factor",
        "recall_window",
        "hybrid_search",
        "hybrid_search_args",
    )

    # Attributes copied from each match into the result dict, in output order.
    # An attribute missing on the match object is reported as None.
    _MATCH_FIELDS = (
        "content",
        "score",
        "id",
        "file_name",
        "file_url",
        "metadata",
        "rerank_score",
        "retrieval_source",
    )

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[ADBProviderSettings] = None,
        retrieve_settings: Optional[ADBRetrieveSettings] = None,
    ):
        """Initialize the ADB knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name; it is sent as the
                ``collection`` of the QueryContent request.
            config: Optional configuration passed to both base classes.
            provider_settings: ADB provider settings. May be None here, but
                retrieval fails without it.
            retrieve_settings: Optional ADB retrieve settings.
        """
        KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config)
        ControlAPI.__init__(self, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings

    def _build_query_content_request(
        self, query: str, config: Optional[Config] = None
    ) -> gpdb_models.QueryContentRequest:
        """Build the GPDB ``QueryContentRequest`` for a query.

        Args:
            query: Query text.
            config: Optional configuration merged with the instance
                configuration to resolve the region id.

        Returns:
            The populated ``QueryContentRequest``.

        Raises:
            ValueError: If ``provider_settings`` is None.
        """
        settings = self.provider_settings
        if settings is None:
            raise ValueError("provider_settings is required for ADB retrieval")

        cfg = Config.with_configs(self.config, config)

        # Mandatory request parameters.
        request_params: Dict[str, Any] = {
            "content": query,
            "dbinstance_id": settings.db_instance_id,
            "namespace": settings.namespace,
            "namespace_password": settings.namespace_password,
            "collection": self.knowledge_base_name,
            "region_id": cfg.get_region_id(),
        }

        # Optional provider setting.
        if settings.metrics is not None:
            request_params["metrics"] = settings.metrics

        # Optional retrieve settings: only explicitly set values are sent.
        if self.retrieve_settings:
            for field in self._RETRIEVE_SETTING_FIELDS:
                value = getattr(self.retrieve_settings, field)
                if value is not None:
                    request_params[field] = value

        return gpdb_models.QueryContentRequest(**request_params)

    def _parse_query_content_response(
        self, response: gpdb_models.QueryContentResponse, query: str
    ) -> Dict[str, Any]:
        """Convert a ``QueryContentResponse`` into the retrieval result dict.

        Args:
            response: GPDB QueryContent response object.
            query: Original query text, echoed back in the result.

        Returns:
            A dict with ``data`` (list of match dicts, empty when the response
            has no body or no matches), ``query``, ``knowledge_base_name`` and
            ``request_id`` (None when unavailable).
        """
        body = response.body
        matches = body.matches if body else None
        match_list = (matches.match_list or []) if matches else []

        all_matches: List[Dict[str, Any]] = [
            {field: getattr(match, field, None) for field in self._MATCH_FIELDS}
            for match in match_list
        ]

        return {
            "data": all_matches,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "request_id": getattr(body, "request_id", None) if body else None,
        }

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the ADB knowledge base asynchronously.

        Calls the GPDB SDK ``QueryContent`` operation. Failures are not
        raised: any exception is logged as a warning and reported through the
        returned dict.

        Args:
            query: Query text.
            config: Optional configuration override.

        Returns:
            On success, the dict produced by
            ``_parse_query_content_response``. On failure, a dict with
            ``data`` set to an error message string, ``query``,
            ``knowledge_base_name`` and ``error`` set to True.
        """
        try:
            # Fail fast before a client is created.
            if self.provider_settings is None:
                raise ValueError(
                    "provider_settings is required for ADB retrieval"
                )

            # NOTE(review): _get_gpdb_client is not defined in this class;
            # presumably it is provided by ControlAPI - confirm.
            client = self._get_gpdb_client(config)

            request = self._build_query_content_request(query, config)
            # Log identifying fields only: the request object carries
            # namespace_password, which must not be written to the logs.
            logger.debug(
                "ADB QueryContent request: "
                f"collection='{self.knowledge_base_name}', "
                f"dbinstance_id='{self.provider_settings.db_instance_id}', "
                f"namespace='{self.provider_settings.namespace}', "
                f"content='{query}'"
            )

            response = await client.query_content_async(request)
            logger.debug(f"ADB QueryContent response: {response}")

            return self._parse_query_content_response(response, query)

        except Exception as e:
            # Deliberate best-effort: report the failure to the caller
            # through the result dict instead of raising.
            logger.warning(
                "Failed to retrieve from ADB knowledge base "
                f"'{self.knowledge_base_name}': {e}"
            )
            return {
                "data": f"Failed to retrieve: {e}",
                "query": query,
                "knowledge_base_name": self.knowledge_base_name,
                "error": True,
            }


def get_data_api(
provider: KnowledgeBaseProvider,
knowledge_base_name: str,
config: Optional[Config] = None,
provider_settings: Optional[
Union[RagFlowProviderSettings, BailianProviderSettings]
Union[
RagFlowProviderSettings,
BailianProviderSettings,
ADBProviderSettings,
]
] = None,
retrieve_settings: Optional[
Union[RagFlowRetrieveSettings, BailianRetrieveSettings]
Union[
RagFlowRetrieveSettings,
BailianRetrieveSettings,
ADBRetrieveSettings,
]
] = None,
credential_name: Optional[str] = None,
) -> KnowledgeBaseDataAPI:
Expand Down Expand Up @@ -410,5 +611,22 @@ def get_data_api(
provider_settings=bailian_provider_settings,
retrieve_settings=bailian_retrieve_settings,
)
elif provider == KnowledgeBaseProvider.ADB or provider == "adb":
adb_provider_settings = (
provider_settings
if isinstance(provider_settings, ADBProviderSettings)
else None
)
adb_retrieve_settings = (
retrieve_settings
if isinstance(retrieve_settings, ADBRetrieveSettings)
else None
)
return ADBDataAPI(
knowledge_base_name,
config,
provider_settings=adb_provider_settings,
retrieve_settings=adb_retrieve_settings,
)
else:
raise ValueError(f"Unsupported provider type: {provider}")
2 changes: 2 additions & 0 deletions agentrun/knowledgebase/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .control import KnowledgeBaseControlAPI
from .data import (
ADBDataAPI,
BailianDataAPI,
get_data_api,
KnowledgeBaseDataAPI,
Expand All @@ -15,5 +16,6 @@
"KnowledgeBaseDataAPI",
"RagFlowDataAPI",
"BailianDataAPI",
"ADBDataAPI",
"get_data_api",
]
Loading
Loading