mcp_client.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. """
  2. MCP 客户端 - 工具发现和调用(支持 SSE 响应)
  3. """
  4. import json
  5. import httpx
  6. import asyncio
  7. import re
  8. from typing import Dict, List, Any, Optional
  9. from config import MCP_SERVERS
  10. def parse_sse_response(text: str) -> str:
  11. """
  12. 解析 SSE 响应,提取 JSON 数据
  13. SSE 格式:
  14. event: message
  15. data: {...json...}
  16. Args:
  17. text: SSE 响应文本
  18. Returns:
  19. 提取的 JSON 字符串
  20. """
  21. # 查找 data: 行并提取 JSON
  22. for line in text.split('\n'):
  23. if line.startswith('data:'):
  24. data_content = line[5:].strip()
  25. if data_content:
  26. return data_content
  27. return text
  28. class MCPClient:
  29. """MCP 客户端,负责工具发现和调用"""
  30. def __init__(self, server_id: str = None, session_id: str = None, auth_token: str = None):
  31. self.server_id = server_id or "novel-translator"
  32. self.server_config = MCP_SERVERS.get(self.server_id, {})
  33. self.session_id = session_id
  34. self.auth_token = auth_token # JWT token for authenticated MCPs
  35. self.base_url = self.server_config.get("url", "")
  36. async def discover_tools(self) -> List[Dict[str, Any]]:
  37. """从 MCP 服务器发现可用工具"""
  38. if not self.base_url:
  39. return []
  40. try:
  41. async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
  42. url = f"{self.base_url}"
  43. headers = {
  44. "Content-Type": "application/json",
  45. "Accept": "application/json, text/event-stream"
  46. }
  47. if self.session_id:
  48. headers["X-Session-ID"] = self.session_id
  49. # 添加 JWT token 认证
  50. if self.auth_token:
  51. headers["Authorization"] = f"Bearer {self.auth_token}"
  52. payload = {
  53. "jsonrpc": "2.0",
  54. "id": 1,
  55. "method": "tools/list"
  56. }
  57. response = await client.post(url, json=payload, headers=headers)
  58. if response.status_code == 200:
  59. # 解析 SSE 响应
  60. json_text = parse_sse_response(response.text)
  61. result = json.loads(json_text)
  62. if "result" in result:
  63. return result["result"].get("tools", [])
  64. return []
  65. except Exception as e:
  66. print(f"MCP 工具发现失败: {e}")
  67. import traceback
  68. traceback.print_exc()
  69. return []
  70. async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
  71. """调用 MCP 工具"""
  72. if not self.base_url:
  73. return {"error": "MCP 服务器未配置"}
  74. try:
  75. async with httpx.AsyncClient(timeout=120.0, verify=False) as client:
  76. url = f"{self.base_url}"
  77. headers = {
  78. "Content-Type": "application/json",
  79. "Accept": "application/json, text/event-stream"
  80. }
  81. if self.session_id:
  82. headers["X-Session-ID"] = self.session_id
  83. # 添加 JWT token 认证
  84. if self.auth_token:
  85. headers["Authorization"] = f"Bearer {self.auth_token}"
  86. # DEBUG: 打印调用详情
  87. print(f"[DEBUG MCPClient.call_tool] tool_name: {tool_name}")
  88. print(f"[DEBUG MCPClient.call_tool] server_id: {self.server_id}")
  89. print(f"[DEBUG MCPClient.call_tool] url: {url}")
  90. print(f"[DEBUG MCPClient.call_tool] auth_token present: {bool(self.auth_token)}")
  91. if self.auth_token:
  92. print(f"[DEBUG MCPClient.call_tool] auth_token: {self.auth_token[:30]}...")
  93. print(f"[DEBUG MCPClient.call_tool] Authorization header: {headers['Authorization'][:50]}...")
  94. else:
  95. print(f"[DEBUG MCPClient.call_tool] NO auth_token!")
  96. payload = {
  97. "jsonrpc": "2.0",
  98. "id": 2,
  99. "method": "tools/call",
  100. "params": {
  101. "name": tool_name,
  102. "arguments": arguments
  103. }
  104. }
  105. response = await client.post(url, json=payload, headers=headers)
  106. print(f"[DEBUG MCPClient.call_tool] response status: {response.status_code}")
  107. if response.status_code == 200:
  108. # 解析 SSE 响应
  109. json_text = parse_sse_response(response.text)
  110. result = json.loads(json_text)
  111. if "result" in result:
  112. content_list = result["result"].get("content", [])
  113. text_results = []
  114. for item in content_list:
  115. if item.get("type") == "text":
  116. text_results.append(item.get("text", ""))
  117. return {
  118. "success": True,
  119. "result": "\n".join(text_results),
  120. "raw": result["result"]
  121. }
  122. elif "error" in result:
  123. return {"error": result["error"].get("message", "Unknown error")}
  124. return {"error": f"工具调用失败: {response.status_code}"}
  125. except Exception as e:
  126. print(f"MCP 工具调用失败: {e}")
  127. import traceback
  128. traceback.print_exc()
  129. return {"error": str(e)}
  130. @staticmethod
  131. async def get_all_tools_async(session_id: str = None) -> List[Dict[str, Any]]:
  132. """获取所有已配置 MCP 服务器的工具列表(异步版本)"""
  133. all_tools = []
  134. for server_id in MCP_SERVERS.keys():
  135. if not MCP_SERVERS[server_id].get("enabled", False):
  136. continue
  137. client = MCPClient(server_id, session_id)
  138. try:
  139. tools = await client.discover_tools()
  140. for tool in tools:
  141. tool["_server_id"] = server_id
  142. all_tools.append(tool)
  143. except Exception as e:
  144. print(f"发现 {server_id} 工具失败: {e}")
  145. return all_tools
  146. @staticmethod
  147. async def get_all_tools_with_tokens_async(
  148. session_id: str = None,
  149. mcp_tokens: dict = None
  150. ) -> List[Dict[str, Any]]:
  151. """获取所有已配置 MCP 服务器的工具列表(带 token 认证)"""
  152. all_tools = []
  153. for server_id in MCP_SERVERS.keys():
  154. if not MCP_SERVERS[server_id].get("enabled", False):
  155. continue
  156. # 获取该服务器的 token
  157. auth_token = None
  158. if mcp_tokens and server_id in mcp_tokens:
  159. auth_token = mcp_tokens[server_id]
  160. client = MCPClient(server_id, session_id, auth_token)
  161. try:
  162. tools = await client.discover_tools()
  163. for tool in tools:
  164. tool["_server_id"] = server_id
  165. all_tools.append(tool)
  166. except Exception as e:
  167. print(f"发现 {server_id} 工具失败: {e}")
  168. return all_tools
  169. @staticmethod
  170. def get_all_tools(session_id: str = None) -> List[Dict[str, Any]]:
  171. """获取所有已配置 MCP 服务器的工具列表(同步版本)"""
  172. return asyncio.run(MCPClient.get_all_tools_async(session_id))