2
0

conversation_manager.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """
  2. 多轮对话管理器 - 处理包含工具调用的多轮对话
  3. """
import asyncio
import logging
from typing import Dict, List, Any, Optional

import httpx
from anthropic import Anthropic

from mcp_client import MCPClient
from tool_converter import ToolConverter
from tool_handler import ToolCallHandler
  11. def create_anthropic_client(api_key: str, base_url: str) -> Anthropic:
  12. """
  13. 创建 Anthropic 客户端,支持自定义认证格式
  14. 自定义 API 代理需要 'Authorization: Bearer <token>' 格式,
  15. 而不是 Anthropic SDK 默认的 'x-api-key' header。
  16. """
  17. # 创建自定义 httpx client,设置正确的 Authorization header
  18. http_client = httpx.Client(
  19. headers={"Authorization": f"Bearer {api_key}"},
  20. timeout=120.0
  21. )
  22. return Anthropic(base_url=base_url, http_client=http_client)
  23. class ConversationManager:
  24. """管理包含工具调用的多轮对话"""
  25. def __init__(
  26. self,
  27. api_key: str,
  28. base_url: str,
  29. model: str,
  30. session_id: str = None,
  31. mcp_tokens: dict = None
  32. ):
  33. self.api_key = api_key
  34. self.base_url = base_url
  35. self.model = model
  36. self.session_id = session_id
  37. self.mcp_tokens = mcp_tokens or {} # MCP 服务器 token 映射
  38. # DEBUG: 打印接收到的 token
  39. print(f"[DEBUG ConversationManager.__init__] mcp_tokens keys: {list(self.mcp_tokens.keys())}")
  40. for k, v in self.mcp_tokens.items():
  41. print(f"[DEBUG ConversationManager.__init__] {k}: {v[:30] if v else 'None'}...")
  42. self.tool_handler = ToolCallHandler(session_id=session_id, mcp_tokens=mcp_tokens)
  43. self._cached_tools = None
  44. self._tool_to_server_map = {} # 工具名到服务器 ID 的映射
  45. # 使用自定义 client,支持 Bearer token 认证
  46. self.client = create_anthropic_client(api_key, base_url)
  47. async def get_available_tools(self) -> List[Dict[str, Any]]:
  48. """获取可用的 Claude 格式工具列表(带缓存)"""
  49. if self._cached_tools is not None:
  50. return self._cached_tools
  51. # 从 MCP 服务器发现工具(带 token)
  52. mcp_tools = await MCPClient.get_all_tools_with_tokens_async(
  53. self.session_id, self.mcp_tokens
  54. )
  55. # 转换为 Claude 格式
  56. claude_tools = []
  57. for tool in mcp_tools:
  58. claude_tool = ToolConverter.mcp_to_claude_tool(tool)
  59. claude_tools.append(claude_tool)
  60. # 构建工具名到服务器 ID 的映射
  61. server_id = tool.get("_server_id", "")
  62. if server_id:
  63. self._tool_to_server_map[claude_tool["name"]] = server_id
  64. self._cached_tools = claude_tools
  65. return claude_tools
  66. @classmethod
  67. async def get_tools_async(cls, session_id: str = None) -> List[Dict[str, Any]]:
  68. """
  69. 类方法:获取可用的工具列表(异步)
  70. 用于 API 端点直接调用,无需创建完整实例
  71. """
  72. mcp_tools = await MCPClient.get_all_tools_async(session_id)
  73. return ToolConverter.convert_mcp_tools(mcp_tools)
  74. @staticmethod
  75. def get_tools(session_id: str = None) -> List[Dict[str, Any]]:
  76. """
  77. 静态方法:获取可用的工具列表(同步)
  78. 用于 API 端点直接调用
  79. """
  80. return asyncio.run(ConversationManager.get_tools_async(session_id))
  81. async def chat(
  82. self,
  83. user_message: str,
  84. conversation_history: List[Dict[str, Any]] = None,
  85. max_turns: int = 5
  86. ) -> Dict[str, Any]:
  87. """
  88. 执行多轮对话(自动处理工具调用)
  89. Args:
  90. user_message: 用户消息
  91. conversation_history: 对话历史
  92. max_turns: 最大对话轮数(防止无限循环)
  93. Returns:
  94. 最终响应和对话历史
  95. """
  96. if conversation_history is None:
  97. conversation_history = []
  98. messages = conversation_history.copy()
  99. messages.append({
  100. "role": "user",
  101. "content": user_message
  102. })
  103. current_messages = messages
  104. response_text = ""
  105. tool_calls_made = []
  106. for turn in range(max_turns):
  107. # 获取可用工具
  108. tools = await self.get_available_tools()
  109. # 调用 Claude API
  110. if tools:
  111. response = self.client.messages.create(
  112. model=self.model,
  113. max_tokens=4096,
  114. messages=current_messages,
  115. tools=tools
  116. )
  117. else:
  118. response = self.client.messages.create(
  119. model=self.model,
  120. max_tokens=4096,
  121. messages=current_messages
  122. )
  123. # 检查响应中是否有 tool_use
  124. content_blocks = []
  125. tool_use_blocks = []
  126. text_blocks = []
  127. for block in response.content:
  128. block_type = getattr(block, "type", None)
  129. if block_type == "tool_use":
  130. # 工具调用块
  131. block_dict = {
  132. "type": "tool_use",
  133. "id": getattr(block, "id", ""),
  134. "name": getattr(block, "name", ""),
  135. "input": getattr(block, "input", {})
  136. }
  137. content_blocks.append(block_dict)
  138. tool_use_blocks.append(block_dict)
  139. else:
  140. # 文本块
  141. text_content = getattr(block, "text", "")
  142. if text_content:
  143. text_blocks.append({
  144. "type": "text",
  145. "text": text_content
  146. })
  147. content_blocks.append({
  148. "type": "text",
  149. "text": text_content
  150. })
  151. response_text += text_content
  152. # 如果没有工具调用,返回结果
  153. if not tool_use_blocks:
  154. return {
  155. "response": response_text,
  156. "messages": current_messages,
  157. "tool_calls": tool_calls_made
  158. }
  159. # 处理工具调用
  160. tool_results = await self.tool_handler.process_tool_use_blocks(
  161. tool_use_blocks,
  162. self._tool_to_server_map
  163. )
  164. # 记录工具调用
  165. for tr in tool_results:
  166. tool_calls_made.append({
  167. "tool": tr.get("tool_name"),
  168. "result": tr.get("result", {})
  169. })
  170. # 构建工具结果消息
  171. tool_result_message = ToolCallHandler.create_tool_result_message(
  172. tool_results
  173. )
  174. # 添加到消息历史
  175. current_messages.append({
  176. "role": "assistant",
  177. "content": content_blocks
  178. })
  179. current_messages.append(tool_result_message)
  180. # 达到最大轮数
  181. return {
  182. "response": response_text,
  183. "messages": current_messages,
  184. "tool_calls": tool_calls_made,
  185. "warning": "达到最大对话轮数"
  186. }
  187. @staticmethod
  188. def format_history_for_claude(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  189. """
  190. 格式化对话历史为 Claude API 格式
  191. Args:
  192. history: 原始对话历史
  193. Returns:
  194. Claude API 格式的消息列表
  195. """
  196. formatted = []
  197. for msg in history:
  198. role = msg.get("role")
  199. content = msg.get("content")
  200. if role == "user":
  201. if isinstance(content, str):
  202. formatted.append({"role": "user", "content": content})
  203. elif isinstance(content, list):
  204. formatted.append({"role": "user", "content": content})
  205. elif role == "assistant":
  206. if isinstance(content, str):
  207. formatted.append({"role": "assistant", "content": content})
  208. elif isinstance(content, list):
  209. formatted.append({"role": "assistant", "content": content})
  210. return formatted