"""
Unit tests for Pipeline orchestration framework.

Tests cover:
- Stage execution in sequence
- Result caching and retrieval
- Exception handling
- State reset functionality
"""

from typing import Any

import pytest

from src.pipeline.pipeline import (
    Stage,
    PipelineExecutor,
    LambdaStage,
    StageResult,
)


class AddStage(Stage):
    """Test stage that adds a fixed value to its input.

    ``value`` is annotated as ``float``; ``int`` arguments are still valid
    (PEP 484 numeric tower).
    """

    def __init__(self, name: str, value: float):
        super().__init__(name)
        self.value = value

    def execute(self, input_data: float) -> float:
        return input_data + self.value


class MultiplyStage(Stage):
    """Test stage that multiplies its input by a fixed value.

    ``value`` is annotated as ``float`` because the tests also use fractional
    factors (e.g. ``MultiplyStage("halve", 0.5)``); ``int`` arguments are
    still valid.
    """

    def __init__(self, name: str, value: float):
        super().__init__(name)
        self.value = value

    def execute(self, input_data: float) -> float:
        return input_data * self.value


class FailingStage(Stage):
    """Test stage that always raises ``RuntimeError(error_message)``."""

    def __init__(self, name: str, error_message: str = "Stage failed"):
        super().__init__(name)
        self.error_message = error_message

    def execute(self, input_data: Any) -> Any:
        raise RuntimeError(self.error_message)


class IdentityStage(Stage):
    """Test stage that returns its input unchanged (same object)."""

    def execute(self, input_data: Any) -> Any:
        return input_data


class TestStage:
    """Test cases for Stage base class."""

    def test_stage_has_name(self):
        """Test that stage stores its name."""
        stage = AddStage("add_ten", 10)
        assert stage.name == "add_ten"

    def test_stage_repr(self):
        """Test stage string representation."""
        stage = AddStage("add_ten", 10)
        assert "AddStage" in repr(stage)
        assert "add_ten" in repr(stage)

    def test_subclass_must_implement_execute(self):
        """Test that Stage subclasses must implement execute."""
        # With abc.ABC, TypeError is raised at instantiation time
        with pytest.raises(TypeError, match="abstract"):
            Stage("test")  # type: ignore


class TestLambdaStage:
    """Test cases for LambdaStage convenience class."""

    def test_lambda_stage_executes_function(self):
        """Test that LambdaStage wraps and executes a function."""
        stage = LambdaStage("double", lambda x: x * 2)
        assert stage.execute(5) == 10

    def test_lambda_stage_with_string_function(self):
        """Test LambdaStage with string manipulation."""
        stage = LambdaStage("uppercase", lambda s: s.upper())
        assert stage.execute("hello") == "HELLO"


