| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494 |
- """
- Unit tests for Pipeline orchestration framework.
- Tests cover:
- - Stage execution in sequence
- - Result caching and retrieval
- - Exception handling
- - State reset functionality
- """
- import pytest
- from typing import Any
- from src.pipeline.pipeline import (
- Stage,
- PipelineExecutor,
- LambdaStage,
- StageResult,
- )
class AddStage(Stage):
    """Stage fixture whose ``execute`` adds a fixed amount to its input."""

    def __init__(self, name: str, value: int):
        # The base initializer receives the name; the addend is stored here.
        super().__init__(name)
        self.value = value

    def execute(self, input_data: int) -> int:
        """Return ``input_data`` plus the configured ``value``."""
        addend = self.value
        return input_data + addend
class MultiplyStage(Stage):
    """Test stage that multiplies its input by a fixed factor.

    The factor is annotated as ``float`` because the tests in this module
    build stages with both integer factors and a fractional one (``0.5``);
    type checkers accept ``int`` wherever ``float`` is expected.
    """

    def __init__(self, name: str, value: float):
        super().__init__(name)
        self.value = value

    def execute(self, input_data: float) -> float:
        """Return ``input_data`` multiplied by the configured ``value``."""
        return input_data * self.value
class FailingStage(Stage):
    """Stage fixture whose ``execute`` unconditionally raises ``RuntimeError``."""

    def __init__(self, name: str, error_message: str = "Stage failed"):
        super().__init__(name)
        self.error_message = error_message

    def execute(self, input_data: Any) -> Any:
        """Raise ``RuntimeError`` carrying the configured message."""
        failure = RuntimeError(self.error_message)
        raise failure
class IdentityStage(Stage):
    """Stage fixture that hands its input back untouched."""

    def execute(self, input_data: Any) -> Any:
        """Return the very same object that was passed in."""
        passthrough = input_data
        return passthrough
class TestStage:
    """Checks on the ``Stage`` base class, exercised through ``AddStage``."""

    def test_stage_has_name(self):
        """The name given at construction is exposed as ``name``."""
        add_ten = AddStage("add_ten", 10)
        assert add_ten.name == "add_ten"

    def test_stage_repr(self):
        """``repr`` mentions both the concrete class and the stage name."""
        text = repr(AddStage("add_ten", 10))
        assert "AddStage" in text
        assert "add_ten" in text

    def test_subclass_must_implement_execute(self):
        """Instantiating ``Stage`` directly is rejected with ``TypeError``."""
        # Abstract base classes refuse instantiation, so the constructor
        # call itself is what raises.
        with pytest.raises(TypeError, match="abstract"):
            Stage("test")  # type: ignore
class TestLambdaStage:
    """Checks on the ``LambdaStage`` convenience wrapper."""

    def test_lambda_stage_executes_function(self):
        """``execute`` applies the wrapped callable to its argument."""
        doubler = LambdaStage("double", lambda x: x * 2)
        outcome = doubler.execute(5)
        assert outcome == 10

    def test_lambda_stage_with_string_function(self):
        """The wrapped callable may operate on strings as well."""
        shouter = LambdaStage("uppercase", lambda s: s.upper())
        outcome = shouter.execute("hello")
        assert outcome == "HELLO"
class TestPipelineExecutorCreation:
    """Construction and stage-registration behaviour of ``PipelineExecutor``."""

    def test_create_pipeline(self):
        """A freshly named pipeline keeps its name and holds no stages."""
        fresh = PipelineExecutor("TestPipeline")
        assert fresh.name == "TestPipeline"
        stage_count = len(fresh)
        assert stage_count == 0

    def test_add_stage(self):
        """Adding one stage grows the pipeline and returns the pipeline."""
        executor = PipelineExecutor()
        returned = executor.add_stage(AddStage("add_ten", 10))
        assert len(executor) == 1
        # add_stage hands back the executor itself so calls can be chained.
        assert returned is executor

    def test_add_multiple_stages(self):
        """Each added stage is counted by ``len``."""
        executor = PipelineExecutor()
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        assert len(executor) == 2

    def test_add_stage_raises_for_non_stage(self):
        """Objects that are not ``Stage`` instances are rejected."""
        executor = PipelineExecutor()
        bogus = "not a stage"
        with pytest.raises(TypeError):
            executor.add_stage(bogus)  # type: ignore

    def test_get_stage_names(self):
        """Stage names are reported in insertion order."""
        executor = PipelineExecutor()
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        assert executor.get_stage_names() == ["add_ten", "double"]

    def test_clear_stages(self):
        """``clear_stages`` leaves the pipeline with no stages."""
        executor = PipelineExecutor()
        executor.add_stage(AddStage("add_ten", 10))
        executor.clear_stages()
        remaining = len(executor)
        assert remaining == 0
