tool_handler.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. """
  2. 工具调用处理器 - 处理 Claude 的 tool_use 响应
  3. """
  4. import asyncio
  5. import json
  6. from typing import Dict, List, Any
  7. from mcp_client import MCPClient
  8. from tool_converter import ToolConverter
  9. class ToolCallHandler:
  10. """处理 Claude 返回的 tool_use 类型内容块"""
  11. def __init__(self, session_id: str = None, mcp_tokens: dict = None):
  12. self.session_id = session_id
  13. self.mcp_tokens = mcp_tokens or {} # MCP 服务器 token 映射
  14. async def process_tool_use_block(
  15. self,
  16. tool_use_block: Dict[str, Any],
  17. tool_to_server_map: dict = None
  18. ) -> Dict[str, Any]:
  19. """
  20. 处理单个 tool_use 内容块
  21. Args:
  22. tool_use_block: Claude 返回的 tool_use 内容块
  23. {
  24. "type": "tool_use",
  25. "id": "...",
  26. "name": "tool_name",
  27. "input": {...}
  28. }
  29. tool_to_server_map: 工具名到服务器 ID 的映射
  30. Returns:
  31. 工具执行结果
  32. """
  33. tool_name = tool_use_block.get("name", "")
  34. tool_input = tool_use_block.get("input", {})
  35. tool_id = tool_use_block.get("id", "")
  36. if not tool_name:
  37. return {
  38. "tool_use_id": tool_id,
  39. "error": "工具名称为空"
  40. }
  41. # 查找工具所属的服务器
  42. server_id = None
  43. if tool_to_server_map and tool_name in tool_to_server_map:
  44. server_id = tool_to_server_map[tool_name]
  45. # 获取该服务器的 token
  46. auth_token = None
  47. if server_id and server_id in self.mcp_tokens:
  48. auth_token = self.mcp_tokens[server_id]
  49. # 调用 MCP 工具(带 token)
  50. client = MCPClient(session_id=self.session_id, server_id=server_id, auth_token=auth_token)
  51. result = await client.call_tool(tool_name, tool_input)
  52. return {
  53. "tool_use_id": tool_id,
  54. "tool_name": tool_name,
  55. "result": result,
  56. "server_id": server_id
  57. }
  58. async def process_tool_use_blocks(
  59. self,
  60. content_blocks: List[Dict[str, Any]],
  61. tool_to_server_map: dict = None
  62. ) -> List[Dict[str, Any]]:
  63. """
  64. 批量处理 tool_use 内容块
  65. Args:
  66. content_blocks: Claude 返回的内容块列表
  67. tool_to_server_map: 工具名到服务器 ID 的映射
  68. Returns:
  69. 工具执行结果列表
  70. """
  71. tool_use_blocks = [
  72. block for block in content_blocks
  73. if block.get("type") == "tool_use"
  74. ]
  75. if not tool_use_blocks:
  76. return []
  77. # 并发执行所有工具调用
  78. tasks = [
  79. self.process_tool_use_block(block, tool_to_server_map)
  80. for block in tool_use_blocks
  81. ]
  82. results = await asyncio.gather(*tasks, return_exceptions=True)
  83. # 处理异常
  84. formatted_results = []
  85. for i, result in enumerate(results):
  86. if isinstance(result, Exception):
  87. formatted_results.append({
  88. "tool_use_id": tool_use_blocks[i].get("id", ""),
  89. "error": str(result)
  90. })
  91. else:
  92. formatted_results.append(result)
  93. return formatted_results
  94. @staticmethod
  95. def create_tool_result_message(tool_results: List[Dict[str, Any]]) -> Dict[str, Any]:
  96. """
  97. 将工具执行结果转换为 Claude API 可用的消息格式
  98. Args:
  99. tool_results: 工具执行结果列表
  100. Returns:
  101. Claude API 消息格式
  102. """
  103. content = []
  104. for result in tool_results:
  105. tool_use_id = result.get("tool_use_id", "")
  106. tool_result = result.get("result", {})
  107. # 检查是否有错误
  108. if "error" in tool_result:
  109. content.append({
  110. "type": "tool_result",
  111. "tool_use_id": tool_use_id,
  112. "content": f"错误: {tool_result['error']}",
  113. "is_error": True
  114. })
  115. else:
  116. # 成功结果
  117. result_text = tool_result.get("result", "")
  118. content.append({
  119. "type": "tool_result",
  120. "tool_use_id": tool_use_id,
  121. "content": result_text
  122. })
  123. return {
  124. "role": "user",
  125. "content": content
  126. }