2
0
Просмотр исходного кода

feat(core): Implement StateMachine transition engine (Story 1.1.2)

- Add StateMachine class with transition validation
- Implement callback system for state changes (on_transition, on_enter_<state>)
- Add context storage for state data (set/get/clear_context)
- Add transition history tracking (TransitionEvent dataclass)
- Add transition_or_raise() for strict mode error handling
- Add can_transition_to() for validation checking
- Add reset() method to return to initial state
- Add get_state_info() for comprehensive state information
- Unit tests covering all scenarios:
  * Valid/invalid transitions
  * Callback registration/unregistration
  * Context management
  * History tracking
  * Reset functionality
  * Error handling (InvalidTransitionError)

Part of Epic 1.1: State Machine (Phase 1a)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
d8dfun 3 дней назад
Родитель
Сommit
b1a92c8928
3 измененных файлов с 690 добавлено и 0 удалено
  1. 10 0
      src/core/__init__.py
  2. 299 0
      src/core/state_machine.py
  3. 381 0
      tests/test_core_state_machine.py

+ 10 - 0
src/core/__init__.py

@@ -7,9 +7,19 @@ translation pipeline lifecycle, including states, transitions, and validation.
 
 from .states import PipelineState
 from .transitions import is_transition_allowed, ALLOWED_TRANSITIONS
+from .state_machine import (
+    StateMachine,
+    StateMachineError,
+    InvalidTransitionError,
+    TransitionEvent,
+)
 
 __all__ = [
     "PipelineState",
     "is_transition_allowed",
     "ALLOWED_TRANSITIONS",
+    "StateMachine",
+    "StateMachineError",
+    "InvalidTransitionError",
+    "TransitionEvent",
 ]

+ 299 - 0
src/core/state_machine.py

