2
0

app.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. """
  2. AI MCP Web UI - Flask 后端
  3. 提供聊天界面与 MCP 工具调用的桥梁
  4. """
  5. import os
  6. import asyncio
  7. from typing import Optional, Dict
  8. from flask import Flask, request, jsonify, send_from_directory, Response
  9. import json as json_module
  10. from flask_cors import CORS
  11. import httpx
  12. from anthropic import Anthropic
  13. from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
  14. from conversation_manager import ConversationManager
  15. from tool_handler import ToolCallHandler
  16. app = Flask(__name__)
  17. CORS(app)
  18. app.secret_key = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production')
  19. # 存储认证会话 (生产环境应使用 Redis 或数据库)
  20. auth_sessions: Dict[str, dict] = {}
  21. def create_anthropic_client(api_key: str, base_url: str) -> Anthropic:
  22. """
  23. 创建 Anthropic 客户端,支持自定义认证格式
  24. 自定义 API 代理需要 'Authorization: Bearer <token>' 格式,
  25. 而不是 Anthropic SDK 默认的 'x-api-key' header。
  26. """
  27. import httpx
  28. # 创建自定义 httpx client,设置正确的 Authorization header
  29. http_client = httpx.Client(
  30. headers={"Authorization": f"Bearer {api_key}"},
  31. timeout=120.0
  32. )
  33. return Anthropic(base_url=base_url, http_client=http_client)
  34. @app.route('/')
  35. def index():
  36. return send_from_directory('../frontend', 'index.html')
  37. @app.route('/<path:path>')
  38. def static_files(path):
  39. return send_from_directory('../frontend', path)
  40. # 初始化 Claude 客户端(使用自定义认证格式)
  41. client = create_anthropic_client(
  42. api_key=ANTHROPIC_API_KEY,
  43. base_url=ANTHROPIC_BASE_URL
  44. )
  45. @app.route('/api/health', methods=['GET'])
  46. def health():
  47. """健康检查端点"""
  48. return jsonify({
  49. "status": "ok",
  50. "model": ANTHROPIC_MODEL,
  51. "mcp_servers": list(MCP_SERVERS.keys())
  52. })
  53. def run_async(coro):
  54. """在同步上下文中运行异步函数"""
  55. loop = asyncio.new_event_loop()
  56. asyncio.set_event_loop(loop)
  57. try:
  58. return loop.run_until_complete(coro)
  59. finally:
  60. loop.close()
  61. @app.route('/api/chat', methods=['POST'])
  62. def chat():
  63. """
  64. 聊天端点 - 接收用户消息,返回 Claude 响应(支持 MCP 工具调用)
  65. """
  66. try:
  67. data = request.json
  68. message = data.get('message', '')
  69. conversation_history = data.get('history', [])
  70. session_id = request.headers.get('X-Session-ID')
  71. if not message:
  72. return jsonify({"error": "Message is required"}), 400
  73. # 创建对话管理器
  74. conv_manager = ConversationManager(
  75. api_key=ANTHROPIC_API_KEY,
  76. base_url=ANTHROPIC_BASE_URL,
  77. model=ANTHROPIC_MODEL,
  78. session_id=session_id
  79. )
  80. # 格式化对话历史
  81. formatted_history = ConversationManager.format_history_for_claude(conversation_history)
  82. # 执行多轮对话(自动处理工具调用)
  83. result = run_async(conv_manager.chat(
  84. user_message=message,
  85. conversation_history=formatted_history,
  86. max_turns=5
  87. ))
  88. # 提取响应文本
  89. response_text = result.get("response", "")
  90. tool_calls = result.get("tool_calls", [])
  91. return jsonify({
  92. "response": response_text,
  93. "model": ANTHROPIC_MODEL,
  94. "tool_calls": tool_calls,
  95. "has_tools": len(tool_calls) > 0
  96. })
  97. except Exception as e:
  98. import traceback
  99. return jsonify({
  100. "error": str(e),
  101. "traceback": traceback.format_exc()
  102. }), 500
  103. @app.route('/api/chat/stream', methods=['POST'])
  104. def chat_stream():
  105. """
  106. 聊天端点 - 流式输出版本(解决超时问题)
  107. 使用 Server-Sent Events (SSE) 实时返回:
  108. 1. Claude 的思考过程
  109. 2. 工具调用状态
  110. 3. 最终响应
  111. """
  112. try:
  113. data = request.json
  114. message = data.get('message', '')
  115. conversation_history = data.get('history', [])
  116. session_id = request.headers.get('X-Session-ID')
  117. if not message:
  118. return jsonify({"error": "Message is required"}), 400
  119. def generate():
  120. """生成 SSE 流式响应"""
  121. try:
  122. # 发送开始事件
  123. yield f"event: start\ndata: {json_module.dumps({'status': 'started'})}\n\n"
  124. # 创建对话管理器
  125. conv_manager = ConversationManager(
  126. api_key=ANTHROPIC_API_KEY,
  127. base_url=ANTHROPIC_BASE_URL,
  128. model=ANTHROPIC_MODEL,
  129. session_id=session_id
  130. )
  131. # 格式化对话历史
  132. formatted_history = ConversationManager.format_history_for_claude(conversation_history)
  133. messages = formatted_history + [{"role": "user", "content": message}]
  134. current_messages = messages
  135. tool_calls_info = []
  136. for turn in range(5): # 最多 5 轮
  137. # 获取可用工具
  138. tools = run_async(conv_manager.get_available_tools())
  139. # 发送工具列表
  140. yield f"event: tools\ndata: {json_module.dumps({'count': len(tools), 'tools': [t['name'] for t in tools[:5]]})}\n\n"
  141. # 调用 Claude API(流式)
  142. if tools:
  143. response_stream = conv_manager.client.messages.create(
  144. model=conv_manager.model,
  145. max_tokens=4096,
  146. messages=current_messages,
  147. tools=tools,
  148. stream=True
  149. )
  150. else:
  151. response_stream = conv_manager.client.messages.create(
  152. model=conv_manager.model,
  153. max_tokens=4096,
  154. messages=current_messages,
  155. stream=True
  156. )
  157. # 处理流式响应
  158. content_blocks = []
  159. tool_use_blocks = []
  160. response_text = ""
  161. current_block_type = None
  162. current_tool_index = -1
  163. partial_json = ""
  164. for event in response_stream:
  165. # 处理内容块开始 - 检查是否是工具调用
  166. if event.type == "content_block_start":
  167. # 检查块的类型
  168. if hasattr(event, "content_block"):
  169. current_block_type = getattr(event.content_block, "type", None)
  170. if current_block_type == "tool_use":
  171. # 这是工具调用块的开始
  172. tool_use_id = getattr(event.content_block, "id", "")
  173. # content_block 包含 name
  174. tool_name = getattr(event.content_block, "name", "")
  175. tool_use_blocks.append({
  176. "type": "tool_use",
  177. "id": tool_use_id,
  178. "name": tool_name,
  179. "input": {}
  180. })
  181. current_tool_index = len(tool_use_blocks) - 1
  182. partial_json = ""
  183. # 处理内容块增量
  184. elif event.type == "content_block_delta":
  185. delta_type = getattr(event.delta, "type", "")
  186. # 文本增量
  187. if delta_type == "text_delta":
  188. text = event.delta.text
  189. response_text += text
  190. yield f"event: token\ndata: {json_module.dumps({'text': text})}\n\n"
  191. # 工具名称增量
  192. elif delta_type == "tool_use_delta":
  193. # 获取工具名称和参数增量
  194. delta_name = getattr(event.delta, "name", None)
  195. delta_input = getattr(event.delta, "input", None)
  196. if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks):
  197. if delta_name is not None:
  198. tool_use_blocks[current_tool_index]["name"] = delta_name
  199. if delta_input is not None:
  200. # 更新输入参数
  201. current_input = tool_use_blocks[current_tool_index]["input"]
  202. if isinstance(delta_input, dict):
  203. current_input.update(delta_input)
  204. tool_use_blocks[current_tool_index]["input"] = current_input
  205. # 工具参数增量 - input_json_delta
  206. elif delta_type == "input_json_delta":
  207. # 累积 partial_json 构建完整参数
  208. partial_json_str = getattr(event.delta, "partial_json", "")
  209. if partial_json_str:
  210. partial_json += partial_json_str
  211. try:
  212. # 尝试解析累积的 JSON
  213. parsed_input = json_module.loads(partial_json)
  214. if current_tool_index >= 0 and current_tool_index < len(tool_use_blocks):
  215. tool_use_blocks[current_tool_index]["input"] = parsed_input
  216. except json_module.JSONDecodeError:
  217. # JSON 还不完整,继续累积
  218. pass
  219. # 处理内容块停止
  220. elif event.type == "content_block_stop":
  221. current_block_type = None
  222. current_tool_index = -1
  223. partial_json = ""
  224. # 如果没有工具调用,发送完成事件
  225. if not tool_use_blocks:
  226. yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info})}\n\n"
  227. return
  228. # 处理工具调用
  229. yield f"event: tools_start\ndata: {json_module.dumps({'count': len(tool_use_blocks)})}\n\n"
  230. # 为每个工具调用发送 tool_call 事件
  231. for tool_block in tool_use_blocks:
  232. yield f"event: tool_call\ndata: {json_module.dumps({'tool': tool_block['name'], 'args': tool_block['input'], 'tool_id': tool_block['id']})}\n\n"
  233. tool_results = run_async(conv_manager.tool_handler.process_tool_use_blocks(
  234. tool_use_blocks
  235. ))
  236. for i, tr in enumerate(tool_results):
  237. tool_name = tr.get("tool_name", "")
  238. tool_result = tr.get("result", {})
  239. tool_use_id = tr.get("tool_use_id", "")
  240. # 发送工具完成事件
  241. if "error" in tool_result:
  242. yield f"event: tool_error\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'error': tool_result['error']})}\n\n"
  243. else:
  244. result_data = tool_result.get('result', '')
  245. # 限制结果长度避免传输过大
  246. if isinstance(result_data, str) and len(result_data) > 500:
  247. result_data = result_data[:500] + '...'
  248. yield f"event: tool_done\ndata: {json_module.dumps({'tool': tool_name, 'tool_id': tool_use_id, 'result': result_data})}\n\n"
  249. tool_calls_info.append({
  250. "tool": tool_name,
  251. "result": tool_result
  252. })
  253. # 构建工具结果消息
  254. tool_result_message = ToolCallHandler.create_tool_result_message(
  255. tool_results
  256. )
  257. # 添加到消息历史
  258. current_messages.append({
  259. "role": "assistant",
  260. "content": content_blocks
  261. })
  262. current_messages.append(tool_result_message)
  263. # 达到最大轮数
  264. yield f"event: complete\ndata: {json_module.dumps({'response': response_text, 'tool_calls': tool_calls_info, 'warning': '达到最大对话轮数'})}\n\n"
  265. except Exception as e:
  266. import traceback
  267. yield f"event: error\ndata: {json_module.dumps({'error': str(e), 'traceback': traceback.format_exc()})}\n\n"
  268. return Response(
  269. generate(),
  270. mimetype='text/event-stream',
  271. headers={
  272. 'Cache-Control': 'no-cache',
  273. 'X-Accel-Buffering': 'no' # 禁用 Nginx 缓冲
  274. }
  275. )
  276. except Exception as e:
  277. import traceback
  278. return jsonify({
  279. "error": str(e),
  280. "traceback": traceback.format_exc()
  281. }), 500
  282. @app.route('/api/mcp/servers', methods=['GET'])
  283. def list_mcp_servers():
  284. """获取已配置的 MCP 服务器列表"""
  285. servers = []
  286. for name, server in MCP_SERVERS.items():
  287. servers.append({
  288. "id": name,
  289. "name": server.get("name", name),
  290. "url": server.get("url", ""),
  291. "auth_type": server.get("auth_type", "none"),
  292. "enabled": server.get("enabled", False)
  293. })
  294. return jsonify({"servers": servers})
  295. @app.route('/api/mcp/tools', methods=['GET'])
  296. def list_mcp_tools():
  297. """获取可用的 MCP 工具列表"""
  298. try:
  299. session_id = request.headers.get('X-Session-ID')
  300. # 使用静态方法获取工具
  301. tools = ConversationManager.get_tools(session_id=session_id)
  302. return jsonify({
  303. "tools": tools,
  304. "count": len(tools)
  305. })
  306. except Exception as e:
  307. import traceback
  308. return jsonify({
  309. "error": str(e),
  310. "traceback": traceback.format_exc(),
  311. "tools": []
  312. }), 500
  313. # ========== 认证 API ==========
  314. @app.route('/api/auth/login', methods=['POST'])
  315. def login():
  316. """
  317. Novel Platform 用户登录
  318. 代理到实际的登录端点并返回 JWT Token
  319. """
  320. try:
  321. data = request.json
  322. username = data.get('username')
  323. password = data.get('password')
  324. if not username or not password:
  325. return jsonify({"error": "Username and password are required"}), 400
  326. # 查找需要 JWT 认证的 MCP 服务器
  327. target_server = None
  328. for server_id, config in MCP_SERVERS.items():
  329. if config.get('auth_type') == 'jwt' and 'login_url' in config:
  330. target_server = config
  331. break
  332. if not target_server:
  333. return jsonify({"error": "No JWT-authenticated server configured"}), 400
  334. # 构建登录 URL
  335. base_url = target_server.get('base_url', '')
  336. login_path = target_server.get('login_url', '/api/auth/login')
  337. login_url = f"{base_url}{login_path}"
  338. # 调用实际的登录接口(同步版本)
  339. response = httpx.post(
  340. login_url,
  341. json={"username": username, "password": password},
  342. timeout=30.0
  343. )
  344. if response.status_code == 200:
  345. result = response.json()
  346. import uuid
  347. session_id = str(uuid.uuid4())
  348. # 存储会话信息
  349. auth_sessions[session_id] = {
  350. "username": username,
  351. "token": result.get("token"),
  352. "refresh_token": result.get("refresh_token"),
  353. "server": target_server.get("name")
  354. }
  355. return jsonify({
  356. "success": True,
  357. "session_id": session_id,
  358. "username": username,
  359. "server": target_server.get("name"),
  360. "token": result.get("token")
  361. })
  362. else:
  363. return jsonify({
  364. "error": "Login failed",
  365. "details": response.text
  366. }), response.status_code
  367. except Exception as e:
  368. return jsonify({"error": str(e)}), 500
  369. @app.route('/api/auth/admin-login', methods=['POST'])
  370. def admin_login():
  371. """
  372. Novel Platform 管理员登录
  373. """
  374. try:
  375. data = request.json
  376. username = data.get('username')
  377. password = data.get('password')
  378. if not username or not password:
  379. return jsonify({"error": "Username and password are required"}), 400
  380. # 查找管理员 MCP 服务器
  381. target_server = MCP_SERVERS.get('novel-platform-admin')
  382. if not target_server:
  383. return jsonify({"error": "Admin server not configured"}), 400
  384. # 构建登录 URL
  385. base_url = target_server.get('base_url', '')
  386. login_path = target_server.get('login_url', '/api/auth/admin-login')
  387. login_url = f"{base_url}{login_path}"
  388. # 调用实际的登录接口
  389. response = httpx.post(
  390. login_url,
  391. json={"username": username, "password": password},
  392. timeout=30.0
  393. )
  394. if response.status_code == 200:
  395. result = response.json()
  396. import uuid
  397. session_id = str(uuid.uuid4())
  398. auth_sessions[session_id] = {
  399. "username": username,
  400. "token": result.get("token"),
  401. "refresh_token": result.get("refresh_token"),
  402. "server": target_server.get("name"),
  403. "role": "admin"
  404. }
  405. return jsonify({
  406. "success": True,
  407. "session_id": session_id,
  408. "username": username,
  409. "server": target_server.get("name"),
  410. "role": "admin",
  411. "token": result.get("token")
  412. })
  413. else:
  414. return jsonify({
  415. "error": "Admin login failed",
  416. "details": response.text
  417. }), response.status_code
  418. except Exception as e:
  419. return jsonify({"error": str(e)}), 500
  420. @app.route('/api/auth/logout', methods=['POST'])
  421. def logout():
  422. """登出并清除会话"""
  423. try:
  424. data = request.json
  425. session_id = data.get('session_id')
  426. if session_id and session_id in auth_sessions:
  427. del auth_sessions[session_id]
  428. return jsonify({"success": True})
  429. except Exception as e:
  430. return jsonify({"error": str(e)}), 500
  431. @app.route('/api/auth/status', methods=['GET'])
  432. def auth_status():
  433. """检查认证状态"""
  434. session_id = request.headers.get('X-Session-ID')
  435. if session_id and session_id in auth_sessions:
  436. session = auth_sessions[session_id]
  437. return jsonify({
  438. "authenticated": True,
  439. "username": session.get("username"),
  440. "server": session.get("server"),
  441. "role": session.get("role", "user")
  442. })
  443. return jsonify({
  444. "authenticated": False
  445. })
  446. if __name__ == '__main__':
  447. port = int(os.getenv('PORT', 8080))
  448. debug = os.getenv('DEBUG', 'False').lower() == 'true'
  449. app.run(host='0.0.0.0', port=port, debug=debug)