2
0

hooks.ts 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. /**
  2. * 自定义 React Hooks - 用于与 FastAPI 后端交互
  3. */
  4. 'use client';
  5. import { useState, useCallback, useRef, useEffect } from 'react';
  6. import { createMixedStreamParser, createSpecStreamCompiler, type SpecStreamLine } from '@json-render/core';
  7. import type { Spec } from '@json-render/core';
  8. import { apiClient, type ChatMessage, type UserInfo, type UserRole } from './api-client';
  /**
   * Names of the server-sent events the chat backend can emit.
   *
   * Note: the `useChat` handler in this file dispatches on `token`, `tool_call`,
   * `tool_done`, `complete` and `error`; the remaining names are accepted but
   * currently ignored there.
   */
  export type SSEEventType =
    | 'start'
    | 'complete'
    | 'error'
    | 'token'
    | 'tools'
    | 'tools_start'
    | 'tool_call'
    | 'tool_done'
    | 'tool_error';

  /** One SSE message: its event name plus an as-yet unvalidated payload. */
  export interface SSEEvent {
    /** The event name. */
    type: SSEEventType;
    /** Event payload; its shape depends on `type` and is not checked here. */
    data: unknown;
  }
  // ---------------------------------------------------------------------------
  // Chat hook
  // ---------------------------------------------------------------------------

  /** A tool invocation reported by the chat stream, with its latest known result. */
  type ChatToolCall = { tool: string; tool_id?: string; args?: unknown; result: unknown };

  /**
   * Extracts a JSON value from model output text.
   *
   * Candidates are tried in order: the first ```json fenced block, the first
   * fenced block of any kind, then the whole text. The first candidate that
   * `JSON.parse` accepts is returned.
   *
   * @param text - Raw (markdown) model output.
   * @returns The parsed value, or `null` when no candidate is valid JSON.
   */
  function extractJsonFromMarkdown(text: string): unknown {
    if (typeof text !== 'string') {
      return null;
    }
    const candidates = [
      text.match(/```json\s*([\s\S]*?)\s*```/)?.[1],
      text.match(/```\s*([\s\S]*?)\s*```/)?.[1],
      text,
    ];
    for (const candidate of candidates) {
      if (candidate === undefined) {
        continue;
      }
      try {
        return JSON.parse(candidate.trim());
      } catch {
        // Not JSON - fall through to the next, less specific candidate.
      }
    }
    return null;
  }

  /** True for objects carrying a non-empty string `type` field (a component spec). */
  function isSpecLike(value: unknown): value is { type: string } {
    if (typeof value !== 'object' || value === null) {
      return false;
    }
    const { type } = value as { type?: unknown };
    return typeof type === 'string' && type.length > 0;
  }

  /**
   * Normalizes extracted JSON into a list of specs.
   *
   * - A single spec object becomes a one-element list.
   * - An array keeps only its spec-like items.
   * - Anything else yields an empty list.
   */
  function collectSpecs(data: unknown): any[] {
    if (isSpecLike(data)) {
      return [data];
    }
    return Array.isArray(data) ? data.filter(isSpecLike) : [];
  }

  /**
   * Streams a chat exchange from the backend and exposes the accumulating state.
   *
   * While a request is in flight, `token` events go through a mixed stream parser:
   * - Plain text accumulates into `response`.
   * - Spec patch lines are compiled into `specs`.
   *
   * `tool_call` / `tool_done` events populate `toolCalls`. When the stream
   * completes or errors, `specs` is rebuilt from the tool calls plus any JSON
   * block in the final response, and only then does `isLoading` turn false.
   *
   * Starting a new request aborts the previous one; unmounting aborts the
   * current one.
   *
   * @returns `sendMessage(message, history?)`, `abort()`, and the current
   * `isLoading`, `response`, `toolCalls`, `specs` and `error` state.
   */
  export function useChat() {
    const [isLoading, setIsLoading] = useState(false);
    const [response, setResponse] = useState('');
    const [toolCalls, setToolCalls] = useState<ChatToolCall[]>([]);
    const [specs, setSpecs] = useState<any[]>([]);
    const [error, setError] = useState<string | null>(null);

    const abortRef = useRef<(() => void) | null>(null);
    // Synchronous mirrors of `response` / `toolCalls`. Stream callbacks are created
    // once per request, so reading state through their closures would only see
    // the values from the render that called sendMessage (i.e. empty).
    const responseRef = useRef('');
    const toolCallsRef = useRef<ChatToolCall[]>([]);
    // Bumped on every new request and on unmount. Callbacks belonging to an
    // older request compare against it and become no-ops.
    const requestIdRef = useRef(0);
    // Per-request streaming parser state.
    const streamCompilerRef = useRef<ReturnType<typeof createSpecStreamCompiler> | null>(null);
    const mixedParserRef = useRef<ReturnType<typeof createMixedStreamParser> | null>(null);
    const textBufferRef = useRef<string>('');

    // Abort the in-flight stream when the owning component unmounts.
    useEffect(
      () => () => {
        requestIdRef.current += 1;
        abortRef.current?.();
        abortRef.current = null;
      },
      [],
    );

    /** Sets `response` and its ref together. */
    const updateResponse = useCallback((next: string) => {
      responseRef.current = next;
      setResponse(next);
    }, []);

    /** Sets `toolCalls` and its ref together. */
    const updateToolCalls = useCallback((next: ChatToolCall[]) => {
      toolCallsRef.current = next;
      setToolCalls(next);
    }, []);

    /**
     * Builds the final specs for request `requestId`, then clears `isLoading`.
     *
     * Specs come from the tool calls and from the final response text. This is
     * a no-op when a newer request has started since.
     */
    const generateSpecs = useCallback((requestId: number) => {
      if (requestId !== requestIdRef.current) {
        return;
      }
      const newSpecs: any[] = [];

      const currentToolCalls = toolCallsRef.current;
      if (currentToolCalls.length > 0) {
        // Loaded lazily, as in the original implementation.
        const { specFromToolCall } = require('@/lib/json-render-catalog');
        const toolSpecs = currentToolCalls
          .map((call) => specFromToolCall(call.tool, call.result))
          .filter(Boolean);
        newSpecs.push(...toolSpecs);
      }

      const finalResponse = responseRef.current;
      if (finalResponse.trim()) {
        newSpecs.push(...collectSpecs(extractJsonFromMarkdown(finalResponse)));
      }

      console.log('[generateSpecs] Setting specs:', newSpecs);
      // When nothing was found, keep whatever the streaming compiler produced.
      if (newSpecs.length > 0) {
        setSpecs(newSpecs);
      }
      // Deferred so the specs update is processed before loading ends.
      setTimeout(() => {
        if (requestId === requestIdRef.current) {
          setIsLoading(false);
        }
      }, 0);
    }, []);

    /** Applies one decoded SSE event (payload with its `type` merged in) to the hook state. */
    const handleSSEEvent = useCallback(
      (data: unknown) => {
        const event = data as { type?: string; [key: string]: unknown };
        console.log('[SSE Event]', event.type, JSON.stringify(event).substring(0, 200));
        switch (event.type) {
          case 'token': {
            const { text } = event as { text?: string };
            if (text && mixedParserRef.current) {
              mixedParserRef.current.push(text);
            }
            break;
          }
          case 'tool_call': {
            const { tool, tool_id: toolId, args } = event as { tool?: string; args?: unknown; tool_id?: string };
            if (tool) {
              // `result` starts out as the call arguments until tool_done arrives.
              const newToolCall: ChatToolCall = { tool, tool_id: toolId, args, result: args };
              updateToolCalls([...toolCallsRef.current, newToolCall]);
              console.log('[Tool Call] Added:', newToolCall);
            }
            break;
          }
          case 'tool_done': {
            const done = event as { tool?: string; tool_id?: string; result?: unknown };
            if (!done.tool_id && !done.tool) {
              break;
            }
            const current = toolCallsRef.current;
            // Match by id only when the event carries one. Otherwise
            // `undefined === undefined` would pair this result with the first
            // id-less call of any tool.
            let index = done.tool_id ? current.findIndex((t) => t.tool_id === done.tool_id) : -1;
            if (index < 0 && done.tool) {
              index = current.findIndex((t) => t.tool === done.tool);
            }
            const existing = index >= 0 ? current[index] : undefined;
            if (!existing) {
              console.log('[Tool Done] No matching tool_call found for:', done);
              break;
            }
            const updated = [...current];
            // Keep the original args; only the result changes.
            updated[index] = { ...existing, result: done.result };
            updateToolCalls(updated);
            console.log('[Tool Done] Updated tool call:', updated[index]);
            break;
          }
          case 'complete': {
            const completeData = event as { response?: string; tool_calls?: unknown };
            console.log('[Complete]', completeData);
            // Flush whatever the mixed parser still has buffered.
            mixedParserRef.current?.flush();
            // Prefer the full response carried by the complete event, when present,
            // over the text rebuilt from token events.
            if (completeData.response) {
              updateResponse(completeData.response);
            }
            // isLoading is cleared later by generateSpecs (stream onComplete).
            break;
          }
          case 'error': {
            const errorData = event as { error?: string };
            console.error('[SSE Error]', errorData);
            setError(errorData.error || 'Unknown error');
            // isLoading is cleared by the stream's onError/onComplete via generateSpecs.
            break;
          }
          default:
            break;
        }
      },
      [updateResponse, updateToolCalls],
    );

    /**
     * Starts a new streamed chat request, superseding any request in flight.
     *
     * @param message - The user message.
     * @param history - Prior conversation turns, forwarded to the backend.
     */
    const sendMessage = useCallback(
      async (message: string, history: ChatMessage[] = []) => {
        // Bump the id before aborting, so callbacks fired by the abort are
        // already recognised as stale.
        const requestId = ++requestIdRef.current;
        const isCurrent = () => requestId === requestIdRef.current;
        abortRef.current?.();
        abortRef.current = null;

        setIsLoading(true);
        updateResponse('');
        updateToolCalls([]);
        setSpecs([]);
        setError(null);

        // Fresh streaming state for this request: spec patch lines go to the
        // compiler, everything else accumulates as response text.
        streamCompilerRef.current = createSpecStreamCompiler<Spec>();
        textBufferRef.current = '';
        mixedParserRef.current = createMixedStreamParser({
          onPatch: (patch: SpecStreamLine) => {
            const compiler = streamCompilerRef.current;
            if (!compiler) {
              return;
            }
            const { result } = compiler.push(JSON.stringify(patch) + '\n');
            if (result && typeof result === 'object') {
              setSpecs([result]);
            }
          },
          onText: (text: string) => {
            textBufferRef.current += text;
            updateResponse(textBufferRef.current);
          },
        });

        try {
          const abortStream = await apiClient.chatStreamFetch(
            message,
            history,
            (event) => {
              if (!isCurrent()) {
                return;
              }
              try {
                console.log('[SSE Raw] type:', event.type, 'data:', event.data?.substring(0, 100));
                const data = JSON.parse(event.data);
                // Merge the SSE event name into a copy of the payload so that
                // handleSSEEvent can dispatch on it. Null or primitive payloads
                // become `{ type }`.
                handleSSEEvent({ ...data, type: event.type });
              } catch (e) {
                console.error('Failed to parse SSE data:', e);
              }
            },
            (err) => {
              if (!isCurrent()) {
                return;
              }
              console.error('[sendMessage] Error:', err);
              setError(err.message);
              // Still surface whatever specs were gathered before the failure.
              generateSpecs(requestId);
            },
            () => {
              generateSpecs(requestId);
            },
          );
          if (isCurrent()) {
            abortRef.current = abortStream;
          } else {
            // Superseded or unmounted while connecting: stop this stream.
            abortStream?.();
          }
        } catch (err) {
          if (!isCurrent()) {
            return;
          }
          setError(err instanceof Error ? err.message : 'Unknown error');
          setIsLoading(false);
        }
      },
      [generateSpecs, handleSSEEvent, updateResponse, updateToolCalls],
    );

    /** Cancels the in-flight stream (if any) and clears `isLoading`. */
    const abort = useCallback(() => {
      abortRef.current?.();
      abortRef.current = null;
      setIsLoading(false);
    }, []);

    return {
      sendMessage,
      abort,
      isLoading,
      response,
      toolCalls,
      specs,
      error,
    };
  }
  // Authentication hook

  /**
   * Tracks who is signed in and exposes login / register / logout / status helpers.
   *
   * The session is persisted to localStorage under the keys `session_id`,
   * `username` and `userInfo`. On mount it is restored from there into both
   * React state and `apiClient`.
   *
   * @returns The auth state and actions:
   * - state: `isAuthenticated`, `username`, `role`, `isLoading`
   * - actions: `login`, `register`, `logout`, `checkStatus`, `getMcpUrl`
   */
  export function useAuth() {
    const [isAuthenticated, setIsAuthenticated] = useState(false);
    const [username, setUsername] = useState<string | null>(null);
    const [role, setRole] = useState<UserRole | null>(null);
    const [isLoading, setIsLoading] = useState(false);

    // Rehydrate from localStorage once, after the first render.
    useEffect(() => {
      const savedSession = localStorage.getItem('session_id');
      const savedUser = localStorage.getItem('userInfo');

      if (savedSession) {
        apiClient.setSession(savedSession);
        setIsAuthenticated(true);
      }
      if (!savedUser) {
        return;
      }
      try {
        const info: UserInfo = JSON.parse(savedUser);
        setUsername(info.username);
        setRole(info.role);
        apiClient.setUserInfo(info);
      } catch (e) {
        console.error('Failed to parse stored userInfo:', e);
      }
    }, []);

    /** MCP URL as resolved by `apiClient.getMcpUrl()` (presumably role-dependent; see apiClient). */
    const getMcpUrl = useCallback((): string => apiClient.getMcpUrl(), []);

    /**
     * Logs in and persists the session.
     *
     * @throws Error when the backend returns no `session_id`, or when the request itself fails.
     */
    const login = useCallback(async (email: string, password: string) => {
      setIsLoading(true);
      try {
        console.log('[useAuth] Starting login:', email);
        const result = await apiClient.login(email, password);
        console.log('[useAuth] Login response:', result);
        // Success is judged by the presence of session_id, not a `success` flag.
        if (!result.session_id) {
          throw new Error('Login failed: No session_id returned');
        }
        console.log('[useAuth] Login successful, setting state...');
        setIsAuthenticated(true);
        setUsername(result.username);
        setRole(result.role);

        const userInfo: UserInfo = { username: result.username, role: result.role };
        localStorage.setItem('session_id', result.session_id);
        localStorage.setItem('username', result.username);
        localStorage.setItem('userInfo', JSON.stringify(userInfo));
        console.log('[useAuth] State updated, userInfo:', userInfo);
        return result;
      } finally {
        setIsLoading(false);
      }
    }, []);

    /**
     * Registers a new account, then logs in with the same credentials.
     *
     * @returns The registration response merged with `session` (the login response).
     * @throws Error when registration reports failure or the follow-up login returns no `session_id`.
     */
    const register = useCallback(async (email: string, username: string, password: string) => {
      setIsLoading(true);
      try {
        const registration = await apiClient.register(email, username, password);
        if (!registration.success) {
          throw new Error(registration.message || 'Registration failed');
        }

        // Log in automatically after a successful registration.
        const session = await apiClient.login(email, password);
        if (!session.session_id) {
          throw new Error('Auto-login after registration failed');
        }
        setIsAuthenticated(true);
        setUsername(session.username);
        setRole(session.role);

        const userInfo: UserInfo = { username: session.username, role: session.role };
        localStorage.setItem('session_id', session.session_id);
        localStorage.setItem('username', session.username);
        localStorage.setItem('userInfo', JSON.stringify(userInfo));
        return { ...registration, session };
      } finally {
        setIsLoading(false);
      }
    }, []);

    /**
     * Logs out on the backend, then clears local state and storage.
     * If `apiClient.logout()` rejects, local state is left untouched and the error propagates.
     */
    const logout = useCallback(async () => {
      setIsLoading(true);
      try {
        await apiClient.logout();
        setIsAuthenticated(false);
        setUsername(null);
        setRole(null);
        for (const key of ['session_id', 'username', 'userInfo']) {
          localStorage.removeItem(key);
        }
      } finally {
        setIsLoading(false);
      }
    }, []);

    /**
     * Asks the backend for the current auth status and mirrors it into state.
     *
     * @returns The status object, or `null` if the request failed; on failure the
     * user is treated as logged out.
     */
    const checkStatus = useCallback(async () => {
      try {
        const status = await apiClient.authStatus();
        setIsAuthenticated(status.authenticated);
        setUsername(status.username || null);
        setRole((status.role as UserRole) || null);
        return status;
      } catch {
        setIsAuthenticated(false);
        setUsername(null);
        setRole(null);
        return null;
      }
    }, []);

    return {
      isAuthenticated,
      username,
      role,
      isLoading,
      login,
      register,
      logout,
      checkStatus,
      getMcpUrl,
    };
  }
  // MCP tools hook

  /**
   * Loads the MCP tool list from the backend on demand.
   *
   * @returns `tools` (the most recently fetched list), `isLoading`, and
   * `fetchTools()`, which resolves to the raw `apiClient.listMcpTools()` response.
   * `fetchTools()` rethrows any error from that call.
   */
  export function useMcpTools() {
    const [tools, setTools] = useState<{ name: string; description: string }[]>([]);
    const [isLoading, setIsLoading] = useState(false);

    const fetchTools = useCallback(async () => {
      setIsLoading(true);
      try {
        const listing = await apiClient.listMcpTools();
        setTools(listing.tools);
        return listing;
      } finally {
        // Reset the flag whether the request succeeded or threw.
        setIsLoading(false);
      }
    }, []);

    return { tools, isLoading, fetchTools };
  }