2
0

tool_handler.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
"""
Tool-call handler: executes the tool_use content blocks in Claude responses.
(工具调用处理器 - 处理 Claude 的 tool_use 响应)
"""
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

from mcp_client import MCPClient
from tool_converter import ToolConverter
  9. class ToolCallHandler:
  10. """处理 Claude 返回的 tool_use 类型内容块"""
  11. def __init__(self, session_id: str = None, mcp_tokens: dict = None):
  12. self.session_id = session_id
  13. self.mcp_tokens = mcp_tokens or {} # MCP 服务器 token 映射
  14. async def process_tool_use_block(
  15. self,
  16. tool_use_block: Dict[str, Any],
  17. tool_to_server_map: dict = None
  18. ) -> Dict[str, Any]:
  19. """
  20. 处理单个 tool_use 内容块
  21. Args:
  22. tool_use_block: Claude 返回的 tool_use 内容块
  23. {
  24. "type": "tool_use",
  25. "id": "...",
  26. "name": "tool_name",
  27. "input": {...}
  28. }
  29. tool_to_server_map: 工具名到服务器 ID 的映射
  30. Returns:
  31. 工具执行结果
  32. """
  33. tool_name = tool_use_block.get("name", "")
  34. tool_input = tool_use_block.get("input", {})
  35. tool_id = tool_use_block.get("id", "")
  36. if not tool_name:
  37. return {
  38. "tool_use_id": tool_id,
  39. "error": "工具名称为空"
  40. }
  41. # 查找工具所属的服务器
  42. server_id = None
  43. if tool_to_server_map and tool_name in tool_to_server_map:
  44. server_id = tool_to_server_map[tool_name]
  45. # 获取该服务器的 token
  46. auth_token = None
  47. if server_id and server_id in self.mcp_tokens:
  48. auth_token = self.mcp_tokens[server_id]
  49. # DEBUG: 打印工具调用详情
  50. print(f"[DEBUG ToolCallHandler.process_tool_use_block] tool: {tool_name}")
  51. print(f"[DEBUG ToolCallHandler.process_tool_use_block] server_id: {server_id}")
  52. print(f"[DEBUG ToolCallHandler.process_tool_use_block] auth_token present: {bool(auth_token)}")
  53. if auth_token:
  54. print(f"[DEBUG ToolCallHandler.process_tool_use_block] auth_token: {auth_token[:30]}...")
  55. print(f"[DEBUG ToolCallHandler.process_tool_use_block] available tokens: {list(self.mcp_tokens.keys())}")
  56. # 调用 MCP 工具(带 token)
  57. client = MCPClient(session_id=self.session_id, server_id=server_id, auth_token=auth_token)
  58. result = await client.call_tool(tool_name, tool_input)
  59. return {
  60. "tool_use_id": tool_id,
  61. "tool_name": tool_name,
  62. "result": result,
  63. "server_id": server_id
  64. }
  65. async def process_tool_use_blocks(
  66. self,
  67. content_blocks: List[Dict[str, Any]],
  68. tool_to_server_map: dict = None
  69. ) -> List[Dict[str, Any]]:
  70. """
  71. 批量处理 tool_use 内容块
  72. Args:
  73. content_blocks: Claude 返回的内容块列表
  74. tool_to_server_map: 工具名到服务器 ID 的映射
  75. Returns:
  76. 工具执行结果列表
  77. """
  78. tool_use_blocks = [
  79. block for block in content_blocks
  80. if block.get("type") == "tool_use"
  81. ]
  82. # DEBUG: 打印工具映射和可用 tokens
  83. print(f"[DEBUG ToolCallHandler.process_tool_use_blocks] tool_to_server_map keys: {list(tool_to_server_map.keys()) if tool_to_server_map else 'None'}")
  84. print(f"[DEBUG ToolCallHandler.process_tool_use_blocks] self.mcp_tokens keys: {list(self.mcp_tokens.keys())}")
  85. if not tool_use_blocks:
  86. return []
  87. # 并发执行所有工具调用
  88. tasks = [
  89. self.process_tool_use_block(block, tool_to_server_map)
  90. for block in tool_use_blocks
  91. ]
  92. results = await asyncio.gather(*tasks, return_exceptions=True)
  93. # 处理异常
  94. formatted_results = []
  95. for i, result in enumerate(results):
  96. if isinstance(result, Exception):
  97. formatted_results.append({
  98. "tool_use_id": tool_use_blocks[i].get("id", ""),
  99. "error": str(result)
  100. })
  101. else:
  102. formatted_results.append(result)
  103. return formatted_results
  104. @staticmethod
  105. def create_tool_result_message(tool_results: List[Dict[str, Any]]) -> Dict[str, Any]:
  106. """
  107. 将工具执行结果转换为 Claude API 可用的消息格式
  108. Args:
  109. tool_results: 工具执行结果列表
  110. Returns:
  111. Claude API 消息格式
  112. """
  113. content = []
  114. for result in tool_results:
  115. tool_use_id = result.get("tool_use_id", "")
  116. tool_result = result.get("result", {})
  117. # 检查是否有错误
  118. if "error" in tool_result:
  119. content.append({
  120. "type": "tool_result",
  121. "tool_use_id": tool_use_id,
  122. "content": f"错误: {tool_result['error']}",
  123. "is_error": True
  124. })
  125. else:
  126. # 成功结果
  127. result_text = tool_result.get("result", "")
  128. content.append({
  129. "type": "tool_result",
  130. "tool_use_id": tool_use_id,
  131. "content": result_text
  132. })
  133. return {
  134. "role": "user",
  135. "content": content
  136. }