| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522 |
- """
- AI MCP Web UI - Flask 后端
- 提供聊天界面与 MCP 工具调用的桥梁
- """
- import os
- import asyncio
- from typing import Optional, Dict
- from flask import Flask, request, jsonify, send_from_directory, Response
- import json as json_module
- from flask_cors import CORS
- import httpx
- from anthropic import Anthropic
- from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
- from conversation_manager import ConversationManager
- from tool_handler import ToolCallHandler
app = Flask(__name__)
CORS(app)
# Secret used by Flask for signing. When SECRET_KEY is not set, fall back to a
# random per-process key instead of a fixed string that is visible in the
# source code (and therefore known to anyone who can read it).
app.secret_key = os.getenv('SECRET_KEY') or os.urandom(32).hex()
# Authentication sessions: session_id -> {"username", "token", "refresh_token",
# "server"[, "role"]}. Kept in process memory, so it is lost on restart and not
# shared between processes; production should use Redis or a database.
auth_sessions: Dict[str, dict] = {}
@app.route('/')
def index():
    """Serve the frontend's entry page, ``index.html``."""
    frontend_dir, entry_page = '../frontend', 'index.html'
    return send_from_directory(frontend_dir, entry_page)
@app.route('/<path:path>')
def static_files(path):
    """Serve any other frontend asset, looked up by its relative ``path``."""
    asset_root = '../frontend'
    return send_from_directory(asset_root, path)
# Module-level Claude client built from the configured credentials and endpoint.
# NOTE(review): the routes visible here reach Claude through ConversationManager
# (conv_manager.client); confirm whether this instance is used anywhere before
# removing it.
client = Anthropic(base_url=ANTHROPIC_BASE_URL, api_key=ANTHROPIC_API_KEY)
@app.route('/api/health', methods=['GET'])
def health():
    """Health check: report status, the configured model and the MCP server ids."""
    server_ids = [*MCP_SERVERS]
    return jsonify({"status": "ok", "model": ANTHROPIC_MODEL, "mcp_servers": server_ids})
def run_async(coro):
    """Run an awaitable to completion from synchronous code and return its result.

    A fresh event loop is created for every call. Afterwards, any async
    generators still pending finalization are shut down, and the loop is
    unregistered and closed. This means:
    * generator cleanup (``finally`` blocks) actually runs; and
    * the thread is not left with a closed loop as its "current" event loop.

    Exceptions raised by ``coro`` propagate to the caller.
    """
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)
    finally:
        try:
            # Also drains callbacks still queued from the main run (e.g. aclose
            # tasks scheduled by async-generator finalizers).
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()
@app.route('/api/chat', methods=['POST'])
def chat():
    """
    Chat endpoint: take a user message and return Claude's reply, with MCP tool calls.

    Request JSON: ``message`` (required) and optional ``history``. The
    ``X-Session-ID`` header is forwarded to ConversationManager.

    Returns JSON with ``response``, ``model``, ``tool_calls`` and ``has_tools``.
    Status codes:
    * 400 when the message is missing or the body is not a JSON object.
    * 500 on any other failure. The traceback is included in the response only
      when the app runs in debug mode; it is always logged server-side.
    """
    try:
        # silent=True: a missing or non-JSON body becomes None instead of
        # raising, so the client gets a 400 rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        message = data.get('message', '')
        conversation_history = data.get('history') or []
        session_id = request.headers.get('X-Session-ID')
        if not message:
            return jsonify({"error": "Message is required"}), 400

        conv_manager = ConversationManager(
            api_key=ANTHROPIC_API_KEY,
            base_url=ANTHROPIC_BASE_URL,
            model=ANTHROPIC_MODEL,
            session_id=session_id
        )
        formatted_history = ConversationManager.format_history_for_claude(conversation_history)

        # Multi-turn conversation; ConversationManager handles tool calls (max 5 turns).
        result = run_async(conv_manager.chat(
            user_message=message,
            conversation_history=formatted_history,
            max_turns=5
        ))

        response_text = result.get("response", "")
        tool_calls = result.get("tool_calls", [])
        return jsonify({
            "response": response_text,
            "model": ANTHROPIC_MODEL,
            "tool_calls": tool_calls,
            "has_tools": len(tool_calls) > 0
        })
    except Exception as e:
        import traceback
        app.logger.exception("Chat request failed")
        body = {"error": str(e)}
        # Stack traces reveal internals; only expose them while debugging.
        if app.debug:
            body["traceback"] = traceback.format_exc()
        return jsonify(body), 500
def _sse_event(event, payload):
    """Format one Server-Sent Events frame whose data line is ``payload`` encoded as JSON."""
    return f"event: {event}\ndata: {json_module.dumps(payload)}\n\n"


