# Source: trade-assistant/backend/app/services/mcp_search_client.py
# (Python, 102 lines, 3.4 KiB)

import asyncio
import json
import logging
import os
import sys
import warnings
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Optional

from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__)

# Interpreter used to launch the search server: the same one running this
# process, so the child presumably sees the same installed packages (venv).
VENV_PYTHON = sys.executable

# Directory holding this module; the server script lives right beside it.
_MODULE_DIR = os.path.dirname(__file__)
SERVER_SCRIPT = os.path.join(_MODULE_DIR, "mcp_search_server.py")
class MCPClientManager:
    """Shared stdio connection to the local MCP web-search server.

    ``get_instance`` lazily launches ``SERVER_SCRIPT`` with ``VENV_PYTHON``,
    opens a ``ClientSession`` over the child's stdio, and hands that one
    session to every later caller until ``close`` is called.

    NOTE(review): the anyio-based contexts behind ``stdio_client`` and
    ``ClientSession`` are entered in whichever task first calls
    ``get_instance``; exiting them from another task may raise
    ``RuntimeError`` (``close`` logs and ignores teardown errors). A dedicated
    owner task would make shutdown deterministic — confirm if that matters.
    """

    _instance: Optional["MCPClientManager"] = None
    # Serializes startup so concurrent first callers share one server process.
    _lock = asyncio.Lock()

    def __init__(self):
        self._session: Optional[ClientSession] = None
        self._read = None
        self._write = None
        # Owns the entered stdio_client / ClientSession contexts; closing it
        # exits them in reverse order (session first, then the server client).
        self._stack: Optional[AsyncExitStack] = None
        self._initialized = False

    @classmethod
    async def get_instance(cls) -> "MCPClientManager":
        """Return the shared, initialized manager, starting it on first use.

        Startup is bounded to 10 seconds overall.

        Raises:
            Exception: whatever startup raised (``TimeoutError`` on timeout).
                The failure is logged, and the singleton slot is cleared so
                the next call retries from scratch.
        """
        instance = cls._instance
        if instance is not None and instance._initialized:
            return instance
        async with cls._lock:
            # Re-check under the lock: another caller may have finished startup.
            instance = cls._instance
            if instance is not None and instance._initialized:
                return instance
            instance = cls()
            cls._instance = instance
            try:
                # asyncio.timeout keeps _start in the current task; wait_for on
                # Python 3.11 would run it in a helper task that ends right away,
                # leaving the task groups it entered hosted by a finished task.
                async with asyncio.timeout(10):
                    await instance._start()
            except Exception as e:
                logger.warning("MCP init failed: %s", e)
                raise
            finally:
                # _start has already released what it opened; drop the
                # half-built instance (also when startup is cancelled).
                if not instance._initialized and cls._instance is instance:
                    cls._instance = None
            return instance

    async def _start(self):
        """Spawn the server and complete the MCP handshake.

        Each step (spawn, session open, ``initialize``) is bounded to 5
        seconds. If any step fails or is cancelled, every context entered so
        far is exited before the exception propagates, so a failed start does
        not leave the server client or session open.
        """
        params = StdioServerParameters(
            command=VENV_PYTHON,
            args=[SERVER_SCRIPT],
        )
        stack = AsyncExitStack()
        try:
            async with asyncio.timeout(5):
                self._read, self._write = await stack.enter_async_context(
                    stdio_client(params)
                )
            async with asyncio.timeout(5):
                self._session = await stack.enter_async_context(
                    ClientSession(self._read, self._write)
                )
            async with asyncio.timeout(5):
                await self._session.initialize()
        except BaseException:
            self._session = None
            self._read = self._write = None
            await self._close_stack(stack)
            raise
        self._stack = stack
        self._initialized = True
        logger.info("MCP search client initialized")

    @staticmethod
    async def _close_stack(stack: AsyncExitStack) -> None:
        """Close *stack* best-effort: teardown errors are logged, not raised."""
        try:
            await stack.aclose()
        except (BaseExceptionGroup, Exception) as e:
            # anyio-backed contexts may raise exception groups, or RuntimeError
            # when exited from a different task; keep going but leave a trace.
            logger.debug("MCP client teardown error: %r", e)

    async def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        """Call the server's ``web_search`` tool and return its ``results``.

        Args:
            query: search string forwarded to the tool.
            max_results: forwarded to the tool as ``max_results``.

        Returns:
            The ``results`` list from the tool's JSON reply, or ``[]`` when the
            client is not initialized, the call fails or exceeds 10 seconds,
            the tool reports an error, or the reply has no ``results`` list.
        """
        if not self._initialized or self._session is None:
            logger.warning("MCP client not initialized")
            return []
        try:
            result = await asyncio.wait_for(
                self._session.call_tool(
                    "web_search",
                    {"query": query, "max_results": max_results},
                ),
                timeout=10,
            )
            if getattr(result, "isError", False):
                logger.warning("MCP search tool returned an error: %s", result.content)
                return []
            if not result.content:
                return []
            data = json.loads(result.content[0].text)
        except Exception as e:
            # Best-effort: timeouts (TimeoutError is an Exception), transport
            # errors, non-text content and malformed JSON all degrade to [].
            logger.warning("MCP search call failed: %s", e)
            return []
        if not isinstance(data, dict):
            logger.warning(
                "MCP search returned unexpected payload type: %s", type(data).__name__
            )
            return []
        results = data.get("results")
        # A missing, null or non-list "results" means no results — never None.
        return results if isinstance(results, list) else []

    async def close(self):
        """Shut down the session and server client; safe to call repeatedly.

        Clears the shared singleton slot only if it still points at this
        instance, so closing a stale manager cannot orphan a newer one.
        Teardown errors are logged at debug level and otherwise ignored.
        """
        self._initialized = False
        if type(self)._instance is self:
            type(self)._instance = None
        # Detach before awaiting so a concurrent/second close() is a no-op.
        stack, self._stack = self._stack, None
        self._session = None
        self._read = self._write = None
        if stack is not None:
            await self._close_stack(stack)
async def mcp_search(query: str, max_results: int = 10) -> List[Dict[str, str]]:
    """Search the web through the shared MCP client without raising.

    Starts the client on first use. Any ``Exception`` from startup or from
    the search call is logged as a warning and turned into an empty list.
    """
    try:
        manager = await MCPClientManager.get_instance()
        results = await manager.search(query, max_results)
    except Exception as exc:
        logger.warning(f"MCP search failed: {exc}")
        results = []
    return results