class TestPipelineExecutorCreation:
    """Test cases for PipelineExecutor creation and configuration."""

    def test_create_pipeline(self):
        """Test creating a new pipeline."""
        pipeline = PipelineExecutor("TestPipeline")
        assert pipeline.name == "TestPipeline"
        assert len(pipeline) == 0

    def test_add_stage(self):
        """Test adding a stage to the pipeline."""
        pipeline = PipelineExecutor()
        stage = AddStage("add_ten", 10)
        result = pipeline.add_stage(stage)
        assert len(pipeline) == 1
        assert result is pipeline  # Method chaining

    def test_add_multiple_stages(self):
        """Test adding multiple stages."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        assert len(pipeline) == 2

    def test_add_stage_raises_for_non_stage(self):
        """Test that adding non-Stage objects raises TypeError."""
        pipeline = PipelineExecutor()
        with pytest.raises(TypeError):
            pipeline.add_stage("not a stage")  # type: ignore

    def test_get_stage_names(self):
        """Test getting list of stage names."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        names = pipeline.get_stage_names()
        assert names == ["add_ten", "double"]

    def test_clear_stages(self):
        """Test clearing all stages."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.clear_stages()
        assert len(pipeline) == 0


class TestPipelineExecutorSequentialExecution:
    """Test cases for sequential stage execution."""

    def test_execute_single_stage(self):
        """Test executing a single stage."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        result = pipeline.execute(5)
        assert result == 15

    def test_execute_multiple_stages_in_order(self):
        """Test that stages execute in the correct order."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        # (5 + 10) * 2 = 30
        result = pipeline.execute(5)
        assert result == 30

    def test_execute_complex_pipeline(self):
        """Test a more complex multi-stage pipeline."""
        pipeline = PipelineExecutor("Complex")
        pipeline.add_stage(AddStage("add_five", 5))
        pipeline.add_stage(MultiplyStage("triple", 3))
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("halve", 0.5))
        # ((5 + 5) * 3 + 10) * 0.5 = (10 * 3 + 10) * 0.5 = 40 * 0.5 = 20
        result = pipeline.execute(5)
        assert result == 20.0

    def test_execute_with_string_data(self):
        """Test pipeline with string data transformation."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(LambdaStage("prefix", lambda s: ">> " + s))
        pipeline.add_stage(LambdaStage("suffix", lambda s: s + " <<"))
        result = pipeline.execute("hello")
        assert result == ">> hello <<"

    def test_execute_with_list_data(self):
        """Test pipeline with list data transformation."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(LambdaStage("append", lambda lst: lst + [3]))
        pipeline.add_stage(LambdaStage("extend", lambda lst: lst + [4, 5]))
        result = pipeline.execute([1, 2])
        assert result == [1, 2, 3, 4, 5]

    def test_execute_empty_pipeline_raises(self):
        """Test that executing an empty pipeline raises ValueError."""
        pipeline = PipelineExecutor()
        with pytest.raises(ValueError, match="no stages"):
            pipeline.execute(10)


class TestPipelineStageResultCaching:
    """Test cases for stage result caching and retrieval."""

    def test_get_stage_result(self):
        """Test retrieving a specific stage result."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        result = pipeline.get_stage_result("add_ten")
        assert result is not None
        assert result.stage_name == "add_ten"
        assert result.success is True
        assert result.output == 15

    def test_get_stage_result_for_nonexistent_stage(self):
        """Test getting result for non-existent stage."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        result = pipeline.get_stage_result("nonexistent")
        assert result is None

    def test_get_all_results(self):
        """Test getting all stage results."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        results = pipeline.get_all_results()
        assert len(results) == 2
        assert "add_ten" in results
        assert "double" in results

    def test_get_final_output(self):
        """Test getting the final pipeline output."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        output = pipeline.get_final_output()
        assert output == 30

    def test_intermediate_stage_outputs(self):
        """Test that intermediate stage outputs are correct."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_five", 5))
        pipeline.add_stage(MultiplyStage("triple", 3))
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        # Input: 5
        # add_five: 5 + 5 = 10
        # triple: 10 * 3 = 30
        # add_ten: 30 + 10 = 40
        assert pipeline.get_stage_result("add_five").output == 10
        assert pipeline.get_stage_result("triple").output == 30
        assert pipeline.get_stage_result("add_ten").output == 40