def _stream_claude_turn(response_stream):
    """Relay one streamed Claude response and collect its content blocks.

    Yields an SSE ``token`` frame for each text delta as it arrives. When the
    stream is exhausted, it *returns* (retrieve it with ``yield from``) the
    assistant content blocks, in the order Claude produced them:

    * ``{"type": "text", "text": ...}`` for each non-empty text block;
    * ``{"type": "tool_use", "id": ..., "name": ..., "input": {...}}``, where
      ``input`` is parsed from the accumulated ``input_json_delta`` fragments.

    Other block types are ignored.
    """
    content_blocks = []
    current_block = None  # dict being filled for the currently open block
    partial_json = ""     # raw argument fragments of the open tool_use block

    for event in response_stream:
        if event.type == "content_block_start":
            started = getattr(event, "content_block", None)
            block_type = getattr(started, "type", None)
            partial_json = ""
            if block_type == "text":
                current_block = {"type": "text", "text": ""}
            elif block_type == "tool_use":
                current_block = {
                    "type": "tool_use",
                    "id": getattr(started, "id", ""),
                    "name": getattr(started, "name", ""),
                    "input": {},
                }
            else:
                current_block = None
            if current_block is not None:
                content_blocks.append(current_block)

        elif event.type == "content_block_delta":
            delta_type = getattr(event.delta, "type", "")
            if delta_type == "text_delta":
                text = event.delta.text
                if current_block is not None and current_block["type"] == "text":
                    current_block["text"] += text
                yield _sse_event("token", {"text": text})
            elif delta_type == "input_json_delta":
                partial_json += getattr(event.delta, "partial_json", "") or ""

        elif event.type == "content_block_stop":
            # Parse the tool arguments once, when the block is complete, rather
            # than re-parsing the growing buffer after every fragment.
            if current_block is not None and current_block["type"] == "tool_use" and partial_json:
                try:
                    current_block["input"] = json_module.loads(partial_json)
                except json_module.JSONDecodeError:
                    # Truncated arguments (e.g. output was cut off): keep {}.
                    pass
            current_block = None
            partial_json = ""

    # Drop empty text blocks: the Messages API rejects empty text content
    # blocks when they are sent back as conversation history.
    return [b for b in content_blocks if b["type"] != "text" or b["text"]]


@app.route('/api/chat/stream', methods=['POST'])
def chat_stream():
    """
    Streaming chat endpoint using Server-Sent Events, to avoid request timeouts.

    Request JSON: ``message`` (required) and optional ``history``. The
    ``X-Session-ID`` header is forwarded to ConversationManager.

    SSE events emitted:
    * ``start``;
    * ``tools`` — available-tool count and up to five tool names;
    * ``token`` — streamed text deltas;
    * per tool round: ``tools_start``, ``tool_call``, then ``tool_done`` or
      ``tool_error``;
    * finally ``complete`` (carrying a ``warning`` if the turn limit was hit)
      or ``error``.

    At most ``max_turns`` Claude round-trips are made. Tracebacks are sent to
    the client only when the app runs in debug mode.
    """
    max_turns = 5
    try:
        # silent=True: a missing or non-JSON body becomes None -> 400, not 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        message = data.get('message', '')
        conversation_history = data.get('history') or []
        session_id = request.headers.get('X-Session-ID')
        if not message:
            return jsonify({"error": "Message is required"}), 400

        def generate():
            """Drive the Claude/tool loop and yield SSE frames (runs after the request returns)."""
            try:
                yield _sse_event("start", {"status": "started"})

                conv_manager = ConversationManager(
                    api_key=ANTHROPIC_API_KEY,
                    base_url=ANTHROPIC_BASE_URL,
                    model=ANTHROPIC_MODEL,
                    session_id=session_id
                )
                current_messages = ConversationManager.format_history_for_claude(
                    conversation_history
                ) + [{"role": "user", "content": message}]
                tool_calls_info = []
                response_text = ""

                # Fetched once per request; the tool list is assumed not to
                # change while a single request is being answered.
                tools = run_async(conv_manager.get_available_tools())
                yield _sse_event("tools", {'count': len(tools), 'tools': [t['name'] for t in tools[:5]]})

                request_kwargs = {"model": conv_manager.model, "max_tokens": 4096, "stream": True}
                if tools:
                    request_kwargs["tools"] = tools

                for _turn in range(max_turns):
                    # The with-block closes the HTTP stream even when the client
                    # disconnects mid-response (GeneratorExit at a yield).
                    with conv_manager.client.messages.create(
                        messages=current_messages, **request_kwargs
                    ) as response_stream:
                        content_blocks = yield from _stream_claude_turn(response_stream)

                    response_text = "".join(b["text"] for b in content_blocks if b["type"] == "text")
                    tool_use_blocks = [b for b in content_blocks if b["type"] == "tool_use"]

                    # No tool call: this is Claude's final answer.
                    if not tool_use_blocks:
                        yield _sse_event("complete", {'response': response_text, 'tool_calls': tool_calls_info})
                        return

                    yield _sse_event("tools_start", {'count': len(tool_use_blocks)})
                    for tool_block in tool_use_blocks:
                        yield _sse_event("tool_call", {
                            'tool': tool_block['name'],
                            'args': tool_block['input'],
                            'tool_id': tool_block['id'],
                        })

                    tool_results = run_async(conv_manager.tool_handler.process_tool_use_blocks(
                        tool_use_blocks
                    ))
                    for tr in tool_results:
                        tool_name = tr.get("tool_name", "")
                        tool_result = tr.get("result", {})
                        if isinstance(tool_result, dict) and "error" in tool_result:
                            yield _sse_event("tool_error", {'tool': tool_name, 'error': tool_result['error']})
                        else:
                            preview = tool_result.get('result', '') if isinstance(tool_result, dict) else tool_result
                            # str() so non-string results cannot break the [:200] preview.
                            yield _sse_event("tool_done", {
                                'tool': tool_name,
                                'result': "" if preview is None else str(preview)[:200],
                            })
                        tool_calls_info.append({
                            "tool": tool_name,
                            "result": tool_result
                        })

                    # Each tool_result in the next user message must refer to a
                    # tool_use block in this assistant message, so the collected
                    # blocks (text + tool_use) are sent back verbatim.
                    current_messages.append({
                        "role": "assistant",
                        "content": content_blocks
                    })
                    current_messages.append(ToolCallHandler.create_tool_result_message(tool_results))

                # Turn limit reached ("达到最大对话轮数" = maximum conversation turns reached).
                yield _sse_event("complete", {
                    'response': response_text,
                    'tool_calls': tool_calls_info,
                    'warning': '达到最大对话轮数',
                })
            except Exception as e:
                import traceback
                app.logger.exception("Streaming chat failed")
                payload = {'error': str(e)}
                if app.debug:
                    payload['traceback'] = traceback.format_exc()
                yield _sse_event("error", payload)

        return Response(
            generate(),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no'  # disable Nginx response buffering
            }
        )
    except Exception as e:
        import traceback
        app.logger.exception("Streaming chat setup failed")
        body = {"error": str(e)}
        if app.debug:
            body["traceback"] = traceback.format_exc()
        return jsonify(body), 500
