tool_handler.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """
  2. 工具调用处理器 - 处理 Claude 的 tool_use 响应
  3. """
  4. import asyncio
  5. import json
  6. from typing import Dict, List, Any
  7. from mcp_client import MCPClient
  8. from tool_converter import ToolConverter
  9. class ToolCallHandler:
  10. """处理 Claude 返回的 tool_use 类型内容块"""
  11. def __init__(self, session_id: str = None):
  12. self.session_id = session_id
  13. async def process_tool_use_block(
  14. self,
  15. tool_use_block: Dict[str, Any]
  16. ) -> Dict[str, Any]:
  17. """
  18. 处理单个 tool_use 内容块
  19. Args:
  20. tool_use_block: Claude 返回的 tool_use 内容块
  21. {
  22. "type": "tool_use",
  23. "id": "...",
  24. "name": "tool_name",
  25. "input": {...}
  26. }
  27. Returns:
  28. 工具执行结果
  29. """
  30. tool_name = tool_use_block.get("name", "")
  31. tool_input = tool_use_block.get("input", {})
  32. tool_id = tool_use_block.get("id", "")
  33. if not tool_name:
  34. return {
  35. "tool_use_id": tool_id,
  36. "error": "工具名称为空"
  37. }
  38. # 调用 MCP 工具
  39. client = MCPClient(session_id=self.session_id)
  40. result = await client.call_tool(tool_name, tool_input)
  41. return {
  42. "tool_use_id": tool_id,
  43. "tool_name": tool_name,
  44. "result": result
  45. }
  46. async def process_tool_use_blocks(
  47. self,
  48. content_blocks: List[Dict[str, Any]]
  49. ) -> List[Dict[str, Any]]:
  50. """
  51. 批量处理 tool_use 内容块
  52. Args:
  53. content_blocks: Claude 返回的内容块列表
  54. Returns:
  55. 工具执行结果列表
  56. """
  57. tool_use_blocks = [
  58. block for block in content_blocks
  59. if block.get("type") == "tool_use"
  60. ]
  61. if not tool_use_blocks:
  62. return []
  63. # 并发执行所有工具调用
  64. tasks = [
  65. self.process_tool_use_block(block)
  66. for block in tool_use_blocks
  67. ]
  68. results = await asyncio.gather(*tasks, return_exceptions=True)
  69. # 处理异常
  70. formatted_results = []
  71. for i, result in enumerate(results):
  72. if isinstance(result, Exception):
  73. formatted_results.append({
  74. "tool_use_id": tool_use_blocks[i].get("id", ""),
  75. "error": str(result)
  76. })
  77. else:
  78. formatted_results.append(result)
  79. return formatted_results
  80. @staticmethod
  81. def create_tool_result_message(tool_results: List[Dict[str, Any]]) -> Dict[str, Any]:
  82. """
  83. 将工具执行结果转换为 Claude API 可用的消息格式
  84. Args:
  85. tool_results: 工具执行结果列表
  86. Returns:
  87. Claude API 消息格式
  88. """
  89. content = []
  90. for result in tool_results:
  91. tool_use_id = result.get("tool_use_id", "")
  92. tool_result = result.get("result", {})
  93. # 检查是否有错误
  94. if "error" in tool_result:
  95. content.append({
  96. "type": "tool_result",
  97. "tool_use_id": tool_use_id,
  98. "content": f"错误: {tool_result['error']}",
  99. "is_error": True
  100. })
  101. else:
  102. # 成功结果
  103. result_text = tool_result.get("result", "")
  104. content.append({
  105. "type": "tool_result",
  106. "tool_use_id": tool_use_id,
  107. "content": result_text
  108. })
  109. return {
  110. "role": "user",
  111. "content": content
  112. }