LiamKhoaLe commited on
Commit
1034c81
·
1 Parent(s): 4f116ec

Simplify MCP arch #4

Browse files
Files changed (2) hide show
  1. agent.py +10 -4
  2. app.py +24 -20
agent.py CHANGED
@@ -16,7 +16,8 @@ from pathlib import Path
16
 
17
  # MCP imports
18
  try:
19
- from mcp.server import Server
 
20
  from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource
21
  from mcp.server.models import InitializationOptions
22
  except ImportError:
@@ -285,12 +286,17 @@ async def main():
285
  # Prepare server capabilities for initialization
286
  try:
287
  if hasattr(app, "get_capabilities"):
288
- server_capabilities = app.get_capabilities()
 
 
 
 
 
289
  else:
290
- server_capabilities = {}
291
  except Exception as cap_error:
292
  logger.warning(f"Failed to gather server capabilities: {cap_error}")
293
- server_capabilities = {}
294
 
295
  init_options = InitializationOptions(
296
  server_name="gemini-mcp-server",
 
16
 
17
  # MCP imports
18
  try:
19
+ from mcp import types as mcp_types
20
+ from mcp.server import Server, NotificationOptions
21
  from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource
22
  from mcp.server.models import InitializationOptions
23
  except ImportError:
 
286
  # Prepare server capabilities for initialization
287
  try:
288
  if hasattr(app, "get_capabilities"):
289
+ notification_options = NotificationOptions()
290
+ experimental_capabilities: dict[str, dict[str, Any]] = {}
291
+ server_capabilities = app.get_capabilities(
292
+ notification_options=notification_options,
293
+ experimental_capabilities=experimental_capabilities,
294
+ )
295
  else:
296
+ server_capabilities = mcp_types.ServerCapabilities()
297
  except Exception as cap_error:
298
  logger.warning(f"Failed to gather server capabilities: {cap_error}")
299
+ server_capabilities = mcp_types.ServerCapabilities()
300
 
301
  init_options = InitializationOptions(
302
  server_name="gemini-mcp-server",
app.py CHANGED
@@ -43,8 +43,10 @@ mcp_client_logger.setLevel(logging.WARNING)
43
  hf_logging.set_verbosity_error()
44
 
45
  # MCP imports
 
46
  try:
47
  from mcp import ClientSession, StdioServerParameters
 
48
  from mcp.client.stdio import stdio_client
49
  import asyncio
50
  try:
@@ -52,8 +54,11 @@ try:
52
  nest_asyncio.apply() # Allow nested event loops
53
  except ImportError:
54
  pass # nest_asyncio is optional
55
-
56
  MCP_AVAILABLE = True
 
 
 
 
57
  except ImportError as e:
58
  logger.warning(f"MCP SDK not available: {e}")
59
  MCP_AVAILABLE = False
@@ -265,7 +270,11 @@ async def get_mcp_session():
265
  read, write = await stdio_ctx.__aenter__()
266
 
267
  # Create ClientSession from the streams
268
- session = ClientSession(read, write)
 
 
 
 
269
 
270
  # Initialize the session (this sends initialize request and waits for response + initialized notification)
271
  # The __aenter__() method handles the complete initialization handshake:
@@ -279,16 +288,24 @@ async def get_mcp_session():
279
  # including waiting for the server's initialized notification
280
  # This is a blocking call that completes only after the server sends initialized
281
  await session.__aenter__()
282
- logger.info("✅ MCP session initialized")
 
 
 
 
283
  except Exception as e:
284
  error_msg = str(e)
285
  error_type = type(e).__name__
286
  logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
287
 
288
  # Clean up and return None
 
 
 
 
289
  try:
290
  await stdio_ctx.__aexit__(None, None, None)
291
- except:
292
  pass
293
  return None
294
 
@@ -317,25 +334,12 @@ async def call_agent(user_prompt: str, system_prompt: str = None, files: list =
317
  logger.warning("Failed to get MCP session for Gemini call")
318
  return ""
319
 
320
- # List tools - session should be ready after proper initialization
321
- # Add a small delay to ensure server has fully processed initialization
322
- await asyncio.sleep(0.1)
323
  try:
324
  tools = await session.list_tools()
325
  except Exception as e:
326
- error_msg = str(e)
327
- # Check if it's an initialization error
328
- if "initialization" in error_msg.lower() or "before initialization" in error_msg.lower():
329
- logger.warning(f"⚠️ Server not ready yet, waiting a bit more...: {error_msg}")
330
- await asyncio.sleep(0.5)
331
- try:
332
- tools = await session.list_tools()
333
- except Exception as retry_error:
334
- logger.error(f"❌ Failed to list MCP tools after retry: {retry_error}")
335
- return ""
336
- else:
337
- logger.error(f"❌ Failed to list MCP tools: {error_msg}")
338
- return ""
339
 
340
  if not tools or not hasattr(tools, 'tools'):
341
  logger.error("Invalid tools response from MCP server")
 
43
  hf_logging.set_verbosity_error()
44
 
45
  # MCP imports
46
+ MCP_CLIENT_INFO = None
47
  try:
48
  from mcp import ClientSession, StdioServerParameters
49
+ from mcp import types as mcp_types
50
  from mcp.client.stdio import stdio_client
51
  import asyncio
52
  try:
 
54
  nest_asyncio.apply() # Allow nested event loops
55
  except ImportError:
56
  pass # nest_asyncio is optional
 
57
  MCP_AVAILABLE = True
58
+ MCP_CLIENT_INFO = mcp_types.Implementation(
59
+ name="MedLLM-Agent",
60
+ version=os.environ.get("SPACE_VERSION", "local"),
61
+ )
62
  except ImportError as e:
63
  logger.warning(f"MCP SDK not available: {e}")
64
  MCP_AVAILABLE = False
 
270
  read, write = await stdio_ctx.__aenter__()
271
 
272
  # Create ClientSession from the streams
273
+ session = ClientSession(
274
+ read,
275
+ write,
276
+ client_info=MCP_CLIENT_INFO,
277
+ )
278
 
279
  # Initialize the session (this sends initialize request and waits for response + initialized notification)
280
  # The __aenter__() method handles the complete initialization handshake:
 
288
  # including waiting for the server's initialized notification
289
  # This is a blocking call that completes only after the server sends initialized
290
  await session.__aenter__()
291
+ init_result = await session.initialize()
292
+ server_info = getattr(init_result, "serverInfo", None)
293
+ server_name = getattr(server_info, "name", "unknown")
294
+ server_version = getattr(server_info, "version", "unknown")
295
+ logger.info(f"✅ MCP session initialized (server={server_name} v{server_version})")
296
  except Exception as e:
297
  error_msg = str(e)
298
  error_type = type(e).__name__
299
  logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
300
 
301
  # Clean up and return None
302
+ try:
303
+ await session.__aexit__(None, None, None)
304
+ except Exception:
305
+ pass
306
  try:
307
  await stdio_ctx.__aexit__(None, None, None)
308
+ except Exception:
309
  pass
310
  return None
311
 
 
334
  logger.warning("Failed to get MCP session for Gemini call")
335
  return ""
336
 
337
+ # List tools - session is fully initialized via ClientSession.initialize()
 
 
338
  try:
339
  tools = await session.list_tools()
340
  except Exception as e:
341
+ logger.error(f"❌ Failed to list MCP tools: {e}")
342
+ return ""
 
 
 
 
 
 
 
 
 
 
 
343
 
344
  if not tools or not hasattr(tools, 'tools'):
345
  logger.error("Invalid tools response from MCP server")