""" Unit tests for state machine persistence. Tests cover saving, loading, atomic writes, and error handling. """ import json import tempfile from pathlib import Path import pytest from src.core.states import PipelineState from src.core.state_machine import StateMachine, TransitionEvent from src.core.persistence import ( StateMachinePersistence, StateMachinePersistenceError, PersistentStateMachine, PersistentTransitionEvent, ) class TestPersistentTransitionEvent: """Test PersistentTransitionEvent dataclass.""" def test_creation(self): """Test creating a PersistentTransitionEvent.""" event = PersistentTransitionEvent( from_state="idle", to_state="fingerprinting", context={"file": "test.txt"}, ) assert event.from_state == "idle" assert event.to_state == "fingerprinting" assert event.context == {"file": "test.txt"} def test_default_timestamp(self): """Test that timestamp is set automatically.""" event = PersistentTransitionEvent( from_state="idle", to_state="fingerprinting", ) assert event.timestamp != "" assert len(event.timestamp) > 0 class TestPersistentStateMachine: """Test PersistentStateMachine dataclass.""" def test_default_values(self): """Test default values.""" psm = PersistentStateMachine() assert psm.state == "idle" assert psm.context == {} assert psm.history == [] assert psm.metadata == {} def test_to_dict(self): """Test converting to dictionary.""" psm = PersistentStateMachine( state="translating", context={"progress": 50}, history=[ PersistentTransitionEvent( from_state="idle", to_state="translating", ) ], ) data = psm.to_dict() assert data["state"] == "translating" assert data["context"]["progress"] == 50 assert len(data["history"]) == 1 def test_from_dict(self): """Test creating from dictionary.""" data = { "version": "1.0", "state": "translating", "context": {"progress": 50}, "history": [ { "from_state": "idle", "to_state": "translating", "context": {}, "timestamp": "2026-03-15T10:00:00", } ], "metadata": {}, } psm = PersistentStateMachine.from_dict(data) assert psm.state == "translating" assert psm.context["progress"] == 50 assert len(psm.history) == 1 class TestStateMachinePersistence: """Test StateMachinePersistence class.""" def test_init_with_path(self): """Test initialization with path.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" persistence = StateMachinePersistence(path) assert persistence.state_file == path def test_save_and_load(self): """Test saving and loading state machine.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" # Create and configure state machine sm = StateMachine() sm.transition_to(PipelineState.FINGERPRINTING, file="novel.txt") sm.transition_to(PipelineState.CLEANING, mode="deep") # Save persistence = StateMachinePersistence(path) persistence.save(sm) # Load loaded_data = persistence.load() assert loaded_data is not None assert loaded_data.state == "cleaning" assert loaded_data.context["file"] == "novel.txt" assert loaded_data.context["mode"] == "deep" assert len(loaded_data.history) == 2 def test_save_creates_parent_directories(self): """Test that save creates parent directories.""" with tempfile.TemporaryDirectory() as tmpdir: nested_path = Path(tmpdir) / "nested" / "dir" / "state.json" sm = StateMachine() persistence = StateMachinePersistence(nested_path) persistence.save(sm) assert nested_path.exists() def test_save_creates_valid_json(self): """Test that save creates valid JSON file.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.transition_to(PipelineState.TRANSLATING, progress=75) persistence = StateMachinePersistence(path) persistence.save(sm) # Read and verify JSON with open(path, "r") as f: data = json.load(f) assert data["state"] == "translating" assert data["context"]["progress"] == 75 def test_atomic_write(self): """Test that save uses atomic write (temp file + rename).""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() persistence = StateMachinePersistence(path) persistence.save(sm) # Temp file should not exist after successful save temp_file = path.with_suffix(path.suffix + ".tmp") assert not temp_file.exists() # Final file should exist assert path.exists() def test_load_nonexistent_file(self): """Test loading a non-existent file returns None.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "nonexistent.json" persistence = StateMachinePersistence(path) result = persistence.load() assert result is None def test_load_invalid_json(self): """Test loading invalid JSON raises error.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "invalid.json" # Write invalid JSON with open(path, "w") as f: f.write("{ invalid json }") persistence = StateMachinePersistence(path) with pytest.raises(StateMachinePersistenceError): persistence.load() def test_load_and_restore(self): """Test load_and_restore method.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" # Save original state sm1 = StateMachine() sm1.transition_to(PipelineState.TRANSLATING, progress=50) persistence = StateMachinePersistence(path) persistence.save(sm1) # Restore into new state machine sm2 = StateMachine() result = persistence.load_and_restore(sm2) assert result is True assert sm2.state == PipelineState.TRANSLATING assert sm2.get_context_value("progress") == 50 assert len(sm2.history) == 1 def test_load_and_restore_nonexistent(self): """Test load_and_restore with non-existent file.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "nonexistent.json" sm = StateMachine() persistence = StateMachinePersistence(path) result = persistence.load_and_restore(sm) assert result is False assert sm.state == PipelineState.IDLE def test_exists(self): """Test exists method.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" persistence = StateMachinePersistence(path) assert persistence.exists() is False sm = StateMachine() persistence.save(sm) assert persistence.exists() is True def test_delete(self): """Test delete method.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() persistence = StateMachinePersistence(path) persistence.save(sm) assert path.exists() assert persistence.delete() is True assert not path.exists() def test_delete_nonexistent(self): """Test deleting non-existent file.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "nonexistent.json" persistence = StateMachinePersistence(path) result = persistence.delete() assert result is False class TestStateMachinePersistenceMethods: """Test StateMachine save_to_file and load_from_file methods.""" def test_save_to_file(self): """Test StateMachine.save_to_file method.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.transition_to(PipelineState.UPLOADING, target="web") sm.save_to_file(path) # Verify file was created assert path.exists() # Verify content with open(path, "r") as f: data = json.load(f) assert data["state"] == "uploading" assert data["context"]["target"] == "web" def test_load_from_file(self): """Test StateMachine.load_from_file class method.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" # Save original sm1 = StateMachine() sm1.transition_to(PipelineState.COMPLETED, output="/path/to/file.txt") sm1.save_to_file(path) # Load sm2 = StateMachine.load_from_file(path) assert sm2 is not None assert sm2.state == PipelineState.COMPLETED assert sm2.get_context_value("output") == "/path/to/file.txt" def test_load_from_file_nonexistent(self): """Test load_from_file with non-existent file.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "nonexistent.json" sm = StateMachine.load_from_file(path) assert sm is None def test_roundtrip_preserves_all_data(self): """Test that save/load roundtrip preserves all data.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" # Create state machine with various states sm1 = StateMachine() sm1.transition_to(PipelineState.FINGERPRINTING, file="novel.txt") sm1.transition_to(PipelineState.CLEANING) sm1.transition_to(PipelineState.TERM_EXTRACTION, terms=5) sm1.transition_to(PipelineState.TRANSLATING, progress=25, chapter=1) sm1.transition_to(PipelineState.TRANSLATING, progress=50, chapter=2) # Save and load sm1.save_to_file(path) sm2 = StateMachine.load_from_file(path) # Verify all data assert sm2.state == sm1.state assert sm2.context == sm1.context assert len(sm2.history) == len(sm1.history) for i, (h1, h2) in enumerate(zip(sm1.history, sm2.history)): assert h1.from_state == h2.from_state assert h1.to_state == h2.to_state assert h1.context == h2.context class TestPersistenceEdgeCases: """Test edge cases and error conditions.""" def test_save_empty_state_machine(self): """Test saving an empty (initial) state machine.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2 is not None assert sm2.state == PipelineState.IDLE assert sm2.context == {} assert sm2.history == [] def test_save_with_large_context(self): """Test saving with large context data.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() large_context = {f"key_{i}": f"value_{i}" * 100 for i in range(50)} sm.transition_to(PipelineState.TRANSLATING, **large_context) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2 is not None assert len(sm2.context) == 50 def test_version_field_preserved(self): """Test that version field is saved and can be checked.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.save_to_file(path) with open(path, "r") as f: data = json.load(f) assert "version" in data assert data["version"] == "1.0" def test_metadata_includes_saved_at(self): """Test that metadata includes timestamp.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.save_to_file(path) persistence = StateMachinePersistence(path) data = persistence.load() assert "metadata" in data.to_dict() assert "saved_at" in data.metadata assert data.metadata["saved_at"] != "" class TestStateValidation: """Test state validation functionality.""" def test_validate_on_restore_valid(self): """Test validation of valid state machine.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.transition_to(PipelineState.TRANSLATING, progress=50) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2.validate_on_restore() is True def test_validate_on_restore_empty(self): """Test validation of empty state machine.""" sm = StateMachine() assert sm.validate_on_restore() is True def test_validate_on_restored_with_complete_flow(self): """Test validation of complete pipeline flow.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() for state in [ PipelineState.FINGERPRINTING, PipelineState.CLEANING, PipelineState.TERM_EXTRACTION, PipelineState.TRANSLATING, PipelineState.UPLOADING, PipelineState.COMPLETED, ]: sm.transition_to(state) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2.validate_on_restore() is True def test_get_resume_point(self): """Test getting resume point description.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" # Test active state sm = StateMachine() sm.transition_to(PipelineState.TRANSLATING, progress=75) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) resume_point = sm2.get_resume_point() assert "Translating" in resume_point # Test terminal state sm3 = StateMachine() sm3.transition_to(PipelineState.COMPLETED) assert "completed" in sm3.get_resume_point().lower() # Test idle state sm4 = StateMachine() assert "Ready to start" in sm4.get_resume_point() def test_validate_detects_invalid_state(self): """Test validation detects manually corrupted state.""" sm = StateMachine() # Manually corrupt the state sm._state = "invalid_state" assert sm.validate_on_restore() is False def test_validate_detects_invalid_history(self): """Test validation detects invalid history transitions.""" sm = StateMachine() # Manually add invalid history entry from src.core.state_machine import TransitionEvent sm._history.append( TransitionEvent( from_state=PipelineState.IDLE, to_state=PipelineState.TRANSLATING, # Invalid: IDLE can't go to TRANSLATING context={}, ) ) sm._state = PipelineState.TRANSLATING assert sm.validate_on_restore() is False class TestPersistenceEdgeCases: """Test edge cases for persistence.""" def test_save_with_special_characters_in_context(self): """Test saving with special characters in context values.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.transition_to( PipelineState.TRANSLATING, text="Hello\nWorld\t!", path="C:\\Users\\Test", quote='Test "quoted" string', ) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2.context["text"] == "Hello\nWorld\t!" assert sm2.context["path"] == "C:\\Users\\Test" def test_save_with_unicode(self): """Test saving with Unicode characters.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() sm.transition_to( PipelineState.TRANSLATING, chinese="林风是主角", emoji="😀🎉", mixed="Hello 世界 🌍", ) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2.context["chinese"] == "林风是主角" assert sm2.context["emoji"] == "😀🎉" assert sm2.context["mixed"] == "Hello 世界 🌍" def test_overwrite_existing_state_file(self): """Test overwriting an existing state file.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" # Save first state sm1 = StateMachine() sm1.transition_to(PipelineState.FINGERPRINTING) sm1.save_to_file(path) # Overwrite with new state sm2 = StateMachine() sm2.transition_to(PipelineState.UPLOADING, target="web") sm2.save_to_file(path) # Load should get the new state sm3 = StateMachine.load_from_file(path) assert sm3.state == PipelineState.UPLOADING assert sm3.context["target"] == "web" def test_save_load_cycle_preserves_callbacks_config(self): """Test that callbacks are not persisted (as expected).""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm1 = StateMachine() sm1.register_callback("on_transition", lambda e: None) sm1.transition_to(PipelineState.TRANSLATING) sm1.save_to_file(path) sm2 = StateMachine.load_from_file(path) # Loaded machine should have no callbacks registered assert len(sm2._callbacks) == 0 # But state should be preserved assert sm2.state == PipelineState.TRANSLATING class TestLargeScalePersistence: """Test persistence with larger data sets.""" def test_large_history(self): """Test saving and loading with large history.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() # Create a back-and-forth pattern for i in range(50): sm.transition_to(PipelineState.TRANSLATING, iteration=i) sm.transition_to(PipelineState.UPLOADING) sm.transition_to(PipelineState.COMPLETED) sm._state = PipelineState.IDLE # Reset for next iteration sm.transition_to(PipelineState.FINGERPRINTING) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert len(sm2.history) == len(sm.history) assert sm2.context["iteration"] == 49 # Last iteration def test_many_context_entries(self): """Test saving with many context entries.""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "state.json" sm = StateMachine() large_context = {f"key_{i}": f"value_{i}" for i in range(100)} sm.transition_to(PipelineState.TRANSLATING, **large_context) sm.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert len(sm2.context) == 100 assert sm2.context["key_0"] == "value_0" assert sm2.context["key_99"] == "value_99" class TestPersistenceWithDifferentStates: """Test persistence across different pipeline states.""" def test_persist_from_each_state(self): """Test saving and restoring from each possible state.""" with tempfile.TemporaryDirectory() as tmpdir: for state in PipelineState: path = Path(tmpdir) / f"state_{state.value}.json" sm1 = StateMachine() # Flow to the target state if state != PipelineState.IDLE: # Find a path to the state if state == PipelineState.PAUSED: sm1.transition_to(PipelineState.TRANSLATING) sm1.transition_to(PipelineState.PAUSED) elif state == PipelineState.FAILED: sm1.transition_to(PipelineState.TRANSLATING) sm1.transition_to(PipelineState.FAILED) elif state == PipelineState.COMPLETED: for s in [ PipelineState.FINGERPRINTING, PipelineState.CLEANING, PipelineState.TERM_EXTRACTION, PipelineState.TRANSLATING, PipelineState.UPLOADING, PipelineState.COMPLETED, ]: sm1.transition_to(s) else: # For other states, try direct flow try: sm1.transition_to(state) except: pass # Skip if not reachable directly sm1.save_to_file(path) sm2 = StateMachine.load_from_file(path) assert sm2 is not None if sm1.state == state: # Only check if we successfully reached the state assert sm2.state == state assert sm2.validate_on_restore()