| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- """
- 多轮对话管理器 - 处理包含工具调用的多轮对话
- """
- import asyncio
- from typing import Dict, List, Any, Optional
- from anthropic import Anthropic
- from mcp_client import MCPClient
- from tool_converter import ToolConverter
- from tool_handler import ToolCallHandler
class ConversationManager:
    """Manage multi-turn Claude conversations that may involve tool calls.

    Tools are discovered through ``MCPClient``, converted to Claude's tool
    format by ``ToolConverter`` and executed by a ``ToolCallHandler`` that
    shares this manager's ``session_id``.
    """

    # max_tokens sent with every messages.create request.
    _MAX_TOKENS = 4096

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        session_id: Optional[str] = None,
    ):
        """Create a manager bound to one model endpoint and tool session.

        Args:
            api_key: Anthropic API key.
            base_url: Base URL passed to the Anthropic client.
            model: Model name used for every ``messages.create`` call.
            session_id: Optional session identifier forwarded to MCP tool
                discovery and to the ``ToolCallHandler``.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.session_id = session_id
        self.tool_handler = ToolCallHandler(session_id=session_id)
        # Claude-format tools, filled lazily by get_available_tools().
        self._cached_tools: Optional[List[Dict[str, Any]]] = None
        self.client = Anthropic(api_key=api_key, base_url=base_url)

    async def get_available_tools(self) -> List[Dict[str, Any]]:
        """Return the Claude-format tool list, discovering it on first use.

        The converted list is cached on the instance (an empty list is
        cached too), so MCP servers are queried at most once per manager.
        """
        if self._cached_tools is None:
            mcp_tools = await MCPClient.get_all_tools_async(self.session_id)
            self._cached_tools = ToolConverter.convert_mcp_tools(mcp_tools)
        return self._cached_tools

    @classmethod
    async def get_tools_async(cls, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return the Claude-format tool list without creating an instance.

        Intended for API endpoints that only need the tool list. Nothing is
        cached: every call queries the MCP servers.
        """
        mcp_tools = await MCPClient.get_all_tools_async(session_id)
        return ToolConverter.convert_mcp_tools(mcp_tools)

    @staticmethod
    def get_tools(session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Synchronously return the Claude-format tool list.

        Wrapper around ``get_tools_async`` for synchronous API endpoints. It
        drives the coroutine with ``asyncio.run``, which raises
        ``RuntimeError`` when called from a thread whose event loop is
        already running; async callers should await ``get_tools_async``.
        """
        return asyncio.run(ConversationManager.get_tools_async(session_id))

    async def chat(
        self,
        user_message: str,
        conversation_history: Optional[List[Dict[str, Any]]] = None,
        max_turns: int = 5,
    ) -> Dict[str, Any]:
        """Run a conversation, executing tool calls until Claude stops asking.

        The user message is appended to a copy of ``conversation_history``
        and Claude is called repeatedly. Whenever the reply contains tool_use
        blocks, they are executed via ``self.tool_handler`` and the results
        are sent back in the next request, for at most ``max_turns`` calls.

        Args:
            user_message: The new user message.
            conversation_history: Prior messages in Claude API format. The
                caller's list is not modified. ``None`` means no history.
            max_turns: Maximum number of Claude API calls (guards against an
                endless tool-call loop).

        Returns:
            A dict with:
              - ``response``: text from every turn, concatenated in order.
              - ``messages``: the history sent to the API (input history, the
                new user message, and each assistant tool_use turn followed
                by its tool-result message).
              - ``tool_calls``: one ``{"tool", "result"}`` entry per
                executed tool.
              - ``warning``: present only when ``max_turns`` was exhausted.
        """
        # Shallow copy so the caller's history list is never appended to.
        messages: List[Dict[str, Any]] = list(conversation_history or [])
        messages.append({"role": "user", "content": user_message})

        response_text = ""
        tool_calls_made: List[Dict[str, Any]] = []

        for _ in range(max_turns):
            # Cheap after the first turn: the tool list is cached.
            tools = await self.get_available_tools()
            response = await self._create_message(messages, tools)

            content_blocks, tool_use_blocks, turn_text = self._split_content(
                response.content
            )
            response_text += turn_text

            if not tool_use_blocks:
                # NOTE(review): the final assistant reply is not appended to
                # ``messages``; callers reusing it as history must add it
                # themselves. Confirm this is intended.
                return {
                    "response": response_text,
                    "messages": messages,
                    "tool_calls": tool_calls_made,
                }

            tool_results = await self.tool_handler.process_tool_use_blocks(
                tool_use_blocks
            )
            tool_calls_made.extend(
                {"tool": tr.get("tool_name"), "result": tr.get("result", {})}
                for tr in tool_results
            )
            tool_result_message = ToolCallHandler.create_tool_result_message(
                tool_results
            )

            # The assistant's tool_use turn must precede its tool results.
            messages.append({"role": "assistant", "content": content_blocks})
            messages.append(tool_result_message)

        # Turn budget exhausted while Claude was still requesting tools.
        return {
            "response": response_text,
            "messages": messages,
            "tool_calls": tool_calls_made,
            "warning": "达到最大对话轮数",
        }

    async def _create_message(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
    ) -> Any:
        """Send one ``messages.create`` request without blocking the event loop.

        ``self.client`` is the synchronous Anthropic client, so its blocking
        HTTP call is run in the running loop's default executor. ``tools`` is
        only included in the request when it is non-empty.
        """
        request: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": self._MAX_TOKENS,
            "messages": messages,
        }
        if tools:
            request["tools"] = tools
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.client.messages.create(**request)
        )

    @staticmethod
    def _split_content(blocks: Any) -> tuple:
        """Convert response content blocks into plain dicts.

        Returns ``(content_blocks, tool_use_blocks, text)``:
          - ``content_blocks``: tool_use and non-empty text blocks, in their
            original order.
          - ``tool_use_blocks``: only the tool_use dicts. These are the same
            objects that appear in ``content_blocks``.
          - ``text``: all non-empty text joined together.

        Any block whose ``type`` is not "tool_use" is treated as text. Such a
        block is dropped when it has no non-empty ``text`` attribute.
        """
        content_blocks: List[Dict[str, Any]] = []
        tool_use_blocks: List[Dict[str, Any]] = []
        text_parts: List[str] = []
        for block in blocks:
            if getattr(block, "type", None) == "tool_use":
                block_dict = {
                    "type": "tool_use",
                    "id": getattr(block, "id", ""),
                    "name": getattr(block, "name", ""),
                    "input": getattr(block, "input", {}),
                }
                content_blocks.append(block_dict)
                tool_use_blocks.append(block_dict)
                continue
            text = getattr(block, "text", "")
            if text:
                content_blocks.append({"type": "text", "text": text})
                text_parts.append(text)
        return content_blocks, tool_use_blocks, "".join(text_parts)

    @staticmethod
    def format_history_for_claude(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter a raw conversation history down to Claude API messages.

        A message is kept only when its ``role`` is "user" or "assistant"
        and its ``content`` is a str or a list. Each kept message is
        re-emitted as a new dict holding just ``role`` and ``content``, so
        extra keys are dropped. All other messages are discarded.

        Args:
            history: Raw conversation history (list of message dicts).

        Returns:
            The filtered messages, in their original order.
        """
        return [
            {"role": msg.get("role"), "content": msg.get("content")}
            for msg in history
            if msg.get("role") in ("user", "assistant")
            and isinstance(msg.get("content"), (str, list))
        ]
|