class TestPipelineExceptionHandling:
    """Test cases for exception handling during execution."""

    def test_exception_stops_execution(self):
        """Test that an exception stops further stage execution."""
        # Mutable cell so the nested stage class can update the count.
        executed_count = [0]

        class CountingStage(Stage):
            def execute(self, input_data):
                executed_count[0] += 1
                return input_data

        pipeline = PipelineExecutor()
        pipeline.add_stage(CountingStage("first"))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.add_stage(CountingStage("never_runs"))
        pipeline.execute(10)
        # The counter only tracks CountingStage runs: "first" ran once,
        # FailingStage raised (and does not count), and "never_runs" must
        # have been skipped.
        assert executed_count[0] == 1
        assert pipeline.get_stage_result("never_runs") is None

    def test_exception_is_caught(self):
        """Test that exceptions are caught and stored."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails", "Custom error"))
        pipeline.add_stage(MultiplyStage("double", 2))
        result = pipeline.execute(5)
        assert result is None  # Returns None on failure

    def test_get_last_exception(self):
        """Test retrieving the last exception."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails", "Test error"))
        pipeline.execute(5)
        exception = pipeline.get_last_exception()
        assert exception is not None
        assert isinstance(exception, RuntimeError)
        assert "Test error" in str(exception)

    def test_get_stopped_at_stage(self):
        """Test getting the stage where execution stopped."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        stopped_at = pipeline.get_stopped_at_stage()
        assert stopped_at == "fails"

    def test_failed_stage_result(self):
        """Test that failed stage result has correct info."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        result = pipeline.get_stage_result("fails")
        assert result.success is False
        assert result.error is not None
        assert isinstance(result.error, RuntimeError)

    def test_is_completed_returns_false_on_failure(self):
        """Test that is_completed returns False after failure."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.is_completed() is False

    def test_is_completed_returns_true_on_success(self):
        """Test that is_completed returns True after success."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        assert pipeline.is_completed() is True

    def test_first_stage_failure(self):
        """Test handling of failure in the first stage."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(FailingStage("first_fails"))
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        assert pipeline.get_stopped_at_stage() == "first_fails"
        assert pipeline.get_stage_result("add_ten") is None


class TestPipelineReset:
    """Test cases for pipeline state reset functionality."""

    def test_reset_clears_results(self):
        """Test that reset clears cached results."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        assert pipeline.get_all_results()
        pipeline.reset()
        assert not pipeline.get_all_results()

    def test_reset_clears_exception(self):
        """Test that reset clears exception state."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.get_last_exception() is not None
        pipeline.reset()
        assert pipeline.get_last_exception() is None

    def test_reset_clears_stopped_at(self):
        """Test that reset clears stopped_at_stage."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.get_stopped_at_stage() is not None
        pipeline.reset()
        assert pipeline.get_stopped_at_stage() is None

    def test_re_execute_after_reset(self):
        """Test that pipeline can be re-executed after reset."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.execute(5)
        assert pipeline.get_final_output() == 15
        pipeline.reset()
        pipeline.execute(20)
        assert pipeline.get_final_output() == 30

    def test_re_execute_after_failure_and_reset(self):
        """Test re-execution after failure and reset."""
        pipeline = PipelineExecutor()
        add_ten = AddStage("add_ten", 10)
        pipeline.add_stage(add_ten)
        pipeline.add_stage(FailingStage("fails"))
        pipeline.execute(5)
        assert pipeline.is_completed() is False

        # Fix the pipeline by replacing the failing stage. Rebuild it through
        # the public API rather than assigning to the private ``_stages``
        # attribute, so the test does not depend on internal storage.
        pipeline.reset()
        pipeline.clear_stages()
        pipeline.add_stage(add_ten)
        pipeline.add_stage(MultiplyStage("double", 2))
        pipeline.execute(5)
        assert pipeline.is_completed() is True
        assert pipeline.get_final_output() == 30


class TestPipelineRepr:
    """Test cases for string representations."""

    def test_pipeline_repr(self):
        """Test pipeline string representation."""
        pipeline = PipelineExecutor("TestPipeline")
        pipeline.add_stage(AddStage("add_ten", 10))
        pipeline.add_stage(MultiplyStage("double", 2))
        repr_str = repr(pipeline)
        assert "TestPipeline" in repr_str
        # NOTE(review): weak check -- "2" could match anywhere in the repr;
        # tighten once the exact repr format is pinned down.
        assert "2" in repr_str  # Stage count

    def test_stage_result_repr_success(self):
        """Test StageResult repr for success."""
        result = StageResult("test_stage", success=True, output=42)
        assert "test_stage" in repr(result)
        assert "success=True" in repr(result)

    def test_stage_result_repr_failure(self):
        """Test StageResult repr for failure."""
        error = RuntimeError("test error")
        result = StageResult("test_stage", success=False, error=error)
        assert "test_stage" in repr(result)
        assert "success=False" in repr(result)


class TestPipelineWithVariousDataTypes:
    """Test pipeline behavior with different data types."""

    def test_pipeline_with_dict(self):
        """Test pipeline transforming dictionaries."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(LambdaStage("add_key", lambda d: {**d, "y": 20}))
        pipeline.add_stage(LambdaStage("add_another", lambda d: {**d, "z": 30}))
        result = pipeline.execute({"x": 10})
        assert result == {"x": 10, "y": 20, "z": 30}

    def test_pipeline_with_identity_stages(self):
        """Test pipeline with stages that pass data through unchanged."""
        pipeline = PipelineExecutor()
        pipeline.add_stage(IdentityStage("first"))
        pipeline.add_stage(IdentityStage("second"))
        pipeline.add_stage(IdentityStage("third"))
        test_data = {"complex": "object"}
        result = pipeline.execute(test_data)
        # Identity stages must hand back the very same object, not a copy.
        assert result is test_data
        assert pipeline.is_completed() is True