@app.route('/api/mcp/servers', methods=['GET'])
def list_mcp_servers():
    """List every configured MCP server with its display name, URL, auth type and enabled flag."""
    servers = [
        {
            "id": server_id,
            "name": cfg.get("name", server_id),
            "url": cfg.get("url", ""),
            "auth_type": cfg.get("auth_type", "none"),
            "enabled": cfg.get("enabled", False),
        }
        for server_id, cfg in MCP_SERVERS.items()
    ]
    return jsonify({"servers": servers})
@app.route('/api/mcp/tools', methods=['GET'])
def list_mcp_tools():
    """List the MCP tools available to the caller's session (``X-Session-ID`` header).

    Returns JSON ``{"tools": [...], "count": n}``. On failure it returns 500
    with ``error`` and an empty ``tools`` list. The traceback is added only in
    debug mode and is always logged server-side.
    """
    try:
        session_id = request.headers.get('X-Session-ID')
        # Called on the class itself; no ConversationManager instance is needed.
        tools = ConversationManager.get_tools(session_id=session_id)
        return jsonify({
            "tools": tools,
            "count": len(tools)
        })
    except Exception as e:
        import traceback
        app.logger.exception("Listing MCP tools failed")
        body = {"error": str(e), "tools": []}
        # Stack traces reveal internals; only expose them while debugging.
        if app.debug:
            body["traceback"] = traceback.format_exc()
        return jsonify(body), 500