class TestPipelineExecutorSequentialExecution:
    """Stages run one after another, each feeding the next."""

    def test_execute_single_stage(self):
        """A one-stage pipeline returns that stage's output."""
        executor = PipelineExecutor()
        executor.add_stage(AddStage("add_ten", 10))
        assert executor.execute(5) == 15

    def test_execute_multiple_stages_in_order(self):
        """Stages are applied in the order they were added."""
        executor = PipelineExecutor()
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        # Add first, then multiply: (5 + 10) * 2 = 30
        assert executor.execute(5) == 30

    def test_execute_complex_pipeline(self):
        """A four-stage arithmetic chain yields the expected value."""
        executor = PipelineExecutor("Complex")
        steps = [
            AddStage("add_five", 5),
            MultiplyStage("triple", 3),
            AddStage("add_ten", 10),
            MultiplyStage("halve", 0.5),
        ]
        for step in steps:
            executor.add_stage(step)
        # 5 -> 10 -> 30 -> 40 -> 20.0
        outcome = executor.execute(5)
        assert outcome == 20.0

    def test_execute_with_string_data(self):
        """String payloads flow through lambda stages."""
        executor = PipelineExecutor()
        decorate_front = LambdaStage("prefix", lambda s: ">> " + s)
        decorate_back = LambdaStage("suffix", lambda s: s + " <<")
        executor.add_stage(decorate_front)
        executor.add_stage(decorate_back)
        assert executor.execute("hello") == ">> hello <<"

    def test_execute_with_list_data(self):
        """List payloads flow through lambda stages."""
        executor = PipelineExecutor()
        grow_by_one = LambdaStage("append", lambda lst: lst + [3])
        grow_by_two = LambdaStage("extend", lambda lst: lst + [4, 5])
        executor.add_stage(grow_by_one)
        executor.add_stage(grow_by_two)
        assert executor.execute([1, 2]) == [1, 2, 3, 4, 5]

    def test_execute_empty_pipeline_raises(self):
        """Running a pipeline without stages raises ``ValueError``."""
        empty = PipelineExecutor()
        with pytest.raises(ValueError, match="no stages"):
            empty.execute(10)
class TestPipelineStageResultCaching:
    """Per-stage results are retained after a run and can be looked up."""

    def test_get_stage_result(self):
        """A stage's cached result carries its name, status and output."""
        executor = PipelineExecutor()
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        executor.execute(5)

        cached = executor.get_stage_result("add_ten")
        assert cached is not None
        assert cached.stage_name == "add_ten"
        assert cached.success is True
        assert cached.output == 15

    def test_get_stage_result_for_nonexistent_stage(self):
        """Looking up an unknown stage name yields ``None``."""
        executor = PipelineExecutor()
        executor.add_stage(AddStage("add_ten", 10))
        executor.execute(5)
        assert executor.get_stage_result("nonexistent") is None

    def test_get_all_results(self):
        """``get_all_results`` holds one entry per executed stage name."""
        executor = PipelineExecutor()
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        executor.execute(5)

        everything = executor.get_all_results()
        assert len(everything) == 2
        for expected_name in ("add_ten", "double"):
            assert expected_name in everything

    def test_get_final_output(self):
        """``get_final_output`` reports the last stage's output."""
        executor = PipelineExecutor()
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        executor.execute(5)
        assert executor.get_final_output() == 30

    def test_intermediate_stage_outputs(self):
        """Every intermediate output along the chain is cached."""
        executor = PipelineExecutor()
        chain = (
            AddStage("add_five", 5),
            MultiplyStage("triple", 3),
            AddStage("add_ten", 10),
        )
        for stage in chain:
            executor.add_stage(stage)
        executor.execute(5)

        # Starting from 5: add_five -> 10, triple -> 30, add_ten -> 40
        expected_outputs = (("add_five", 10), ("triple", 30), ("add_ten", 40))
        for stage_name, expected in expected_outputs:
            assert executor.get_stage_result(stage_name).output == expected
