| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- """
- MCP 客户端 - 工具发现和调用(支持 SSE 响应)
- """
- import json
- import httpx
- import asyncio
- import re
- from typing import Dict, List, Any, Optional
- from config import MCP_SERVERS
def parse_sse_response(text: str) -> str:
    """Extract the JSON payload from a Server-Sent Events (SSE) response.

    An SSE body looks like::

        event: message
        data: {...json...}

    Per the SSE specification, a single event may carry its data across
    several consecutive ``data:`` lines; those lines are joined with a
    newline. Lines may end in LF, CRLF, or CR, and a blank line terminates
    an event. The data of the first event with non-blank data is returned.

    Args:
        text: Raw HTTP response body (an SSE stream or plain JSON).

    Returns:
        The extracted, whitespace-stripped data string, or ``text`` unchanged
        when it contains no non-blank ``data:`` field (e.g. the server replied
        with plain JSON instead of an event stream).
    """
    data_lines = []
    # SSE line terminators are exactly CRLF, LF, or CR. str.splitlines() is
    # avoided on purpose: it also splits on U+2028/U+2029 and similar, which
    # may legitimately appear unescaped inside JSON string values.
    for line in re.split(r"\r\n|\r|\n", text):
        if not line:
            # A blank line dispatches the event collected so far.
            payload = "\n".join(data_lines).strip()
            if payload:
                return payload
            data_lines = []
        elif line.startswith("data:"):
            value = line[5:]
            # The spec strips exactly one optional space after the colon.
            if value.startswith(" "):
                value = value[1:]
            data_lines.append(value)
    # The stream may end without a trailing blank line: flush the last event.
    payload = "\n".join(data_lines).strip()
    return payload if payload else text
