| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653 |
- """
- Unit tests for state machine persistence.
- Tests cover saving, loading, atomic writes, and error handling.
- """
- import json
- import tempfile
- from pathlib import Path
- import pytest
- from src.core.states import PipelineState
- from src.core.state_machine import StateMachine, TransitionEvent
- from src.core.persistence import (
- StateMachinePersistence,
- StateMachinePersistenceError,
- PersistentStateMachine,
- PersistentTransitionEvent,
- )
class TestPersistentTransitionEvent:
    """Tests for the PersistentTransitionEvent dataclass."""

    def test_creation(self):
        """Constructor arguments are stored on the matching fields."""
        ctx = {"file": "test.txt"}
        evt = PersistentTransitionEvent(
            from_state="idle",
            to_state="fingerprinting",
            context=ctx,
        )
        assert evt.from_state == "idle"
        assert evt.to_state == "fingerprinting"
        assert evt.context == {"file": "test.txt"}

    def test_default_timestamp(self):
        """Omitting the timestamp still yields a non-empty default."""
        evt = PersistentTransitionEvent(from_state="idle", to_state="fingerprinting")
        stamp = evt.timestamp
        assert stamp != ""
        assert len(stamp) > 0
class TestPersistentStateMachine:
    """Tests for the PersistentStateMachine dataclass."""

    def test_default_values(self):
        """A bare instance is idle with empty context, history and metadata."""
        blank = PersistentStateMachine()
        assert blank.state == "idle"
        assert blank.context == {}
        assert blank.history == []
        assert blank.metadata == {}

    def test_to_dict(self):
        """to_dict exposes the state, the context and the serialised history."""
        only_event = PersistentTransitionEvent(from_state="idle", to_state="translating")
        snapshot = PersistentStateMachine(
            state="translating",
            context={"progress": 50},
            history=[only_event],
        )
        as_dict = snapshot.to_dict()
        assert as_dict["state"] == "translating"
        assert as_dict["context"]["progress"] == 50
        assert len(as_dict["history"]) == 1

    def test_from_dict(self):
        """from_dict rebuilds an instance from its serialised dictionary form."""
        history_entry = {
            "from_state": "idle",
            "to_state": "translating",
            "context": {},
            "timestamp": "2026-03-15T10:00:00",
        }
        payload = {
            "version": "1.0",
            "state": "translating",
            "context": {"progress": 50},
            "history": [history_entry],
            "metadata": {},
        }
        restored = PersistentStateMachine.from_dict(payload)
        assert restored.state == "translating"
        assert restored.context["progress"] == 50
        assert len(restored.history) == 1
