File size: 7,542 Bytes
88f3fce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
from contextlib import AsyncExitStack
from typing import Dict, List, Optional
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult, TextContent
from app.logger import logger
from app.tool.base import BaseTool, ToolResult
from app.tool.tool_collection import ToolCollection
class MCPClientTool(BaseTool):
    """Represents a tool proxy that can be called on the MCP server from the client side."""

    session: Optional[ClientSession] = None
    server_id: str = ""  # Identifier of the MCP server that provides this tool
    original_name: str = ""  # Tool name as the server reports it (before prefixing/sanitizing)

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool by making a remote call to the MCP server.

        Args:
            **kwargs: Tool arguments, forwarded unchanged to ``call_tool``.

        Returns:
            A ToolResult whose ``output`` is the ", "-joined text content of the
            response ("No output returned." when there is none). ``error`` is set
            instead when there is no session, when the call raises, or when the
            server marks the result with ``isError``.
        """
        if not self.session:
            return ToolResult(error="Not connected to MCP server")

        try:
            logger.info(f"Executing tool: {self.original_name}")
            result = await self.session.call_tool(self.original_name, kwargs)
            # Only TextContent items are rendered; other content types are skipped.
            content_str = ", ".join(
                item.text for item in result.content if isinstance(item, TextContent)
            )
            # MCP reports tool-level failures in-band (isError=True) rather than
            # raising, so surface them as errors instead of ordinary output.
            if getattr(result, "isError", False):
                return ToolResult(error=content_str or "Tool reported an error.")
            return ToolResult(output=content_str or "No output returned.")
        except Exception as e:
            # Tool boundary: convert any failure into an error result for the caller.
            return ToolResult(error=f"Error executing tool: {str(e)}")
class MCPClients(ToolCollection):
    """
    A collection of tools that connects to multiple MCP servers and manages available tools through the Model Context Protocol.

    Each connection is keyed by a ``server_id``. Every tool a server reports is
    wrapped in an ``MCPClientTool`` named ``mcp_<server_id>_<tool name>``
    (passed through ``_sanitize_tool_name``) and stored in ``tool_map`` and
    ``tools``.
    """

    # Open sessions, and the exit stacks that own their transports, keyed by
    # server_id. Both are created per instance in __init__: a class-level ``{}``
    # default is a single dict object shared by every instance of a plain class,
    # so separate MCPClients would otherwise see each other's connections.
    sessions: Dict[str, ClientSession]
    exit_stacks: Dict[str, AsyncExitStack]
    description: str = "MCP client tools for server interaction"

    def __init__(self):
        super().__init__()  # Initialize with empty tools list
        self.name = "mcp"  # Keep name for backward compatibility
        self.sessions = {}
        self.exit_stacks = {}

    async def connect_sse(self, server_url: str, server_id: str = "") -> None:
        """Connect to an MCP server using SSE transport.

        Args:
            server_url: URL passed to ``sse_client``.
            server_id: Key for this connection; defaults to ``server_url``.
                An existing connection with the same id is closed first.

        Raises:
            ValueError: If ``server_url`` is empty.
        """
        if not server_url:
            raise ValueError("Server URL is required.")
        server_id = server_id or server_url
        await self._connect(server_id, sse_client(url=server_url))

    async def connect_stdio(
        self, command: str, args: List[str], server_id: str = ""
    ) -> None:
        """Connect to an MCP server using stdio transport.

        Args:
            command: Server command, passed to ``StdioServerParameters``.
            args: Arguments for ``command``.
            server_id: Key for this connection; defaults to ``command``.
                An existing connection with the same id is closed first.

        Raises:
            ValueError: If ``command`` is empty.
        """
        if not command:
            raise ValueError("Server command is required.")
        server_id = server_id or command
        server_params = StdioServerParameters(command=command, args=args)
        await self._connect(server_id, stdio_client(server_params))

    async def _connect(self, server_id: str, transport) -> None:
        """Open ``transport``, start a session over it and register its tools.

        ``transport`` is an un-entered async context manager whose value is the
        stream tuple handed to ``ClientSession`` (as with ``sse_client`` and
        ``stdio_client`` above). If any step fails, everything opened so far is
        closed and nothing is left registered under ``server_id``; the original
        exception is then re-raised.
        """
        # Always ensure clean disconnection before new connection
        if server_id in self.sessions or server_id in self.exit_stacks:
            await self.disconnect(server_id)

        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack
        try:
            streams = await exit_stack.enter_async_context(transport)
            session = await exit_stack.enter_async_context(ClientSession(*streams))
            self.sessions[server_id] = session
            await self._initialize_and_list_tools(server_id)
        except BaseException:
            # Without this, an opened transport, a half-initialized session and
            # any tools already added would stay registered under server_id.
            await self.disconnect(server_id)
            raise

    async def _initialize_and_list_tools(self, server_id: str) -> None:
        """Initialize the session for ``server_id`` and register its tools.

        Raises:
            RuntimeError: If no session is registered for ``server_id``.
        """
        session = self.sessions.get(server_id)
        if not session:
            raise RuntimeError(f"Session not initialized for server {server_id}")

        await session.initialize()
        response = await session.list_tools()

        # Create proper tool objects for each server tool
        for tool in response.tools:
            tool_name = self._sanitize_tool_name(f"mcp_{server_id}_{tool.name}")
            if tool_name in self.tool_map:
                # Sanitizing/truncating can map distinct names to the same key.
                logger.warning(
                    f"Tool name {tool_name!r} from server {server_id} collides "
                    f"with an existing tool; replacing it"
                )
            self.tool_map[tool_name] = MCPClientTool(
                name=tool_name,
                # The server may omit a description; always pass a string.
                description=tool.description or "",
                parameters=tool.inputSchema,
                session=session,
                server_id=server_id,
                original_name=tool.name,
            )

        # Update tools tuple
        self.tools = tuple(self.tool_map.values())
        logger.info(
            f"Connected to server {server_id} with tools: {[tool.name for tool in response.tools]}"
        )

    def _sanitize_tool_name(self, name: str) -> str:
        """Sanitize tool name to match MCPClientTool requirements.

        Characters outside ``[a-zA-Z0-9_-]`` become ``_``, runs of ``_`` are
        collapsed, leading/trailing ``_`` are stripped, and the result is
        truncated to 64 characters. NOTE(review): this presumably mirrors the
        function-name rules of the LLM tool-calling API; confirm.
        """
        import re

        # Replace invalid characters with underscores
        sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
        # Remove consecutive underscores
        sanitized = re.sub(r"_+", "_", sanitized)
        # Remove leading/trailing underscores
        sanitized = sanitized.strip("_")
        # Truncate last, so the result may still end with "_"
        return sanitized[:64]

    async def list_tools(self) -> ListToolsResult:
        """List all available tools.

        Queries every connected session and concatenates the returned tools,
        with their original (unprefixed) names.
        """
        tools_result = ListToolsResult(tools=[])
        for session in self.sessions.values():
            response = await session.list_tools()
            tools_result.tools += response.tools
        return tools_result

    async def disconnect(self, server_id: str = "") -> None:
        """Disconnect from a specific MCP server or all servers if no server_id provided.

        Best-effort: ``Exception``s raised while closing a connection are
        logged, not propagated. The server's session, exit stack and tools are
        always unregistered, even when closing fails.
        """
        if not server_id:
            # Disconnect from all servers in a deterministic order
            for sid in sorted(self.sessions.keys() | self.exit_stacks.keys()):
                await self.disconnect(sid)
            self.tool_map = {}
            self.tools = tuple()
            logger.info("Disconnected from all MCP servers")
            return

        if server_id not in self.sessions and server_id not in self.exit_stacks:
            return

        try:
            exit_stack = self.exit_stacks.get(server_id)
            # Close the exit stack which will handle session cleanup
            if exit_stack:
                try:
                    await exit_stack.aclose()
                except RuntimeError as e:
                    # Presumably anyio's "exited cancel scope in a different
                    # task" error; tolerated so cleanup can proceed.
                    if "cancel scope" in str(e).lower():
                        logger.warning(
                            f"Cancel scope error during disconnect from {server_id}, continuing with cleanup: {e}"
                        )
                    else:
                        raise
            logger.info(f"Disconnected from MCP server {server_id}")
        except Exception as e:
            logger.error(f"Error disconnecting from server {server_id}: {e}")
        finally:
            # Always drop references so a broken session and its tools are
            # never left registered.
            self.sessions.pop(server_id, None)
            self.exit_stacks.pop(server_id, None)
            # Remove tools associated with this server
            self.tool_map = {
                k: v
                for k, v in self.tool_map.items()
                if getattr(v, "server_id", None) != server_id
            }
            self.tools = tuple(self.tool_map.values())
|