class TestPipelineExceptionHandling:
    """Test cases for exception handling during execution."""

    def test_exception_stops_execution(self):
        """Test that an exception stops further stage execution."""
        # Record *which* stages ran rather than just how many, so the test
        # distinguishes "first" running from "never_runs" running.
        executed_names = []

        class RecordingStage(Stage):
            def execute(self, input_data):
                executed_names.append(self.name)
                return input_data

        pipeline = PipelineExecutor()
        pipeline.add_stage(RecordingStage("first"))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.add_stage(RecordingStage("never_runs"))
        pipeline.execute(10)

        # The failing stage is not a RecordingStage, so only the stage
        # before it shows up; the stage after it must never execute.
        assert executed_names == ["first"]

    def test_exception_is_caught(self):
        """Test that exceptions are caught and stored."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails", "Custom error"))
        pipeline.add_stage(MultiplyStage("double", 2))
        result = pipeline.execute(5)
        assert result is None  # Returns None on failure

    def test_get_last_exception(self):
        """Test retrieving the last exception."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails", "Test error"))
        pipeline.execute(5)
        exception = pipeline.get_last_exception()
        assert exception is not None
        assert isinstance(exception, RuntimeError)
        assert "Test error" in str(exception)

    def test_get_stopped_at_stage(self):
        """Test getting the stage where execution stopped."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        stopped_at = pipeline.get_stopped_at_stage()
        assert stopped_at == "fails"

    def test_failed_stage_result(self):
        """Test that failed stage result has correct info."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        result = pipeline.get_stage_result("fails")
        # get_stage_result can return None; fail with a clear assertion
        # instead of an AttributeError on the lines below.
        assert result is not None
        assert result.success is False
        assert result.error is not None
        assert isinstance(result.error, RuntimeError)

    def test_is_completed_returns_false_on_failure(self):
        """Test that is_completed returns False after failure."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.is_completed() is False

    def test_is_completed_returns_true_on_success(self):
        """Test that is_completed returns True after success."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        assert pipeline.is_completed() is True

    def test_first_stage_failure(self):
        """Test handling of failure in the first stage."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(FailingStage("first_fails"))
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        assert pipeline.get_stopped_at_stage() == "first_fails"
        # Stages after the failure never ran, so they have no cached result.
        assert pipeline.get_stage_result("add_ten") is None
class TestPipelineReset:
    """Test cases for pipeline state reset functionality."""

    def test_reset_clears_results(self):
        """Test that reset clears cached results."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        assert pipeline.get_all_results()
        pipeline.reset()
        assert not pipeline.get_all_results()

    def test_reset_clears_exception(self):
        """Test that reset clears exception state."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.get_last_exception() is not None
        pipeline.reset()
        assert pipeline.get_last_exception() is None

    def test_reset_clears_stopped_at(self):
        """Test that reset clears stopped_at_stage."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.get_stopped_at_stage() is not None
        pipeline.reset()
        assert pipeline.get_stopped_at_stage() is None

    def test_re_execute_after_reset(self):
        """Test that pipeline can be re-executed after reset."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        assert pipeline.get_final_output() == 15
        pipeline.reset()
        pipeline.execute(20)
        assert pipeline.get_final_output() == 30

    def test_re_execute_after_failure_and_reset(self):
        """Test re-execution after failure and reset."""
        add_ten = AddStage("add_ten", 10)
        pipeline = PipelineExecutor()
        pipeline.add_stage(add_ten)
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.is_completed() is False

        pipeline.reset()
        # Replace the failing stage through the public API rather than by
        # rewriting the private ``_stages`` list, so the test does not
        # depend on the executor's internal storage.
        pipeline.clear_stages()
        pipeline.add_stage(add_ten)
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        assert pipeline.is_completed() is True
        assert pipeline.get_final_output() == 30
class TestPipelineRepr:
    """``repr`` output of the executor and of ``StageResult``."""

    def test_pipeline_repr(self):
        """The executor's repr shows its name and its stage count."""
        executor = PipelineExecutor("TestPipeline")
        for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
            executor.add_stage(stage)
        rendered = repr(executor)
        assert "TestPipeline" in rendered
        assert "2" in rendered  # Stage count

    def test_stage_result_repr_success(self):
        """A successful result's repr shows the stage name and status."""
        rendered = repr(StageResult("test_stage", success=True, output=42))
        assert "test_stage" in rendered
        assert "success=True" in rendered

    def test_stage_result_repr_failure(self):
        """A failed result's repr shows the stage name and status."""
        failed = StageResult(
            "test_stage", success=False, error=RuntimeError("test error")
        )
        rendered = repr(failed)
        assert "test_stage" in rendered
        assert "success=False" in rendered
class TestPipelineWithVariousDataTypes:
    """Pipelines carrying payloads other than plain numbers."""

    def test_pipeline_with_dict(self):
        """Dictionary payloads gain keys stage by stage."""
        executor = PipelineExecutor()
        with_y = LambdaStage("add_key", lambda d: {**d, "y": 20})
        with_z = LambdaStage("add_another", lambda d: {**d, "z": 30})
        executor.add_stage(with_y)
        executor.add_stage(with_z)
        outcome = executor.execute({"x": 10})
        assert outcome == {"x": 10, "y": 20, "z": 30}

    def test_pipeline_with_identity_stages(self):
        """Pass-through stages return the identical input object."""
        executor = PipelineExecutor()
        for label in ("first", "second", "third"):
            executor.add_stage(IdentityStage(label))
        payload = {"complex": "object"}
        outcome = executor.execute(payload)
        # Identity, not just equality: no stage may copy the payload.
        assert outcome is payload
        assert executor.is_completed() is True
|