test_core_persistence.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. """
  2. Unit tests for state machine persistence.
  3. Tests cover saving, loading, atomic writes, and error handling.
  4. """
  5. import json
  6. import tempfile
  7. from pathlib import Path
  8. import pytest
  9. from src.core.states import PipelineState
  10. from src.core.state_machine import StateMachine, TransitionEvent
  11. from src.core.persistence import (
  12. StateMachinePersistence,
  13. StateMachinePersistenceError,
  14. PersistentStateMachine,
  15. PersistentTransitionEvent,
  16. )
  17. class TestPersistentTransitionEvent:
  18. """Test PersistentTransitionEvent dataclass."""
  19. def test_creation(self):
  20. """Test creating a PersistentTransitionEvent."""
  21. event = PersistentTransitionEvent(
  22. from_state="idle",
  23. to_state="fingerprinting",
  24. context={"file": "test.txt"},
  25. )
  26. assert event.from_state == "idle"
  27. assert event.to_state == "fingerprinting"
  28. assert event.context == {"file": "test.txt"}
  29. def test_default_timestamp(self):
  30. """Test that timestamp is set automatically."""
  31. event = PersistentTransitionEvent(
  32. from_state="idle",
  33. to_state="fingerprinting",
  34. )
  35. assert event.timestamp != ""
  36. assert len(event.timestamp) > 0
  37. class TestPersistentStateMachine:
  38. """Test PersistentStateMachine dataclass."""
  39. def test_default_values(self):
  40. """Test default values."""
  41. psm = PersistentStateMachine()
  42. assert psm.state == "idle"
  43. assert psm.context == {}
  44. assert psm.history == []
  45. assert psm.metadata == {}
  46. def test_to_dict(self):
  47. """Test converting to dictionary."""
  48. psm = PersistentStateMachine(
  49. state="translating",
  50. context={"progress": 50},
  51. history=[
  52. PersistentTransitionEvent(
  53. from_state="idle",
  54. to_state="translating",
  55. )
  56. ],
  57. )
  58. data = psm.to_dict()
  59. assert data["state"] == "translating"
  60. assert data["context"]["progress"] == 50
  61. assert len(data["history"]) == 1
  62. def test_from_dict(self):
  63. """Test creating from dictionary."""
  64. data = {
  65. "version": "1.0",
  66. "state": "translating",
  67. "context": {"progress": 50},
  68. "history": [
  69. {
  70. "from_state": "idle",
  71. "to_state": "translating",
  72. "context": {},
  73. "timestamp": "2026-03-15T10:00:00",
  74. }
  75. ],
  76. "metadata": {},
  77. }
  78. psm = PersistentStateMachine.from_dict(data)
  79. assert psm.state == "translating"
  80. assert psm.context["progress"] == 50
  81. assert len(psm.history) == 1
  82. class TestStateMachinePersistence:
  83. """Test StateMachinePersistence class."""
  84. def test_init_with_path(self):
  85. """Test initialization with path."""
  86. with tempfile.TemporaryDirectory() as tmpdir:
  87. path = Path(tmpdir) / "state.json"
  88. persistence = StateMachinePersistence(path)
  89. assert persistence.state_file == path
  90. def test_save_and_load(self):
  91. """Test saving and loading state machine."""
  92. with tempfile.TemporaryDirectory() as tmpdir:
  93. path = Path(tmpdir) / "state.json"
  94. # Create and configure state machine
  95. sm = StateMachine()
  96. sm.transition_to(PipelineState.FINGERPRINTING, file="novel.txt")
  97. sm.transition_to(PipelineState.CLEANING, mode="deep")
  98. # Save
  99. persistence = StateMachinePersistence(path)
  100. persistence.save(sm)
  101. # Load
  102. loaded_data = persistence.load()
  103. assert loaded_data is not None
  104. assert loaded_data.state == "cleaning"
  105. assert loaded_data.context["file"] == "novel.txt"
  106. assert loaded_data.context["mode"] == "deep"
  107. assert len(loaded_data.history) == 2
  108. def test_save_creates_parent_directories(self):
  109. """Test that save creates parent directories."""
  110. with tempfile.TemporaryDirectory() as tmpdir:
  111. nested_path = Path(tmpdir) / "nested" / "dir" / "state.json"
  112. sm = StateMachine()
  113. persistence = StateMachinePersistence(nested_path)
  114. persistence.save(sm)
  115. assert nested_path.exists()
  116. def test_save_creates_valid_json(self):
  117. """Test that save creates valid JSON file."""
  118. with tempfile.TemporaryDirectory() as tmpdir:
  119. path = Path(tmpdir) / "state.json"
  120. sm = StateMachine()
  121. sm.transition_to(PipelineState.TRANSLATING, progress=75)
  122. persistence = StateMachinePersistence(path)
  123. persistence.save(sm)
  124. # Read and verify JSON
  125. with open(path, "r") as f:
  126. data = json.load(f)
  127. assert data["state"] == "translating"
  128. assert data["context"]["progress"] == 75
  129. def test_atomic_write(self):
  130. """Test that save uses atomic write (temp file + rename)."""
  131. with tempfile.TemporaryDirectory() as tmpdir:
  132. path = Path(tmpdir) / "state.json"
  133. sm = StateMachine()
  134. persistence = StateMachinePersistence(path)
  135. persistence.save(sm)
  136. # Temp file should not exist after successful save
  137. temp_file = path.with_suffix(path.suffix + ".tmp")
  138. assert not temp_file.exists()
  139. # Final file should exist
  140. assert path.exists()
  141. def test_load_nonexistent_file(self):
  142. """Test loading a non-existent file returns None."""
  143. with tempfile.TemporaryDirectory() as tmpdir:
  144. path = Path(tmpdir) / "nonexistent.json"
  145. persistence = StateMachinePersistence(path)
  146. result = persistence.load()
  147. assert result is None
  148. def test_load_invalid_json(self):
  149. """Test loading invalid JSON raises error."""
  150. with tempfile.TemporaryDirectory() as tmpdir:
  151. path = Path(tmpdir) / "invalid.json"
  152. # Write invalid JSON
  153. with open(path, "w") as f:
  154. f.write("{ invalid json }")
  155. persistence = StateMachinePersistence(path)
  156. with pytest.raises(StateMachinePersistenceError):
  157. persistence.load()
  158. def test_load_and_restore(self):
  159. """Test load_and_restore method."""
  160. with tempfile.TemporaryDirectory() as tmpdir:
  161. path = Path(tmpdir) / "state.json"
  162. # Save original state
  163. sm1 = StateMachine()
  164. sm1.transition_to(PipelineState.TRANSLATING, progress=50)
  165. persistence = StateMachinePersistence(path)
  166. persistence.save(sm1)
  167. # Restore into new state machine
  168. sm2 = StateMachine()
  169. result = persistence.load_and_restore(sm2)
  170. assert result is True
  171. assert sm2.state == PipelineState.TRANSLATING
  172. assert sm2.get_context_value("progress") == 50
  173. assert len(sm2.history) == 1
  174. def test_load_and_restore_nonexistent(self):
  175. """Test load_and_restore with non-existent file."""
  176. with tempfile.TemporaryDirectory() as tmpdir:
  177. path = Path(tmpdir) / "nonexistent.json"
  178. sm = StateMachine()
  179. persistence = StateMachinePersistence(path)
  180. result = persistence.load_and_restore(sm)
  181. assert result is False
  182. assert sm.state == PipelineState.IDLE
  183. def test_exists(self):
  184. """Test exists method."""
  185. with tempfile.TemporaryDirectory() as tmpdir:
  186. path = Path(tmpdir) / "state.json"
  187. persistence = StateMachinePersistence(path)
  188. assert persistence.exists() is False
  189. sm = StateMachine()
  190. persistence.save(sm)
  191. assert persistence.exists() is True
  192. def test_delete(self):
  193. """Test delete method."""
  194. with tempfile.TemporaryDirectory() as tmpdir:
  195. path = Path(tmpdir) / "state.json"
  196. sm = StateMachine()
  197. persistence = StateMachinePersistence(path)
  198. persistence.save(sm)
  199. assert path.exists()
  200. assert persistence.delete() is True
  201. assert not path.exists()
  202. def test_delete_nonexistent(self):
  203. """Test deleting non-existent file."""
  204. with tempfile.TemporaryDirectory() as tmpdir:
  205. path = Path(tmpdir) / "nonexistent.json"
  206. persistence = StateMachinePersistence(path)
  207. result = persistence.delete()
  208. assert result is False
  209. class TestStateMachinePersistenceMethods:
  210. """Test StateMachine save_to_file and load_from_file methods."""
  211. def test_save_to_file(self):
  212. """Test StateMachine.save_to_file method."""
  213. with tempfile.TemporaryDirectory() as tmpdir:
  214. path = Path(tmpdir) / "state.json"
  215. sm = StateMachine()
  216. sm.transition_to(PipelineState.UPLOADING, target="web")
  217. sm.save_to_file(path)
  218. # Verify file was created
  219. assert path.exists()
  220. # Verify content
  221. with open(path, "r") as f:
  222. data = json.load(f)
  223. assert data["state"] == "uploading"
  224. assert data["context"]["target"] == "web"
  225. def test_load_from_file(self):
  226. """Test StateMachine.load_from_file class method."""
  227. with tempfile.TemporaryDirectory() as tmpdir:
  228. path = Path(tmpdir) / "state.json"
  229. # Save original
  230. sm1 = StateMachine()
  231. sm1.transition_to(PipelineState.COMPLETED, output="/path/to/file.txt")
  232. sm1.save_to_file(path)
  233. # Load
  234. sm2 = StateMachine.load_from_file(path)
  235. assert sm2 is not None
  236. assert sm2.state == PipelineState.COMPLETED
  237. assert sm2.get_context_value("output") == "/path/to/file.txt"
  238. def test_load_from_file_nonexistent(self):
  239. """Test load_from_file with non-existent file."""
  240. with tempfile.TemporaryDirectory() as tmpdir:
  241. path = Path(tmpdir) / "nonexistent.json"
  242. sm = StateMachine.load_from_file(path)
  243. assert sm is None
  244. def test_roundtrip_preserves_all_data(self):
  245. """Test that save/load roundtrip preserves all data."""
  246. with tempfile.TemporaryDirectory() as tmpdir:
  247. path = Path(tmpdir) / "state.json"
  248. # Create state machine with various states
  249. sm1 = StateMachine()
  250. sm1.transition_to(PipelineState.FINGERPRINTING, file="novel.txt")
  251. sm1.transition_to(PipelineState.CLEANING)
  252. sm1.transition_to(PipelineState.TERM_EXTRACTION, terms=5)
  253. sm1.transition_to(PipelineState.TRANSLATING, progress=25, chapter=1)
  254. sm1.transition_to(PipelineState.TRANSLATING, progress=50, chapter=2)
  255. # Save and load
  256. sm1.save_to_file(path)
  257. sm2 = StateMachine.load_from_file(path)
  258. # Verify all data
  259. assert sm2.state == sm1.state
  260. assert sm2.context == sm1.context
  261. assert len(sm2.history) == len(sm1.history)
  262. for i, (h1, h2) in enumerate(zip(sm1.history, sm2.history)):
  263. assert h1.from_state == h2.from_state
  264. assert h1.to_state == h2.to_state
  265. assert h1.context == h2.context
  266. class TestPersistenceEdgeCases:
  267. """Test edge cases and error conditions."""
  268. def test_save_empty_state_machine(self):
  269. """Test saving an empty (initial) state machine."""
  270. with tempfile.TemporaryDirectory() as tmpdir:
  271. path = Path(tmpdir) / "state.json"
  272. sm = StateMachine()
  273. sm.save_to_file(path)
  274. sm2 = StateMachine.load_from_file(path)
  275. assert sm2 is not None
  276. assert sm2.state == PipelineState.IDLE
  277. assert sm2.context == {}
  278. assert sm2.history == []
  279. def test_save_with_large_context(self):
  280. """Test saving with large context data."""
  281. with tempfile.TemporaryDirectory() as tmpdir:
  282. path = Path(tmpdir) / "state.json"
  283. sm = StateMachine()
  284. large_context = {f"key_{i}": f"value_{i}" * 100 for i in range(50)}
  285. sm.transition_to(PipelineState.TRANSLATING, **large_context)
  286. sm.save_to_file(path)
  287. sm2 = StateMachine.load_from_file(path)
  288. assert sm2 is not None
  289. assert len(sm2.context) == 50
  290. def test_version_field_preserved(self):
  291. """Test that version field is saved and can be checked."""
  292. with tempfile.TemporaryDirectory() as tmpdir:
  293. path = Path(tmpdir) / "state.json"
  294. sm = StateMachine()
  295. sm.save_to_file(path)
  296. with open(path, "r") as f:
  297. data = json.load(f)
  298. assert "version" in data
  299. assert data["version"] == "1.0"
  300. def test_metadata_includes_saved_at(self):
  301. """Test that metadata includes timestamp."""
  302. with tempfile.TemporaryDirectory() as tmpdir:
  303. path = Path(tmpdir) / "state.json"
  304. sm = StateMachine()
  305. sm.save_to_file(path)
  306. persistence = StateMachinePersistence(path)
  307. data = persistence.load()
  308. assert "metadata" in data.to_dict()
  309. assert "saved_at" in data.metadata
  310. assert data.metadata["saved_at"] != ""