class TestStateMachinePersistence:
    """Test StateMachinePersistence class."""

    def test_init_with_path(self):
        """Test initialization with path."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            persistence = StateMachinePersistence(path)
            assert persistence.state_file == path

    def test_save_and_load(self):
        """Test saving and loading state machine."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            # Create and configure state machine
            sm = StateMachine()
            sm.transition_to(PipelineState.FINGERPRINTING, file="novel.txt")
            sm.transition_to(PipelineState.CLEANING, mode="deep")
            # Save
            persistence = StateMachinePersistence(path)
            persistence.save(sm)
            # Load: context values from both transitions are expected back.
            loaded_data = persistence.load()
            assert loaded_data is not None
            assert loaded_data.state == "cleaning"
            assert loaded_data.context["file"] == "novel.txt"
            assert loaded_data.context["mode"] == "deep"
            assert len(loaded_data.history) == 2

    def test_save_creates_parent_directories(self):
        """Test that save creates parent directories."""
        with tempfile.TemporaryDirectory() as tmpdir:
            nested_path = Path(tmpdir) / "nested" / "dir" / "state.json"
            sm = StateMachine()
            persistence = StateMachinePersistence(nested_path)
            persistence.save(sm)
            assert nested_path.exists()

    def test_save_creates_valid_json(self):
        """Test that save creates valid JSON file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.transition_to(PipelineState.TRANSLATING, progress=75)
            persistence = StateMachinePersistence(path)
            persistence.save(sm)
            # Decode explicitly as UTF-8 (the JSON interchange encoding) so the
            # check does not depend on the platform's default locale encoding.
            data = json.loads(path.read_text(encoding="utf-8"))
            assert data["state"] == "translating"
            assert data["context"]["progress"] == 75

    def test_atomic_write(self):
        """Test that save uses atomic write (temp file + rename)."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            persistence = StateMachinePersistence(path)
            persistence.save(sm)
            # Assumes the implementation stages writes in "<name>.tmp";
            # that staging file must be gone after a successful save.
            temp_file = path.with_suffix(path.suffix + ".tmp")
            assert not temp_file.exists()
            # Final file should exist
            assert path.exists()

    def test_load_nonexistent_file(self):
        """Test loading a non-existent file returns None."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "nonexistent.json"
            persistence = StateMachinePersistence(path)
            result = persistence.load()
            assert result is None

    def test_load_invalid_json(self):
        """Test loading invalid JSON raises error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "invalid.json"
            # Write syntactically invalid JSON
            path.write_text("{ invalid json }", encoding="utf-8")
            persistence = StateMachinePersistence(path)
            with pytest.raises(StateMachinePersistenceError):
                persistence.load()

    def test_load_and_restore(self):
        """Test load_and_restore method."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            # Save original state
            sm1 = StateMachine()
            sm1.transition_to(PipelineState.TRANSLATING, progress=50)
            persistence = StateMachinePersistence(path)
            persistence.save(sm1)
            # Restore into a fresh state machine
            sm2 = StateMachine()
            result = persistence.load_and_restore(sm2)
            assert result is True
            assert sm2.state == PipelineState.TRANSLATING
            assert sm2.get_context_value("progress") == 50
            assert len(sm2.history) == 1

    def test_load_and_restore_nonexistent(self):
        """Test load_and_restore with non-existent file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "nonexistent.json"
            sm = StateMachine()
            persistence = StateMachinePersistence(path)
            result = persistence.load_and_restore(sm)
            # Nothing restored: the machine keeps its initial state.
            assert result is False
            assert sm.state == PipelineState.IDLE

    def test_exists(self):
        """Test exists method."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            persistence = StateMachinePersistence(path)
            assert persistence.exists() is False
            sm = StateMachine()
            persistence.save(sm)
            assert persistence.exists() is True

    def test_delete(self):
        """Test delete method."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            persistence = StateMachinePersistence(path)
            persistence.save(sm)
            assert path.exists()
            assert persistence.delete() is True
            assert not path.exists()

    def test_delete_nonexistent(self):
        """Test deleting non-existent file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "nonexistent.json"
            persistence = StateMachinePersistence(path)
            result = persistence.delete()
            assert result is False
class TestStateMachinePersistenceMethods:
    """Test StateMachine save_to_file and load_from_file methods."""

    def test_save_to_file(self):
        """Test StateMachine.save_to_file method."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.transition_to(PipelineState.UPLOADING, target="web")
            sm.save_to_file(path)
            # Verify file was created
            assert path.exists()
            # Verify content; decode as UTF-8 so the check is locale-independent.
            data = json.loads(path.read_text(encoding="utf-8"))
            assert data["state"] == "uploading"
            assert data["context"]["target"] == "web"

    def test_load_from_file(self):
        """Test StateMachine.load_from_file class method."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            # Save original
            sm1 = StateMachine()
            sm1.transition_to(PipelineState.COMPLETED, output="/path/to/file.txt")
            sm1.save_to_file(path)
            # Load
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            assert sm2.state == PipelineState.COMPLETED
            assert sm2.get_context_value("output") == "/path/to/file.txt"

    def test_load_from_file_nonexistent(self):
        """Test load_from_file with non-existent file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "nonexistent.json"
            sm = StateMachine.load_from_file(path)
            assert sm is None

    def test_roundtrip_preserves_all_data(self):
        """Test that save/load roundtrip preserves all data."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            # Drive the machine through several states, including a repeated one
            sm1 = StateMachine()
            sm1.transition_to(PipelineState.FINGERPRINTING, file="novel.txt")
            sm1.transition_to(PipelineState.CLEANING)
            sm1.transition_to(PipelineState.TERM_EXTRACTION, terms=5)
            sm1.transition_to(PipelineState.TRANSLATING, progress=25, chapter=1)
            sm1.transition_to(PipelineState.TRANSLATING, progress=50, chapter=2)
            # Save and load
            sm1.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            # Fail with a clear assertion rather than an AttributeError on None.
            assert sm2 is not None
            # Verify all data
            assert sm2.state == sm1.state
            assert sm2.context == sm1.context
            # Lengths are checked first, so zip() cannot silently truncate.
            assert len(sm2.history) == len(sm1.history)
            for h1, h2 in zip(sm1.history, sm2.history):
                assert h1.from_state == h2.from_state
                assert h1.to_state == h2.to_state
                assert h1.context == h2.context
class TestPersistenceEdgeCases:
    """Test edge cases and error conditions."""

    def test_save_empty_state_machine(self):
        """Test saving an empty (initial) state machine."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            assert sm2.state == PipelineState.IDLE
            assert sm2.context == {}
            assert sm2.history == []

    def test_save_with_large_context(self):
        """Test saving with large context data."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            # 50 keys, each value several hundred characters long
            large_context = {f"key_{i}": f"value_{i}" * 100 for i in range(50)}
            sm.transition_to(PipelineState.TRANSLATING, **large_context)
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            assert len(sm2.context) == 50

    def test_version_field_preserved(self):
        """Test that version field is saved and can be checked."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.save_to_file(path)
            # Decode as UTF-8 so the check does not depend on the locale encoding.
            data = json.loads(path.read_text(encoding="utf-8"))
            assert "version" in data
            assert data["version"] == "1.0"

    def test_metadata_includes_saved_at(self):
        """Test that metadata includes timestamp."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.save_to_file(path)
            persistence = StateMachinePersistence(path)
            data = persistence.load()
            # Fail with a clear assertion rather than an AttributeError on None.
            assert data is not None
            assert "metadata" in data.to_dict()
            assert "saved_at" in data.metadata
            assert data.metadata["saved_at"] != ""
class TestStateValidation:
    """Test state validation functionality."""

    def test_validate_on_restore_valid(self):
        """Test validation of valid state machine."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.transition_to(PipelineState.TRANSLATING, progress=50)
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            assert sm2.validate_on_restore() is True

    def test_validate_on_restore_empty(self):
        """Test validation of empty state machine."""
        sm = StateMachine()
        assert sm.validate_on_restore() is True

    def test_validate_on_restored_with_complete_flow(self):
        """Test validation of complete pipeline flow."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            for state in [
                PipelineState.FINGERPRINTING,
                PipelineState.CLEANING,
                PipelineState.TERM_EXTRACTION,
                PipelineState.TRANSLATING,
                PipelineState.UPLOADING,
                PipelineState.COMPLETED,
            ]:
                sm.transition_to(state)
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            assert sm2.validate_on_restore() is True

    def test_get_resume_point(self):
        """Test getting resume point description."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            # Active state, round-tripped through a file
            sm = StateMachine()
            sm.transition_to(PipelineState.TRANSLATING, progress=75)
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            resume_point = sm2.get_resume_point()
            assert "Translating" in resume_point
            # Terminal state
            sm3 = StateMachine()
            sm3.transition_to(PipelineState.COMPLETED)
            assert "completed" in sm3.get_resume_point().lower()
            # Idle state
            sm4 = StateMachine()
            assert "Ready to start" in sm4.get_resume_point()

    def test_validate_detects_invalid_state(self):
        """Test validation detects manually corrupted state."""
        sm = StateMachine()
        # Overwrite the private state attribute with a non-PipelineState value
        sm._state = "invalid_state"
        assert sm.validate_on_restore() is False

    def test_validate_detects_invalid_history(self):
        """Test validation detects invalid history transitions."""
        sm = StateMachine()
        # Append a history entry directly, bypassing transition_to().
        # TransitionEvent is already imported at module level.
        # NOTE(review): this test treats IDLE -> TRANSLATING as invalid, yet
        # test_validate_on_restore_valid reaches TRANSLATING from IDLE via
        # transition_to() and expects validation to pass. Confirm against the
        # transition table which expectation is correct.
        sm._history.append(
            TransitionEvent(
                from_state=PipelineState.IDLE,
                to_state=PipelineState.TRANSLATING,
                context={},
            )
        )
        sm._state = PipelineState.TRANSLATING
        assert sm.validate_on_restore() is False
