conversation_manager.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. """
  2. 多轮对话管理器 - 处理包含工具调用的多轮对话
  3. """
  4. import asyncio
  5. from typing import Dict, List, Any, Optional
  6. from anthropic import Anthropic
  7. from mcp_client import MCPClient
  8. from tool_converter import ToolConverter
  9. from tool_handler import ToolCallHandler
  10. class ConversationManager:
  11. """管理包含工具调用的多轮对话"""
  12. def __init__(
  13. self,
  14. api_key: str,
  15. base_url: str,
  16. model: str,
  17. session_id: str = None
  18. ):
  19. self.api_key = api_key
  20. self.base_url = base_url
  21. self.model = model
  22. self.session_id = session_id
  23. self.tool_handler = ToolCallHandler(session_id=session_id)
  24. self._cached_tools = None
  25. self.client = Anthropic(api_key=api_key, base_url=base_url)
  26. async def get_available_tools(self) -> List[Dict[str, Any]]:
  27. """获取可用的 Claude 格式工具列表(带缓存)"""
  28. if self._cached_tools is not None:
  29. return self._cached_tools
  30. # 从 MCP 服务器发现工具
  31. mcp_tools = await MCPClient.get_all_tools_async(self.session_id)
  32. # 转换为 Claude 格式
  33. claude_tools = ToolConverter.convert_mcp_tools(mcp_tools)
  34. self._cached_tools = claude_tools
  35. return claude_tools
  36. @classmethod
  37. async def get_tools_async(cls, session_id: str = None) -> List[Dict[str, Any]]:
  38. """
  39. 类方法:获取可用的工具列表(异步)
  40. 用于 API 端点直接调用,无需创建完整实例
  41. """
  42. mcp_tools = await MCPClient.get_all_tools_async(session_id)
  43. return ToolConverter.convert_mcp_tools(mcp_tools)
  44. @staticmethod
  45. def get_tools(session_id: str = None) -> List[Dict[str, Any]]:
  46. """
  47. 静态方法:获取可用的工具列表(同步)
  48. 用于 API 端点直接调用
  49. """
  50. return asyncio.run(ConversationManager.get_tools_async(session_id))
  51. async def chat(
  52. self,
  53. user_message: str,
  54. conversation_history: List[Dict[str, Any]] = None,
  55. max_turns: int = 5
  56. ) -> Dict[str, Any]:
  57. """
  58. 执行多轮对话(自动处理工具调用)
  59. Args:
  60. user_message: 用户消息
  61. conversation_history: 对话历史
  62. max_turns: 最大对话轮数(防止无限循环)
  63. Returns:
  64. 最终响应和对话历史
  65. """
  66. if conversation_history is None:
  67. conversation_history = []
  68. messages = conversation_history.copy()
  69. messages.append({
  70. "role": "user",
  71. "content": user_message
  72. })
  73. current_messages = messages
  74. response_text = ""
  75. tool_calls_made = []
  76. for turn in range(max_turns):
  77. # 获取可用工具
  78. tools = await self.get_available_tools()
  79. # 调用 Claude API
  80. if tools:
  81. response = self.client.messages.create(
  82. model=self.model,
  83. max_tokens=4096,
  84. messages=current_messages,
  85. tools=tools
  86. )
  87. else:
  88. response = self.client.messages.create(
  89. model=self.model,
  90. max_tokens=4096,
  91. messages=current_messages
  92. )
  93. # 检查响应中是否有 tool_use
  94. content_blocks = []
  95. tool_use_blocks = []
  96. text_blocks = []
  97. for block in response.content:
  98. block_type = getattr(block, "type", None)
  99. if block_type == "tool_use":
  100. # 工具调用块
  101. block_dict = {
  102. "type": "tool_use",
  103. "id": getattr(block, "id", ""),
  104. "name": getattr(block, "name", ""),
  105. "input": getattr(block, "input", {})
  106. }
  107. content_blocks.append(block_dict)
  108. tool_use_blocks.append(block_dict)
  109. else:
  110. # 文本块
  111. text_content = getattr(block, "text", "")
  112. if text_content:
  113. text_blocks.append({
  114. "type": "text",
  115. "text": text_content
  116. })
  117. content_blocks.append({
  118. "type": "text",
  119. "text": text_content
  120. })
  121. response_text += text_content
  122. # 如果没有工具调用,返回结果
  123. if not tool_use_blocks:
  124. return {
  125. "response": response_text,
  126. "messages": current_messages,
  127. "tool_calls": tool_calls_made
  128. }
  129. # 处理工具调用
  130. tool_results = await self.tool_handler.process_tool_use_blocks(
  131. tool_use_blocks
  132. )
  133. # 记录工具调用
  134. for tr in tool_results:
  135. tool_calls_made.append({
  136. "tool": tr.get("tool_name"),
  137. "result": tr.get("result", {})
  138. })
  139. # 构建工具结果消息
  140. tool_result_message = ToolCallHandler.create_tool_result_message(
  141. tool_results
  142. )
  143. # 添加到消息历史
  144. current_messages.append({
  145. "role": "assistant",
  146. "content": content_blocks
  147. })
  148. current_messages.append(tool_result_message)
  149. # 达到最大轮数
  150. return {
  151. "response": response_text,
  152. "messages": current_messages,
  153. "tool_calls": tool_calls_made,
  154. "warning": "达到最大对话轮数"
  155. }
  156. @staticmethod
  157. def format_history_for_claude(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  158. """
  159. 格式化对话历史为 Claude API 格式
  160. Args:
  161. history: 原始对话历史
  162. Returns:
  163. Claude API 格式的消息列表
  164. """
  165. formatted = []
  166. for msg in history:
  167. role = msg.get("role")
  168. content = msg.get("content")
  169. if role == "user":
  170. if isinstance(content, str):
  171. formatted.append({"role": "user", "content": content})
  172. elif isinstance(content, list):
  173. formatted.append({"role": "user", "content": content})
  174. elif role == "assistant":
  175. if isinstance(content, str):
  176. formatted.append({"role": "assistant", "content": content})
  177. elif isinstance(content, list):
  178. formatted.append({"role": "assistant", "content": content})
  179. return formatted