# ========== Authentication API ==========
def _read_credentials():
    """Return ``(username, password)`` from the JSON request body.

    A missing, non-JSON or non-object body yields ``(None, None)``, so callers
    answer with 400 instead of failing with a 500.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return None, None
    return data.get('username'), data.get('password')


def _proxy_login(target_server, default_login_path, username, password, failure_message, role=None):
    """Forward credentials to ``target_server``'s login endpoint and open a local session.

    The URL is ``base_url`` + ``login_url`` from the server config, falling
    back to ``default_login_path``; exactly one slash is placed between them.

    On upstream HTTP 200:
    * a new uuid4 session id is stored in ``auth_sessions`` with the upstream
      ``token``/``refresh_token`` (plus ``role`` when given);
    * the session id and token are returned to the caller.

    Any other upstream status is relayed with ``failure_message`` and the
    upstream body. Transport errors from httpx propagate to the caller.
    """
    import uuid

    base_url = target_server.get('base_url', '')
    login_path = target_server.get('login_url', default_login_path)
    login_url = f"{base_url.rstrip('/')}/{login_path.lstrip('/')}"

    response = httpx.post(
        login_url,
        json={"username": username, "password": password},
        timeout=30.0
    )
    if response.status_code != 200:
        return jsonify({
            "error": failure_message,
            "details": response.text
        }), response.status_code

    result = response.json()
    session_id = str(uuid.uuid4())
    session_record = {
        "username": username,
        "token": result.get("token"),
        "refresh_token": result.get("refresh_token"),
        "server": target_server.get("name")
    }
    payload = {
        "success": True,
        "session_id": session_id,
        "username": username,
        "server": target_server.get("name"),
        "token": result.get("token")
    }
    if role is not None:
        session_record["role"] = role
        payload["role"] = role
    auth_sessions[session_id] = session_record
    return jsonify(payload)


@app.route('/api/auth/login', methods=['POST'])
def login():
    """
    Novel Platform user login.

    Proxies the credentials to the first configured MCP server with
    ``auth_type == 'jwt'`` and a ``login_url``, and returns its JWT token
    together with a new local session id.

    Status codes:
    * 400 for missing credentials or when no such server is configured;
    * 502 when the login server cannot be reached.
    """
    try:
        username, password = _read_credentials()
        if not username or not password:
            return jsonify({"error": "Username and password are required"}), 400

        target_server = next(
            (cfg for cfg in MCP_SERVERS.values()
             if cfg.get('auth_type') == 'jwt' and 'login_url' in cfg),
            None
        )
        if target_server is None:
            return jsonify({"error": "No JWT-authenticated server configured"}), 400

        return _proxy_login(target_server, '/api/auth/login', username, password, "Login failed")
    except httpx.RequestError as e:
        # Upstream unreachable/timed out: a gateway error, not an internal one.
        return jsonify({"error": str(e)}), 502
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route('/api/auth/admin-login', methods=['POST'])
def admin_login():
    """
    Novel Platform administrator login.

    Proxies the credentials to the ``novel-platform-admin`` MCP server and, on
    success, opens a local session with ``role == "admin"``.

    Status codes:
    * 400 for missing credentials or an unconfigured admin server;
    * 502 when the login server cannot be reached.
    """
    try:
        username, password = _read_credentials()
        if not username or not password:
            return jsonify({"error": "Username and password are required"}), 400

        target_server = MCP_SERVERS.get('novel-platform-admin')
        if not target_server:
            return jsonify({"error": "Admin server not configured"}), 400

        return _proxy_login(
            target_server, '/api/auth/admin-login', username, password,
            "Admin login failed", role="admin"
        )
    except httpx.RequestError as e:
        # Upstream unreachable/timed out: a gateway error, not an internal one.
        return jsonify({"error": str(e)}), 502
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/api/auth/logout', methods=['POST'])
def logout():
    """Log out by discarding a local auth session.

    The session id is taken from the JSON body's ``session_id``. If that is
    absent, it falls back to the ``X-Session-ID`` header, the same header
    ``/api/auth/status`` reads.

    Unknown or missing ids are ignored. The response is always
    ``{"success": True}`` unless an unexpected error occurs.
    """
    try:
        # silent=True: a missing or non-JSON body must not turn logout into a 500.
        data = request.get_json(silent=True)
        session_id = data.get('session_id') if isinstance(data, dict) else None
        if not session_id:
            session_id = request.headers.get('X-Session-ID')
        # Only strings can be session keys; anything else (e.g. a JSON list)
        # would be unhashable.
        if isinstance(session_id, str) and session_id:
            auth_sessions.pop(session_id, None)
        return jsonify({"success": True})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/api/auth/status', methods=['GET'])
def auth_status():
    """Report whether the ``X-Session-ID`` header names a known auth session.

    Returns ``{"authenticated": False}`` for a missing or unknown id.
    Otherwise it returns the session's username and server, plus its role
    (``"user"`` when none was stored).
    """
    sid = request.headers.get('X-Session-ID')
    record = auth_sessions.get(sid) if sid else None
    if record is None:
        return jsonify({
            "authenticated": False
        })
    return jsonify({
        "authenticated": True,
        "username": record.get("username"),
        "server": record.get("server"),
        "role": record.get("role", "user"),
    })
if __name__ == '__main__':
    # Run the built-in server on all interfaces. PORT defaults to 5000; debug
    # mode is on only when DEBUG equals "true" (case-insensitive).
    debug_enabled = os.getenv('DEBUG', 'False').lower() == 'true'
    listen_port = int(os.getenv('PORT', 5000))
    app.run(host='0.0.0.0', port=listen_port, debug=debug_enabled)
|