class TestPersistenceContentEdgeCases:
    """Test persistence of unusual context content and file overwrites.

    Kept distinct from ``TestPersistenceEdgeCases``: reusing that name would
    rebind the module attribute, so pytest would collect only the later class
    and silently drop the earlier class's tests.
    """

    def test_save_with_special_characters_in_context(self):
        """Test saving with special characters in context values."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.transition_to(
                PipelineState.TRANSLATING,
                text="Hello\nWorld\t!",
                path="C:\\Users\\Test",
                quote='Test "quoted" string',
            )
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            # Control characters, backslashes and embedded quotes must all
            # survive JSON escaping unchanged.
            assert sm2.context["text"] == "Hello\nWorld\t!"
            assert sm2.context["path"] == "C:\\Users\\Test"
            assert sm2.context["quote"] == 'Test "quoted" string'

    def test_save_with_unicode(self):
        """Test saving with Unicode characters."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm = StateMachine()
            sm.transition_to(
                PipelineState.TRANSLATING,
                chinese="林风是主角",
                emoji="😀🎉",
                mixed="Hello 世界 🌍",
            )
            sm.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            assert sm2.context["chinese"] == "林风是主角"
            assert sm2.context["emoji"] == "😀🎉"
            assert sm2.context["mixed"] == "Hello 世界 🌍"

    def test_overwrite_existing_state_file(self):
        """Test overwriting an existing state file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            # Save first state
            sm1 = StateMachine()
            sm1.transition_to(PipelineState.FINGERPRINTING)
            sm1.save_to_file(path)
            # Overwrite with new state
            sm2 = StateMachine()
            sm2.transition_to(PipelineState.UPLOADING, target="web")
            sm2.save_to_file(path)
            # Load should get the new state
            sm3 = StateMachine.load_from_file(path)
            assert sm3 is not None
            assert sm3.state == PipelineState.UPLOADING
            assert sm3.context["target"] == "web"

    def test_save_load_cycle_preserves_callbacks_config(self):
        """Test that callbacks are not persisted (as expected)."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.json"
            sm1 = StateMachine()
            sm1.register_callback("on_transition", lambda e: None)
            sm1.transition_to(PipelineState.TRANSLATING)
            sm1.save_to_file(path)
            sm2 = StateMachine.load_from_file(path)
            assert sm2 is not None
            # Loaded machine should have no callbacks registered
            assert len(sm2._callbacks) == 0
            # But state should be preserved
            assert sm2.state == PipelineState.TRANSLATING
