conversation_manager.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """
  2. 多轮对话管理器 - 处理包含工具调用的多轮对话
  3. """
import asyncio
import logging
from typing import Dict, List, Any, Optional

import httpx
from anthropic import Anthropic

from mcp_client import MCPClient
from tool_converter import ToolConverter
from tool_handler import ToolCallHandler
# ========== Base system prompt ==========
# Template only: it contains neither the component list nor any MCP tool
# description. ConversationManager._build_system_prompt fills in the
# {MCP_STATUS} placeholder and appends the component prompt at runtime.
BASE_SYSTEM_PROMPT = """你是一个 AI 助手,可以通过调用 MCP 工具来帮助用户完成任务。
## 当前状态
{MCP_STATUS}
"""
  17. def create_anthropic_client(api_key: str, base_url: str) -> Anthropic:
  18. """
  19. 创建 Anthropic 客户端,支持自定义认证格式
  20. 自定义 API 代理需要 'Authorization: Bearer <token>' 格式,
  21. 而不是 Anthropic SDK 默认的 'x-api-key' header。
  22. """
  23. # 创建自定义 httpx client,设置正确的 Authorization header
  24. http_client = httpx.Client(
  25. headers={"Authorization": f"Bearer {api_key}"},
  26. timeout=120.0
  27. )
  28. return Anthropic(base_url=base_url, http_client=http_client)
  29. class ConversationManager:
  30. """管理包含工具调用的多轮对话"""
  31. def __init__(
  32. self,
  33. api_key: str,
  34. base_url: str,
  35. model: str,
  36. session_id: str = None,
  37. mcp_tokens: dict = None,
  38. components_prompt: str = None,
  39. enabled_mcp_list: list = None # 新增:前端传递的已启用 MCP 列表
  40. ):
  41. self.api_key = api_key
  42. self.base_url = base_url
  43. self.model = model
  44. self.session_id = session_id
  45. self.mcp_tokens = mcp_tokens or {} # MCP 服务器 token 映射
  46. self.enabled_mcp_list = enabled_mcp_list # 前端传递的已启用 MCP 列表
  47. # 组件提示词(由前端动态提供)
  48. self.components_prompt = components_prompt # 前端必须提供,无 fallback
  49. # 构建完整的系统提示词
  50. self.system_prompt = self._build_system_prompt()
  51. # DEBUG: 打印接收到的 token 和启用列表
  52. print(f"[DEBUG ConversationManager.__init__] mcp_tokens keys: {list(self.mcp_tokens.keys())}")
  53. print(f"[DEBUG ConversationManager.__init__] enabled_mcp_list: {self.enabled_mcp_list}")
  54. print(f"[DEBUG ConversationManager.__init__] components_prompt length: {len(components_prompt) if components_prompt else 0} chars")
  55. for k, v in self.mcp_tokens.items():
  56. print(f"[DEBUG ConversationManager.__init__] {k}: {v[:30] if v else 'None'}...")
  57. self.tool_handler = ToolCallHandler(session_id=session_id, mcp_tokens=mcp_tokens)
  58. self._cached_tools = None
  59. self._tool_to_server_map = {} # 工具名到服务器 ID 的映射
  60. # 使用自定义 client,支持 Bearer token 认证
  61. self.client = create_anthropic_client(api_key, base_url)
  62. def _build_system_prompt(self) -> str:
  63. """构建完整的 system prompt"""
  64. # 构建 MCP 状态部分
  65. if self.enabled_mcp_list is not None and len(self.enabled_mcp_list) == 0:
  66. # 用户明确禁用了所有 MCP
  67. mcp_status = """当前没有启用任何 MCP 服务器。你只能进行普通对话和返回 json-render 组件。"""
  68. elif self.enabled_mcp_list:
  69. # 部分或全部 MCP 已启用
  70. enabled_names = ", ".join(self.enabled_mcp_list)
  71. mcp_status = f"""**已启用的 MCP 服务器**: {enabled_names}
  72. 你可以调用这些 MCP 服务器的工具来帮助用户完成任务。"""
  73. else:
  74. # enabled_mcp_list 是 None,使用默认状态
  75. mcp_status = """你可以通过调用 MCP 工具来帮助用户完成任务。"""
  76. # 替换占位符并添加组件列表
  77. prompt = BASE_SYSTEM_PROMPT.replace("{MCP_STATUS}", mcp_status)
  78. if self.components_prompt:
  79. return prompt + "\n\n" + self.components_prompt
  80. else:
  81. return prompt
  82. async def get_available_tools(self) -> List[Dict[str, Any]]:
  83. """获取可用的 Claude 格式工具列表(带缓存)"""
  84. if self._cached_tools is not None:
  85. return self._cached_tools
  86. # 从 MCP 服务器发现工具(带 token 和启用列表)
  87. mcp_tools = await MCPClient.get_all_tools_with_tokens_async(
  88. self.session_id,
  89. self.mcp_tokens,
  90. self.enabled_mcp_list # 传递前端传递的已启用 MCP 列表
  91. )
  92. # 转换为 Claude 格式
  93. claude_tools = []
  94. for tool in mcp_tools:
  95. claude_tool = ToolConverter.mcp_to_claude_tool(tool)
  96. claude_tools.append(claude_tool)
  97. # 构建工具名到服务器 ID 的映射
  98. server_id = tool.get("_server_id", "")
  99. if server_id:
  100. self._tool_to_server_map[claude_tool["name"]] = server_id
  101. self._cached_tools = claude_tools
  102. return claude_tools
  103. @classmethod
  104. async def get_tools_async(cls, session_id: str = None) -> List[Dict[str, Any]]:
  105. """
  106. 类方法:获取可用的工具列表(异步)
  107. 用于 API 端点直接调用,无需创建完整实例
  108. """
  109. mcp_tools = await MCPClient.get_all_tools_async(session_id)
  110. return ToolConverter.convert_mcp_tools(mcp_tools)
  111. @staticmethod
  112. def get_tools(session_id: str = None) -> List[Dict[str, Any]]:
  113. """
  114. 静态方法:获取可用的工具列表(同步)
  115. 用于 API 端点直接调用
  116. """
  117. return asyncio.run(ConversationManager.get_tools_async(session_id))
  118. async def chat(
  119. self,
  120. user_message: str,
  121. conversation_history: List[Dict[str, Any]] = None,
  122. max_turns: int = 5
  123. ) -> Dict[str, Any]:
  124. """
  125. 执行多轮对话(自动处理工具调用)
  126. Args:
  127. user_message: 用户消息
  128. conversation_history: 对话历史
  129. max_turns: 最大对话轮数(防止无限循环)
  130. Returns:
  131. 最终响应和对话历史
  132. """
  133. if conversation_history is None:
  134. conversation_history = []
  135. messages = conversation_history.copy()
  136. messages.append({
  137. "role": "user",
  138. "content": user_message
  139. })
  140. current_messages = messages
  141. response_text = ""
  142. tool_calls_made = []
  143. for turn in range(max_turns):
  144. # 获取可用工具
  145. tools = await self.get_available_tools()
  146. # 调用 Claude API
  147. if tools:
  148. response = self.client.messages.create(
  149. model=self.model,
  150. max_tokens=4096,
  151. system=self.system_prompt, # 使用动态系统提示
  152. messages=current_messages,
  153. tools=tools
  154. )
  155. else:
  156. response = self.client.messages.create(
  157. model=self.model,
  158. max_tokens=4096,
  159. system=self.system_prompt, # 使用动态系统提示
  160. messages=current_messages
  161. )
  162. # 检查响应中是否有 tool_use
  163. content_blocks = []
  164. tool_use_blocks = []
  165. text_blocks = []
  166. for block in response.content:
  167. block_type = getattr(block, "type", None)
  168. if block_type == "tool_use":
  169. # 工具调用块
  170. block_dict = {
  171. "type": "tool_use",
  172. "id": getattr(block, "id", ""),
  173. "name": getattr(block, "name", ""),
  174. "input": getattr(block, "input", {})
  175. }
  176. content_blocks.append(block_dict)
  177. tool_use_blocks.append(block_dict)
  178. else:
  179. # 文本块
  180. text_content = getattr(block, "text", "")
  181. if text_content:
  182. text_blocks.append({
  183. "type": "text",
  184. "text": text_content
  185. })
  186. content_blocks.append({
  187. "type": "text",
  188. "text": text_content
  189. })
  190. response_text += text_content
  191. # 如果没有工具调用,返回结果
  192. if not tool_use_blocks:
  193. return {
  194. "response": response_text,
  195. "messages": current_messages,
  196. "tool_calls": tool_calls_made
  197. }
  198. # 处理工具调用
  199. tool_results = await self.tool_handler.process_tool_use_blocks(
  200. tool_use_blocks,
  201. self._tool_to_server_map
  202. )
  203. # 记录工具调用
  204. for tr in tool_results:
  205. tool_calls_made.append({
  206. "tool": tr.get("tool_name"),
  207. "result": tr.get("result", {})
  208. })
  209. # 构建工具结果消息
  210. tool_result_message = ToolCallHandler.create_tool_result_message(
  211. tool_results
  212. )
  213. # 添加到消息历史
  214. current_messages.append({
  215. "role": "assistant",
  216. "content": content_blocks
  217. })
  218. current_messages.append(tool_result_message)
  219. # 达到最大轮数
  220. return {
  221. "response": response_text,
  222. "messages": current_messages,
  223. "tool_calls": tool_calls_made,
  224. "warning": "达到最大对话轮数"
  225. }
  226. @staticmethod
  227. def format_history_for_claude(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  228. """
  229. 格式化对话历史为 Claude API 格式
  230. Args:
  231. history: 原始对话历史
  232. Returns:
  233. Claude API 格式的消息列表
  234. """
  235. formatted = []
  236. for msg in history:
  237. role = msg.get("role")
  238. content = msg.get("content")
  239. if role == "user":
  240. if isinstance(content, str):
  241. formatted.append({"role": "user", "content": content})
  242. elif isinstance(content, list):
  243. formatted.append({"role": "user", "content": content})
  244. elif role == "assistant":
  245. if isinstance(content, str):
  246. formatted.append({"role": "assistant", "content": content})
  247. elif isinstance(content, list):
  248. formatted.append({"role": "assistant", "content": content})
  249. return formatted