From 8106995654899e82bf006b65d5e13cf17a088b79 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 7 Feb 2026 01:57:33 +0100 Subject: [PATCH 1/2] add new tools --- agent/tools.py | 9 + onchain/analytics/analytics_tools.py | 385 +++++++++++++++++++++- onchain/analytics/test_analytics_tools.py | 120 +++++++ templates/analytics_agent.jinja2 | 4 + 4 files changed, 517 insertions(+), 1 deletion(-) diff --git a/agent/tools.py b/agent/tools.py index 9a0790f..05ba21d 100644 --- a/agent/tools.py +++ b/agent/tools.py @@ -19,6 +19,10 @@ analyze_price_trend, analyze_wallet_portfolio, get_coingecko_current_price, + get_token_market_info, + get_top_coins_by_market_cap, + get_global_market_overview, + compare_tokens, ) from onchain.tokens.trending import ( get_trending_tokens, @@ -96,6 +100,11 @@ async def search_token( portfolio_volatility, analyze_wallet_portfolio, get_coingecko_current_price, + # Market data tools + get_token_market_info, + get_top_coins_by_market_cap, + get_global_market_overview, + compare_tokens, # Token tools get_trending_tokens, evaluate_token_risk, diff --git a/onchain/analytics/analytics_tools.py b/onchain/analytics/analytics_tools.py index 076f94b..33ac99e 100644 --- a/onchain/analytics/analytics_tools.py +++ b/onchain/analytics/analytics_tools.py @@ -9,7 +9,7 @@ import requests from time import sleep from datetime import datetime, timedelta, UTC -from cachetools import TTLCache +from cachetools import TTLCache, cached from agent.telemetry import track_tool_usage from api.api_types import WalletTokenHolding @@ -33,6 +33,12 @@ class CandleInterval(StrEnum): # Cache for price data (10-minute TTL) price_data_cache = TTLCache(maxsize=1000, ttl=600) +# Caches for market data (5-minute TTL) +_coin_data_cache = TTLCache(maxsize=100, ttl=300) +_coins_markets_cache = TTLCache(maxsize=50, ttl=300) +_global_data_cache = TTLCache(maxsize=1, ttl=300) +_global_defi_cache = TTLCache(maxsize=1, ttl=300) + # Common token mappings for better user experience PREFERRED_TOKEN_IDS = { "btc": "bitcoin", @@ -1366,3 +1372,380 @@ def portfolio_volatility( "error": f"Error calculating portfolio volatility: {str(e)}", "traceback": traceback.format_exc(), } + + +# --- Cached helpers for market data endpoints --- + + +@cached(_coin_data_cache) +def _fetch_coin_data(coin_id: str) -> Dict[str, Any]: + """Fetch comprehensive coin data from /coins/{id} with caching.""" + url = f"{COINGECKO_BASE_URL}/coins/{coin_id}" + params = { + "localization": "false", + "tickers": "false", + "market_data": "true", + "community_data": "true", + "developer_data": "true", + } + return make_coingecko_request(url, params=params) + + +@cached(_coins_markets_cache) +def _fetch_coins_markets( + vs_currency: str, num_coins: int, category: Optional[str], ids: Optional[str] +) -> List[Dict[str, Any]]: + """Fetch coins/markets data with caching.""" + url = f"{COINGECKO_BASE_URL}/coins/markets" + params = { + "vs_currency": vs_currency, + "order": "market_cap_desc", + "per_page": num_coins, + "page": 1, + "sparkline": "false", + "price_change_percentage": "1h,24h,7d,30d", + } + if category: + params["category"] = category + if ids: + params["ids"] = ids + return make_coingecko_request(url, params=params) + + +@cached(_global_data_cache) +def _fetch_global_data() -> Dict[str, Any]: + """Fetch global market data with caching.""" + url = f"{COINGECKO_BASE_URL}/global" + return make_coingecko_request(url) + + +@cached(_global_defi_cache) +def _fetch_global_defi_data() -> Dict[str, Any]: + """Fetch global DeFi data with caching.""" + url = f"{COINGECKO_BASE_URL}/global/decentralized_finance_defi" + return make_coingecko_request(url) + + +# --- New market data tools --- + + +@tool() +@track_tool_usage("get_token_market_info") +def get_token_market_info(token_symbol: str) -> Dict[str, Any]: + """ + Get comprehensive market data for a single token including market cap, FDV, + volume, price changes across multiple timeframes, ATH/ATL, supply metrics, + community data, and developer activity. + """ + try: + token_id, error_message = get_coingecko_id(token_symbol) + if not token_id: + return {"error": f"Failed to resolve CoinGecko ID for {token_symbol}"} + + data = _fetch_coin_data(token_id) + + if isinstance(data, dict) and "error" in data: + suggestions = get_coin_suggestions(token_symbol, token_id) + error_msg = f"CoinGecko API couldn't find coin with ID '{token_id}'." + if suggestions: + error_msg += f" {suggestions}" + return {"error": error_msg} + + md = data.get("market_data", {}) + + result = { + "token_symbol": token_symbol, + "token_id": token_id, + "name": data.get("name"), + "market_cap_rank": data.get("market_cap_rank"), + "market_data": { + "current_price_usd": md.get("current_price", {}).get("usd"), + "market_cap_usd": md.get("market_cap", {}).get("usd"), + "fully_diluted_valuation_usd": md.get( + "fully_diluted_valuation", {} + ).get("usd"), + "total_volume_usd": md.get("total_volume", {}).get("usd"), + "price_change_percentage": { + "1h": md.get("price_change_percentage_1h_in_currency", {}).get( + "usd" + ), + "24h": md.get("price_change_percentage_24h"), + "7d": md.get("price_change_percentage_7d"), + "14d": md.get("price_change_percentage_14d"), + "30d": md.get("price_change_percentage_30d"), + "60d": md.get("price_change_percentage_60d"), + "200d": md.get("price_change_percentage_200d"), + "1y": md.get("price_change_percentage_1y"), + }, + "ath": { + "price_usd": md.get("ath", {}).get("usd"), + "date": md.get("ath_date", {}).get("usd"), + "change_percentage": md.get("ath_change_percentage", {}).get("usd"), + }, + "atl": { + "price_usd": md.get("atl", {}).get("usd"), + "date": md.get("atl_date", {}).get("usd"), + "change_percentage": md.get("atl_change_percentage", {}).get("usd"), + }, + "supply": { + "circulating": md.get("circulating_supply"), + "total": md.get("total_supply"), + "max": md.get("max_supply"), + }, + }, + "community_data": data.get("community_data"), + "developer_data": { + "stars": data.get("developer_data", {}).get("stars"), + "forks": data.get("developer_data", {}).get("forks"), + "subscribers": data.get("developer_data", {}).get("subscribers"), + "total_issues": data.get("developer_data", {}).get("total_issues"), + "commit_count_4_weeks": data.get("developer_data", {}).get( + "commit_count_4_weeks" + ), + }, + } + + if error_message: + result["warning"] = error_message + + return result + + except Exception as e: + return { + "error": f"Error fetching market info for {token_symbol}: {str(e)}", + "traceback": traceback.format_exc(), + } + + +@tool() +@track_tool_usage("get_top_coins_by_market_cap") +def get_top_coins_by_market_cap( + vs_currency: str = "usd", + num_coins: int = 20, + category: Optional[str] = None, +) -> Dict[str, Any]: + """ + Get top N coins ranked by market cap with optional category filter. + Categories include: 'decentralized-finance-defi', 'meme-token', 'layer-1', + 'layer-2', 'gaming', etc. Max 50 coins per request. + """ + try: + num_coins = min(max(1, num_coins), 50) + + coins = _fetch_coins_markets(vs_currency, num_coins, category, None) + + if isinstance(coins, dict) and "error" in coins: + return {"error": f"CoinGecko API error: {coins['error']}"} + + if not isinstance(coins, list) or len(coins) == 0: + return {"error": "No coins found for the specified criteria."} + + results = [] + for coin in coins: + results.append( + { + "rank": coin.get("market_cap_rank"), + "name": coin.get("name"), + "symbol": coin.get("symbol", "").upper(), + "current_price": coin.get("current_price"), + "market_cap": coin.get("market_cap"), + "total_volume": coin.get("total_volume"), + "price_change_percentage": { + "1h": coin.get("price_change_percentage_1h_in_currency"), + "24h": coin.get("price_change_percentage_24h"), + "7d": coin.get("price_change_percentage_7d_in_currency"), + "30d": coin.get("price_change_percentage_30d_in_currency"), + }, + "ath": coin.get("ath"), + "ath_change_percentage": coin.get("ath_change_percentage"), + } + ) + + return { + "vs_currency": vs_currency, + "category": category, + "num_coins": len(results), + "coins": results, + } + + except Exception as e: + return { + "error": f"Error fetching top coins: {str(e)}", + "traceback": traceback.format_exc(), + } + + +@tool() +@track_tool_usage("get_global_market_overview") +def get_global_market_overview() -> Dict[str, Any]: + """ + Get a global crypto market snapshot including total market cap, BTC/ETH dominance, + active cryptocurrencies count, and DeFi-specific metrics like DeFi market cap + and DeFi-to-ETH ratio. + """ + try: + global_data = _fetch_global_data() + defi_data = _fetch_global_defi_data() + + if isinstance(global_data, dict) and "error" in global_data: + return {"error": f"CoinGecko API error: {global_data['error']}"} + + gd = global_data.get("data", {}) + dd = defi_data.get("data", {}) + + return { + "global_market": { + "total_market_cap_usd": gd.get("total_market_cap", {}).get("usd"), + "total_volume_usd": gd.get("total_volume", {}).get("usd"), + "market_cap_change_percentage_24h": gd.get( + "market_cap_change_percentage_24h_usd" + ), + "btc_dominance": gd.get("market_cap_percentage", {}).get("btc"), + "eth_dominance": gd.get("market_cap_percentage", {}).get("eth"), + "active_cryptocurrencies": gd.get("active_cryptocurrencies"), + "markets": gd.get("markets"), + }, + "defi_market": { + "defi_market_cap": dd.get("defi_market_cap"), + "eth_market_cap": dd.get("eth_market_cap"), + "defi_to_eth_ratio": dd.get("defi_to_eth_ratio"), + "trading_volume_24h": dd.get("trading_volume_24h"), + "defi_dominance": dd.get("defi_dominance"), + "top_coin_name": dd.get("top_coin_name"), + "top_coin_defi_dominance": dd.get("top_coin_defi_dominance"), + }, + } + + except Exception as e: + return { + "error": f"Error fetching global market overview: {str(e)}", + "traceback": traceback.format_exc(), + } + + +@tool() +@track_tool_usage("compare_tokens") +def compare_tokens(token_symbols: List[str]) -> Dict[str, Any]: + """ + Side-by-side fundamental comparison of 2-4 tokens using market data. + Compares price, market cap, volume, price changes, and ATH distance. + Also computes relative metrics like highest market cap, best/worst 24h performer, + closest to ATH, and highest volume/market-cap ratio. + This complements compare_assets which does technical comparison. + """ + try: + if len(token_symbols) < 2: + return {"error": "Please provide at least 2 tokens to compare."} + if len(token_symbols) > 4: + return {"error": "Please provide at most 4 tokens to compare."} + + # Resolve all token IDs + coin_ids = [] + id_to_symbol = {} + for symbol in token_symbols: + token_id, err = get_coingecko_id(symbol) + if not token_id: + return {"error": f"Failed to resolve CoinGecko ID for {symbol}"} + coin_ids.append(token_id) + id_to_symbol[token_id] = symbol + + # Single batched API call + ids_param = ",".join(coin_ids) + coins = _fetch_coins_markets("usd", len(coin_ids), None, ids_param) + + if isinstance(coins, dict) and "error" in coins: + return {"error": f"CoinGecko API error: {coins['error']}"} + + if not isinstance(coins, list) or len(coins) == 0: + return {"error": "No data returned for the specified tokens."} + + # Build per-token data + tokens_data = [] + for coin in coins: + cid = coin.get("id", "") + symbol = id_to_symbol.get(cid, coin.get("symbol", "").upper()) + market_cap = coin.get("market_cap") or 0 + total_volume = coin.get("total_volume") or 0 + vol_mcap_ratio = ( + round(total_volume / market_cap, 4) if market_cap > 0 else None + ) + + tokens_data.append( + { + "symbol": symbol, + "name": coin.get("name"), + "current_price": coin.get("current_price"), + "market_cap": market_cap, + "total_volume": total_volume, + "volume_to_market_cap_ratio": vol_mcap_ratio, + "price_change_percentage": { + "1h": coin.get("price_change_percentage_1h_in_currency"), + "24h": coin.get("price_change_percentage_24h"), + "7d": coin.get("price_change_percentage_7d_in_currency"), + "30d": coin.get("price_change_percentage_30d_in_currency"), + }, + "ath": coin.get("ath"), + "ath_change_percentage": coin.get("ath_change_percentage"), + "market_cap_rank": coin.get("market_cap_rank"), + } + ) + + # Compute relative metrics + valid_tokens = [t for t in tokens_data if t["market_cap"] > 0] + + relative = {} + if valid_tokens: + highest_mcap = max(valid_tokens, key=lambda t: t["market_cap"]) + relative["highest_market_cap"] = highest_mcap["symbol"] + + tokens_with_24h = [ + t + for t in valid_tokens + if t["price_change_percentage"]["24h"] is not None + ] + if tokens_with_24h: + best_24h = max( + tokens_with_24h, + key=lambda t: t["price_change_percentage"]["24h"], + ) + worst_24h = min( + tokens_with_24h, + key=lambda t: t["price_change_percentage"]["24h"], + ) + relative["best_24h_performer"] = best_24h["symbol"] + relative["worst_24h_performer"] = worst_24h["symbol"] + + tokens_with_ath = [ + t + for t in valid_tokens + if t["ath_change_percentage"] is not None + ] + if tokens_with_ath: + closest_ath = max( + tokens_with_ath, + key=lambda t: t["ath_change_percentage"], + ) + relative["closest_to_ath"] = closest_ath["symbol"] + + tokens_with_ratio = [ + t + for t in valid_tokens + if t["volume_to_market_cap_ratio"] is not None + ] + if tokens_with_ratio: + highest_vol_ratio = max( + tokens_with_ratio, + key=lambda t: t["volume_to_market_cap_ratio"], + ) + relative["highest_volume_mcap_ratio"] = highest_vol_ratio["symbol"] + + return { + "tokens": tokens_data, + "relative_metrics": relative, + } + + except Exception as e: + return { + "error": f"Error comparing tokens: {str(e)}", + "traceback": traceback.format_exc(), + } diff --git a/onchain/analytics/test_analytics_tools.py b/onchain/analytics/test_analytics_tools.py index cc45775..1643981 100644 --- a/onchain/analytics/test_analytics_tools.py +++ b/onchain/analytics/test_analytics_tools.py @@ -8,6 +8,10 @@ max_drawdown_for_token, compare_assets, analyze_price_trend, + get_token_market_info, + get_top_coins_by_market_cap, + get_global_market_overview, + compare_tokens, CandleInterval, ) @@ -188,5 +192,121 @@ def test_compare_assets(self): print(result) + # --- Tests for new market data tools --- + + def test_get_token_market_info(self): + """Test fetching comprehensive market data for a token""" + result = get_token_market_info.invoke({"token_symbol": "BTC"}) + + self.assertNotIn("error", result) + self.assertEqual(result["token_id"], "bitcoin") + self.assertIn("market_data", result) + self.assertIn("current_price_usd", result["market_data"]) + self.assertIn("market_cap_usd", result["market_data"]) + self.assertIn("price_change_percentage", result["market_data"]) + self.assertIn("ath", result["market_data"]) + self.assertIn("atl", result["market_data"]) + self.assertIn("supply", result["market_data"]) + self.assertIsNotNone(result["market_data"]["current_price_usd"]) + self.assertIsNotNone(result.get("market_cap_rank")) + + def test_get_token_market_info_invalid(self): + """Test error handling for invalid token in market info""" + result = get_token_market_info.invoke( + {"token_symbol": "TOTALLYINVALIDTOKEN999"} + ) + self.assertIn("error", result) + + def test_get_top_coins_by_market_cap(self): + """Test fetching top coins by market cap""" + result = get_top_coins_by_market_cap.invoke( + {"vs_currency": "usd", "num_coins": 10} + ) + + self.assertNotIn("error", result) + self.assertIn("coins", result) + self.assertEqual(result["num_coins"], 10) + self.assertEqual(len(result["coins"]), 10) + + # Verify first coin has expected fields + first_coin = result["coins"][0] + self.assertIn("rank", first_coin) + self.assertIn("name", first_coin) + self.assertIn("current_price", first_coin) + self.assertIn("market_cap", first_coin) + self.assertIn("price_change_percentage", first_coin) + + def test_get_top_coins_by_market_cap_with_category(self): + """Test fetching top coins filtered by category""" + result = get_top_coins_by_market_cap.invoke( + { + "vs_currency": "usd", + "num_coins": 5, + "category": "decentralized-finance-defi", + } + ) + + self.assertNotIn("error", result) + self.assertIn("coins", result) + self.assertEqual(result["category"], "decentralized-finance-defi") + self.assertTrue(len(result["coins"]) > 0) + + def test_get_global_market_overview(self): + """Test fetching global market overview""" + result = get_global_market_overview.invoke({}) + + self.assertNotIn("error", result) + self.assertIn("global_market", result) + self.assertIn("defi_market", result) + + gm = result["global_market"] + self.assertIn("total_market_cap_usd", gm) + self.assertIn("btc_dominance", gm) + self.assertIn("eth_dominance", gm) + self.assertIn("active_cryptocurrencies", gm) + self.assertIsNotNone(gm["total_market_cap_usd"]) + + dm = result["defi_market"] + self.assertIn("defi_market_cap", dm) + self.assertIn("defi_to_eth_ratio", dm) + + def test_compare_tokens(self): + """Test side-by-side fundamental comparison of tokens""" + result = compare_tokens.invoke( + {"token_symbols": ["BTC", "ETH", "SOL"]} + ) + + self.assertNotIn("error", result) + self.assertIn("tokens", result) + self.assertIn("relative_metrics", result) + self.assertEqual(len(result["tokens"]), 3) + + # Verify relative metrics + rm = result["relative_metrics"] + self.assertIn("highest_market_cap", rm) + self.assertIn("best_24h_performer", rm) + self.assertIn("worst_24h_performer", rm) + + # Verify per-token data + for token in result["tokens"]: + self.assertIn("current_price", token) + self.assertIn("market_cap", token) + self.assertIn("price_change_percentage", token) + + def test_compare_tokens_too_few(self): + """Test that compare_tokens rejects fewer than 2 tokens""" + result = compare_tokens.invoke({"token_symbols": ["BTC"]}) + self.assertIn("error", result) + self.assertIn("at least 2", result["error"]) + + def test_compare_tokens_too_many(self): + """Test that compare_tokens rejects more than 4 tokens""" + result = compare_tokens.invoke( + {"token_symbols": ["BTC", "ETH", "SOL", "ADA", "DOT"]} + ) + self.assertIn("error", result) + self.assertIn("at most 4", result["error"]) + + if __name__ == "__main__": unittest.main() diff --git a/templates/analytics_agent.jinja2 b/templates/analytics_agent.jinja2 index 4e2cff3..2e218b1 100644 --- a/templates/analytics_agent.jinja2 +++ b/templates/analytics_agent.jinja2 @@ -15,6 +15,10 @@ IMPORTANT: ALWAYS use the provided tools to answer questions instead of relying - Analyze trending tokens or memecoins on a chain: use the `get_trending_tokens()` tool to get the trending tokens. In your answer, include the ID of each token you mention in the following format: ```token:```. - Analyze the risk of a token: use the `evaluate_token_risk()` tool to get the risk analysis. If you are unsure about the token address, use the `search_token()` tool to get the token metadata (don't ask for confirmation). - Analyze the top holders of a token: use the `get_top_token_holders()` tool to get the top holders of the token. +- Get comprehensive market data for a token (market cap, FDV, volume, supply, ATH/ATL, community & developer stats): use the `get_token_market_info()` tool. +- Get top coins by market cap, optionally filtered by category (e.g. 'decentralized-finance-defi', 'meme-token', 'layer-1', 'layer-2', 'gaming'): use the `get_top_coins_by_market_cap()` tool. +- Get a global crypto market overview (total market cap, BTC/ETH dominance, DeFi metrics): use the `get_global_market_overview()` tool. +- Compare fundamentals of 2-4 tokens side-by-side (price, market cap, volume, price changes, ATH distance): use the `compare_tokens()` tool. For technical comparison (SMAs, Bollinger Bands), use `compare_assets()` instead. If the user asks a more complex question, you can combine and use the tools to get the information you need to answer the question. Be intelligent and careful with the tools you use. From 1b64ca58091ab6e5b51380d44cc5779379a4a130 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 7 Feb 2026 02:12:21 +0100 Subject: [PATCH 2/2] rm router --- agent/agent_executors.py | 14 ------------ agent/prompts.py | 27 ----------------------- api/api_types.py | 2 +- server/fastapi_server.py | 46 ++++------------------------------------ templates/router.jinja2 | 36 ------------------------------- 5 files changed, 5 insertions(+), 120 deletions(-) delete mode 100644 templates/router.jinja2 diff --git a/agent/agent_executors.py b/agent/agent_executors.py index cf31f8c..86f7c57 100644 --- a/agent/agent_executors.py +++ b/agent/agent_executors.py @@ -72,30 +72,16 @@ # Select model based on configuration if not config.SUBNET_MODE: SUGGESTIONS_MODEL = GOOGLE_GEMINI_20_FLASH_MODEL - ROUTING_MODEL = GOOGLE_GEMINI_FLASH_15_8B_MODEL REASONING_MODEL = GOOGLE_GEMINI_20_FLASH_MODEL BASE_URL = "https://generativelanguage.googleapis.com/v1beta/" API_KEY = os.getenv("GEMINI_API_KEY") else: SUGGESTIONS_MODEL = LOCAL_LLM_MODEL - ROUTING_MODEL = LOCAL_LLM_MODEL REASONING_MODEL = LOCAL_LLM_MODEL BASE_URL = LOCAL_LLM_BASE_URL API_KEY = "dummy_key" -def create_routing_model() -> BaseChatModel: - return ChatOpenAI( - model=ROUTING_MODEL, - temperature=0.0, - max_tokens=500, - api_key=config.DUMMY_X402_API_KEY, - http_async_client=x402_http_client, - stream_usage=True, - streaming=True, - base_url=config.LLM_SERVER_URL, - ) - def create_suggestions_model() -> BaseChatModel: return ChatOpenAI( diff --git a/agent/prompts.py b/agent/prompts.py index 4cb1251..64c3173 100644 --- a/agent/prompts.py +++ b/agent/prompts.py @@ -14,7 +14,6 @@ investor_agent_template = env.get_template("investor_agent.jinja2") analytics_agent_template = env.get_template("analytics_agent.jinja2") suggestions_template = env.get_template("suggestions.jinja2") -router_template = env.get_template("router.jinja2") # We ignore token holdings with a total value of less than $1 @@ -106,29 +105,3 @@ def get_analytics_prompt( ) return analytics_agent_prompt - - -def get_router_prompt(message_history: List[Message], current_message: str) -> str: - """Get the router prompt to determine which agent should handle the request.""" - - MAX_AGENT_MESSAGE_LENGTH = 400 - - # Truncate assistant response to 400 characters, also include the message type - message_history = [ - { - "type": message.type, - "message": ( - message.message[:MAX_AGENT_MESSAGE_LENGTH] + "..." - if message.type == "assistant" - and len(message.message) > MAX_AGENT_MESSAGE_LENGTH - else message.message - ), - } - for message in message_history - ] - - router_prompt = router_template.render( - message_history=message_history, - current_message=current_message, - ) - return router_prompt diff --git a/api/api_types.py b/api/api_types.py index 56c2519..a4eee40 100644 --- a/api/api_types.py +++ b/api/api_types.py @@ -123,7 +123,7 @@ class Context(BaseModel): class AgentChatRequest(BaseModel): context: Context message: UserMessage - agent: Optional[AgentType] = None + agent: AgentType captchaToken: Optional[str] = None diff --git a/server/fastapi_server.py b/server/fastapi_server.py index a6e3836..9b64985 100644 --- a/server/fastapi_server.py +++ b/server/fastapi_server.py @@ -37,13 +37,11 @@ create_investor_executor, create_suggestions_model, create_analytics_executor, - create_routing_model, ) from agent.prompts import ( get_investor_agent_prompt, get_suggestions_prompt, get_analytics_prompt, - get_router_prompt, ) from agent.tools import ( create_investor_agent_toolkit, @@ -142,7 +140,6 @@ async def shutdown_event(): await cow_validator.close() # Initialize agents - router_model = create_routing_model() suggestions_model = create_suggestions_model() analytics_agent = create_analytics_executor(token_metadata_repo) investor_agent = create_investor_executor() @@ -154,7 +151,6 @@ async def shutdown_event(): protocol_registry.register_protocol(KaminoProtocol()) # Store agents in app state - app.state.router_model = router_model app.state.suggestions_model = suggestions_model app.state.analytics_agent = analytics_agent app.state.investor_agent = investor_agent @@ -326,7 +322,6 @@ async def run_agent( portfolio=portfolio, investor_agent=investor_agent, analytics_agent=analytics_agent, - router_model=router_model, ) return ( @@ -574,7 +569,7 @@ async def subnet_query_endpoint( address=wallet_address, conversationHistory=[], miner_token=None ), message=UserMessage(message=quant_query.query), - agent=None, # Let router decide + agent=AgentType.ANALYTICS, captchaToken=None, # No captcha for subnet ) @@ -591,7 +586,6 @@ async def subnet_query_endpoint( portfolio=portfolio, investor_agent=investor_agent, analytics_agent=analytics_agent, - router_model=router_model, ) # Convert response to QuantResponse format @@ -671,49 +665,17 @@ async def handle_agent_chat_request( token_metadata_repo: TokenMetadataRepo, investor_agent: any, analytics_agent: any, - router_model: ChatOpenAI, ) -> AgentMessage: - # If agent is explicitly specified, bypass router - if request.agent is not None: - if request.agent == AgentType.ANALYTICS: - return await handle_analytics_chat_request( - request, token_metadata_repo, portfolio, analytics_agent - ) - elif request.agent == AgentType.INVESTOR: - return await handle_investor_chat_request( - request, portfolio, investor_agent, protocol_registry - ) - else: - raise ValueError(f"Invalid agent type specified: {request.agent}") - - # Otherwise use router to determine agent - router_prompt = get_router_prompt( - message_history=request.context.conversationHistory[-NUM_MESSAGES_TO_KEEP:], - current_message=request.message.message, - ) - - router_response = await router_model.ainvoke(router_prompt) - selected_agent = router_response.content.strip().lower() - - # Extract agent type from response if it contains additional text - if "yield_agent" in selected_agent: - selected_agent = AgentType.YIELD - elif "analytics_agent" in selected_agent: - selected_agent = AgentType.ANALYTICS - else: - # Default to analytics agent if no clear choice - selected_agent = AgentType.ANALYTICS - - if selected_agent == AgentType.ANALYTICS: + if request.agent == AgentType.ANALYTICS: return await handle_analytics_chat_request( request, token_metadata_repo, portfolio, analytics_agent ) - elif selected_agent == AgentType.YIELD: + elif request.agent == AgentType.INVESTOR: return await handle_investor_chat_request( request, portfolio, investor_agent, protocol_registry ) else: - raise ValueError(f"Invalid agent selection from router: {selected_agent}") + raise ValueError(f"Invalid agent type specified: {request.agent}") async def handle_investor_chat_request( diff --git a/templates/router.jinja2 b/templates/router.jinja2 deleted file mode 100644 index a3dbc43..0000000 --- a/templates/router.jinja2 +++ /dev/null @@ -1,36 +0,0 @@ -You are a router that helps direct user requests to the appropriate agent. Based on the message history and current message, you should determine which agent should handle the request. - -You have two options: -1. "yield_agent" - The yield agent that helps users: - - Guide users through pool selection based on their tokens and risk preferences - - Focus on practical yield-earning strategies and pool interactions - - Help users find yield opportunities - -2. "analytics_agent" - The data analysis agent that helps users run analytics on their portfolio: - - Analyze market trends and protocol performance - - Compare assets and analyze portfolio performance - - Provide data-driven insights about protocols, TVL, and market conditions - - Analyze trending tokens (memecoins etc) - - Evaluate token risk - - Search and buy tokens - - Swap tokens - -Routing Guidelines: -1. Conversation Continuity: - - If the current message is a follow-up to a previous topic, route to the same agent - - For example, if user asked about yield opportunities and then asks follow-up questions about the same topic, keep routing to yield_agent - - Only switch agents if the user starts a completely new topic or conversation thread - -2. Conversation Flow: - - Consider the natural flow of conversation - - If the user is in the middle of a yield-related discussion, maintain that context - - Only switch agents if there's a clear indication of a new topic or conversation thread - -Message History: -{{ message_history }} - -Current Message: -{{ current_message }} - -REMEMBER, RESPONSD ONLY with exactly one of: "yield_agent" or "analytics_agent", nothing else! -IF YOU ARE NOT SURE WHICH AGENT TO USE, ROUTE TO THE ANALYTICS AGENT. \ No newline at end of file