|
|
@@ -5,12 +5,14 @@ AI MCP Web UI - Flask 后端
|
|
|
import os
|
|
|
import asyncio
|
|
|
from typing import Optional, Dict
|
|
|
-from flask import Flask, request, jsonify, send_from_directory
|
|
|
+from flask import Flask, request, jsonify, send_from_directory, Response
|
|
|
+import json as json_module
|
|
|
from flask_cors import CORS
|
|
|
import httpx
|
|
|
from anthropic import Anthropic
|
|
|
from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
|
|
|
from conversation_manager import ConversationManager
|
|
|
+from tool_handler import ToolCallHandler
|
|
|
|
|
|
# Flask application object; the route decorators in this module register on it.
app = Flask(__name__)
|
|
|
# Apply flask_cors to the app with its default settings (no options passed).
CORS(app)
|
|
|
@@ -108,6 +110,192 @@ def chat():
|
|
|
}), 500
|
|
|
|
|
|
|
|
|
@app.route('/api/chat/stream', methods=['POST'])
def chat_stream():
    """
    Chat endpoint, streaming variant (avoids timeouts on long replies).

    Responds with Server-Sent Events (SSE) so the client receives, in real time:
    1. Claude's text output as it is generated (``token`` events)
    2. Tool-call status (``tools``, ``tools_start``, ``tool_done``, ``tool_error``)
    3. The final response (``complete``), or ``error`` if anything fails mid-stream

    Request: JSON body with ``message`` (required) and optional ``history``;
    the optional ``X-Session-ID`` header is forwarded to ConversationManager.

    Returns:
        A ``text/event-stream`` response, a 400 JSON error when ``message`` is
        missing or empty, or a 500 JSON error if setup fails before streaming.
    """
    max_turns = 5  # upper bound on model <-> tool round trips per request

    def sse(event_name, payload):
        """Format one SSE frame: an ``event:`` line followed by a JSON ``data:`` line."""
        return f"event: {event_name}\ndata: {json_module.dumps(payload)}\n\n"

    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so the request falls through to the 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        message = data.get('message', '')
        conversation_history = data.get('history') or []
        session_id = request.headers.get('X-Session-ID')

        if not message:
            return jsonify({"error": "Message is required"}), 400

        def generate():
            """Yield SSE frames for the whole exchange, including tool rounds."""
            try:
                yield sse("start", {'status': 'started'})

                conv_manager = ConversationManager(
                    api_key=ANTHROPIC_API_KEY,
                    base_url=ANTHROPIC_BASE_URL,
                    model=ANTHROPIC_MODEL,
                    session_id=session_id
                )

                formatted_history = ConversationManager.format_history_for_claude(conversation_history)
                current_messages = formatted_history + [{"role": "user", "content": message}]
                tool_calls_info = []
                response_text = ""

                for _turn in range(max_turns):
                    # The tool list is fetched again on every turn, as before.
                    tools = run_async(conv_manager.get_available_tools())
                    yield sse("tools", {'count': len(tools), 'tools': [t['name'] for t in tools[:5]]})

                    # Only pass `tools` when there are any to offer.
                    request_kwargs = {
                        "model": conv_manager.model,
                        "max_tokens": 4096,
                        "messages": current_messages,
                        "stream": True,
                    }
                    if tools:
                        request_kwargs["tools"] = tools
                    response_stream = conv_manager.client.messages.create(**request_kwargs)

                    # State rebuilt from the stream for this turn.
                    content_blocks = []         # assistant content, in stream order
                    tool_use_blocks = []        # the tool_use entries of content_blocks
                    tool_input_fragments = []   # raw JSON pieces, parallel to tool_use_blocks
                    response_text = ""
                    open_text_block = None      # text block currently receiving deltas
                    open_fragments = None       # fragment list of the open tool_use block

                    for event in response_stream:
                        if event.type == "content_block_start":
                            block = getattr(event, "content_block", None)
                            block_type = getattr(block, "type", None)
                            open_text_block = None
                            open_fragments = None
                            if block_type == "tool_use":
                                tool_block = {
                                    "type": "tool_use",
                                    "id": getattr(block, "id", ""),
                                    "name": getattr(block, "name", ""),
                                    "input": {}
                                }
                                tool_use_blocks.append(tool_block)
                                content_blocks.append(tool_block)
                                open_fragments = []
                                tool_input_fragments.append(open_fragments)
                            elif block_type == "text":
                                open_text_block = {"type": "text", "text": ""}
                                content_blocks.append(open_text_block)

                        elif event.type == "content_block_delta":
                            delta_type = getattr(event.delta, "type", "")

                            if delta_type == "text_delta":
                                text = event.delta.text
                                response_text += text
                                if open_text_block is None:
                                    # Text without a preceding block start: still record it.
                                    open_text_block = {"type": "text", "text": ""}
                                    content_blocks.append(open_text_block)
                                open_text_block["text"] += text
                                yield sse("token", {'text': text})

                            elif delta_type == "input_json_delta":
                                # Tool arguments arrive as JSON fragments; collect them
                                # and parse once the stream is finished.
                                fragment = getattr(event.delta, "partial_json", "")
                                if fragment and open_fragments is not None:
                                    open_fragments.append(fragment)

                        elif event.type == "content_block_stop":
                            open_text_block = None
                            open_fragments = None

                    # Parse each tool's accumulated argument JSON exactly once.
                    for tool_block, fragments in zip(tool_use_blocks, tool_input_fragments):
                        raw_input = "".join(fragments)
                        if not raw_input:
                            continue
                        try:
                            tool_block["input"] = json_module.loads(raw_input)
                        except json_module.JSONDecodeError:
                            # Incomplete/invalid arguments: keep the empty input,
                            # matching the previous best-effort behaviour.
                            pass

                    # No tool calls: this turn's text is the final answer.
                    if not tool_use_blocks:
                        yield sse("complete", {'response': response_text, 'tool_calls': tool_calls_info})
                        return

                    yield sse("tools_start", {'count': len(tool_use_blocks)})

                    tool_results = run_async(conv_manager.tool_handler.process_tool_use_blocks(
                        tool_use_blocks
                    ))

                    for tr in tool_results:
                        tool_name = tr.get("tool_name", "")
                        tool_result = tr.get("result", {})

                        if "error" in tool_result:
                            yield sse("tool_error", {'tool': tool_name, 'error': tool_result['error']})
                        else:
                            # Send only a short preview. NOTE(review): the shape of
                            # `result` is defined by ToolCallHandler; non-string values
                            # are stringified so truncation cannot raise.
                            preview = tool_result.get('result', '')
                            if not isinstance(preview, str):
                                preview = str(preview)
                            yield sse("tool_done", {'tool': tool_name, 'result': preview[:200]})

                        tool_calls_info.append({
                            "tool": tool_name,
                            "result": tool_result
                        })

                    tool_result_message = ToolCallHandler.create_tool_result_message(
                        tool_results
                    )

                    # The assistant turn must carry the tool_use blocks that the tool
                    # results answer; empty text blocks are dropped.
                    assistant_content = [
                        block for block in content_blocks
                        if block["type"] != "text" or block["text"]
                    ]
                    current_messages.append({
                        "role": "assistant",
                        "content": assistant_content
                    })
                    current_messages.append(tool_result_message)

                # Turn limit reached while the model was still requesting tools.
                yield sse("complete", {'response': response_text, 'tool_calls': tool_calls_info, 'warning': '达到最大对话轮数'})

            except Exception as e:
                import traceback
                # NOTE(review): the traceback is sent to the client, which is useful
                # for debugging but exposes internals; consider logging it instead.
                yield sse("error", {'error': str(e), 'traceback': traceback.format_exc()})

        return Response(
            generate(),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no'  # disable Nginx response buffering
            }
        )

    except Exception as e:
        import traceback
        return jsonify({
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
|
|
+
|
|
|
+
|
|
|
@app.route('/api/mcp/servers', methods=['GET'])
|
|
|
def list_mcp_servers():
|
|
|
"""获取已配置的 MCP 服务器列表"""
|