# test_core_persistence.py (23 KB)

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. """
  2. Unit tests for state machine persistence.
  3. Tests cover saving, loading, atomic writes, and error handling.
  4. """
  5. import json
  6. import tempfile
  7. from pathlib import Path
  8. import pytest
  9. from src.core.states import PipelineState
  10. from src.core.state_machine import StateMachine, TransitionEvent
  11. from src.core.persistence import (
  12. StateMachinePersistence,
  13. StateMachinePersistenceError,
  14. PersistentStateMachine,
  15. PersistentTransitionEvent,
  16. )
  17. class TestPersistentTransitionEvent:
  18. """Test PersistentTransitionEvent dataclass."""
  19. def test_creation(self):
  20. """Test creating a PersistentTransitionEvent."""
  21. event = PersistentTransitionEvent(
  22. from_state="idle",
  23. to_state="fingerprinting",
  24. context={"file": "test.txt"},
  25. )
  26. assert event.from_state == "idle"
  27. assert event.to_state == "fingerprinting"
  28. assert event.context == {"file": "test.txt"}
  29. def test_default_timestamp(self):
  30. """Test that timestamp is set automatically."""
  31. event = PersistentTransitionEvent(
  32. from_state="idle",
  33. to_state="fingerprinting",
  34. )
  35. assert event.timestamp != ""
  36. assert len(event.timestamp) > 0
  37. class TestPersistentStateMachine:
  38. """Test PersistentStateMachine dataclass."""
  39. def test_default_values(self):
  40. """Test default values."""
  41. psm = PersistentStateMachine()
  42. assert psm.state == "idle"
  43. assert psm.context == {}
  44. assert psm.history == []
  45. assert psm.metadata == {}
  46. def test_to_dict(self):
  47. """Test converting to dictionary."""
  48. psm = PersistentStateMachine(
  49. state="translating",
  50. context={"progress": 50},
  51. history=[
  52. PersistentTransitionEvent(
  53. from_state="idle",
  54. to_state="translating",
  55. )
  56. ],
  57. )
  58. data = psm.to_dict()
  59. assert data["state"] == "translating"
  60. assert data["context"]["progress"] == 50
  61. assert len(data["history"]) == 1
  62. def test_from_dict(self):
  63. """Test creating from dictionary."""
  64. data = {
  65. "version": "1.0",
  66. "state": "translating",
  67. "context": {"progress": 50},
  68. "history": [
  69. {
  70. "from_state": "idle",
  71. "to_state": "translating",
  72. "context": {},
  73. "timestamp": "2026-03-15T10:00:00",
  74. }
  75. ],
  76. "metadata": {},
  77. }
  78. psm = PersistentStateMachine.from_dict(data)
  79. assert psm.state == "translating"
  80. assert psm.context["progress"] == 50
  81. assert len(psm.history) == 1
  82. class TestStateMachinePersistence:
  83. """Test StateMachinePersistence class."""
  84. def test_init_with_path(self):
  85. """Test initialization with path."""
  86. with tempfile.TemporaryDirectory() as tmpdir:
  87. path = Path(tmpdir) / "state.json"
  88. persistence = StateMachinePersistence(path)
  89. assert persistence.state_file == path
  90. def test_save_and_load(self):
  91. """Test saving and loading state machine."""
  92. with tempfile.TemporaryDirectory() as tmpdir:
  93. path = Path(tmpdir) / "state.json"
  94. # Create and configure state machine
  95. sm = StateMachine()
  96. sm.transition_to(PipelineState.FINGERPRINTING, file="novel.txt")
  97. sm.transition_to(PipelineState.CLEANING, mode="deep")
  98. # Save
  99. persistence = StateMachinePersistence(path)
  100. persistence.save(sm)
  101. # Load
  102. loaded_data = persistence.load()
  103. assert loaded_data is not None
  104. assert loaded_data.state == "cleaning"
  105. assert loaded_data.context["file"] == "novel.txt"
  106. assert loaded_data.context["mode"] == "deep"
  107. assert len(loaded_data.history) == 2
  108. def test_save_creates_parent_directories(self):
  109. """Test that save creates parent directories."""
  110. with tempfile.TemporaryDirectory() as tmpdir:
  111. nested_path = Path(tmpdir) / "nested" / "dir" / "state.json"
  112. sm = StateMachine()
  113. persistence = StateMachinePersistence(nested_path)
  114. persistence.save(sm)
  115. assert nested_path.exists()
  116. def test_save_creates_valid_json(self):
  117. """Test that save creates valid JSON file."""
  118. with tempfile.TemporaryDirectory() as tmpdir:
  119. path = Path(tmpdir) / "state.json"
  120. sm = StateMachine()
  121. sm.transition_to(PipelineState.TRANSLATING, progress=75)
  122. persistence = StateMachinePersistence(path)
  123. persistence.save(sm)
  124. # Read and verify JSON
  125. with open(path, "r") as f:
  126. data = json.load(f)
  127. assert data["state"] == "translating"
  128. assert data["context"]["progress"] == 75
  129. def test_atomic_write(self):
  130. """Test that save uses atomic write (temp file + rename)."""
  131. with tempfile.TemporaryDirectory() as tmpdir:
  132. path = Path(tmpdir) / "state.json"
  133. sm = StateMachine()
  134. persistence = StateMachinePersistence(path)
  135. persistence.save(sm)
  136. # Temp file should not exist after successful save
  137. temp_file = path.with_suffix(path.suffix + ".tmp")
  138. assert not temp_file.exists()
  139. # Final file should exist
  140. assert path.exists()
  141. def test_load_nonexistent_file(self):
  142. """Test loading a non-existent file returns None."""
  143. with tempfile.TemporaryDirectory() as tmpdir:
  144. path = Path(tmpdir) / "nonexistent.json"
  145. persistence = StateMachinePersistence(path)
  146. result = persistence.load()
  147. assert result is None
  148. def test_load_invalid_json(self):
  149. """Test loading invalid JSON raises error."""
  150. with tempfile.TemporaryDirectory() as tmpdir:
  151. path = Path(tmpdir) / "invalid.json"
  152. # Write invalid JSON
  153. with open(path, "w") as f:
  154. f.write("{ invalid json }")
  155. persistence = StateMachinePersistence(path)
  156. with pytest.raises(StateMachinePersistenceError):
  157. persistence.load()
  158. def test_load_and_restore(self):
  159. """Test load_and_restore method."""
  160. with tempfile.TemporaryDirectory() as tmpdir:
  161. path = Path(tmpdir) / "state.json"
  162. # Save original state
  163. sm1 = StateMachine()
  164. sm1.transition_to(PipelineState.TRANSLATING, progress=50)
  165. persistence = StateMachinePersistence(path)
  166. persistence.save(sm1)
  167. # Restore into new state machine
  168. sm2 = StateMachine()
  169. result = persistence.load_and_restore(sm2)
  170. assert result is True
  171. assert sm2.state == PipelineState.TRANSLATING
  172. assert sm2.get_context_value("progress") == 50
  173. assert len(sm2.history) == 1
  174. def test_load_and_restore_nonexistent(self):
  175. """Test load_and_restore with non-existent file."""
  176. with tempfile.TemporaryDirectory() as tmpdir:
  177. path = Path(tmpdir) / "nonexistent.json"
  178. sm = StateMachine()
  179. persistence = StateMachinePersistence(path)
  180. result = persistence.load_and_restore(sm)
  181. assert result is False
  182. assert sm.state == PipelineState.IDLE
  183. def test_exists(self):
  184. """Test exists method."""
  185. with tempfile.TemporaryDirectory() as tmpdir:
  186. path = Path(tmpdir) / "state.json"
  187. persistence = StateMachinePersistence(path)
  188. assert persistence.exists() is False
  189. sm = StateMachine()
  190. persistence.save(sm)
  191. assert persistence.exists() is True
  192. def test_delete(self):
  193. """Test delete method."""
  194. with tempfile.TemporaryDirectory() as tmpdir:
  195. path = Path(tmpdir) / "state.json"
  196. sm = StateMachine()
  197. persistence = StateMachinePersistence(path)
  198. persistence.save(sm)
  199. assert path.exists()
  200. assert persistence.delete() is True
  201. assert not path.exists()
  202. def test_delete_nonexistent(self):
  203. """Test deleting non-existent file."""
  204. with tempfile.TemporaryDirectory() as tmpdir:
  205. path = Path(tmpdir) / "nonexistent.json"
  206. persistence = StateMachinePersistence(path)
  207. result = persistence.delete()
  208. assert result is False
  209. class TestStateMachinePersistenceMethods:
  210. """Test StateMachine save_to_file and load_from_file methods."""
  211. def test_save_to_file(self):
  212. """Test StateMachine.save_to_file method."""
  213. with tempfile.TemporaryDirectory() as tmpdir:
  214. path = Path(tmpdir) / "state.json"
  215. sm = StateMachine()
  216. sm.transition_to(PipelineState.UPLOADING, target="web")
  217. sm.save_to_file(path)
  218. # Verify file was created
  219. assert path.exists()
  220. # Verify content
  221. with open(path, "r") as f:
  222. data = json.load(f)
  223. assert data["state"] == "uploading"
  224. assert data["context"]["target"] == "web"
  225. def test_load_from_file(self):
  226. """Test StateMachine.load_from_file class method."""
  227. with tempfile.TemporaryDirectory() as tmpdir:
  228. path = Path(tmpdir) / "state.json"
  229. # Save original
  230. sm1 = StateMachine()
  231. sm1.transition_to(PipelineState.COMPLETED, output="/path/to/file.txt")
  232. sm1.save_to_file(path)
  233. # Load
  234. sm2 = StateMachine.load_from_file(path)
  235. assert sm2 is not None
  236. assert sm2.state == PipelineState.COMPLETED
  237. assert sm2.get_context_value("output") == "/path/to/file.txt"
  238. def test_load_from_file_nonexistent(self):
  239. """Test load_from_file with non-existent file."""
  240. with tempfile.TemporaryDirectory() as tmpdir:
  241. path = Path(tmpdir) / "nonexistent.json"
  242. sm = StateMachine.load_from_file(path)
  243. assert sm is None
  244. def test_roundtrip_preserves_all_data(self):
  245. """Test that save/load roundtrip preserves all data."""
  246. with tempfile.TemporaryDirectory() as tmpdir:
  247. path = Path(tmpdir) / "state.json"
  248. # Create state machine with various states
  249. sm1 = StateMachine()
  250. sm1.transition_to(PipelineState.FINGERPRINTING, file="novel.txt")
  251. sm1.transition_to(PipelineState.CLEANING)
  252. sm1.transition_to(PipelineState.TERM_EXTRACTION, terms=5)
  253. sm1.transition_to(PipelineState.TRANSLATING, progress=25, chapter=1)
  254. sm1.transition_to(PipelineState.TRANSLATING, progress=50, chapter=2)
  255. # Save and load
  256. sm1.save_to_file(path)
  257. sm2 = StateMachine.load_from_file(path)
  258. # Verify all data
  259. assert sm2.state == sm1.state
  260. assert sm2.context == sm1.context
  261. assert len(sm2.history) == len(sm1.history)
  262. for i, (h1, h2) in enumerate(zip(sm1.history, sm2.history)):
  263. assert h1.from_state == h2.from_state
  264. assert h1.to_state == h2.to_state
  265. assert h1.context == h2.context
  266. class TestPersistenceEdgeCases:
  267. """Test edge cases and error conditions."""
  268. def test_save_empty_state_machine(self):
  269. """Test saving an empty (initial) state machine."""
  270. with tempfile.TemporaryDirectory() as tmpdir:
  271. path = Path(tmpdir) / "state.json"
  272. sm = StateMachine()
  273. sm.save_to_file(path)
  274. sm2 = StateMachine.load_from_file(path)
  275. assert sm2 is not None
  276. assert sm2.state == PipelineState.IDLE
  277. assert sm2.context == {}
  278. assert sm2.history == []
  279. def test_save_with_large_context(self):
  280. """Test saving with large context data."""
  281. with tempfile.TemporaryDirectory() as tmpdir:
  282. path = Path(tmpdir) / "state.json"
  283. sm = StateMachine()
  284. large_context = {f"key_{i}": f"value_{i}" * 100 for i in range(50)}
  285. sm.transition_to(PipelineState.TRANSLATING, **large_context)
  286. sm.save_to_file(path)
  287. sm2 = StateMachine.load_from_file(path)
  288. assert sm2 is not None
  289. assert len(sm2.context) == 50
  290. def test_version_field_preserved(self):
  291. """Test that version field is saved and can be checked."""
  292. with tempfile.TemporaryDirectory() as tmpdir:
  293. path = Path(tmpdir) / "state.json"
  294. sm = StateMachine()
  295. sm.save_to_file(path)
  296. with open(path, "r") as f:
  297. data = json.load(f)
  298. assert "version" in data
  299. assert data["version"] == "1.0"
  300. def test_metadata_includes_saved_at(self):
  301. """Test that metadata includes timestamp."""
  302. with tempfile.TemporaryDirectory() as tmpdir:
  303. path = Path(tmpdir) / "state.json"
  304. sm = StateMachine()
  305. sm.save_to_file(path)
  306. persistence = StateMachinePersistence(path)
  307. data = persistence.load()
  308. assert "metadata" in data.to_dict()
  309. assert "saved_at" in data.metadata
  310. assert data.metadata["saved_at"] != ""
  311. class TestStateValidation:
  312. """Test state validation functionality."""
  313. def test_validate_on_restore_valid(self):
  314. """Test validation of valid state machine."""
  315. with tempfile.TemporaryDirectory() as tmpdir:
  316. path = Path(tmpdir) / "state.json"
  317. sm = StateMachine()
  318. sm.transition_to(PipelineState.TRANSLATING, progress=50)
  319. sm.save_to_file(path)
  320. sm2 = StateMachine.load_from_file(path)
  321. assert sm2.validate_on_restore() is True
  322. def test_validate_on_restore_empty(self):
  323. """Test validation of empty state machine."""
  324. sm = StateMachine()
  325. assert sm.validate_on_restore() is True
  326. def test_validate_on_restored_with_complete_flow(self):
  327. """Test validation of complete pipeline flow."""
  328. with tempfile.TemporaryDirectory() as tmpdir:
  329. path = Path(tmpdir) / "state.json"
  330. sm = StateMachine()
  331. for state in [
  332. PipelineState.FINGERPRINTING,
  333. PipelineState.CLEANING,
  334. PipelineState.TERM_EXTRACTION,
  335. PipelineState.TRANSLATING,
  336. PipelineState.UPLOADING,
  337. PipelineState.COMPLETED,
  338. ]:
  339. sm.transition_to(state)
  340. sm.save_to_file(path)
  341. sm2 = StateMachine.load_from_file(path)
  342. assert sm2.validate_on_restore() is True
  343. def test_get_resume_point(self):
  344. """Test getting resume point description."""
  345. with tempfile.TemporaryDirectory() as tmpdir:
  346. path = Path(tmpdir) / "state.json"
  347. # Test active state
  348. sm = StateMachine()
  349. sm.transition_to(PipelineState.TRANSLATING, progress=75)
  350. sm.save_to_file(path)
  351. sm2 = StateMachine.load_from_file(path)
  352. resume_point = sm2.get_resume_point()
  353. assert "Translating" in resume_point
  354. # Test terminal state
  355. sm3 = StateMachine()
  356. sm3.transition_to(PipelineState.COMPLETED)
  357. assert "completed" in sm3.get_resume_point().lower()
  358. # Test idle state
  359. sm4 = StateMachine()
  360. assert "Ready to start" in sm4.get_resume_point()
  361. def test_validate_detects_invalid_state(self):
  362. """Test validation detects manually corrupted state."""
  363. sm = StateMachine()
  364. # Manually corrupt the state
  365. sm._state = "invalid_state"
  366. assert sm.validate_on_restore() is False
  367. def test_validate_detects_invalid_history(self):
  368. """Test validation detects invalid history transitions."""
  369. sm = StateMachine()
  370. # Manually add invalid history entry
  371. from src.core.state_machine import TransitionEvent
  372. sm._history.append(
  373. TransitionEvent(
  374. from_state=PipelineState.IDLE,
  375. to_state=PipelineState.TRANSLATING, # Invalid: IDLE can't go to TRANSLATING
  376. context={},
  377. )
  378. )
  379. sm._state = PipelineState.TRANSLATING
  380. assert sm.validate_on_restore() is False
  381. class TestPersistenceEdgeCases:
  382. """Test edge cases for persistence."""
  383. def test_save_with_special_characters_in_context(self):
  384. """Test saving with special characters in context values."""
  385. with tempfile.TemporaryDirectory() as tmpdir:
  386. path = Path(tmpdir) / "state.json"
  387. sm = StateMachine()
  388. sm.transition_to(
  389. PipelineState.TRANSLATING,
  390. text="Hello\nWorld\t!",
  391. path="C:\\Users\\Test",
  392. quote='Test "quoted" string',
  393. )
  394. sm.save_to_file(path)
  395. sm2 = StateMachine.load_from_file(path)
  396. assert sm2.context["text"] == "Hello\nWorld\t!"
  397. assert sm2.context["path"] == "C:\\Users\\Test"
  398. def test_save_with_unicode(self):
  399. """Test saving with Unicode characters."""
  400. with tempfile.TemporaryDirectory() as tmpdir:
  401. path = Path(tmpdir) / "state.json"
  402. sm = StateMachine()
  403. sm.transition_to(
  404. PipelineState.TRANSLATING,
  405. chinese="林风是主角",
  406. emoji="😀🎉",
  407. mixed="Hello 世界 🌍",
  408. )
  409. sm.save_to_file(path)
  410. sm2 = StateMachine.load_from_file(path)
  411. assert sm2.context["chinese"] == "林风是主角"
  412. assert sm2.context["emoji"] == "😀🎉"
  413. assert sm2.context["mixed"] == "Hello 世界 🌍"
  414. def test_overwrite_existing_state_file(self):
  415. """Test overwriting an existing state file."""
  416. with tempfile.TemporaryDirectory() as tmpdir:
  417. path = Path(tmpdir) / "state.json"
  418. # Save first state
  419. sm1 = StateMachine()
  420. sm1.transition_to(PipelineState.FINGERPRINTING)
  421. sm1.save_to_file(path)
  422. # Overwrite with new state
  423. sm2 = StateMachine()
  424. sm2.transition_to(PipelineState.UPLOADING, target="web")
  425. sm2.save_to_file(path)
  426. # Load should get the new state
  427. sm3 = StateMachine.load_from_file(path)
  428. assert sm3.state == PipelineState.UPLOADING
  429. assert sm3.context["target"] == "web"
  430. def test_save_load_cycle_preserves_callbacks_config(self):
  431. """Test that callbacks are not persisted (as expected)."""
  432. with tempfile.TemporaryDirectory() as tmpdir:
  433. path = Path(tmpdir) / "state.json"
  434. sm1 = StateMachine()
  435. sm1.register_callback("on_transition", lambda e: None)
  436. sm1.transition_to(PipelineState.TRANSLATING)
  437. sm1.save_to_file(path)
  438. sm2 = StateMachine.load_from_file(path)
  439. # Loaded machine should have no callbacks registered
  440. assert len(sm2._callbacks) == 0
  441. # But state should be preserved
  442. assert sm2.state == PipelineState.TRANSLATING
  443. class TestLargeScalePersistence:
  444. """Test persistence with larger data sets."""
  445. def test_large_history(self):
  446. """Test saving and loading with large history."""
  447. with tempfile.TemporaryDirectory() as tmpdir:
  448. path = Path(tmpdir) / "state.json"
  449. sm = StateMachine()
  450. # Create a back-and-forth pattern
  451. for i in range(50):
  452. sm.transition_to(PipelineState.TRANSLATING, iteration=i)
  453. sm.transition_to(PipelineState.UPLOADING)
  454. sm.transition_to(PipelineState.COMPLETED)
  455. sm._state = PipelineState.IDLE # Reset for next iteration
  456. sm.transition_to(PipelineState.FINGERPRINTING)
  457. sm.save_to_file(path)
  458. sm2 = StateMachine.load_from_file(path)
  459. assert len(sm2.history) == len(sm.history)
  460. assert sm2.context["iteration"] == 49 # Last iteration
  461. def test_many_context_entries(self):
  462. """Test saving with many context entries."""
  463. with tempfile.TemporaryDirectory() as tmpdir:
  464. path = Path(tmpdir) / "state.json"
  465. sm = StateMachine()
  466. large_context = {f"key_{i}": f"value_{i}" for i in range(100)}
  467. sm.transition_to(PipelineState.TRANSLATING, **large_context)
  468. sm.save_to_file(path)
  469. sm2 = StateMachine.load_from_file(path)
  470. assert len(sm2.context) == 100
  471. assert sm2.context["key_0"] == "value_0"
  472. assert sm2.context["key_99"] == "value_99"
  473. class TestPersistenceWithDifferentStates:
  474. """Test persistence across different pipeline states."""
  475. def test_persist_from_each_state(self):
  476. """Test saving and restoring from each possible state."""
  477. with tempfile.TemporaryDirectory() as tmpdir:
  478. for state in PipelineState:
  479. path = Path(tmpdir) / f"state_{state.value}.json"
  480. sm1 = StateMachine()
  481. # Flow to the target state
  482. if state != PipelineState.IDLE:
  483. # Find a path to the state
  484. if state == PipelineState.PAUSED:
  485. sm1.transition_to(PipelineState.TRANSLATING)
  486. sm1.transition_to(PipelineState.PAUSED)
  487. elif state == PipelineState.FAILED:
  488. sm1.transition_to(PipelineState.TRANSLATING)
  489. sm1.transition_to(PipelineState.FAILED)
  490. elif state == PipelineState.COMPLETED:
  491. for s in [
  492. PipelineState.FINGERPRINTING,
  493. PipelineState.CLEANING,
  494. PipelineState.TERM_EXTRACTION,
  495. PipelineState.TRANSLATING,
  496. PipelineState.UPLOADING,
  497. PipelineState.COMPLETED,
  498. ]:
  499. sm1.transition_to(s)
  500. else:
  501. # For other states, try direct flow
  502. try:
  503. sm1.transition_to(state)
  504. except:
  505. pass # Skip if not reachable directly
  506. sm1.save_to_file(path)
  507. sm2 = StateMachine.load_from_file(path)
  508. assert sm2 is not None
  509. if sm1.state == state: # Only check if we successfully reached the state
  510. assert sm2.state == state
  511. assert sm2.validate_on_restore()