|
|
@@ -0,0 +1,494 @@
|
|
|
+"""
|
|
|
+Unit tests for Pipeline orchestration framework.
|
|
|
+
|
|
|
+Tests cover:
|
|
|
+- Stage execution in sequence
|
|
|
+- Result caching and retrieval
|
|
|
+- Exception handling
|
|
|
+- State reset functionality
|
|
|
+"""
|
|
|
+
|
|
|
+import pytest
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+from src.pipeline.pipeline import (
|
|
|
+ Stage,
|
|
|
+ PipelineExecutor,
|
|
|
+ LambdaStage,
|
|
|
+ StageResult,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+class AddStage(Stage):
|
|
|
+ """Test stage that adds a value."""
|
|
|
+
|
|
|
+ def __init__(self, name: str, value: int):
|
|
|
+ super().__init__(name)
|
|
|
+ self.value = value
|
|
|
+
|
|
|
+ def execute(self, input_data: int) -> int:
|
|
|
+ return input_data + self.value
|
|
|
+
|
|
|
+
|
|
|
+class MultiplyStage(Stage):
|
|
|
+ """Test stage that multiplies by a value."""
|
|
|
+
|
|
|
+ def __init__(self, name: str, value: int):
|
|
|
+ super().__init__(name)
|
|
|
+ self.value = value
|
|
|
+
|
|
|
+ def execute(self, input_data: int) -> int:
|
|
|
+ return input_data * self.value
|
|
|
+
|
|
|
+
|
|
|
+class FailingStage(Stage):
|
|
|
+ """Test stage that always raises an exception."""
|
|
|
+
|
|
|
+ def __init__(self, name: str, error_message: str = "Stage failed"):
|
|
|
+ super().__init__(name)
|
|
|
+ self.error_message = error_message
|
|
|
+
|
|
|
+ def execute(self, input_data: Any) -> Any:
|
|
|
+ raise RuntimeError(self.error_message)
|
|
|
+
|
|
|
+
|
|
|
+class IdentityStage(Stage):
|
|
|
+ """Test stage that returns input unchanged."""
|
|
|
+
|
|
|
+ def execute(self, input_data: Any) -> Any:
|
|
|
+ return input_data
|
|
|
+
|
|
|
+
|
|
|
+class TestStage:
|
|
|
+ """Test cases for Stage base class."""
|
|
|
+
|
|
|
+ def test_stage_has_name(self):
|
|
|
+ """Test that stage stores its name."""
|
|
|
+ stage = AddStage("add_ten", 10)
|
|
|
+ assert stage.name == "add_ten"
|
|
|
+
|
|
|
+ def test_stage_repr(self):
|
|
|
+ """Test stage string representation."""
|
|
|
+ stage = AddStage("add_ten", 10)
|
|
|
+ assert "AddStage" in repr(stage)
|
|
|
+ assert "add_ten" in repr(stage)
|
|
|
+
|
|
|
+ def test_subclass_must_implement_execute(self):
|
|
|
+ """Test that Stage subclasses must implement execute."""
|
|
|
+ # With abc.ABC, TypeError is raised at instantiation time
|
|
|
+ with pytest.raises(TypeError, match="abstract"):
|
|
|
+ Stage("test") # type: ignore
|
|
|
+
|
|
|
+
|
|
|
+class TestLambdaStage:
|
|
|
+ """Test cases for LambdaStage convenience class."""
|
|
|
+
|
|
|
+ def test_lambda_stage_executes_function(self):
|
|
|
+ """Test that LambdaStage wraps and executes a function."""
|
|
|
+ stage = LambdaStage("double", lambda x: x * 2)
|
|
|
+ assert stage.execute(5) == 10
|
|
|
+
|
|
|
+ def test_lambda_stage_with_string_function(self):
|
|
|
+ """Test LambdaStage with string manipulation."""
|
|
|
+ stage = LambdaStage("uppercase", lambda s: s.upper())
|
|
|
+ assert stage.execute("hello") == "HELLO"
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineExecutorCreation:
|
|
|
+ """Test cases for PipelineExecutor creation and configuration."""
|
|
|
+
|
|
|
+ def test_create_pipeline(self):
|
|
|
+ """Test creating a new pipeline."""
|
|
|
+ pipeline = PipelineExecutor("TestPipeline")
|
|
|
+ assert pipeline.name == "TestPipeline"
|
|
|
+ assert len(pipeline) == 0
|
|
|
+
|
|
|
+ def test_add_stage(self):
|
|
|
+ """Test adding a stage to the pipeline."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ stage = AddStage("add_ten", 10)
|
|
|
+ result = pipeline.add_stage(stage)
|
|
|
+
|
|
|
+ assert len(pipeline) == 1
|
|
|
+ assert result is pipeline # Method chaining
|
|
|
+
|
|
|
+ def test_add_multiple_stages(self):
|
|
|
+ """Test adding multiple stages."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ assert len(pipeline) == 2
|
|
|
+
|
|
|
+ def test_add_stage_raises_for_non_stage(self):
|
|
|
+ """Test that adding non-Stage objects raises TypeError."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+
|
|
|
+ with pytest.raises(TypeError):
|
|
|
+ pipeline.add_stage("not a stage") # type: ignore
|
|
|
+
|
|
|
+ def test_get_stage_names(self):
|
|
|
+ """Test getting list of stage names."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ names = pipeline.get_stage_names()
|
|
|
+ assert names == ["add_ten", "double"]
|
|
|
+
|
|
|
+ def test_clear_stages(self):
|
|
|
+ """Test clearing all stages."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.clear_stages()
|
|
|
+
|
|
|
+ assert len(pipeline) == 0
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineExecutorSequentialExecution:
|
|
|
+ """Test cases for sequential stage execution."""
|
|
|
+
|
|
|
+ def test_execute_single_stage(self):
|
|
|
+ """Test executing a single stage."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+
|
|
|
+ result = pipeline.execute(5)
|
|
|
+ assert result == 15
|
|
|
+
|
|
|
+ def test_execute_multiple_stages_in_order(self):
|
|
|
+ """Test that stages execute in the correct order."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ # (5 + 10) * 2 = 30
|
|
|
+ result = pipeline.execute(5)
|
|
|
+ assert result == 30
|
|
|
+
|
|
|
+ def test_execute_complex_pipeline(self):
|
|
|
+ """Test a more complex multi-stage pipeline."""
|
|
|
+ pipeline = PipelineExecutor("Complex")
|
|
|
+ pipeline.add_stage(AddStage("add_five", 5))
|
|
|
+ pipeline.add_stage(MultiplyStage("triple", 3))
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("halve", 0.5))
|
|
|
+
|
|
|
+ # ((5 + 5) * 3 + 10) * 0.5 = (10 * 3 + 10) * 0.5 = 40 * 0.5 = 20
|
|
|
+ result = pipeline.execute(5)
|
|
|
+ assert result == 20.0
|
|
|
+
|
|
|
+ def test_execute_with_string_data(self):
|
|
|
+ """Test pipeline with string data transformation."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(LambdaStage("prefix", lambda s: ">> " + s))
|
|
|
+ pipeline.add_stage(LambdaStage("suffix", lambda s: s + " <<"))
|
|
|
+
|
|
|
+ result = pipeline.execute("hello")
|
|
|
+ assert result == ">> hello <<"
|
|
|
+
|
|
|
+ def test_execute_with_list_data(self):
|
|
|
+ """Test pipeline with list data transformation."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(LambdaStage("append", lambda lst: lst + [3]))
|
|
|
+ pipeline.add_stage(LambdaStage("extend", lambda lst: lst + [4, 5]))
|
|
|
+
|
|
|
+ result = pipeline.execute([1, 2])
|
|
|
+ assert result == [1, 2, 3, 4, 5]
|
|
|
+
|
|
|
+ def test_execute_empty_pipeline_raises(self):
|
|
|
+ """Test that executing an empty pipeline raises ValueError."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+
|
|
|
+ with pytest.raises(ValueError, match="no stages"):
|
|
|
+ pipeline.execute(10)
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineStageResultCaching:
|
|
|
+ """Test cases for stage result caching and retrieval."""
|
|
|
+
|
|
|
+ def test_get_stage_result(self):
|
|
|
+ """Test retrieving a specific stage result."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ result = pipeline.get_stage_result("add_ten")
|
|
|
+ assert result is not None
|
|
|
+ assert result.stage_name == "add_ten"
|
|
|
+ assert result.success is True
|
|
|
+ assert result.output == 15
|
|
|
+
|
|
|
+ def test_get_stage_result_for_nonexistent_stage(self):
|
|
|
+ """Test getting result for non-existent stage."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ result = pipeline.get_stage_result("nonexistent")
|
|
|
+ assert result is None
|
|
|
+
|
|
|
+ def test_get_all_results(self):
|
|
|
+ """Test getting all stage results."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ results = pipeline.get_all_results()
|
|
|
+ assert len(results) == 2
|
|
|
+ assert "add_ten" in results
|
|
|
+ assert "double" in results
|
|
|
+
|
|
|
+ def test_get_final_output(self):
|
|
|
+ """Test getting the final pipeline output."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ output = pipeline.get_final_output()
|
|
|
+ assert output == 30
|
|
|
+
|
|
|
+ def test_intermediate_stage_outputs(self):
|
|
|
+ """Test that intermediate stage outputs are correct."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_five", 5))
|
|
|
+ pipeline.add_stage(MultiplyStage("triple", 3))
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ # Input: 5
|
|
|
+ # add_five: 5 + 5 = 10
|
|
|
+ # triple: 10 * 3 = 30
|
|
|
+ # add_ten: 30 + 10 = 40
|
|
|
+ assert pipeline.get_stage_result("add_five").output == 10
|
|
|
+ assert pipeline.get_stage_result("triple").output == 30
|
|
|
+ assert pipeline.get_stage_result("add_ten").output == 40
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineExceptionHandling:
|
|
|
+ """Test cases for exception handling during execution."""
|
|
|
+
|
|
|
+ def test_exception_stops_execution(self):
|
|
|
+ """Test that an exception stops further stage execution."""
|
|
|
+ executed_count = [0]
|
|
|
+
|
|
|
+ class CountingStage(Stage):
|
|
|
+ def execute(self, input_data):
|
|
|
+ executed_count[0] += 1
|
|
|
+ return input_data
|
|
|
+
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(CountingStage("first"))
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+ pipeline.add_stage(CountingStage("never_runs"))
|
|
|
+
|
|
|
+ pipeline.execute(10)
|
|
|
+
|
|
|
+ # Only first stage and failing stage should have run
|
|
|
+ assert executed_count[0] == 1
|
|
|
+
|
|
|
+ def test_exception_is_caught(self):
|
|
|
+ """Test that exceptions are caught and stored."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(FailingStage("fails", "Custom error"))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ result = pipeline.execute(5)
|
|
|
+
|
|
|
+ assert result is None # Returns None on failure
|
|
|
+
|
|
|
+ def test_get_last_exception(self):
|
|
|
+ """Test retrieving the last exception."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(FailingStage("fails", "Test error"))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ exception = pipeline.get_last_exception()
|
|
|
+ assert exception is not None
|
|
|
+ assert isinstance(exception, RuntimeError)
|
|
|
+ assert "Test error" in str(exception)
|
|
|
+
|
|
|
+ def test_get_stopped_at_stage(self):
|
|
|
+ """Test getting the stage where execution stopped."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ stopped_at = pipeline.get_stopped_at_stage()
|
|
|
+ assert stopped_at == "fails"
|
|
|
+
|
|
|
+ def test_failed_stage_result(self):
|
|
|
+ """Test that failed stage result has correct info."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ result = pipeline.get_stage_result("fails")
|
|
|
+ assert result.success is False
|
|
|
+ assert result.error is not None
|
|
|
+ assert isinstance(result.error, RuntimeError)
|
|
|
+
|
|
|
+ def test_is_completed_returns_false_on_failure(self):
|
|
|
+ """Test that is_completed returns False after failure."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ assert pipeline.is_completed() is False
|
|
|
+
|
|
|
+ def test_is_completed_returns_true_on_success(self):
|
|
|
+ """Test that is_completed returns True after success."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ assert pipeline.is_completed() is True
|
|
|
+
|
|
|
+ def test_first_stage_failure(self):
|
|
|
+ """Test handling of failure in the first stage."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(FailingStage("first_fails"))
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+
|
|
|
+ assert pipeline.get_stopped_at_stage() == "first_fails"
|
|
|
+ assert pipeline.get_stage_result("add_ten") is None
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineReset:
|
|
|
+ """Test cases for pipeline state reset functionality."""
|
|
|
+
|
|
|
+ def test_reset_clears_results(self):
|
|
|
+ """Test that reset clears cached results."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+ assert pipeline.get_all_results()
|
|
|
+
|
|
|
+ pipeline.reset()
|
|
|
+ assert not pipeline.get_all_results()
|
|
|
+
|
|
|
+ def test_reset_clears_exception(self):
|
|
|
+ """Test that reset clears exception state."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+ assert pipeline.get_last_exception() is not None
|
|
|
+
|
|
|
+ pipeline.reset()
|
|
|
+ assert pipeline.get_last_exception() is None
|
|
|
+
|
|
|
+ def test_reset_clears_stopped_at(self):
|
|
|
+ """Test that reset clears stopped_at_stage."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+ assert pipeline.get_stopped_at_stage() is not None
|
|
|
+
|
|
|
+ pipeline.reset()
|
|
|
+ assert pipeline.get_stopped_at_stage() is None
|
|
|
+
|
|
|
+ def test_re_execute_after_reset(self):
|
|
|
+ """Test that pipeline can be re-executed after reset."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+ assert pipeline.get_final_output() == 15
|
|
|
+
|
|
|
+ pipeline.reset()
|
|
|
+ pipeline.execute(20)
|
|
|
+ assert pipeline.get_final_output() == 30
|
|
|
+
|
|
|
+ def test_re_execute_after_failure_and_reset(self):
|
|
|
+ """Test re-execution after failure and reset."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(FailingStage("fails"))
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+ assert pipeline.is_completed() is False
|
|
|
+
|
|
|
+ # Fix the pipeline by replacing the failing stage
|
|
|
+ pipeline.reset()
|
|
|
+ pipeline._stages = [pipeline._stages[0], MultiplyStage("double", 2)]
|
|
|
+
|
|
|
+ pipeline.execute(5)
|
|
|
+ assert pipeline.is_completed() is True
|
|
|
+ assert pipeline.get_final_output() == 30
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineRepr:
|
|
|
+ """Test cases for string representations."""
|
|
|
+
|
|
|
+ def test_pipeline_repr(self):
|
|
|
+ """Test pipeline string representation."""
|
|
|
+ pipeline = PipelineExecutor("TestPipeline")
|
|
|
+ pipeline.add_stage(AddStage("add_ten", 10))
|
|
|
+ pipeline.add_stage(MultiplyStage("double", 2))
|
|
|
+
|
|
|
+ repr_str = repr(pipeline)
|
|
|
+ assert "TestPipeline" in repr_str
|
|
|
+ assert "2" in repr_str # Stage count
|
|
|
+
|
|
|
+ def test_stage_result_repr_success(self):
|
|
|
+ """Test StageResult repr for success."""
|
|
|
+ result = StageResult("test_stage", success=True, output=42)
|
|
|
+ assert "test_stage" in repr(result)
|
|
|
+ assert "success=True" in repr(result)
|
|
|
+
|
|
|
+ def test_stage_result_repr_failure(self):
|
|
|
+ """Test StageResult repr for failure."""
|
|
|
+ error = RuntimeError("test error")
|
|
|
+ result = StageResult("test_stage", success=False, error=error)
|
|
|
+ assert "test_stage" in repr(result)
|
|
|
+ assert "success=False" in repr(result)
|
|
|
+
|
|
|
+
|
|
|
+class TestPipelineWithVariousDataTypes:
|
|
|
+ """Test pipeline behavior with different data types."""
|
|
|
+
|
|
|
+ def test_pipeline_with_dict(self):
|
|
|
+ """Test pipeline transforming dictionaries."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(LambdaStage("add_key", lambda d: {**d, "y": 20}))
|
|
|
+ pipeline.add_stage(LambdaStage("add_another", lambda d: {**d, "z": 30}))
|
|
|
+
|
|
|
+ result = pipeline.execute({"x": 10})
|
|
|
+ assert result == {"x": 10, "y": 20, "z": 30}
|
|
|
+
|
|
|
+ def test_pipeline_with_identity_stages(self):
|
|
|
+ """Test pipeline with stages that pass data through unchanged."""
|
|
|
+ pipeline = PipelineExecutor()
|
|
|
+ pipeline.add_stage(IdentityStage("first"))
|
|
|
+ pipeline.add_stage(IdentityStage("second"))
|
|
|
+ pipeline.add_stage(IdentityStage("third"))
|
|
|
+
|
|
|
+ test_data = {"complex": "object"}
|
|
|
+ result = pipeline.execute(test_data)
|
|
|
+
|
|
|
+ assert result is test_data
|
|
|
+ assert pipeline.is_completed() is True
|