mcp_client.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. """
  2. MCP 客户端 - 工具发现和调用(支持 SSE 响应)
  3. """
  4. import json
  5. import httpx
  6. import asyncio
  7. import re
  8. from typing import Dict, List, Any, Optional
  9. from config import MCP_SERVERS
  10. from debug_logger import log_debug
  11. def parse_sse_response(text: str) -> str:
  12. """
  13. 解析 SSE 响应,提取 JSON 数据
  14. SSE 格式:
  15. event: message
  16. data: {...json...}
  17. Args:
  18. text: SSE 响应文本
  19. Returns:
  20. 提取的 JSON 字符串
  21. """
  22. # 查找 data: 行并提取 JSON
  23. for line in text.split('\n'):
  24. if line.startswith('data:'):
  25. data_content = line[5:].strip()
  26. if data_content:
  27. return data_content
  28. return text
  29. class MCPClient:
  30. """MCP 客户端,负责工具发现和调用"""
  31. def __init__(self, server_id: str = None, session_id: str = None, auth_token: str = None):
  32. self.server_id = server_id or "novel-translator"
  33. self.server_config = MCP_SERVERS.get(self.server_id, {})
  34. self.session_id = session_id
  35. self.auth_token = auth_token # JWT token for authenticated MCPs
  36. self.base_url = self.server_config.get("url", "")
  37. async def discover_tools(self) -> List[Dict[str, Any]]:
  38. """从 MCP 服务器发现可用工具"""
  39. if not self.base_url:
  40. return []
  41. try:
  42. async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
  43. url = f"{self.base_url}"
  44. headers = {
  45. "Content-Type": "application/json",
  46. "Accept": "application/json, text/event-stream"
  47. }
  48. if self.session_id:
  49. headers["X-Session-ID"] = self.session_id
  50. # 添加 JWT token 认证
  51. if self.auth_token:
  52. headers["Authorization"] = f"Bearer {self.auth_token}"
  53. payload = {
  54. "jsonrpc": "2.0",
  55. "id": 1,
  56. "method": "tools/list"
  57. }
  58. response = await client.post(url, json=payload, headers=headers)
  59. if response.status_code == 200:
  60. # 解析 SSE 响应
  61. json_text = parse_sse_response(response.text)
  62. result = json.loads(json_text)
  63. if "result" in result:
  64. return result["result"].get("tools", [])
  65. return []
  66. except Exception as e:
  67. print(f"MCP 工具发现失败: {e}")
  68. import traceback
  69. traceback.print_exc()
  70. return []
  71. async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
  72. """调用 MCP 工具"""
  73. if not self.base_url:
  74. return {"error": "MCP 服务器未配置"}
  75. try:
  76. async with httpx.AsyncClient(timeout=120.0, verify=False) as client:
  77. url = f"{self.base_url}"
  78. headers = {
  79. "Content-Type": "application/json",
  80. "Accept": "application/json, text/event-stream"
  81. }
  82. if self.session_id:
  83. headers["X-Session-ID"] = self.session_id
  84. # 添加 JWT token 认证
  85. if self.auth_token:
  86. headers["Authorization"] = f"Bearer {self.auth_token}"
  87. # DEBUG: 打印调用详情
  88. print(f"[DEBUG MCPClient.call_tool] tool_name: {tool_name}")
  89. print(f"[DEBUG MCPClient.call_tool] server_id: {self.server_id}")
  90. print(f"[DEBUG MCPClient.call_tool] url: {url}")
  91. print(f"[DEBUG MCPClient.call_tool] auth_token present: {bool(self.auth_token)}")
  92. if self.auth_token:
  93. print(f"[DEBUG MCPClient.call_tool] auth_token: {self.auth_token[:30]}...")
  94. print(f"[DEBUG MCPClient.call_tool] Authorization header: {headers['Authorization'][:50]}...")
  95. else:
  96. print(f"[DEBUG MCPClient.call_tool] NO auth_token!")
  97. payload = {
  98. "jsonrpc": "2.0",
  99. "id": 2,
  100. "method": "tools/call",
  101. "params": {
  102. "name": tool_name,
  103. "arguments": arguments
  104. }
  105. }
  106. response = await client.post(url, json=payload, headers=headers)
  107. print(f"[DEBUG MCPClient.call_tool] response status: {response.status_code}")
  108. if response.status_code == 200:
  109. # 解析 SSE 响应
  110. json_text = parse_sse_response(response.text)
  111. result = json.loads(json_text)
  112. if "result" in result:
  113. content_list = result["result"].get("content", [])
  114. text_results = []
  115. for item in content_list:
  116. if item.get("type") == "text":
  117. text_results.append(item.get("text", ""))
  118. return {
  119. "success": True,
  120. "result": "\n".join(text_results),
  121. "raw": result["result"]
  122. }
  123. elif "error" in result:
  124. return {"error": result["error"].get("message", "Unknown error")}
  125. return {"error": f"工具调用失败: {response.status_code}"}
  126. except Exception as e:
  127. print(f"MCP 工具调用失败: {e}")
  128. import traceback
  129. traceback.print_exc()
  130. return {"error": str(e)}
  131. @staticmethod
  132. async def get_all_tools_async(session_id: str = None) -> List[Dict[str, Any]]:
  133. """获取所有已配置 MCP 服务器的工具列表(异步版本)"""
  134. all_tools = []
  135. for server_id in MCP_SERVERS.keys():
  136. if not MCP_SERVERS[server_id].get("enabled", False):
  137. continue
  138. client = MCPClient(server_id, session_id)
  139. try:
  140. tools = await client.discover_tools()
  141. for tool in tools:
  142. tool["_server_id"] = server_id
  143. all_tools.append(tool)
  144. except Exception as e:
  145. print(f"发现 {server_id} 工具失败: {e}")
  146. return all_tools
  147. @staticmethod
  148. async def get_all_tools_with_tokens_async(
  149. session_id: str = None,
  150. mcp_tokens: dict = None,
  151. enabled_mcp_list: list = None # 新增:前端传递的已启用 MCP 列表
  152. ) -> List[Dict[str, Any]]:
  153. """
  154. 获取所有已配置 MCP 服务器的工具列表(带 token 认证)
  155. Args:
  156. session_id: 会话 ID
  157. mcp_tokens: MCP token 映射
  158. enabled_mcp_list: 前端传递的已启用 MCP 列表(优先级高于 config.py 中的配置)
  159. - None: 前端未传递,使用 config.py fallback
  160. - []: 前端明确禁用所有 MCP
  161. - [xxx]: 前端指定启用的 MCP 列表
  162. """
  163. all_tools = []
  164. # DEBUG: 打印接收到的启用列表
  165. print(f"[DEBUG MCPClient.get_all_tools_with_tokens_async] enabled_mcp_list: {enabled_mcp_list}")
  166. log_debug("mcp_client.get_all_tools_with_tokens_async_start", {
  167. "enabled_mcp_list": enabled_mcp_list,
  168. "enabled_mcp_list_type": str(type(enabled_mcp_list)),
  169. "enabled_mcp_list_is_none": enabled_mcp_list is None,
  170. "enabled_mcp_list_len": len(enabled_mcp_list) if enabled_mcp_list is not None else "N/A"
  171. })
  172. for server_id in MCP_SERVERS.keys():
  173. # 优先使用前端传递的启用列表
  174. log_debug("mcp_client.checking_server", {
  175. "server_id": server_id,
  176. "enabled_mcp_list": enabled_mcp_list,
  177. "is_none": enabled_mcp_list is None
  178. })
  179. if enabled_mcp_list is not None:
  180. # 前端传递了启用列表(可能是空数组),只处理列表中的 MCP
  181. in_list = server_id in enabled_mcp_list
  182. log_debug("mcp_client.server_in_list_check", {
  183. "server_id": server_id,
  184. "in_enabled_list": in_list,
  185. "will_skip": not in_list
  186. })
  187. if server_id not in enabled_mcp_list:
  188. print(f"[DEBUG MCPClient] Skipping {server_id} (not in enabled_mcp_list from frontend)")
  189. continue
  190. else:
  191. # 前端未传递启用列表(None),使用配置文件中的 enabled 状态作为 fallback
  192. if not MCP_SERVERS[server_id].get("enabled", False):
  193. continue
  194. # 获取该服务器的 token
  195. auth_token = None
  196. if mcp_tokens and server_id in mcp_tokens:
  197. auth_token = mcp_tokens[server_id]
  198. print(f"[DEBUG MCPClient] Discovering tools for {server_id} (has_token: {bool(auth_token)})")
  199. client = MCPClient(server_id, session_id, auth_token)
  200. try:
  201. tools = await client.discover_tools()
  202. for tool in tools:
  203. tool["_server_id"] = server_id
  204. all_tools.append(tool)
  205. except Exception as e:
  206. print(f"发现 {server_id} 工具失败: {e}")
  207. return all_tools
  208. @staticmethod
  209. def get_all_tools(session_id: str = None) -> List[Dict[str, Any]]:
  210. """获取所有已配置 MCP 服务器的工具列表(同步版本)"""
  211. return asyncio.run(MCPClient.get_all_tools_async(session_id))