# mcp_client.py
"""
MCP client - tool discovery and invocation (supports SSE responses).
"""
  4. import json
  5. import httpx
  6. import asyncio
  7. import re
  8. from typing import Dict, List, Any, Optional
  9. from config import MCP_SERVERS
  10. def parse_sse_response(text: str) -> str:
  11. """
  12. 解析 SSE 响应,提取 JSON 数据
  13. SSE 格式:
  14. event: message
  15. data: {...json...}
  16. Args:
  17. text: SSE 响应文本
  18. Returns:
  19. 提取的 JSON 字符串
  20. """
  21. # 查找 data: 行并提取 JSON
  22. for line in text.split('\n'):
  23. if line.startswith('data:'):
  24. data_content = line[5:].strip()
  25. if data_content:
  26. return data_content
  27. return text
  28. class MCPClient:
  29. """MCP 客户端,负责工具发现和调用"""
  30. def __init__(self, server_id: str = None, session_id: str = None, auth_token: str = None):
  31. self.server_id = server_id or "novel-translator"
  32. self.server_config = MCP_SERVERS.get(self.server_id, {})
  33. self.session_id = session_id
  34. self.auth_token = auth_token # JWT token for authenticated MCPs
  35. self.base_url = self.server_config.get("url", "")
  36. async def discover_tools(self) -> List[Dict[str, Any]]:
  37. """从 MCP 服务器发现可用工具"""
  38. if not self.base_url:
  39. return []
  40. try:
  41. async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
  42. url = f"{self.base_url}"
  43. headers = {
  44. "Content-Type": "application/json",
  45. "Accept": "application/json, text/event-stream"
  46. }
  47. if self.session_id:
  48. headers["X-Session-ID"] = self.session_id
  49. # 添加 JWT token 认证
  50. if self.auth_token:
  51. headers["Authorization"] = f"Bearer {self.auth_token}"
  52. payload = {
  53. "jsonrpc": "2.0",
  54. "id": 1,
  55. "method": "tools/list"
  56. }
  57. response = await client.post(url, json=payload, headers=headers)
  58. if response.status_code == 200:
  59. # 解析 SSE 响应
  60. json_text = parse_sse_response(response.text)
  61. result = json.loads(json_text)
  62. if "result" in result:
  63. return result["result"].get("tools", [])
  64. return []
  65. except Exception as e:
  66. print(f"MCP 工具发现失败: {e}")
  67. import traceback
  68. traceback.print_exc()
  69. return []
  70. async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
  71. """调用 MCP 工具"""
  72. if not self.base_url:
  73. return {"error": "MCP 服务器未配置"}
  74. try:
  75. async with httpx.AsyncClient(timeout=120.0, verify=False) as client:
  76. url = f"{self.base_url}"
  77. headers = {
  78. "Content-Type": "application/json",
  79. "Accept": "application/json, text/event-stream"
  80. }
  81. if self.session_id:
  82. headers["X-Session-ID"] = self.session_id
  83. # 添加 JWT token 认证
  84. if self.auth_token:
  85. headers["Authorization"] = f"Bearer {self.auth_token}"
  86. payload = {
  87. "jsonrpc": "2.0",
  88. "id": 2,
  89. "method": "tools/call",
  90. "params": {
  91. "name": tool_name,
  92. "arguments": arguments
  93. }
  94. }
  95. response = await client.post(url, json=payload, headers=headers)
  96. if response.status_code == 200:
  97. # 解析 SSE 响应
  98. json_text = parse_sse_response(response.text)
  99. result = json.loads(json_text)
  100. if "result" in result:
  101. content_list = result["result"].get("content", [])
  102. text_results = []
  103. for item in content_list:
  104. if item.get("type") == "text":
  105. text_results.append(item.get("text", ""))
  106. return {
  107. "success": True,
  108. "result": "\n".join(text_results),
  109. "raw": result["result"]
  110. }
  111. elif "error" in result:
  112. return {"error": result["error"].get("message", "Unknown error")}
  113. return {"error": f"工具调用失败: {response.status_code}"}
  114. except Exception as e:
  115. print(f"MCP 工具调用失败: {e}")
  116. import traceback
  117. traceback.print_exc()
  118. return {"error": str(e)}
  119. @staticmethod
  120. async def get_all_tools_async(session_id: str = None) -> List[Dict[str, Any]]:
  121. """获取所有已配置 MCP 服务器的工具列表(异步版本)"""
  122. all_tools = []
  123. for server_id in MCP_SERVERS.keys():
  124. if not MCP_SERVERS[server_id].get("enabled", False):
  125. continue
  126. client = MCPClient(server_id, session_id)
  127. try:
  128. tools = await client.discover_tools()
  129. for tool in tools:
  130. tool["_server_id"] = server_id
  131. all_tools.append(tool)
  132. except Exception as e:
  133. print(f"发现 {server_id} 工具失败: {e}")
  134. return all_tools
  135. @staticmethod
  136. async def get_all_tools_with_tokens_async(
  137. session_id: str = None,
  138. mcp_tokens: dict = None
  139. ) -> List[Dict[str, Any]]:
  140. """获取所有已配置 MCP 服务器的工具列表(带 token 认证)"""
  141. all_tools = []
  142. for server_id in MCP_SERVERS.keys():
  143. if not MCP_SERVERS[server_id].get("enabled", False):
  144. continue
  145. # 获取该服务器的 token
  146. auth_token = None
  147. if mcp_tokens and server_id in mcp_tokens:
  148. auth_token = mcp_tokens[server_id]
  149. client = MCPClient(server_id, session_id, auth_token)
  150. try:
  151. tools = await client.discover_tools()
  152. for tool in tools:
  153. tool["_server_id"] = server_id
  154. all_tools.append(tool)
  155. except Exception as e:
  156. print(f"发现 {server_id} 工具失败: {e}")
  157. return all_tools
  158. @staticmethod
  159. def get_all_tools(session_id: str = None) -> List[Dict[str, Any]]:
  160. """获取所有已配置 MCP 服务器的工具列表(同步版本)"""
  161. return asyncio.run(MCPClient.get_all_tools_async(session_id))