class TestLargeScalePersistence:
    """Persistence behaviour with larger histories and context dictionaries."""

    def test_large_history(self):
        """A long, repeating transition history survives a save/load round trip."""
        with tempfile.TemporaryDirectory() as tmpdir:
            state_path = Path(tmpdir) / "state.json"
            machine = StateMachine()
            tail = (PipelineState.UPLOADING, PipelineState.COMPLETED)
            # Repeat the same cycle 50 times; poking the private _state back to
            # IDLE lets each cycle restart from the beginning.
            for iteration in range(50):
                machine.transition_to(PipelineState.TRANSLATING, iteration=iteration)
                for step in tail:
                    machine.transition_to(step)
                machine._state = PipelineState.IDLE
                machine.transition_to(PipelineState.FINGERPRINTING)
            machine.save_to_file(state_path)
            reloaded = StateMachine.load_from_file(state_path)
            assert len(reloaded.history) == len(machine.history)
            # The context should hold the value written by the final iteration.
            assert reloaded.context["iteration"] == 49

    def test_many_context_entries(self):
        """One hundred context keys are all persisted and restored."""
        with tempfile.TemporaryDirectory() as tmpdir:
            state_path = Path(tmpdir) / "state.json"
            machine = StateMachine()
            entries = {}
            for idx in range(100):
                entries[f"key_{idx}"] = f"value_{idx}"
            machine.transition_to(PipelineState.TRANSLATING, **entries)
            machine.save_to_file(state_path)
            reloaded = StateMachine.load_from_file(state_path)
            restored_ctx = reloaded.context
            assert len(restored_ctx) == 100
            assert restored_ctx["key_0"] == "value_0"
            assert restored_ctx["key_99"] == "value_99"
class TestPersistenceWithDifferentStates:
    """Test persistence across different pipeline states."""

    def test_persist_from_each_state(self):
        """Test saving and restoring from each possible state."""
        # Forward pipeline sequence starting from IDLE; states on it are
        # reached by walking the sequence rather than jumping directly.
        happy_path = [
            PipelineState.FINGERPRINTING,
            PipelineState.CLEANING,
            PipelineState.TERM_EXTRACTION,
            PipelineState.TRANSLATING,
            PipelineState.UPLOADING,
            PipelineState.COMPLETED,
        ]
        with tempfile.TemporaryDirectory() as tmpdir:
            for state in PipelineState:
                path = Path(tmpdir) / f"state_{state.value}.json"
                sm1 = StateMachine()
                self._drive_to(sm1, state, happy_path)
                sm1.save_to_file(path)
                sm2 = StateMachine.load_from_file(path)
                assert sm2 is not None
                # Only check states the helper actually managed to reach.
                if sm1.state == state:
                    assert sm2.state == state
                    assert sm2.validate_on_restore()

    @staticmethod
    def _drive_to(sm, target, happy_path):
        """Move a fresh machine from IDLE toward *target* using known transitions.

        Args:
            sm: A freshly constructed StateMachine (in IDLE).
            target: The PipelineState to reach.
            happy_path: Ordered forward sequence of states starting after IDLE.
        """
        if target == PipelineState.IDLE:
            return
        if target in (PipelineState.PAUSED, PipelineState.FAILED):
            sm.transition_to(PipelineState.TRANSLATING)
            sm.transition_to(target)
        elif target in happy_path:
            for step in happy_path[: happy_path.index(target) + 1]:
                sm.transition_to(step)
        else:
            # Unknown route: try a direct jump and leave the machine where it
            # is if that fails. Catch Exception rather than using a bare
            # except, so KeyboardInterrupt/SystemExit still propagate.
            try:
                sm.transition_to(target)
            except Exception:
                pass
|