"""
Multi-turn conversation manager: runs conversations that may include tool calls.
"""
import asyncio
from typing import Dict, List, Any, Optional

import httpx
from anthropic import Anthropic

from mcp_client import MCPClient
from tool_converter import ToolConverter
from tool_handler import ToolCallHandler


def create_anthropic_client(api_key: str, base_url: str) -> Anthropic:
    """
    Create an Anthropic client that authenticates with a Bearer token.

    The custom API proxy expects an ``Authorization: Bearer <key>`` header
    rather than the ``x-api-key`` header the Anthropic SDK sends by default,
    so the header is attached to a custom httpx client instead.

    Args:
        api_key: Token placed in the ``Authorization`` header.
        base_url: Base URL of the API proxy.

    Returns:
        An ``Anthropic`` client backed by the custom httpx client
        (120 second timeout).
    """
    # NOTE(review): `api_key` is not passed to Anthropic(...) itself, so the
    # SDK may still try to resolve its own credentials; confirm requests are
    # accepted without an SDK-level key configured.
    http_client = httpx.Client(
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=120.0,
    )
    return Anthropic(base_url=base_url, http_client=http_client)


class ConversationManager:
    """Manage multi-turn conversations that include tool calls."""

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        session_id: Optional[str] = None,
    ):
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.session_id = session_id
        self.tool_handler = ToolCallHandler(session_id=session_id)
        # Claude-format tool list, filled lazily by get_available_tools().
        self._cached_tools: Optional[List[Dict[str, Any]]] = None
        # Custom client so the proxy receives a Bearer token.
        self.client = create_anthropic_client(api_key, base_url)

    async def get_available_tools(self) -> List[Dict[str, Any]]:
        """
        Return the Claude-format tool list, discovering it on first use.

        Tools are fetched from the MCP servers for this session, converted
        to Claude format, and cached on the instance for subsequent calls.
        """
        if self._cached_tools is not None:
            return self._cached_tools

        # Discover tools from the MCP servers, then convert to Claude format.
        mcp_tools = await MCPClient.get_all_tools_async(self.session_id)
        self._cached_tools = ToolConverter.convert_mcp_tools(mcp_tools)
        return self._cached_tools

    @classmethod
    async def get_tools_async(
        cls, session_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Return the Claude-format tool list without creating an instance.

        Intended for API endpoints that only need the tool list. Unlike
        ``get_available_tools`` the result is not cached.
        """
        mcp_tools = await MCPClient.get_all_tools_async(session_id)
        return ToolConverter.convert_mcp_tools(mcp_tools)

    @staticmethod
    def get_tools(session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Synchronous wrapper around ``get_tools_async``.

        Uses ``asyncio.run``, so it must not be called from a thread that is
        already running an event loop.
        """
        return asyncio.run(ConversationManager.get_tools_async(session_id))

    async def chat(
        self,
        user_message: str,
        conversation_history: Optional[List[Dict[str, Any]]] = None,
        max_turns: int = 5,
        max_tokens: int = 4096,
    ) -> Dict[str, Any]:
        """
        Run a multi-turn conversation, executing tool calls automatically.

        Each turn sends the accumulated messages to the model. If the reply
        contains ``tool_use`` blocks, the tools are executed and both the
        assistant turn and the tool results are appended before the next
        turn. The loop ends when the model replies without tool calls or
        after ``max_turns`` model calls.

        Args:
            user_message: The new user message.
            conversation_history: Prior messages; copied, never mutated.
            max_turns: Maximum number of model calls (guards against an
                endless tool-call loop).
            max_tokens: ``max_tokens`` value sent with every model call.

        Returns:
            A dict with ``response`` (text from every turn, concatenated),
            ``messages`` (the message list sent to the model, including the
            tool exchanges), ``tool_calls`` (``{"tool", "result"}`` records)
            and, only when ``max_turns`` was exhausted, ``warning``.
        """
        # Shallow copy so the caller's history is left untouched.
        current_messages: List[Dict[str, Any]] = list(conversation_history or [])
        current_messages.append({"role": "user", "content": user_message})

        response_text = ""
        tool_calls_made: List[Dict[str, Any]] = []
        # The tool list is cached per instance, so one lookup serves every turn.
        tools = await self.get_available_tools()

        for _ in range(max_turns):
            response = await self._create_message(
                current_messages, tools, max_tokens
            )
            content_blocks, tool_use_blocks, turn_text = (
                self._split_response_content(response)
            )
            response_text += turn_text

            # No tool calls: the model has produced its final answer.
            if not tool_use_blocks:
                # NOTE(review): the final assistant reply is not appended to
                # `messages`; callers reusing it as history must add it.
                return {
                    "response": response_text,
                    "messages": current_messages,
                    "tool_calls": tool_calls_made,
                }

            tool_results = await self.tool_handler.process_tool_use_blocks(
                tool_use_blocks
            )
            tool_calls_made.extend(
                {"tool": tr.get("tool_name"), "result": tr.get("result", {})}
                for tr in tool_results
            )
            tool_result_message = ToolCallHandler.create_tool_result_message(
                tool_results
            )

            # The assistant turn (with its tool_use blocks) must precede the
            # message carrying the tool results.
            current_messages.append({"role": "assistant", "content": content_blocks})
            current_messages.append(tool_result_message)

        # Turn budget exhausted while the model was still calling tools.
        return {
            "response": response_text,
            "messages": current_messages,
            "tool_calls": tool_calls_made,
            "warning": "达到最大对话轮数",
        }

    async def _create_message(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        max_tokens: int,
    ) -> Any:
        """Send one Messages API request without blocking the event loop."""
        request: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": messages,
        }
        # `tools` is only sent when at least one tool is available.
        if tools:
            request["tools"] = tools
        # The Anthropic client used here is synchronous; run it in the default
        # executor so other coroutines keep running during the HTTP request.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.client.messages.create(**request)
        )

    @staticmethod
    def _split_response_content(response: Any) -> tuple:
        """
        Convert a model response's content blocks into plain dicts.

        Returns ``(content_blocks, tool_use_blocks, text)``:
        ``content_blocks`` holds every kept block in order,
        ``tool_use_blocks`` holds only the ``tool_use`` blocks (the same dict
        objects), and ``text`` is the concatenation of all text blocks.
        Non-``tool_use`` blocks without a non-empty ``text`` attribute are
        dropped.
        """
        content_blocks: List[Dict[str, Any]] = []
        tool_use_blocks: List[Dict[str, Any]] = []
        text_parts: List[str] = []

        for block in response.content:
            if getattr(block, "type", None) == "tool_use":
                block_dict = {
                    "type": "tool_use",
                    "id": getattr(block, "id", ""),
                    "name": getattr(block, "name", ""),
                    "input": getattr(block, "input", {}),
                }
                content_blocks.append(block_dict)
                tool_use_blocks.append(block_dict)
            else:
                # Any other block is treated as text; blocks without text are skipped.
                text_content = getattr(block, "text", "")
                if text_content:
                    content_blocks.append({"type": "text", "text": text_content})
                    text_parts.append(text_content)

        return content_blocks, tool_use_blocks, "".join(text_parts)

    @staticmethod
    def format_history_for_claude(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Filter a raw conversation history down to Claude API messages.

        Only ``user`` and ``assistant`` messages whose ``content`` is a string
        or a list of content blocks are kept; they are rebuilt as
        ``{"role", "content"}`` dicts. Everything else is dropped.

        Args:
            history: Raw conversation history (dicts with ``role``/``content``).

        Returns:
            The message list in Claude API format.
        """
        formatted = []
        for msg in history:
            role = msg.get("role")
            content = msg.get("content")
            if role in ("user", "assistant") and isinstance(content, (str, list)):
                formatted.append({"role": role, "content": content})
        return formatted