class MCPClient:
    """MCP client responsible for tool discovery and tool invocation.

    Each instance targets one server entry from ``MCP_SERVERS`` and sends
    JSON-RPC 2.0 requests to it via HTTP POST. Replies may be plain JSON or
    an SSE stream; both are decoded through ``parse_sse_response``.
    """

    def __init__(self, server_id: Optional[str] = None, session_id: Optional[str] = None):
        """Bind the client to one configured MCP server.

        Args:
            server_id: Key into ``MCP_SERVERS``; falls back to
                ``"novel-translator"`` when empty or None.
            session_id: Optional value forwarded in the ``X-Session-ID``
                request header.
        """
        self.server_id = server_id or "novel-translator"
        # An unknown server id yields an empty config and therefore an empty
        # base_url, which makes every request method short-circuit.
        self.server_config = MCP_SERVERS.get(self.server_id, {})
        self.session_id = session_id
        self.base_url = self.server_config.get("url", "")

    def _build_headers(self) -> Dict[str, str]:
        """Return the HTTP headers shared by every JSON-RPC request."""
        headers = {
            "Content-Type": "application/json",
            # The server may answer with plain JSON or with an SSE stream.
            "Accept": "application/json, text/event-stream",
        }
        if self.session_id:
            # NOTE(review): the MCP Streamable HTTP transport names this header
            # "Mcp-Session-Id"; "X-Session-ID" is kept because the target
            # server presumably expects it — confirm.
            headers["X-Session-ID"] = self.session_id
        return headers

    async def _post_jsonrpc(self, payload: Dict[str, Any], timeout: float) -> tuple:
        """POST one JSON-RPC request to ``base_url``.

        Returns:
            ``(status_code, body)`` where ``body`` is the decoded JSON reply
            when the status is 200, and ``None`` otherwise.

        Raises:
            httpx.HTTPError: on transport failures (timeouts, connect errors).
            json.JSONDecodeError: if a 200 reply is not valid JSON.
        """
        # NOTE(review): TLS certificate verification stays disabled by default
        # to preserve previous behaviour; set "verify_ssl": True in the server
        # config to enable it.
        verify = self.server_config.get("verify_ssl", False)
        async with httpx.AsyncClient(timeout=timeout, verify=verify) as client:
            response = await client.post(
                self.base_url, json=payload, headers=self._build_headers()
            )
            if response.status_code != 200:
                return response.status_code, None
            return response.status_code, json.loads(parse_sse_response(response.text))

    @staticmethod
    def _format_tool_result(raw: Any) -> Dict[str, Any]:
        """Join the ``text`` content items of a ``tools/call`` result.

        Non-text items (and malformed entries) are skipped; the untouched
        result object is returned under ``"raw"``.
        """
        content = (raw.get("content") if isinstance(raw, dict) else None) or []
        texts = [
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        ]
        # NOTE(review): MCP marks tool-level failures with "isError": true in
        # the result; it is surfaced only via "raw" here, as before.
        return {"success": True, "result": "\n".join(texts), "raw": raw}

    @staticmethod
    def _log_failure(message: str) -> None:
        """Print ``message`` followed by the traceback of the active exception."""
        import traceback

        print(message)
        traceback.print_exc()

    async def discover_tools(self) -> List[Dict[str, Any]]:
        """Fetch the server's tool list via the ``tools/list`` method.

        Returns:
            The ``result.tools`` list from the reply, or ``[]`` when the
            server is not configured, answers with a non-200 status, returns
            no usable ``result``, or any error occurs (errors are printed).
        """
        if not self.base_url:
            return []
        payload = {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}
        try:
            status, body = await self._post_jsonrpc(payload, timeout=30.0)
            if status != 200:
                # Previously this failure was silently swallowed.
                print(f"MCP 工具发现失败: HTTP {status}")
                return []
            result = body.get("result") if isinstance(body, dict) else None
            if isinstance(result, dict):
                return result.get("tools", [])
            return []
        except Exception as e:
            self._log_failure(f"MCP 工具发现失败: {e}")
            return []

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke one MCP tool via the ``tools/call`` method.

        Args:
            tool_name: Name of the tool to call.
            arguments: Tool arguments; ``None`` is sent as an empty object.

        Returns:
            On success ``{"success": True, "result": <joined text>, "raw":
            <result object>}``; otherwise ``{"error": <message>}``. Errors are
            never raised to the caller.
        """
        if not self.base_url:
            return {"error": "MCP 服务器未配置"}
        payload = {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": arguments or {}},
        }
        try:
            status, body = await self._post_jsonrpc(payload, timeout=120.0)
            if status == 200 and isinstance(body, dict):
                if "result" in body:
                    return self._format_tool_result(body["result"])
                if "error" in body:
                    error = body["error"]
                    # JSON-RPC errors are objects, but tolerate bare values.
                    if isinstance(error, dict):
                        return {"error": error.get("message", "Unknown error")}
                    return {"error": str(error)}
            return {"error": f"工具调用失败: {status}"}
        except Exception as e:
            MCPClient._log_failure(f"MCP 工具调用失败: {e}")
            return {"error": str(e)}

    @staticmethod
    async def get_all_tools_async(session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Collect tools from every enabled server in ``MCP_SERVERS``.

        Servers are queried concurrently; results keep ``MCP_SERVERS`` order.
        Each returned tool dict is tagged in place with ``"_server_id"``.
        """
        server_ids = [
            server_id
            for server_id, server_cfg in MCP_SERVERS.items()
            if server_cfg.get("enabled", False)
        ]
        clients = [MCPClient(server_id, session_id) for server_id in server_ids]
        outcomes = await asyncio.gather(
            *(client.discover_tools() for client in clients),
            return_exceptions=True,
        )
        all_tools = []
        for server_id, outcome in zip(server_ids, outcomes):
            if isinstance(outcome, Exception):
                print(f"发现 {server_id} 工具失败: {outcome}")
                continue
            if isinstance(outcome, BaseException):
                # Cancellation and similar must keep propagating.
                raise outcome
            for tool in outcome:
                if isinstance(tool, dict):
                    tool["_server_id"] = server_id
                    all_tools.append(tool)
        return all_tools

    @staticmethod
    def get_all_tools(session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Synchronous wrapper around ``get_all_tools_async``.

        ``asyncio.run`` refuses to start while an event loop is already
        running in the current thread, so in that case the coroutine runs on
        a fresh loop in a short-lived worker thread. The caller still blocks
        until discovery completes.
        """
        coro = MCPClient.get_all_tools_async(session_id)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()
|