2
0
Просмотр исходного кода

feat: 将后端从 Flask 迁移到 FastAPI

主要变更:
- 创建 app_fastapi.py 作为新的 FastAPI 应用入口
- 使用 FastAPI 的异步特性替代 Flask 的同步模式
- 更新 requirements.txt,移除 Flask 依赖,添加 FastAPI 和 Uvicorn
- 所有路由端点保持相同的 API 接口,确保前端无需修改

技术细节:
- 使用 @app.get/@app.post 替代 @app.route
- 使用 async def 进行异步处理
- 使用 await request.json() 替代 request.json
- 使用 StreamingResponse 实现流式 SSE 响应
- 使用 CORSMiddleware 替代 flask-cors
- 保留端口 8080 配置,确保外网访问正常

测试结果:
- /api/health - 健康检查正常
- /api/mcp/servers - MCP 服务器列表正常
- /api/auth/status - 认证状态检查正常
- 服务器在端口 8080 成功启动

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Claude AI 4 часов назад
Родитель
Сommit
6d7a3901dd
2 измененных файлов с 601 добавлено и 2 удалено
  1. 598 0
      backend/app_fastapi.py
  2. 3 2
      backend/requirements.txt

+ 598 - 0
backend/app_fastapi.py

@@ -0,0 +1,598 @@
+"""
+AI MCP Web UI - FastAPI 后端
+提供聊天界面与 MCP 工具调用的桥梁
+"""
+import os
+import asyncio
+import uuid
+import json as json_module
+from typing import Optional, Dict, List, Any
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, Request, HTTPException, Header
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+import httpx
+from anthropic import Anthropic
+
+from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
+from conversation_manager import ConversationManager
+from tool_handler import ToolCallHandler
+
+
+# 存储认证会话 (生产环境应使用 Redis 或数据库)
+auth_sessions: Dict[str, dict] = {}
+
+
+def create_anthropic_client(api_key: str, base_url: str) -> Anthropic:
+    """
+    创建 Anthropic 客户端,支持自定义认证格式
+
+    自定义 API 代理需要 'Authorization: Bearer <token>' 格式,
+    而不是 Anthropic SDK 默认的 'x-api-key' header。
+    """
+    # 创建自定义 httpx client,设置正确的 Authorization header
+    http_client = httpx.Client(
+        headers={"Authorization": f"Bearer {api_key}"},
+        timeout=120.0
+    )
+    return Anthropic(base_url=base_url, http_client=http_client)
+
+
+# 初始化 Claude 客户端(使用自定义认证格式)
+client = create_anthropic_client(
+    api_key=ANTHROPIC_API_KEY,
+    base_url=ANTHROPIC_BASE_URL
+)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """应用生命周期管理"""
+    # 启动时执行
+    print(f"FastAPI 应用启动 - 模型: {ANTHROPIC_MODEL}")
+    print(f"MCP 服务器: {list(MCP_SERVERS.keys())}")
+    yield
+    # 关闭时执行
+    print("FastAPI 应用关闭")
+
+
+# 创建 FastAPI 应用
+app = FastAPI(
+    title="AI MCP Web UI Backend",
+    description="AI MCP Web UI 后端服务 - 支持 Claude AI 和 MCP 工具调用",
+    version="2.0.0",
+    lifespan=lifespan
+)
+
+# CORS 配置
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# 挂载静态文件
+frontend_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend")
+app.mount("/static", StaticFiles(directory=frontend_path), name="static")
+
+
+# ========== 根路由 ==========
+
+@app.get("/")
+async def index():
+    """返回前端主页"""
+    from fastapi.responses import FileResponse
+    index_path = os.path.join(frontend_path, "index.html")
+    return FileResponse(index_path)
+
+
+# ========== 健康检查 ==========
+
+@app.get("/api/health")
+async def health():
+    """健康检查端点"""
+    return {
+        "status": "ok",
+        "model": ANTHROPIC_MODEL,
+        "mcp_servers": list(MCP_SERVERS.keys())
+    }
+
+
+# ========== 聊天 API ==========
+
+@app.post("/api/chat")
+async def chat(request: Request):
+    """
+    聊天端点 - 接收用户消息,返回 Claude 响应(支持 MCP 工具调用)
+    """
+    try:
+        data = await request.json()
+        message = data.get('message', '')
+        conversation_history = data.get('history', [])
+        session_id = request.headers.get('X-Session-ID')
+
+        if not message:
+            raise HTTPException(status_code=400, detail="Message is required")
+
+        # 创建对话管理器
+        conv_manager = ConversationManager(
+            api_key=ANTHROPIC_API_KEY,
+            base_url=ANTHROPIC_BASE_URL,
+            model=ANTHROPIC_MODEL,
+            session_id=session_id
+        )
+
+        # 格式化对话历史
+        formatted_history = ConversationManager.format_history_for_claude(conversation_history)
+
+        # 执行多轮对话(自动处理工具调用)
+        result = await conv_manager.chat(
+            user_message=message,
+            conversation_history=formatted_history,
+            max_turns=5
+        )
+
+        # 提取响应文本
+        response_text = result.get("response", "")
+        tool_calls = result.get("tool_calls", [])
+
+        return {
+            "response": response_text,
+            "model": ANTHROPIC_MODEL,
+            "tool_calls": tool_calls,
+            "has_tools": len(tool_calls) > 0
+        }
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        import traceback
+        return JSONResponse(
+            status_code=500,
+            content={
+                "error": str(e),
+                "traceback": traceback.format_exc()
+            }
+        )
+
+
+async def generate_chat_stream(
+    message: str,
+    conversation_history: List[Dict[str, Any]],
+    session_id: Optional[str]
+):
+    """生成 SSE 流式响应的异步生成器"""
+    try:
+        # 发送开始事件
+        yield f"event: start\ndata: {json_module.dumps({'status': 'started'})}\n\n"
+
+        # 创建对话管理器
+        conv_manager = ConversationManager(
+            api_key=ANTHROPIC_API_KEY,
+            base_url=ANTHROPIC_BASE_URL,
+            model=ANTHROPIC_MODEL,
+            session_id=session_id
+        )
+
+        # 格式化对话历史
+        formatted_history = ConversationManager.format_history_for_claude(conversation_history)
+        messages = formatted_history + [{"role": "user", "content": message}]
+
+        current_messages = messages
+        tool_calls_info = []
+
+        for turn in range(5):  # 最多 5 轮
+            # 获取可用工具
+            tools = await conv_manager.get_available_tools()
+
+            # 发送工具列表
+            yield f"event: tools\ndata: {json_module.dumps({'count': len(tools), 'tools': [t['name'] for t in tools[:5]]})}\n\n"
+
+            # 调用 Claude API(流式)
+            if tools:
+                response_stream = conv_manager.client.messages.create(
+                    model=conv_manager.model,
+                    max_tokens=4096,
+                    messages=current_messages,
+                    tools=tools,
+                    stream=True
+                )
+            else:
+                response_stream = conv_manager.client.messages.create(
+                    model=conv_manager.model,
+                    max_tokens=4096,
+                    messages=current_messages,
+                    stream=True
+                )
+
+            # 处理流式响应
+            content_blocks = []
+            tool_use_blocks = []
+            response_text = ""
+            current_block_type = None
+            current_tool_index = -1
+            partial_json = ""
+
+            for event in response_stream:
+                # 处理内容块开始 - 检查是否是工具调用
+                if event.type == "content_block_start":
+                    # 检查块的类型
+                    if hasattr(event, "content_block"):
+                        current_block_type = getattr(event.content_block, "type", None)
+                        if current_block_type == "tool_use":
+                            # 这是工具调用块的开始
+                            tool_use_id = getattr(event.content_block, "id", "")
+                            # content_block 包含 name
+                            tool_name = getattr(event.content_block, "name", "")
+                            tool_use_blocks.append({
+                                "type": "tool_use",
+                                "id": tool_use_id,
+                                "name": tool_name,
+                                "input": {}
+                            })
+                            current_tool_index = len(tool_use_blocks) - 1
+                            partial_json = ""
+
+                # 处理内容块增量
+                elif event.type == "content_block_delta":
+                    delta_type = getattr(event.delta, "type", "")
+
+                    # 文本增量
+                    if delta_type == "text_delta":
+                        text = event.delta.text
+                        response_text += text
+                        yield f"event: token\ndata: {json_module.dumps({'text': text})}\n\n"
+
+                    # 工具名称增量
+                    elif delta_type == "tool_use_delta":
+                        # 获取工具名称和参数增量
+                        delta_name = getattr(event.delta, "name", None)
+                        delta_input = getattr(event.delta, "input", None)
+
+                        if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks):
+                            if delta_name is not None:
+                                tool_use_blocks[current_tool_index]["name"] = delta_name
+                            if delta_input is not None:
+                                # 更新输入参数
+                                current_input = tool_use_blocks[current_tool_index]["input"]
+                                if isinstance(delta_input, dict):
+                                    current_input.update(delta_input)
+                                    tool_use_blocks[current_tool_index]["input"] = current_input
+
+                    # 工具参数增量 - input_json_delta
+                    elif delta_type == "input_json_delta":
+                        # 累积 partial_json 构建完整参数
+                        partial_json_str = getattr(event.delta, "partial_json", "")
+                        if partial_json_str:
+                            partial_json += partial_json_str
+                            try:
+                                # 尝试解析累积的 JSON
+                                parsed_input = json_module.loads(partial_json)
+                                if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks):
+                                    tool_use_blocks[current_tool_index]["input"] = parsed_input
+                            except json_module.JSONDecodeError:
+                                # JSON 还不完整,继续累积
+                                pass
+
+                # 处理内容块停止
+                elif event.type == "content_block_stop":
+                    current_block_type = None
+                    current_tool_index = -1
+                    partial_json = ""
+
+            # 如果没有工具调用,发送完成事件
+            if not tool_use_blocks:
+                yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info})}\n\n"
+                return
+
+            # 处理工具调用
+            yield f"event: tools_start\ndata: {json_module.dumps({'count': len(tool_use_blocks)})}\n\n"
+
+            # 为每个工具调用发送 tool_call 事件
+            for tool_block in tool_use_blocks:
+                yield f"event: tool_call\ndata: {json_module.dumps({'tool': tool_block['name'], 'args': tool_block['input'], 'tool_id': tool_block['id']})}\n\n"
+
+            tool_results = await conv_manager.tool_handler.process_tool_use_blocks(
+                tool_use_blocks
+            )
+
+            for tr in tool_results:
+                tool_name = tr.get("tool_name", "")
+                tool_result = tr.get("result", {})
+                tool_use_id = tr.get("tool_use_id", "")
+
+                # 发送工具完成事件
+                if "error" in tool_result:
+                    yield f"event: tool_error\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'error': tool_result['error']})}\n\n"
+                else:
+                    result_data = tool_result.get('result', '')
+                    # 限制结果长度避免传输过大
+                    if isinstance(result_data, str) and len(result_data) > 500:
+                        result_data = result_data[:500] + '...'
+                    yield f"event: tool_done\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'result': result_data})}\n\n"
+
+                tool_calls_info.append({
+                    "tool": tool_name,
+                    "result": tool_result
+                })
+
+            # 构建工具结果消息
+            tool_result_message = ToolCallHandler.create_tool_result_message(
+                tool_results
+            )
+
+            # 添加到消息历史
+            current_messages.append({
+                "role": "assistant",
+                "content": content_blocks
+            })
+            current_messages.append(tool_result_message)
+
+        # 达到最大轮数
+        yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info, 'warning': '达到最大对话轮数'})}\n\n"
+
+    except Exception as e:
+        import traceback
+        yield f"event: error\ndata: {json_module.dumps({'error': str(e), 'traceback': traceback.format_exc()})}\n\n"
+
+
+@app.post("/api/chat/stream")
+async def chat_stream(request: Request):
+    """
+    聊天端点 - 流式输出版本(解决超时问题)
+
+    使用 Server-Sent Events (SSE) 实时返回:
+    1. Claude 的思考过程
+    2. 工具调用状态
+    3. 最终响应
+    """
+    try:
+        data = await request.json()
+        message = data.get('message', '')
+        conversation_history = data.get('history', [])
+        session_id = request.headers.get('X-Session-ID')
+
+        if not message:
+            raise HTTPException(status_code=400, detail="Message is required")
+
+        return StreamingResponse(
+            generate_chat_stream(message, conversation_history, session_id),
+            media_type="text/event-stream",
+            headers={
+                'Cache-Control': 'no-cache',
+                'X-Accel-Buffering': 'no'  # 禁用 Nginx 缓冲
+            }
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        import traceback
+        return JSONResponse(
+            status_code=500,
+            content={
+                "error": str(e),
+                "traceback": traceback.format_exc()
+            }
+        )
+
+
+# ========== MCP API ==========
+
+@app.get("/api/mcp/servers")
+async def list_mcp_servers():
+    """获取已配置的 MCP 服务器列表"""
+    servers = []
+    for name, server in MCP_SERVERS.items():
+        servers.append({
+            "id": name,
+            "name": server.get("name", name),
+            "url": server.get("url", ""),
+            "auth_type": server.get("auth_type", "none"),
+            "enabled": server.get("enabled", False)
+        })
+    return {"servers": servers}
+
+
+@app.get("/api/mcp/tools")
+async def list_mcp_tools(x_session_id: Optional[str] = Header(None, alias='X-Session-ID')):
+    """获取可用的 MCP 工具列表"""
+    try:
+        # 使用静态方法获取工具
+        tools = ConversationManager.get_tools(session_id=x_session_id)
+
+        return {
+            "tools": tools,
+            "count": len(tools)
+        }
+    except Exception as e:
+        import traceback
+        return JSONResponse(
+            status_code=500,
+            content={
+                "error": str(e),
+                "traceback": traceback.format_exc(),
+                "tools": []
+            }
+        )
+
+
+# ========== 认证 API ==========
+
+@app.post("/api/auth/login")
+async def login(request: Request):
+    """
+    Novel Platform 用户登录
+    代理到实际的登录端点并返回 JWT Token
+    """
+    try:
+        data = await request.json()
+        username = data.get('username')
+        password = data.get('password')
+
+        if not username or not password:
+            raise HTTPException(status_code=400, detail="Username and password are required")
+
+        # 查找需要 JWT 认证的 MCP 服务器
+        target_server = None
+        for server_id, config in MCP_SERVERS.items():
+            if config.get('auth_type') == 'jwt' and 'login_url' in config:
+                target_server = config
+                break
+
+        if not target_server:
+            raise HTTPException(status_code=400, detail="No JWT-authenticated server configured")
+
+        # 构建登录 URL
+        base_url = target_server.get('base_url', '')
+        login_path = target_server.get('login_url', '/api/auth/login')
+        login_url = f"{base_url}{login_path}"
+
+        # 调用实际的登录接口(异步版本)
+        async with httpx.AsyncClient(timeout=30.0) as http_client:
+            response = await http_client.post(
+                login_url,
+                json={"username": username, "password": password}
+            )
+
+        if response.status_code == 200:
+            result = response.json()
+            session_id = str(uuid.uuid4())
+
+            # 存储会话信息
+            auth_sessions[session_id] = {
+                "username": username,
+                "token": result.get("token"),
+                "refresh_token": result.get("refresh_token"),
+                "server": target_server.get("name")
+            }
+
+            return {
+                "success": True,
+                "session_id": session_id,
+                "username": username,
+                "server": target_server.get("name"),
+                "token": result.get("token")
+            }
+        else:
+            raise HTTPException(
+                status_code=response.status_code,
+                detail=f"Login failed: {response.text}"
+            )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/auth/admin-login")
+async def admin_login(request: Request):
+    """
+    Novel Platform 管理员登录
+    """
+    try:
+        data = await request.json()
+        username = data.get('username')
+        password = data.get('password')
+
+        if not username or not password:
+            raise HTTPException(status_code=400, detail="Username and password are required")
+
+        # 查找管理员 MCP 服务器
+        target_server = MCP_SERVERS.get('novel-platform-admin')
+        if not target_server:
+            raise HTTPException(status_code=400, detail="Admin server not configured")
+
+        # 构建登录 URL
+        base_url = target_server.get('base_url', '')
+        login_path = target_server.get('login_url', '/api/auth/admin-login')
+        login_url = f"{base_url}{login_path}"
+
+        # 调用实际的登录接口(异步版本)
+        async with httpx.AsyncClient(timeout=30.0) as http_client:
+            response = await http_client.post(
+                login_url,
+                json={"username": username, "password": password}
+            )
+
+        if response.status_code == 200:
+            result = response.json()
+            session_id = str(uuid.uuid4())
+
+            auth_sessions[session_id] = {
+                "username": username,
+                "token": result.get("token"),
+                "refresh_token": result.get("refresh_token"),
+                "server": target_server.get("name"),
+                "role": "admin"
+            }
+
+            return {
+                "success": True,
+                "session_id": session_id,
+                "username": username,
+                "server": target_server.get("name"),
+                "role": "admin",
+                "token": result.get("token")
+            }
+        else:
+            raise HTTPException(
+                status_code=response.status_code,
+                detail=f"Admin login failed: {response.text}"
+            )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/auth/logout")
+async def logout(request: Request):
+    """登出并清除会话"""
+    try:
+        data = await request.json()
+        session_id = data.get('session_id')
+
+        if session_id and session_id in auth_sessions:
+            del auth_sessions[session_id]
+
+        return {"success": True}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/api/auth/status")
+async def auth_status(x_session_id: Optional[str] = Header(None, alias='X-Session-ID')):
+    """检查认证状态"""
+    if x_session_id and x_session_id in auth_sessions:
+        session = auth_sessions[x_session_id]
+        return {
+            "authenticated": True,
+            "username": session.get("username"),
+            "server": session.get("server"),
+            "role": session.get("role", "user")
+        }
+
+    return {"authenticated": False}
+
+
+# ========== 主程序入口 ==========
+
+if __name__ == '__main__':
+    import uvicorn
+
+    port = int(os.getenv('PORT', 8080))
+    debug = os.getenv('DEBUG', 'False').lower() == 'true'
+
+    uvicorn.run(
+        "app_fastapi:app",
+        host='0.0.0.0',
+        port=port,
+        reload=debug
+    )

+ 3 - 2
backend/requirements.txt

@@ -1,5 +1,6 @@
-flask==3.0.0
-flask-cors==4.0.0
+# FastAPI 后端依赖
+fastapi>=0.135.1
+uvicorn[standard]>=0.30.0
 anthropic==0.40.0
 python-dotenv==1.0.0
 mcp==0.9.1