app_fastapi.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. """
  2. AI MCP Web UI - FastAPI 后端
  3. 提供聊天界面与 MCP 工具调用的桥梁
  4. """
  5. import os
  6. import asyncio
  7. import uuid
  8. import json as json_module
  9. from typing import Optional, Dict, List, Any
  10. from contextlib import asynccontextmanager
  11. from fastapi import FastAPI, Request, HTTPException, Header
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from fastapi.responses import StreamingResponse, JSONResponse
  14. from fastapi.staticfiles import StaticFiles
  15. import httpx
  16. from anthropic import Anthropic
  17. from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
  18. from conversation_manager import ConversationManager
  19. from tool_handler import ToolCallHandler
  20. # 存储认证会话 (生产环境应使用 Redis 或数据库)
  21. auth_sessions: Dict[str, dict] = {}
  22. def create_anthropic_client(api_key: str, base_url: str) -> Anthropic:
  23. """
  24. 创建 Anthropic 客户端,支持自定义认证格式
  25. 自定义 API 代理需要 'Authorization: Bearer <token>' 格式,
  26. 而不是 Anthropic SDK 默认的 'x-api-key' header。
  27. """
  28. # 创建自定义 httpx client,设置正确的 Authorization header
  29. http_client = httpx.Client(
  30. headers={"Authorization": f"Bearer {api_key}"},
  31. timeout=120.0
  32. )
  33. return Anthropic(base_url=base_url, http_client=http_client)
  34. # 初始化 Claude 客户端(使用自定义认证格式)
  35. client = create_anthropic_client(
  36. api_key=ANTHROPIC_API_KEY,
  37. base_url=ANTHROPIC_BASE_URL
  38. )
  39. @asynccontextmanager
  40. async def lifespan(app: FastAPI):
  41. """应用生命周期管理"""
  42. # 启动时执行
  43. print(f"FastAPI 应用启动 - 模型: {ANTHROPIC_MODEL}")
  44. print(f"MCP 服务器: {list(MCP_SERVERS.keys())}")
  45. yield
  46. # 关闭时执行
  47. print("FastAPI 应用关闭")
  48. # 创建 FastAPI 应用
  49. app = FastAPI(
  50. title="AI MCP Web UI Backend",
  51. description="AI MCP Web UI 后端服务 - 支持 Claude AI 和 MCP 工具调用",
  52. version="2.0.0",
  53. lifespan=lifespan
  54. )
  55. # CORS 配置
  56. app.add_middleware(
  57. CORSMiddleware,
  58. allow_origins=["*"],
  59. allow_credentials=True,
  60. allow_methods=["*"],
  61. allow_headers=["*"],
  62. )
  63. # 挂载静态文件
  64. frontend_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend")
  65. app.mount("/static", StaticFiles(directory=frontend_path), name="static")
  66. # ========== 根路由 ==========
  67. @app.get("/")
  68. async def index():
  69. """返回前端主页"""
  70. from fastapi.responses import FileResponse
  71. index_path = os.path.join(frontend_path, "index.html")
  72. return FileResponse(index_path)
  73. # ========== 健康检查 ==========
  74. @app.get("/api/health")
  75. async def health():
  76. """健康检查端点"""
  77. return {
  78. "status": "ok",
  79. "model": ANTHROPIC_MODEL,
  80. "mcp_servers": list(MCP_SERVERS.keys())
  81. }
  82. # ========== 聊天 API ==========
  83. @app.post("/api/chat")
  84. async def chat(request: Request):
  85. """
  86. 聊天端点 - 接收用户消息,返回 Claude 响应(支持 MCP 工具调用)
  87. """
  88. try:
  89. data = await request.json()
  90. message = data.get('message', '')
  91. conversation_history = data.get('history', [])
  92. session_id = request.headers.get('X-Session-ID')
  93. if not message:
  94. raise HTTPException(status_code=400, detail="Message is required")
  95. # 创建对话管理器
  96. conv_manager = ConversationManager(
  97. api_key=ANTHROPIC_API_KEY,
  98. base_url=ANTHROPIC_BASE_URL,
  99. model=ANTHROPIC_MODEL,
  100. session_id=session_id
  101. )
  102. # 格式化对话历史
  103. formatted_history = ConversationManager.format_history_for_claude(conversation_history)
  104. # 执行多轮对话(自动处理工具调用)
  105. result = await conv_manager.chat(
  106. user_message=message,
  107. conversation_history=formatted_history,
  108. max_turns=5
  109. )
  110. # 提取响应文本
  111. response_text = result.get("response", "")
  112. tool_calls = result.get("tool_calls", [])
  113. return {
  114. "response": response_text,
  115. "model": ANTHROPIC_MODEL,
  116. "tool_calls": tool_calls,
  117. "has_tools": len(tool_calls) > 0
  118. }
  119. except HTTPException:
  120. raise
  121. except Exception as e:
  122. import traceback
  123. return JSONResponse(
  124. status_code=500,
  125. content={
  126. "error": str(e),
  127. "traceback": traceback.format_exc()
  128. }
  129. )
  130. async def generate_chat_stream(
  131. message: str,
  132. conversation_history: List[Dict[str, Any]],
  133. session_id: Optional[str]
  134. ):
  135. """生成 SSE 流式响应的异步生成器"""
  136. try:
  137. # 发送开始事件
  138. yield f"event: start\ndata: {json_module.dumps({'status': 'started'})}\n\n"
  139. # 创建对话管理器
  140. conv_manager = ConversationManager(
  141. api_key=ANTHROPIC_API_KEY,
  142. base_url=ANTHROPIC_BASE_URL,
  143. model=ANTHROPIC_MODEL,
  144. session_id=session_id
  145. )
  146. # 格式化对话历史
  147. formatted_history = ConversationManager.format_history_for_claude(conversation_history)
  148. messages = formatted_history + [{"role": "user", "content": message}]
  149. current_messages = messages
  150. tool_calls_info = []
  151. for turn in range(5): # 最多 5 轮
  152. # 获取可用工具
  153. tools = await conv_manager.get_available_tools()
  154. # 发送工具列表
  155. yield f"event: tools\ndata: {json_module.dumps({'count': len(tools), 'tools': [t['name'] for t in tools[:5]]})}\n\n"
  156. # 调用 Claude API(流式)
  157. if tools:
  158. response_stream = conv_manager.client.messages.create(
  159. model=conv_manager.model,
  160. max_tokens=4096,
  161. messages=current_messages,
  162. tools=tools,
  163. stream=True
  164. )
  165. else:
  166. response_stream = conv_manager.client.messages.create(
  167. model=conv_manager.model,
  168. max_tokens=4096,
  169. messages=current_messages,
  170. stream=True
  171. )
  172. # 处理流式响应
  173. content_blocks = []
  174. tool_use_blocks = []
  175. response_text = ""
  176. current_block_type = None
  177. current_tool_index = -1
  178. partial_json = ""
  179. for event in response_stream:
  180. # 处理内容块开始 - 检查是否是工具调用
  181. if event.type == "content_block_start":
  182. # 检查块的类型
  183. if hasattr(event, "content_block"):
  184. current_block_type = getattr(event.content_block, "type", None)
  185. if current_block_type == "tool_use":
  186. # 这是工具调用块的开始
  187. tool_use_id = getattr(event.content_block, "id", "")
  188. # content_block 包含 name
  189. tool_name = getattr(event.content_block, "name", "")
  190. tool_use_blocks.append({
  191. "type": "tool_use",
  192. "id": tool_use_id,
  193. "name": tool_name,
  194. "input": {}
  195. })
  196. current_tool_index = len(tool_use_blocks) - 1
  197. partial_json = ""
  198. # 处理内容块增量
  199. elif event.type == "content_block_delta":
  200. delta_type = getattr(event.delta, "type", "")
  201. # 文本增量
  202. if delta_type == "text_delta":
  203. text = event.delta.text
  204. response_text += text
  205. yield f"event: token\ndata: {json_module.dumps({'text': text})}\n\n"
  206. # 工具名称增量
  207. elif delta_type == "tool_use_delta":
  208. # 获取工具名称和参数增量
  209. delta_name = getattr(event.delta, "name", None)
  210. delta_input = getattr(event.delta, "input", None)
  211. if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks):
  212. if delta_name is not None:
  213. tool_use_blocks[current_tool_index]["name"] = delta_name
  214. if delta_input is not None:
  215. # 更新输入参数
  216. current_input = tool_use_blocks[current_tool_index]["input"]
  217. if isinstance(delta_input, dict):
  218. current_input.update(delta_input)
  219. tool_use_blocks[current_tool_index]["input"] = current_input
  220. # 工具参数增量 - input_json_delta
  221. elif delta_type == "input_json_delta":
  222. # 累积 partial_json 构建完整参数
  223. partial_json_str = getattr(event.delta, "partial_json", "")
  224. if partial_json_str:
  225. partial_json += partial_json_str
  226. try:
  227. # 尝试解析累积的 JSON
  228. parsed_input = json_module.loads(partial_json)
  229. if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks):
  230. tool_use_blocks[current_tool_index]["input"] = parsed_input
  231. except json_module.JSONDecodeError:
  232. # JSON 还不完整,继续累积
  233. pass
  234. # 处理内容块停止
  235. elif event.type == "content_block_stop":
  236. current_block_type = None
  237. current_tool_index = -1
  238. partial_json = ""
  239. # 如果没有工具调用,发送完成事件
  240. if not tool_use_blocks:
  241. yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info})}\n\n"
  242. return
  243. # 处理工具调用
  244. yield f"event: tools_start\ndata: {json_module.dumps({'count': len(tool_use_blocks)})}\n\n"
  245. # 为每个工具调用发送 tool_call 事件
  246. for tool_block in tool_use_blocks:
  247. yield f"event: tool_call\ndata: {json_module.dumps({'tool': tool_block['name'], 'args': tool_block['input'], 'tool_id': tool_block['id']})}\n\n"
  248. tool_results = await conv_manager.tool_handler.process_tool_use_blocks(
  249. tool_use_blocks
  250. )
  251. for tr in tool_results:
  252. tool_name = tr.get("tool_name", "")
  253. tool_result = tr.get("result", {})
  254. tool_use_id = tr.get("tool_use_id", "")
  255. # 发送工具完成事件
  256. if "error" in tool_result:
  257. yield f"event: tool_error\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'error': tool_result['error']})}\n\n"
  258. else:
  259. result_data = tool_result.get('result', '')
  260. # 限制结果长度避免传输过大
  261. if isinstance(result_data, str) and len(result_data) > 500:
  262. result_data = result_data[:500] + '...'
  263. yield f"event: tool_done\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'result': result_data})}\n\n"
  264. tool_calls_info.append({
  265. "tool": tool_name,
  266. "result": tool_result
  267. })
  268. # 构建工具结果消息
  269. tool_result_message = ToolCallHandler.create_tool_result_message(
  270. tool_results
  271. )
  272. # 添加到消息历史
  273. current_messages.append({
  274. "role": "assistant",
  275. "content": content_blocks
  276. })
  277. current_messages.append(tool_result_message)
  278. # 达到最大轮数
  279. yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info, 'warning': '达到最大对话轮数'})}\n\n"
  280. except Exception as e:
  281. import traceback
  282. yield f"event: error\ndata: {json_module.dumps({'error': str(e), 'traceback': traceback.format_exc()})}\n\n"
  283. @app.post("/api/chat/stream")
  284. async def chat_stream(request: Request):
  285. """
  286. 聊天端点 - 流式输出版本(解决超时问题)
  287. 使用 Server-Sent Events (SSE) 实时返回:
  288. 1. Claude 的思考过程
  289. 2. 工具调用状态
  290. 3. 最终响应
  291. """
  292. try:
  293. data = await request.json()
  294. message = data.get('message', '')
  295. conversation_history = data.get('history', [])
  296. session_id = request.headers.get('X-Session-ID')
  297. if not message:
  298. raise HTTPException(status_code=400, detail="Message is required")
  299. return StreamingResponse(
  300. generate_chat_stream(message, conversation_history, session_id),
  301. media_type="text/event-stream",
  302. headers={
  303. 'Cache-Control': 'no-cache',
  304. 'X-Accel-Buffering': 'no' # 禁用 Nginx 缓冲
  305. }
  306. )
  307. except HTTPException:
  308. raise
  309. except Exception as e:
  310. import traceback
  311. return JSONResponse(
  312. status_code=500,
  313. content={
  314. "error": str(e),
  315. "traceback": traceback.format_exc()
  316. }
  317. )
  318. # ========== MCP API ==========
  319. @app.get("/api/mcp/servers")
  320. async def list_mcp_servers():
  321. """获取已配置的 MCP 服务器列表"""
  322. servers = []
  323. for name, server in MCP_SERVERS.items():
  324. servers.append({
  325. "id": name,
  326. "name": server.get("name", name),
  327. "url": server.get("url", ""),
  328. "auth_type": server.get("auth_type", "none"),
  329. "enabled": server.get("enabled", False)
  330. })
  331. return {"servers": servers}
  332. @app.get("/api/mcp/tools")
  333. async def list_mcp_tools(x_session_id: Optional[str] = Header(None, alias='X-Session-ID')):
  334. """获取可用的 MCP 工具列表"""
  335. try:
  336. # 使用静态方法获取工具
  337. tools = ConversationManager.get_tools(session_id=x_session_id)
  338. return {
  339. "tools": tools,
  340. "count": len(tools)
  341. }
  342. except Exception as e:
  343. import traceback
  344. return JSONResponse(
  345. status_code=500,
  346. content={
  347. "error": str(e),
  348. "traceback": traceback.format_exc(),
  349. "tools": []
  350. }
  351. )
  352. # ========== 认证 API ==========
  353. @app.post("/api/auth/login")
  354. async def login(request: Request):
  355. """
  356. Novel Platform 用户登录
  357. 代理到实际的登录端点并返回 JWT Token
  358. """
  359. try:
  360. data = await request.json()
  361. username = data.get('username')
  362. password = data.get('password')
  363. if not username or not password:
  364. raise HTTPException(status_code=400, detail="Username and password are required")
  365. # 查找需要 JWT 认证的 MCP 服务器
  366. target_server = None
  367. for server_id, config in MCP_SERVERS.items():
  368. if config.get('auth_type') == 'jwt' and 'login_url' in config:
  369. target_server = config
  370. break
  371. if not target_server:
  372. raise HTTPException(status_code=400, detail="No JWT-authenticated server configured")
  373. # 构建登录 URL
  374. base_url = target_server.get('base_url', '')
  375. login_path = target_server.get('login_url', '/api/auth/login')
  376. login_url = f"{base_url}{login_path}"
  377. # 调用实际的登录接口(异步版本)
  378. async with httpx.AsyncClient(timeout=30.0) as http_client:
  379. response = await http_client.post(
  380. login_url,
  381. json={"username": username, "password": password}
  382. )
  383. if response.status_code == 200:
  384. result = response.json()
  385. session_id = str(uuid.uuid4())
  386. # 存储会话信息
  387. auth_sessions[session_id] = {
  388. "username": username,
  389. "token": result.get("token"),
  390. "refresh_token": result.get("refresh_token"),
  391. "server": target_server.get("name")
  392. }
  393. return {
  394. "success": True,
  395. "session_id": session_id,
  396. "username": username,
  397. "server": target_server.get("name"),
  398. "token": result.get("token")
  399. }
  400. else:
  401. raise HTTPException(
  402. status_code=response.status_code,
  403. detail=f"Login failed: {response.text}"
  404. )
  405. except HTTPException:
  406. raise
  407. except Exception as e:
  408. raise HTTPException(status_code=500, detail=str(e))
  409. @app.post("/api/auth/admin-login")
  410. async def admin_login(request: Request):
  411. """
  412. Novel Platform 管理员登录
  413. """
  414. try:
  415. data = await request.json()
  416. username = data.get('username')
  417. password = data.get('password')
  418. if not username or not password:
  419. raise HTTPException(status_code=400, detail="Username and password are required")
  420. # 查找管理员 MCP 服务器
  421. target_server = MCP_SERVERS.get('novel-platform-admin')
  422. if not target_server:
  423. raise HTTPException(status_code=400, detail="Admin server not configured")
  424. # 构建登录 URL
  425. base_url = target_server.get('base_url', '')
  426. login_path = target_server.get('login_url', '/api/auth/admin-login')
  427. login_url = f"{base_url}{login_path}"
  428. # 调用实际的登录接口(异步版本)
  429. async with httpx.AsyncClient(timeout=30.0) as http_client:
  430. response = await http_client.post(
  431. login_url,
  432. json={"username": username, "password": password}
  433. )
  434. if response.status_code == 200:
  435. result = response.json()
  436. session_id = str(uuid.uuid4())
  437. auth_sessions[session_id] = {
  438. "username": username,
  439. "token": result.get("token"),
  440. "refresh_token": result.get("refresh_token"),
  441. "server": target_server.get("name"),
  442. "role": "admin"
  443. }
  444. return {
  445. "success": True,
  446. "session_id": session_id,
  447. "username": username,
  448. "server": target_server.get("name"),
  449. "role": "admin",
  450. "token": result.get("token")
  451. }
  452. else:
  453. raise HTTPException(
  454. status_code=response.status_code,
  455. detail=f"Admin login failed: {response.text}"
  456. )
  457. except HTTPException:
  458. raise
  459. except Exception as e:
  460. raise HTTPException(status_code=500, detail=str(e))
  461. @app.post("/api/auth/logout")
  462. async def logout(request: Request):
  463. """登出并清除会话"""
  464. try:
  465. data = await request.json()
  466. session_id = data.get('session_id')
  467. if session_id and session_id in auth_sessions:
  468. del auth_sessions[session_id]
  469. return {"success": True}
  470. except Exception as e:
  471. raise HTTPException(status_code=500, detail=str(e))
  472. @app.get("/api/auth/status")
  473. async def auth_status(x_session_id: Optional[str] = Header(None, alias='X-Session-ID')):
  474. """检查认证状态"""
  475. if x_session_id and x_session_id in auth_sessions:
  476. session = auth_sessions[x_session_id]
  477. return {
  478. "authenticated": True,
  479. "username": session.get("username"),
  480. "server": session.get("server"),
  481. "role": session.get("role", "user")
  482. }
  483. return {"authenticated": False}
  484. # ========== 主程序入口 ==========
  485. if __name__ == '__main__':
  486. import uvicorn
  487. port = int(os.getenv('PORT', 8080))
  488. debug = os.getenv('DEBUG', 'False').lower() == 'true'
  489. uvicorn.run(
  490. "app_fastapi:app",
  491. host='0.0.0.0',
  492. port=port,
  493. reload=debug
  494. )