@@ -0,0 +1,299 @@
+"""
+State machine implementation for pipeline lifecycle management.
+
+This module provides a state machine with transition validation,
+callbacks, and context storage for managing translation pipeline state.
+"""
+
+from typing import Callable, Dict, List, Any, Optional
+from dataclasses import dataclass, field
+
+from .states import PipelineState
+from .transitions import is_transition_allowed
+
+
+@dataclass
+class TransitionEvent:
+    """
+    Represents a state transition event.
+
+    Attributes:
+        from_state: The source state
+        to_state: The target state
+        context: Optional context data passed during transition
+    """
+
+    from_state: PipelineState
+    to_state: PipelineState
+    context: Dict[str, Any] = field(default_factory=dict)
+
+
+class StateMachineError(Exception):
+    """Base exception for state machine errors."""
+
+    pass
+
+
+class InvalidTransitionError(StateMachineError):
+    """Raised when an invalid state transition is attempted."""
+
+    def __init__(self, from_state: PipelineState, to_state: PipelineState):
+        self.from_state = from_state
+        self.to_state = to_state
+        super().__init__(
+            f"Invalid transition from {from_state.value} to {to_state.value}"
+        )
+
+
+class StateMachine:
+    """
+    State machine for managing pipeline lifecycle.
+
+    Features:
+    - Transition validation based on allowed transitions
+    - Callback system for state change events
+    - Context storage for state-related data
+    - Transition history tracking
+
+    Example:
+        >>> sm = StateMachine()
+        >>> sm.state
+        <PipelineState.IDLE: 'idle'>
+        >>> sm.transition_to(PipelineState.FINGERPRINTING)
+        True
+        >>> sm.state
+        <PipelineState.FINGERPRINTING: 'fingerprinting'>
+    """
+
+    def __init__(self):
+        """Initialize the state machine with IDLE as the initial state."""
+        self._state: PipelineState = PipelineState.IDLE
+        self._callbacks: Dict[str, List[Callable]] = {}
+        self._context: Dict[str, Any] = {}
+        self._history: List[TransitionEvent] = []
+
+    @property
+    def state(self) -> PipelineState:
+        """
+        Get the current state.
+
+        Returns:
+            The current PipelineState
+        """
+        return self._state
+
+    @property
+    def context(self) -> Dict[str, Any]:
+        """
+        Get the state context data.
+
+        Returns:
+            Dictionary containing context data
+        """
+        return self._context.copy()
+
+    @property
+    def history(self) -> List[TransitionEvent]:
+        """
+        Get the transition history.
+
+        Returns:
+            List of TransitionEvent objects in chronological order
+        """
+        return self._history.copy()
+
+    def can_transition_to(self, new_state: PipelineState) -> bool:
+        """
+        Check if transition to a new state is allowed.
+
+        Args:
+            new_state: The target state to check
+
+        Returns:
+            True if the transition is allowed, False otherwise
+        """
+        return is_transition_allowed(self._state, new_state)
+
+    def transition_to(
+        self, new_state: PipelineState, **kwargs
+    ) -> bool:
+        """
+        Attempt to transition to a new state.
+
+        Args:
+            new_state: The target state
+            **kwargs: Optional context data to store during transition
+
+        Returns:
+            True if transition succeeded, False otherwise
+
+        Raises:
+            InvalidTransitionError: If strict mode is enabled and transition is invalid
+        """
+        if not is_transition_allowed(self._state, new_state):
+            return False
+
+        old_state = self._state
+        self._state = new_state
+
+        # Update context with provided kwargs
+        if kwargs:
+            self._context.update(kwargs)
+
+        # Record transition in history
+        event = TransitionEvent(
+            from_state=old_state, to_state=new_state, context=kwargs.copy()
+        )
+        self._history.append(event)
+
+        # Trigger callbacks
+        self._trigger_callbacks("on_transition", event)
+        self._trigger_callbacks(f"on_enter_{new_state.value}", event)
+
+        return True
+
+    def transition_or_raise(
+        self, new_state: PipelineState, **kwargs
+    ) -> None:
+        """
+        Transition to a new state, raising an exception if invalid.
+
+        Args:
+            new_state: The target state
+            **kwargs: Optional context data to store during transition
+
+        Raises:
+            InvalidTransitionError: If the transition is not allowed
+        """
+        if not self.transition_to(new_state, **kwargs):
+            raise InvalidTransitionError(self._state, new_state)
+
+    def register_callback(
+        self, event: str, callback: Callable[..., None]
+    ) -> None:
+        """
+        Register a callback for a specific event.
+
+        Available events:
+            - "on_transition": Called on any state change, receives TransitionEvent
+            - "on_enter_<state>": Called when entering a specific state
+            - "on_exit_<state>": Called when exiting a specific state (not yet implemented)
+
+        Args:
+            event: The event name
+            callback: The callback function
+
+        Example:
+            >>> sm = StateMachine()
+            >>> sm.register_callback('on_enter_translating', lambda e: print('Translating!'))
+            >>> sm.transition_to(PipelineState.TRANSLATING)
+            Translating!
+        """
+        if event not in self._callbacks:
+            self._callbacks[event] = []
+        self._callbacks[event].append(callback)
+
+    def unregister_callback(self, event: str, callback: Callable[..., None]) -> bool:
+        """
+        Unregister a previously registered callback.
+
+        Args:
+            event: The event name
+            callback: The callback function to remove
+
+        Returns:
+            True if the callback was found and removed, False otherwise
+        """
+        if event not in self._callbacks:
+            return False
+
+        try:
+            self._callbacks[event].remove(callback)
+            return True
+        except ValueError:
+            return False
+
+    def _trigger_callbacks(self, event: str, *args) -> None:
+        """
+        Trigger all callbacks registered for an event.
+
+        Args:
+            event: The event name
+            *args: Arguments to pass to the callbacks
+        """
+        for callback in self._callbacks.get(event, []):
+            try:
+                callback(*args)
+            except Exception as e:
+                # Log but don't fail the state machine
+                # In production, this should go to a proper logger
+                print(f"Callback error for {event}: {e}")
+
+    def reset(self) -> None:
+        """
+        Reset the state machine to initial state.
+
+        Clears:
+        - State (resets to IDLE)
+        - Context data
+        - Transition history
+        - Callbacks (NOT cleared - they persist)
+        """
+        old_state = self._state
+        self._state = PipelineState.IDLE
+        self._context.clear()
+        self._history.clear()
+
+        # Trigger reset callback if transitioning from non-IDLE
+        if old_state != PipelineState.IDLE:
+            self._trigger_callbacks("on_reset", old_state)
+
+    def get_context_value(self, key: str, default: Any = None) -> Any:
+        """
+        Get a value from the context.
+
+        Args:
+            key: The context key
+            default: Default value if key not found
+
+        Returns:
+            The context value or default
+        """
+        return self._context.get(key, default)
+
+    def set_context_value(self, key: str, value: Any) -> None:
+        """
+        Set a value in the context.
+
+        Args:
+            key: The context key
+            value: The value to set
+        """
+        self._context[key] = value
+
+    def clear_context(self) -> None:
+        """Clear all context data."""
+        self._context.clear()
+
+    def get_state_info(self) -> Dict[str, Any]:
+        """
+        Get comprehensive state information.
+
+        Returns:
+            Dictionary containing state, context, and history info
+        """
+        return {
+            "current_state": self._state.value,
+            "is_terminal": self._state.is_terminal(),
+            "is_active": self._state.is_active(),
+            "context": self._context.copy(),
+            "transition_count": len(self._history),
+            "allowed_transitions": [
+                s.value for s in self._get_allowed_transitions()
+            ],
+        }
+
+    def _get_allowed_transitions(self) -> List[PipelineState]:
+        """Get list of allowed transitions from current state."""
+        from .transitions import get_allowed_transitions
+        return list(get_allowed_transitions(self._state))

