""" AI MCP Web UI - Flask 后端 提供聊天界面与 MCP 工具调用的桥梁 """ import os import asyncio from typing import Optional, Dict from flask import Flask, request, jsonify, send_from_directory, Response import json as json_module from flask_cors import CORS import httpx from anthropic import Anthropic from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL from conversation_manager import ConversationManager from tool_handler import ToolCallHandler app = Flask(__name__) CORS(app) app.secret_key = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production') # 存储认证会话 (生产环境应使用 Redis 或数据库) auth_sessions: Dict[str, dict] = {} def create_anthropic_client(api_key: str, base_url: str) -> Anthropic: """ 创建 Anthropic 客户端,支持自定义认证格式 自定义 API 代理需要 'Authorization: Bearer ' 格式, 而不是 Anthropic SDK 默认的 'x-api-key' header。 """ import httpx # 创建自定义 httpx client,设置正确的 Authorization header http_client = httpx.Client( headers={"Authorization": f"Bearer {api_key}"}, timeout=120.0 ) return Anthropic(base_url=base_url, http_client=http_client) @app.route('/') def index(): return send_from_directory('../frontend', 'index.html') @app.route('/') def static_files(path): return send_from_directory('../frontend', path) # 初始化 Claude 客户端(使用自定义认证格式) client = create_anthropic_client( api_key=ANTHROPIC_API_KEY, base_url=ANTHROPIC_BASE_URL ) @app.route('/api/health', methods=['GET']) def health(): """健康检查端点""" return jsonify({ "status": "ok", "model": ANTHROPIC_MODEL, "mcp_servers": list(MCP_SERVERS.keys()) }) def run_async(coro): """在同步上下文中运行异步函数""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete(coro) finally: loop.close() @app.route('/api/chat', methods=['POST']) def chat(): """ 聊天端点 - 接收用户消息,返回 Claude 响应(支持 MCP 工具调用) """ try: data = request.json message = data.get('message', '') conversation_history = data.get('history', []) session_id = request.headers.get('X-Session-ID') if not message: return jsonify({"error": "Message is required"}), 400 # 创建对话管理器 conv_manager = ConversationManager( api_key=ANTHROPIC_API_KEY, base_url=ANTHROPIC_BASE_URL, model=ANTHROPIC_MODEL, session_id=session_id ) # 格式化对话历史 formatted_history = ConversationManager.format_history_for_claude(conversation_history) # 执行多轮对话(自动处理工具调用) result = run_async(conv_manager.chat( user_message=message, conversation_history=formatted_history, max_turns=5 )) # 提取响应文本 response_text = result.get("response", "") tool_calls = result.get("tool_calls", []) return jsonify({ "response": response_text, "model": ANTHROPIC_MODEL, "tool_calls": tool_calls, "has_tools": len(tool_calls) > 0 }) except Exception as e: import traceback return jsonify({ "error": str(e), "traceback": traceback.format_exc() }), 500 @app.route('/api/chat/stream', methods=['POST']) def chat_stream(): """ 聊天端点 - 流式输出版本(解决超时问题) 使用 Server-Sent Events (SSE) 实时返回: 1. Claude 的思考过程 2. 工具调用状态 3. 最终响应 """ try: data = request.json message = data.get('message', '') conversation_history = data.get('history', []) session_id = request.headers.get('X-Session-ID') if not message: return jsonify({"error": "Message is required"}), 400 def generate(): """生成 SSE 流式响应""" try: # 发送开始事件 yield f"event: start\ndata: {json_module.dumps({'status': 'started'})}\n\n" # 创建对话管理器 conv_manager = ConversationManager( api_key=ANTHROPIC_API_KEY, base_url=ANTHROPIC_BASE_URL, model=ANTHROPIC_MODEL, session_id=session_id ) # 格式化对话历史 formatted_history = ConversationManager.format_history_for_claude(conversation_history) messages = formatted_history + [{"role": "user", "content": message}] current_messages = messages tool_calls_info = [] for turn in range(5): # 最多 5 轮 # 获取可用工具 tools = run_async(conv_manager.get_available_tools()) # 发送工具列表 yield f"event: tools\ndata: {json_module.dumps({'count': len(tools), 'tools': [t['name'] for t in tools[:5]]})}\n\n" # 调用 Claude API(流式) if tools: response_stream = conv_manager.client.messages.create( model=conv_manager.model, max_tokens=4096, messages=current_messages, tools=tools, stream=True ) else: response_stream = conv_manager.client.messages.create( model=conv_manager.model, max_tokens=4096, messages=current_messages, stream=True ) # 处理流式响应 content_blocks = [] tool_use_blocks = [] response_text = "" current_block_type = None current_tool_index = -1 partial_json = "" for event in response_stream: # 处理内容块开始 - 检查是否是工具调用 if event.type == "content_block_start": # 检查块的类型 if hasattr(event, "content_block"): current_block_type = getattr(event.content_block, "type", None) if current_block_type == "tool_use": # 这是工具调用块的开始 tool_use_id = getattr(event.content_block, "id", "") # content_block 包含 name tool_name = getattr(event.content_block, "name", "") tool_use_blocks.append({ "type": "tool_use", "id": tool_use_id, "name": tool_name, "input": {} }) current_tool_index = len(tool_use_blocks) - 1 partial_json = "" # 处理内容块增量 elif event.type == "content_block_delta": delta_type = getattr(event.delta, "type", "") # 文本增量 if delta_type == "text_delta": text = event.delta.text response_text += text yield f"event: token\ndata: {json_module.dumps({'text': text})}\n\n" # 工具名称增量 elif delta_type == "tool_use_delta": # 获取工具名称和参数增量 delta_name = getattr(event.delta, "name", None) delta_input = getattr(event.delta, "input", None) if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks): if delta_name is not None: tool_use_blocks[current_tool_index]["name"] = delta_name if delta_input is not None: # 更新输入参数 current_input = tool_use_blocks[current_tool_index]["input"] if isinstance(delta_input, dict): current_input.update(delta_input) tool_use_blocks[current_tool_index]["input"] = current_input # 工具参数增量 - input_json_delta elif delta_type == "input_json_delta": # 累积 partial_json 构建完整参数 partial_json_str = getattr(event.delta, "partial_json", "") if partial_json_str: partial_json += partial_json_str try: # 尝试解析累积的 JSON parsed_input = json_module.loads(partial_json) if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks): tool_use_blocks[current_tool_index]["input"] = parsed_input except json_module.JSONDecodeError: # JSON 还不完整,继续累积 pass # 处理内容块停止 elif event.type == "content_block_stop": current_block_type = None current_tool_index = -1 partial_json = "" # 如果没有工具调用,发送完成事件 if not tool_use_blocks: yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info})}\n\n" return # 处理工具调用 yield f"event: tools_start\ndata: {json_module.dumps({'count': len(tool_use_blocks)})}\n\n" # 为每个工具调用发送 tool_call 事件 for tool_block in tool_use_blocks: yield f"event: tool_call\ndata: {json_module.dumps({'tool': tool_block['name'], 'args': tool_block['input'], 'tool_id': tool_block['id']})}\n\n" tool_results = run_async(conv_manager.tool_handler.process_tool_use_blocks( tool_use_blocks )) for i, tr in enumerate(tool_results): tool_name = tr.get("tool_name", "") tool_result = tr.get("result", {}) tool_use_id = tr.get("tool_use_id", "") # 发送工具完成事件 if "error" in tool_result: yield f"event: tool_error\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'error': tool_result['error']})}\n\n" else: result_data = tool_result.get('result', '') # 限制结果长度避免传输过大 if isinstance(result_data, str) and len(result_data) > 500: result_data = result_data[:500] + '...' yield f"event: tool_done\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'result': result_data})}\n\n" tool_calls_info.append({ "tool": tool_name, "result": tool_result }) # 构建工具结果消息 tool_result_message = ToolCallHandler.create_tool_result_message( tool_results ) # 添加到消息历史 current_messages.append({ "role": "assistant", "content": content_blocks }) current_messages.append(tool_result_message) # 达到最大轮数 yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info, 'warning': '达到最大对话轮数'})}\n\n" except Exception as e: import traceback yield f"event: error\ndata: {json_module.dumps({'error': str(e), 'traceback': traceback.format_exc()})}\n\n" return Response( generate(), mimetype='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no' # 禁用 Nginx 缓冲 } ) except Exception as e: import traceback return jsonify({ "error": str(e), "traceback": traceback.format_exc() }), 500 @app.route('/api/mcp/servers', methods=['GET']) def list_mcp_servers(): """获取已配置的 MCP 服务器列表""" servers = [] for name, server in MCP_SERVERS.items(): servers.append({ "id": name, "name": server.get("name", name), "url": server.get("url", ""), "auth_type": server.get("auth_type", "none"), "enabled": server.get("enabled", False) }) return jsonify({"servers": servers}) @app.route('/api/mcp/tools', methods=['GET']) def list_mcp_tools(): """获取可用的 MCP 工具列表""" try: session_id = request.headers.get('X-Session-ID') # 使用静态方法获取工具 tools = ConversationManager.get_tools(session_id=session_id) return jsonify({ "tools": tools, "count": len(tools) }) except Exception as e: import traceback return jsonify({ "error": str(e), "traceback": traceback.format_exc(), "tools": [] }), 500 # ========== 认证 API ========== @app.route('/api/auth/login', methods=['POST']) def login(): """ Novel Platform 用户登录 代理到实际的登录端点并返回 JWT Token """ try: data = request.json username = data.get('username') password = data.get('password') if not username or not password: return jsonify({"error": "Username and password are required"}), 400 # 查找需要 JWT 认证的 MCP 服务器 target_server = None for server_id, config in MCP_SERVERS.items(): if config.get('auth_type') == 'jwt' and 'login_url' in config: target_server = config break if not target_server: return jsonify({"error": "No JWT-authenticated server configured"}), 400 # 构建登录 URL base_url = target_server.get('base_url', '') login_path = target_server.get('login_url', '/api/auth/login') login_url = f"{base_url}{login_path}" # 调用实际的登录接口(同步版本) response = httpx.post( login_url, json={"username": username, "password": password}, timeout=30.0 ) if response.status_code == 200: result = response.json() import uuid session_id = str(uuid.uuid4()) # 存储会话信息 auth_sessions[session_id] = { "username": username, "token": result.get("token"), "refresh_token": result.get("refresh_token"), "server": target_server.get("name") } return jsonify({ "success": True, "session_id": session_id, "username": username, "server": target_server.get("name"), "token": result.get("token") }) else: return jsonify({ "error": "Login failed", "details": response.text }), response.status_code except Exception as e: return jsonify({"error": str(e)}), 500 @app.route('/api/auth/admin-login', methods=['POST']) def admin_login(): """ Novel Platform 管理员登录 """ try: data = request.json username = data.get('username') password = data.get('password') if not username or not password: return jsonify({"error": "Username and password are required"}), 400 # 查找管理员 MCP 服务器 target_server = MCP_SERVERS.get('novel-platform-admin') if not target_server: return jsonify({"error": "Admin server not configured"}), 400 # 构建登录 URL base_url = target_server.get('base_url', '') login_path = target_server.get('login_url', '/api/auth/admin-login') login_url = f"{base_url}{login_path}" # 调用实际的登录接口 response = httpx.post( login_url, json={"username": username, "password": password}, timeout=30.0 ) if response.status_code == 200: result = response.json() import uuid session_id = str(uuid.uuid4()) auth_sessions[session_id] = { "username": username, "token": result.get("token"), "refresh_token": result.get("refresh_token"), "server": target_server.get("name"), "role": "admin" } return jsonify({ "success": True, "session_id": session_id, "username": username, "server": target_server.get("name"), "role": "admin", "token": result.get("token") }) else: return jsonify({ "error": "Admin login failed", "details": response.text }), response.status_code except Exception as e: return jsonify({"error": str(e)}), 500 @app.route('/api/auth/logout', methods=['POST']) def logout(): """登出并清除会话""" try: data = request.json session_id = data.get('session_id') if session_id and session_id in auth_sessions: del auth_sessions[session_id] return jsonify({"success": True}) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route('/api/auth/status', methods=['GET']) def auth_status(): """检查认证状态""" session_id = request.headers.get('X-Session-ID') if session_id and session_id in auth_sessions: session = auth_sessions[session_id] return jsonify({ "authenticated": True, "username": session.get("username"), "server": session.get("server"), "role": session.get("role", "user") }) return jsonify({ "authenticated": False }) if __name__ == '__main__': port = int(os.getenv('PORT', 8080)) debug = os.getenv('DEBUG', 'False').lower() == 'true' app.run(host='0.0.0.0', port=port, debug=debug)