|
|
from contextlib import AsyncExitStack |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from mcp import ClientSession, StdioServerParameters |
|
|
from mcp.client.sse import sse_client |
|
|
from mcp.client.stdio import stdio_client |
|
|
from mcp.types import ListToolsResult, TextContent |
|
|
|
|
|
from app.logger import logger |
|
|
from app.tool.base import BaseTool, ToolResult |
|
|
from app.tool.tool_collection import ToolCollection |
|
|
|
|
|
|
|
|
class MCPClientTool(BaseTool):
    """Represents a tool proxy that can be called on the MCP server from the client side."""

    # Live session used to forward calls; None means "not connected".
    session: Optional[ClientSession] = None
    # Key of the server connection this proxy belongs to (used when pruning on disconnect).
    server_id: str = ""
    # Tool name as advertised by the server (the proxy's own `name` is prefixed/sanitized).
    original_name: str = ""

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool by making a remote call to the MCP server.

        Args:
            **kwargs: Tool arguments, forwarded unchanged to ``session.call_tool``.

        Returns:
            A ``ToolResult`` whose ``output`` is the text parts of the call
            result joined with ", " (or "No output returned." when there are
            none). ``error`` is set instead when there is no session, when the
            call raises, or when the server flags the result via ``isError``.
        """
        if not self.session:
            return ToolResult(error="Not connected to MCP server")

        try:
            logger.info(f"Executing tool: {self.original_name}")
            result = await self.session.call_tool(self.original_name, kwargs)
            # Only text parts are surfaced; other content kinds are dropped.
            content_str = ", ".join(
                item.text for item in result.content if isinstance(item, TextContent)
            )
            # MCP servers report tool-level failures in-band (isError=True)
            # instead of raising; surface those as errors, not successful output.
            if getattr(result, "isError", False):
                return ToolResult(
                    error=content_str or "Tool reported an error with no details."
                )
            return ToolResult(output=content_str or "No output returned.")
        except Exception as e:
            return ToolResult(error=f"Error executing tool: {e}")
|
|
|
|
|
|
|
|
class MCPClients(ToolCollection):
    """
    A collection of tools that connects to multiple MCP servers and manages available tools through the Model Context Protocol.

    Each connection is registered under a ``server_id`` and owns an
    ``AsyncExitStack`` that holds its transport and ``ClientSession``. Remote
    tools are exposed in ``tool_map`` as ``MCPClientTool`` proxies named
    ``mcp_<server_id>_<tool name>`` after sanitization.
    """

    # server_id -> active ClientSession. Class-level value kept only as a
    # declaration; __init__ installs a per-instance dict.
    sessions: Dict[str, ClientSession] = {}
    # server_id -> exit stack owning that server's transport and session.
    # Also replaced per instance in __init__.
    exit_stacks: Dict[str, AsyncExitStack] = {}
    description: str = "MCP client tools for server interaction"

    def __init__(self):
        super().__init__()
        self.name = "mcp"
        # A mutable class attribute is one object shared by every instance.
        # Create fresh registries so separate MCPClients never see or tear down
        # each other's connections. (Assumes ToolCollection does not already
        # copy these per instance; its definition is not visible here.)
        self.sessions = {}
        self.exit_stacks = {}

    async def connect_sse(self, server_url: str, server_id: str = "") -> None:
        """Connect to an MCP server using SSE transport.

        Args:
            server_url: URL passed to ``sse_client``.
            server_id: Registry key for the connection; defaults to ``server_url``.

        Raises:
            ValueError: If ``server_url`` is empty.
            Exception: Any error from connecting or initializing is re-raised
                after the partial connection has been cleaned up.
        """
        if not server_url:
            raise ValueError("Server URL is required.")

        server_id = server_id or server_url

        # Replace any existing connection registered under the same id.
        if server_id in self.sessions:
            await self.disconnect(server_id)

        await self._open_session(server_id, sse_client(url=server_url))

    async def connect_stdio(
        self, command: str, args: List[str], server_id: str = ""
    ) -> None:
        """Connect to an MCP server using stdio transport.

        Args:
            command: Executable used to start the server process.
            args: Arguments for ``command``.
            server_id: Registry key for the connection; defaults to ``command``.

        Raises:
            ValueError: If ``command`` is empty.
            Exception: Any error from connecting or initializing is re-raised
                after the partial connection has been cleaned up.
        """
        if not command:
            raise ValueError("Server command is required.")

        server_id = server_id or command

        # Replace any existing connection registered under the same id.
        if server_id in self.sessions:
            await self.disconnect(server_id)

        server_params = StdioServerParameters(command=command, args=args)
        await self._open_session(server_id, stdio_client(server_params))

    async def _open_session(self, server_id: str, transport) -> None:
        """Enter `transport` (an async context manager yielding read/write streams),
        start a ClientSession on it, register both under `server_id`, and load tools.

        On any failure the partially opened connection is closed and
        unregistered before the original exception is re-raised, so a failed
        connect never leaks an open transport or a dangling registry entry.
        """
        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack
        try:
            streams = await exit_stack.enter_async_context(transport)
            session = await exit_stack.enter_async_context(ClientSession(*streams))
            self.sessions[server_id] = session
            await self._initialize_and_list_tools(server_id)
        except Exception:
            try:
                await self._close_connection(server_id)
            except Exception as cleanup_error:
                logger.warning(
                    f"Cleanup after failed connection to {server_id} also failed: {cleanup_error}"
                )
            raise  # re-raise the original connection error

    async def _initialize_and_list_tools(self, server_id: str) -> None:
        """Initialize session and populate tool map.

        Raises:
            RuntimeError: If no session is registered for ``server_id``.
        """
        session = self.sessions.get(server_id)
        if not session:
            raise RuntimeError(f"Session not initialized for server {server_id}")

        await session.initialize()
        response = await session.list_tools()

        for tool in response.tools:
            original_name = tool.name
            tool_name = self._sanitize_tool_name(f"mcp_{server_id}_{original_name}")

            # Sanitization/truncation can map distinct names to the same key;
            # make the overwrite visible instead of silently hiding a tool.
            if tool_name in self.tool_map:
                logger.warning(
                    f"Tool name collision on '{tool_name}'; replacing existing tool"
                )

            self.tool_map[tool_name] = MCPClientTool(
                name=tool_name,
                # MCP tool descriptions are optional; never pass None through.
                description=tool.description or "",
                parameters=tool.inputSchema,
                session=session,
                server_id=server_id,
                original_name=original_name,
            )

        self.tools = tuple(self.tool_map.values())
        logger.info(
            f"Connected to server {server_id} with tools: {[tool.name for tool in response.tools]}"
        )

    def _sanitize_tool_name(self, name: str) -> str:
        """Sanitize tool name to match MCPClientTool requirements.

        Replaces every character outside ``[a-zA-Z0-9_-]`` with "_", collapses
        runs of "_", strips leading/trailing "_", and truncates to 64 characters.
        """
        import re

        sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
        sanitized = re.sub(r"_+", "_", sanitized)
        sanitized = sanitized.strip("_")
        return sanitized[:64]

    async def list_tools(self) -> ListToolsResult:
        """List all available tools, querying every connected server.

        The returned entries carry the servers' own (unprefixed) tool names.
        """
        tools_result = ListToolsResult(tools=[])
        for session in self.sessions.values():
            response = await session.list_tools()
            tools_result.tools += response.tools
        return tools_result

    async def _close_connection(self, server_id: str) -> None:
        """Close one server's exit stack and unregister its session and tools.

        Unregistering happens in ``finally`` so that even if closing raises,
        no stale session, exit stack, or tool proxy is left behind.
        Cancel-scope RuntimeErrors from closing are logged and swallowed;
        any other error propagates after cleanup.
        """
        try:
            exit_stack = self.exit_stacks.get(server_id)
            if exit_stack:
                try:
                    await exit_stack.aclose()
                except RuntimeError as e:
                    # Closing from a different task than the one that opened
                    # the transport can trip a cancel-scope error; registry
                    # cleanup still proceeds.
                    if "cancel scope" in str(e).lower():
                        logger.warning(
                            f"Cancel scope error during disconnect from {server_id}, continuing with cleanup: {e}"
                        )
                    else:
                        raise
        finally:
            self.sessions.pop(server_id, None)
            self.exit_stacks.pop(server_id, None)
            # getattr: tool_map may also hold tools without a server_id.
            self.tool_map = {
                k: v
                for k, v in self.tool_map.items()
                if getattr(v, "server_id", None) != server_id
            }
            self.tools = tuple(self.tool_map.values())

    async def disconnect(self, server_id: str = "") -> None:
        """Disconnect from a specific MCP server or all servers if no server_id provided.

        Errors raised while closing a connection are logged rather than
        propagated; the connection is unregistered either way.
        """
        if server_id:
            if server_id in self.sessions:
                try:
                    await self._close_connection(server_id)
                    logger.info(f"Disconnected from MCP server {server_id}")
                except Exception as e:
                    logger.error(f"Error disconnecting from server {server_id}: {e}")
        else:
            # sorted() snapshots the keys, so the dict can shrink while we loop.
            for sid in sorted(self.sessions):
                await self.disconnect(sid)
            self.tool_map = {}
            self.tools = tuple()
            logger.info("Disconnected from all MCP servers")
|
|
|