+ 381 - 0
tests/test_core_state_machine.py

@@ -0,0 +1,381 @@
+"""
+Unit tests for StateMachine implementation in src/core.
+
+Tests cover state transitions, callbacks, context management,
+and error handling.
+"""
+
+import pytest
+
+from src.core.states import PipelineState
+from src.core.state_machine import (
+    StateMachine,
+    StateMachineError,
+    InvalidTransitionError,
+    TransitionEvent,
+)
+
+
+class TestStateMachineInitialization:
+    """Test StateMachine initialization and basic properties."""
+
+    def test_initial_state_is_idle(self):
+        """Test that state machine starts in IDLE state."""
+        sm = StateMachine()
+        assert sm.state == PipelineState.IDLE
+
+    def test_context_is_empty_on_init(self):
+        """Test that context is empty on initialization."""
+        sm = StateMachine()
+        assert sm.context == {}
+        assert len(sm.context) == 0
+
+    def test_history_is_empty_on_init(self):
+        """Test that history is empty on initialization."""
+        sm = StateMachine()
+        assert sm.history == []
+        assert len(sm.history) == 0
+
+    def test_get_state_info(self):
+        """Test getting comprehensive state information."""
+        sm = StateMachine()
+        info = sm.get_state_info()
+
+        assert info["current_state"] == "idle"
+        assert info["is_terminal"] is False
+        assert info["is_active"] is False
+        assert info["transition_count"] == 0
+        assert "fingerprinting" in info["allowed_transitions"]
+
+
+class TestValidTransitions:
+    """Test valid state transitions."""
+
+    def test_idle_to_fingerprinting(self):
+        """Test IDLE -> FINGERPRINTING transition."""
+        sm = StateMachine()
+        result = sm.transition_to(PipelineState.FINGERPRINTING)
+
+        assert result is True
+        assert sm.state == PipelineState.FINGERPRINTING
+
+    def test_normal_pipeline_flow(self):
+        """Test complete normal flow through the pipeline."""
+        sm = StateMachine()
+
+        flow = [
+            PipelineState.FINGERPRINTING,
+            PipelineState.CLEANING,
+            PipelineState.TERM_EXTRACTION,
+            PipelineState.TRANSLATING,
+            PipelineState.UPLOADING,
+            PipelineState.COMPLETED,
+        ]
+
+        for target_state in flow:
+            result = sm.transition_to(target_state)
+            assert result is True, f"Failed to transition to {target_state}"
+            assert sm.state == target_state
+
+    def test_transition_to_failed(self):
+        """Test transition from active state to FAILED."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.FINGERPRINTING)
+
+        result = sm.transition_to(PipelineState.FAILED)
+        assert result is True
+        assert sm.state == PipelineState.FAILED
+
+    def test_failed_to_idle(self):
+        """Test reset from FAILED to IDLE."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.FINGERPRINTING)
+        sm.transition_to(PipelineState.FAILED)
+
+        result = sm.transition_to(PipelineState.IDLE)
+        assert result is True
+        assert sm.state == PipelineState.IDLE
+
+    def test_pause_and_resume(self):
+        """Test pausing from active state and resuming."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.TRANSLATING)
+
+        # Pause
+        assert sm.transition_to(PipelineState.PAUSED) is True
+        assert sm.state == PipelineState.PAUSED
+
+        # Resume
+        assert sm.transition_to(PipelineState.TRANSLATING) is True
+        assert sm.state == PipelineState.TRANSLATING
+
+
+class TestInvalidTransitions:
+    """Test invalid state transitions."""
+
+    def test_idle_direct_to_translating(self):
+        """Test that IDLE cannot transition directly to TRANSLATING."""
+        sm = StateMachine()
+        result = sm.transition_to(PipelineState.TRANSLATING)
+
+        assert result is False
+        assert sm.state == PipelineState.IDLE
+
+    def test_completed_to_active_state(self):
+        """Test that COMPLETED cannot transition directly to active states."""
+        sm = StateMachine()
+        # Complete the pipeline
+        for state in [
+            PipelineState.FINGERPRINTING,
+            PipelineState.CLEANING,
+            PipelineState.TERM_EXTRACTION,
+            PipelineState.TRANSLATING,
+            PipelineState.UPLOADING,
+            PipelineState.COMPLETED,
+        ]:
+            sm.transition_to(state)
+
+        # Try to go directly to TRANSLATING
+        result = sm.transition_to(PipelineState.TRANSLATING)
+        assert result is False
+        assert sm.state == PipelineState.COMPLETED
+
+    def test_backward_transition(self):
+        """Test that backward transitions are not allowed."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.TRANSLATING)
+
+        # Try to go back to CLEANING
+        result = sm.transition_to(PipelineState.CLEANING)
+        assert result is False
+        assert sm.state == PipelineState.TRANSLATING
+
+    def test_can_transition_to(self):
+        """Test can_transition_to method."""
+        sm = StateMachine()
+
+        assert sm.can_transition_to(PipelineState.FINGERPRINTING) is True
+        assert sm.can_transition_to(PipelineState.TRANSLATING) is False
+
+        sm.transition_to(PipelineState.TRANSLATING)
+
+        assert sm.can_transition_to(PipelineState.UPLOADING) is True
+        assert sm.can_transition_to(PipelineState.CLEANING) is False
+
+
+class TestTransitionOrRaise:
+    """Test transition_or_raise method."""
+
+    def test_valid_transition_with_or_raise(self):
+        """Test valid transition with transition_or_raise."""
+        sm = StateMachine()
+        sm.transition_or_raise(PipelineState.FINGERPRINTING)
+        assert sm.state == PipelineState.FINGERPRINTING
+
+    def test_invalid_transition_raises_error(self):
+        """Test that invalid transition raises InvalidTransitionError."""
+        sm = StateMachine()
+
+        with pytest.raises(InvalidTransitionError) as exc_info:
+            sm.transition_or_raise(PipelineState.TRANSLATING)
+
+        assert exc_info.value.from_state == PipelineState.IDLE
+        assert exc_info.value.to_state == PipelineState.TRANSLATING
+
+
+class TestContextManagement:
+    """Test context storage and retrieval."""
+
+    def test_context_update_on_transition(self):
+        """Test that context is updated during transition."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.FINGERPRINTING, file_path="/test.txt")
+
+        assert sm.get_context_value("file_path") == "/test.txt"
+
+    def test_set_and_get_context_value(self):
+        """Test setting and getting context values."""
+        sm = StateMachine()
+
+        sm.set_context_value("key1", "value1")
+        sm.set_context_value("key2", 123)
+
+        assert sm.get_context_value("key1") == "value1"
+        assert sm.get_context_value("key2") == 123
+        assert sm.get_context_value("key3", "default") == "default"
+
+    def test_context_returns_copy(self):
+        """Test that context property returns a copy."""
+        sm = StateMachine()
+        sm.set_context_value("key", "value")
+
+        context1 = sm.context
+        context2 = sm.context
+
+        # Modifying returned dict should not affect internal state
+        context1["key"] = "modified"
+
+        assert sm.get_context_value("key") == "value"
+        assert context2["key"] == "value"
+
+    def test_clear_context(self):
+        """Test clearing context."""
+        sm = StateMachine()
+        sm.set_context_value("key1", "value1")
+        sm.set_context_value("key2", "value2")
+
+        sm.clear_context()
+
+        assert sm.context == {}
+        assert sm.get_context_value("key1") is None
+        assert sm.get_context_value("key2") is None
+
+
+class TestCallbacks:
+    """Test callback system."""
+
+    def test_on_transition_callback(self):
+        """Test on_transition callback."""
+        sm = StateMachine()
+        events = []
+
+        def callback(event):
+            events.append(event)
+
+        sm.register_callback("on_transition", callback)
+        sm.transition_to(PipelineState.FINGERPRINTING)
+
+        assert len(events) == 1
+        assert events[0].from_state == PipelineState.IDLE
+        assert events[0].to_state == PipelineState.FINGERPRINTING
+
+    def test_on_enter_state_callback(self):
+        """Test on_enter_<state> callbacks."""
+        sm = StateMachine()
+        states = []
+
+        sm.register_callback("on_enter_translating", lambda e: states.append("translating"))
+        sm.register_callback("on_enter_completed", lambda e: states.append("completed"))
+
+        # Flow through to TRANSLATING
+        sm.transition_to(PipelineState.FINGERPRINTING)
+        sm.transition_to(PipelineState.CLEANING)
+        sm.transition_to(PipelineState.TERM_EXTRACTION)
+        sm.transition_to(PipelineState.TRANSLATING)
+
+        assert "translating" in states
+
+    def test_multiple_callbacks_same_event(self):
+        """Test multiple callbacks for the same event."""
+        sm = StateMachine()
+        results = []
+
+        sm.register_callback("on_transition", lambda e: results.append(1))
+        sm.register_callback("on_transition", lambda e: results.append(2))
+        sm.register_callback("on_transition", lambda e: results.append(3))
+
+        sm.transition_to(PipelineState.FINGERPRINTING)
+
+        assert results == [1, 2, 3]
+
+    def test_unregister_callback(self):
+        """Test unregistering a callback."""
+        sm = StateMachine()
+        results = []
+
+        def callback1(e):
+            results.append(1)
+
+        def callback2(e):
+            results.append(2)
+
+        sm.register_callback("on_transition", callback1)
+        sm.register_callback("on_transition", callback2)
+
+        sm.unregister_callback("on_transition", callback1)
+        sm.transition_to(PipelineState.FINGERPRINTING)
+
+        assert results == [2]
+
+    def test_unregister_nonexistent_callback(self):
+        """Test unregistering a callback that wasn't registered."""
+        sm = StateMachine()
+
+        def callback(e):
+            pass
+
+        result = sm.unregister_callback("on_transition", callback)
+        assert result is False
+
+
+class TestTransitionHistory:
+    """Test transition history tracking."""
+
+    def test_history_records_transitions(self):
+        """Test that transitions are recorded in history."""
+        sm = StateMachine()
+
+        sm.transition_to(PipelineState.FINGERPRINTING)
+        sm.transition_to(PipelineState.CLEANING)
+        sm.transition_to(PipelineState.TERM_EXTRACTION)
+
+        assert len(sm.history) == 3
+
+    def test_history_order(self):
+        """Test that history maintains chronological order."""
+        sm = StateMachine()
+
+        states = [
+            PipelineState.FINGERPRINTING,
+            PipelineState.CLEANING,
+            PipelineState.TERM_EXTRACTION,
+        ]
+
+        for state in states:
+            sm.transition_to(state)
+
+        for i, event in enumerate(sm.history):
+            assert event.to_state == states[i]
+
+    def test_history_context_preserved(self):
+        """Test that context during transition is preserved in history."""
+        sm = StateMachine()
+
+        sm.transition_to(PipelineState.FINGERPRINTING, file="test.txt")
+        sm.transition_to(PipelineState.CLEANING, chapters=10)
+
+        assert sm.history[0].context == {"file": "test.txt"}
+        assert sm.history[1].context == {"chapters": 10}
+
+
+class TestReset:
+    """Test state machine reset functionality."""
+
+    def test_reset_to_idle(self):
+        """Test that reset returns state to IDLE."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.TRANSLATING)
+
+        sm.reset()
+
+        assert sm.state == PipelineState.IDLE
+
+    def test_reset_clears_context(self):
+        """Test that reset clears context."""
+        sm = StateMachine()
+        sm.set_context_value("key", "value")
+        sm.transition_to(PipelineState.TRANSLATING)
+
+        sm.reset()
+
+        assert sm.context == {}
+
+    def test_reset_clears_history(self):
+        """Test that reset clears history."""
+        sm = StateMachine()
+        sm.transition_to(PipelineState.FINGERPRINTING)
+        sm.transition_to(PipelineState.CLEANING)
+
+        sm.reset()
+
+        assert sm.history == []