2
0

mcp_client.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. """
  2. MCP 客户端 - 工具发现和调用(支持 SSE 响应)
  3. """
  4. import json
  5. import httpx
  6. import asyncio
  7. import re
  8. from typing import Dict, List, Any, Optional
  9. from config import MCP_SERVERS
  10. def parse_sse_response(text: str) -> str:
  11. """
  12. 解析 SSE 响应,提取 JSON 数据
  13. SSE 格式:
  14. event: message
  15. data: {...json...}
  16. Args:
  17. text: SSE 响应文本
  18. Returns:
  19. 提取的 JSON 字符串
  20. """
  21. # 查找 data: 行并提取 JSON
  22. for line in text.split('\n'):
  23. if line.startswith('data:'):
  24. data_content = line[5:].strip()
  25. if data_content:
  26. return data_content
  27. return text
  28. class MCPClient:
  29. """MCP 客户端,负责工具发现和调用"""
  30. def __init__(self, server_id: str = None, session_id: str = None):
  31. self.server_id = server_id or "novel-translator"
  32. self.server_config = MCP_SERVERS.get(self.server_id, {})
  33. self.session_id = session_id
  34. self.base_url = self.server_config.get("url", "")
  35. async def discover_tools(self) -> List[Dict[str, Any]]:
  36. """从 MCP 服务器发现可用工具"""
  37. if not self.base_url:
  38. return []
  39. try:
  40. async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
  41. url = f"{self.base_url}"
  42. headers = {
  43. "Content-Type": "application/json",
  44. "Accept": "application/json, text/event-stream"
  45. }
  46. if self.session_id:
  47. headers["X-Session-ID"] = self.session_id
  48. payload = {
  49. "jsonrpc": "2.0",
  50. "id": 1,
  51. "method": "tools/list"
  52. }
  53. response = await client.post(url, json=payload, headers=headers)
  54. if response.status_code == 200:
  55. # 解析 SSE 响应
  56. json_text = parse_sse_response(response.text)
  57. result = json.loads(json_text)
  58. if "result" in result:
  59. return result["result"].get("tools", [])
  60. return []
  61. except Exception as e:
  62. print(f"MCP 工具发现失败: {e}")
  63. import traceback
  64. traceback.print_exc()
  65. return []
  66. async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
  67. """调用 MCP 工具"""
  68. if not self.base_url:
  69. return {"error": "MCP 服务器未配置"}
  70. try:
  71. async with httpx.AsyncClient(timeout=120.0, verify=False) as client:
  72. url = f"{self.base_url}"
  73. headers = {
  74. "Content-Type": "application/json",
  75. "Accept": "application/json, text/event-stream"
  76. }
  77. if self.session_id:
  78. headers["X-Session-ID"] = self.session_id
  79. payload = {
  80. "jsonrpc": "2.0",
  81. "id": 2,
  82. "method": "tools/call",
  83. "params": {
  84. "name": tool_name,
  85. "arguments": arguments
  86. }
  87. }
  88. response = await client.post(url, json=payload, headers=headers)
  89. if response.status_code == 200:
  90. # 解析 SSE 响应
  91. json_text = parse_sse_response(response.text)
  92. result = json.loads(json_text)
  93. if "result" in result:
  94. content_list = result["result"].get("content", [])
  95. text_results = []
  96. for item in content_list:
  97. if item.get("type") == "text":
  98. text_results.append(item.get("text", ""))
  99. return {
  100. "success": True,
  101. "result": "\n".join(text_results),
  102. "raw": result["result"]
  103. }
  104. elif "error" in result:
  105. return {"error": result["error"].get("message", "Unknown error")}
  106. return {"error": f"工具调用失败: {response.status_code}"}
  107. except Exception as e:
  108. print(f"MCP 工具调用失败: {e}")
  109. import traceback
  110. traceback.print_exc()
  111. return {"error": str(e)}
  112. @staticmethod
  113. async def get_all_tools_async(session_id: str = None) -> List[Dict[str, Any]]:
  114. """获取所有已配置 MCP 服务器的工具列表(异步版本)"""
  115. all_tools = []
  116. for server_id in MCP_SERVERS.keys():
  117. if not MCP_SERVERS[server_id].get("enabled", False):
  118. continue
  119. client = MCPClient(server_id, session_id)
  120. try:
  121. tools = await client.discover_tools()
  122. for tool in tools:
  123. tool["_server_id"] = server_id
  124. all_tools.append(tool)
  125. except Exception as e:
  126. print(f"发现 {server_id} 工具失败: {e}")
  127. return all_tools
  128. @staticmethod
  129. def get_all_tools(session_id: str = None) -> List[Dict[str, Any]]:
  130. """获取所有已配置 MCP 服务器的工具列表(同步版本)"""
  131. return asyncio.run(MCPClient.get_all_tools_async(session_id))