2
0

test_pipeline.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  """
  Unit tests for the Pipeline orchestration framework.

  Tests cover:
  - Stage base class contract (name, repr, abstract ``execute``)
  - LambdaStage wrapping of plain callables
  - Pipeline construction and configuration
  - Stage execution in sequence
  - Result caching and retrieval
  - Exception handling
  - State reset functionality
  - String representations and assorted input data types
  """
  from typing import Any

  import pytest

  from src.pipeline.pipeline import (
      LambdaStage,
      PipelineExecutor,
      Stage,
      StageResult,
  )
  class AddStage(Stage):
      """Test stage that adds a fixed offset to its input."""

      def __init__(self, name: str, value: int):
          super().__init__(name)
          # Offset applied on every execute() call.
          self.value = value

      def execute(self, input_data: int) -> int:
          """Return ``input_data`` increased by ``self.value``."""
          total = input_data + self.value
          return total
  class MultiplyStage(Stage):
      """Test stage that multiplies its input by a fixed factor.

      ``value`` is annotated as ``float`` because the suite also uses
      fractional factors (e.g. ``MultiplyStage("halve", 0.5)``); plain
      ``int`` factors are still accepted.
      """

      def __init__(self, name: str, value: float):
          super().__init__(name)
          self.value = value

      def execute(self, input_data: float) -> float:
          """Return ``input_data`` multiplied by ``self.value``."""
          return input_data * self.value
  class FailingStage(Stage):
      """Test stage whose execute() unconditionally raises RuntimeError.

      The message is configurable so tests can assert on it.
      """

      def __init__(self, name: str, error_message: str = "Stage failed"):
          super().__init__(name)
          self.error_message = error_message

      def execute(self, input_data: Any) -> Any:
          """Ignore ``input_data`` and raise ``RuntimeError(self.error_message)``."""
          message = self.error_message
          raise RuntimeError(message)
  class IdentityStage(Stage):
      """Pass-through test stage: hands back the object it receives."""

      def execute(self, input_data: Any) -> Any:
          """Return ``input_data`` itself (same object, not a copy)."""
          passthrough = input_data
          return passthrough
  class TestStage:
      """Test cases for the Stage base class."""

      def test_stage_has_name(self):
          """A stage keeps the name it was constructed with."""
          stage = AddStage("add_ten", 10)
          assert stage.name == "add_ten"

      def test_stage_repr(self):
          """repr() mentions both the concrete class and the stage name."""
          stage = AddStage("add_ten", 10)
          text = repr(stage)
          assert "AddStage" in text
          assert "add_ten" in text

      def test_subclass_must_implement_execute(self):
          """Stage, and any subclass lacking execute(), cannot be instantiated."""

          class IncompleteStage(Stage):
              """Subclass that deliberately does not override execute()."""

          # With abc.ABC, TypeError is raised at instantiation time, both for
          # the abstract base itself ...
          with pytest.raises(TypeError, match="abstract"):
              Stage("test")  # type: ignore
          # ... and for a subclass that forgets to implement execute(), which
          # is the contract this test is named after.
          with pytest.raises(TypeError, match="abstract"):
              IncompleteStage("incomplete")  # type: ignore
  class TestLambdaStage:
      """Test cases for the LambdaStage convenience wrapper."""

      def test_lambda_stage_executes_function(self):
          """execute() delegates to the wrapped callable."""
          doubler = LambdaStage("double", lambda x: x * 2)
          assert doubler.execute(5) == 10

      def test_lambda_stage_with_string_function(self):
          """The wrapped callable may operate on strings."""
          shout = LambdaStage("uppercase", lambda s: s.upper())
          assert shout.execute("hello") == "HELLO"
  class TestPipelineExecutorCreation:
      """Test cases for PipelineExecutor construction and configuration."""

      def test_create_pipeline(self):
          """A new pipeline keeps its name and starts with no stages."""
          executor = PipelineExecutor("TestPipeline")
          assert executor.name == "TestPipeline"
          assert len(executor) == 0

      def test_add_stage(self):
          """add_stage() registers the stage and returns the pipeline."""
          executor = PipelineExecutor()
          returned = executor.add_stage(AddStage("add_ten", 10))
          assert len(executor) == 1
          assert returned is executor  # Method chaining

      def test_add_multiple_stages(self):
          """Each add_stage() call grows the pipeline by one."""
          executor = PipelineExecutor()
          for stage in (AddStage("add_ten", 10), MultiplyStage("double", 2)):
              executor.add_stage(stage)
          assert len(executor) == 2

      def test_add_stage_raises_for_non_stage(self):
          """Passing something that is not a Stage raises TypeError."""
          executor = PipelineExecutor()
          with pytest.raises(TypeError):
              executor.add_stage("not a stage")  # type: ignore

      def test_get_stage_names(self):
          """get_stage_names() lists stage names in insertion order."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.add_stage(MultiplyStage("double", 2))
          assert executor.get_stage_names() == ["add_ten", "double"]

      def test_clear_stages(self):
          """clear_stages() removes every registered stage."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.clear_stages()
          assert len(executor) == 0
  class TestPipelineExecutorSequentialExecution:
      """Test cases for running stages one after another."""

      def test_execute_single_stage(self):
          """A one-stage pipeline returns that stage's output."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          assert executor.execute(5) == 15

      def test_execute_multiple_stages_in_order(self):
          """Stages run in insertion order, each feeding the next."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.add_stage(MultiplyStage("double", 2))
          # Add first, then multiply: (5 + 10) * 2 = 30
          assert executor.execute(5) == 30

      def test_execute_complex_pipeline(self):
          """A four-stage numeric pipeline produces the expected value."""
          executor = PipelineExecutor("Complex")
          for stage in (
              AddStage("add_five", 5),
              MultiplyStage("triple", 3),
              AddStage("add_ten", 10),
              MultiplyStage("halve", 0.5),
          ):
              executor.add_stage(stage)
          # ((5 + 5) * 3 + 10) * 0.5 = (30 + 10) * 0.5 = 20.0
          assert executor.execute(5) == 20.0

      def test_execute_with_string_data(self):
          """String data flows through lambda stages."""
          executor = PipelineExecutor()
          executor.add_stage(LambdaStage("prefix", lambda text: ">> " + text))
          executor.add_stage(LambdaStage("suffix", lambda text: text + " <<"))
          assert executor.execute("hello") == ">> hello <<"

      def test_execute_with_list_data(self):
          """List data flows through lambda stages."""
          executor = PipelineExecutor()
          executor.add_stage(LambdaStage("append", lambda items: items + [3]))
          executor.add_stage(LambdaStage("extend", lambda items: items + [4, 5]))
          assert executor.execute([1, 2]) == [1, 2, 3, 4, 5]

      def test_execute_empty_pipeline_raises(self):
          """Running a pipeline with no stages raises ValueError."""
          executor = PipelineExecutor()
          with pytest.raises(ValueError, match="no stages"):
              executor.execute(10)
  class TestPipelineStageResultCaching:
      """Test cases for caching and retrieving per-stage results."""

      def test_get_stage_result(self):
          """A successful stage's cached result records name, success and output."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.add_stage(MultiplyStage("double", 2))
          executor.execute(5)
          cached = executor.get_stage_result("add_ten")
          assert cached is not None
          assert cached.stage_name == "add_ten"
          assert cached.success is True
          assert cached.output == 15

      def test_get_stage_result_for_nonexistent_stage(self):
          """Looking up an unknown stage name yields None."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.execute(5)
          assert executor.get_stage_result("nonexistent") is None

      def test_get_all_results(self):
          """get_all_results() contains an entry for every executed stage."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.add_stage(MultiplyStage("double", 2))
          executor.execute(5)
          everything = executor.get_all_results()
          assert len(everything) == 2
          for name in ("add_ten", "double"):
              assert name in everything

      def test_get_final_output(self):
          """get_final_output() returns the last stage's output."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_ten", 10))
          executor.add_stage(MultiplyStage("double", 2))
          executor.execute(5)
          assert executor.get_final_output() == 30

      def test_intermediate_stage_outputs(self):
          """Every intermediate stage's output is cached, not just the last."""
          executor = PipelineExecutor()
          executor.add_stage(AddStage("add_five", 5))
          executor.add_stage(MultiplyStage("triple", 3))
          executor.add_stage(AddStage("add_ten", 10))
          executor.execute(5)
          # 5 -> add_five -> 10 -> triple -> 30 -> add_ten -> 40
          expected = {"add_five": 10, "triple": 30, "add_ten": 40}
          for name, value in expected.items():
              assert executor.get_stage_result(name).output == value
  class TestPipelineExceptionHandling:
      """Test cases for exception handling during execution."""

      def test_exception_stops_execution(self):
          """A raising stage prevents every later stage from running."""
          executed_count = [0]  # mutable cell the nested class can update

          class CountingStage(Stage):
              """Pass-through stage that counts how often it executes."""

              def execute(self, input_data):
                  executed_count[0] += 1
                  return input_data

          pipeline = PipelineExecutor()
          pipeline.add_stage(CountingStage("first"))
          pipeline.add_stage(FailingStage("fails"))
          pipeline.add_stage(CountingStage("never_runs"))
          pipeline.execute(10)
          # Only CountingStage runs are counted: "first" runs once, "fails"
          # raises, and "never_runs" must be skipped entirely.
          assert executed_count[0] == 1
          assert pipeline.get_stage_result("never_runs") is None

      def test_exception_is_caught(self):
          """Exceptions are caught (not propagated) and stored on the pipeline."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.add_stage(FailingStage("fails", "Custom error"))
          pipeline.add_stage(MultiplyStage("double", 2))
          result = pipeline.execute(5)  # must not raise RuntimeError
          assert result is None  # Returns None on failure
          # "Stored" is part of the contract under test: verify it too.
          stored = pipeline.get_last_exception()
          assert isinstance(stored, RuntimeError)
          assert "Custom error" in str(stored)

      def test_get_last_exception(self):
          """get_last_exception() returns the exception raised by the stage."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.add_stage(FailingStage("fails", "Test error"))
          pipeline.execute(5)
          exception = pipeline.get_last_exception()
          assert exception is not None
          assert isinstance(exception, RuntimeError)
          assert "Test error" in str(exception)

      def test_get_stopped_at_stage(self):
          """get_stopped_at_stage() names the stage where execution stopped."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.add_stage(FailingStage("fails"))
          pipeline.add_stage(MultiplyStage("double", 2))
          pipeline.execute(5)
          assert pipeline.get_stopped_at_stage() == "fails"

      def test_failed_stage_result(self):
          """The failed stage's result records failure and the error object."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.add_stage(FailingStage("fails"))
          pipeline.execute(5)
          result = pipeline.get_stage_result("fails")
          # Guard first so a missing result fails with a clear message
          # instead of an AttributeError on None.
          assert result is not None
          assert result.success is False
          assert result.error is not None
          assert isinstance(result.error, RuntimeError)

      def test_is_completed_returns_false_on_failure(self):
          """is_completed() is False after a stage fails."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.add_stage(FailingStage("fails"))
          pipeline.execute(5)
          assert pipeline.is_completed() is False

      def test_is_completed_returns_true_on_success(self):
          """is_completed() is True after every stage succeeds."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.execute(5)
          assert pipeline.is_completed() is True

      def test_first_stage_failure(self):
          """A failure in the very first stage leaves later stages unrun."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(FailingStage("first_fails"))
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.execute(5)
          assert pipeline.get_stopped_at_stage() == "first_fails"
          assert pipeline.get_stage_result("add_ten") is None
  class TestPipelineReset:
      """Test cases for pipeline state reset functionality."""

      def test_reset_clears_results(self):
          """reset() discards cached stage results."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.add_stage(MultiplyStage("double", 2))
          pipeline.execute(5)
          assert pipeline.get_all_results()
          pipeline.reset()
          assert not pipeline.get_all_results()

      def test_reset_clears_exception(self):
          """reset() clears the stored exception."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(FailingStage("fails"))
          pipeline.execute(5)
          assert pipeline.get_last_exception() is not None
          pipeline.reset()
          assert pipeline.get_last_exception() is None

      def test_reset_clears_stopped_at(self):
          """reset() clears the recorded stopping stage."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(FailingStage("fails"))
          pipeline.execute(5)
          assert pipeline.get_stopped_at_stage() is not None
          pipeline.reset()
          assert pipeline.get_stopped_at_stage() is None

      def test_re_execute_after_reset(self):
          """A successful pipeline can be reset and run again on new input."""
          pipeline = PipelineExecutor()
          pipeline.add_stage(AddStage("add_ten", 10))
          pipeline.execute(5)
          assert pipeline.get_final_output() == 15
          pipeline.reset()
          pipeline.execute(20)
          assert pipeline.get_final_output() == 30

      def test_re_execute_after_failure_and_reset(self):
          """A failed pipeline can be repaired and re-run after reset()."""
          add_ten = AddStage("add_ten", 10)
          pipeline = PipelineExecutor()
          pipeline.add_stage(add_ten)
          pipeline.add_stage(FailingStage("fails"))
          pipeline.execute(5)
          assert pipeline.is_completed() is False
          # Fix the pipeline by replacing the failing stage. Rebuild it through
          # the public API instead of assigning to the private ``_stages``
          # list, reusing the same first-stage instance as before.
          pipeline.reset()
          pipeline.clear_stages()
          pipeline.add_stage(add_ten)
          pipeline.add_stage(MultiplyStage("double", 2))
          pipeline.execute(5)
          assert pipeline.is_completed() is True
          assert pipeline.get_final_output() == 30
  class TestPipelineRepr:
      """Test cases for string representations."""

      def test_pipeline_repr(self):
          """Pipeline repr shows its name and stage count."""
          executor = PipelineExecutor("TestPipeline")
          executor.add_stage(AddStage("add_ten", 10))
          executor.add_stage(MultiplyStage("double", 2))
          text = repr(executor)
          assert "TestPipeline" in text
          assert "2" in text  # Stage count

      def test_stage_result_repr_success(self):
          """A successful StageResult repr shows its name and success=True."""
          text = repr(StageResult("test_stage", success=True, output=42))
          assert "test_stage" in text
          assert "success=True" in text

      def test_stage_result_repr_failure(self):
          """A failed StageResult repr shows its name and success=False."""
          failure = RuntimeError("test error")
          text = repr(StageResult("test_stage", success=False, error=failure))
          assert "test_stage" in text
          assert "success=False" in text
  class TestPipelineWithVariousDataTypes:
      """Test pipeline behavior with different data types."""

      def test_pipeline_with_dict(self):
          """Dict data can be extended stage by stage."""
          executor = PipelineExecutor()
          executor.add_stage(LambdaStage("add_key", lambda d: {**d, "y": 20}))
          executor.add_stage(LambdaStage("add_another", lambda d: {**d, "z": 30}))
          assert executor.execute({"x": 10}) == {"x": 10, "y": 20, "z": 30}

      def test_pipeline_with_identity_stages(self):
          """Pass-through stages return the very same input object."""
          executor = PipelineExecutor()
          for name in ("first", "second", "third"):
              executor.add_stage(IdentityStage(name))
          payload = {"complex": "object"}
          assert executor.execute(payload) is payload
          assert executor.is_completed() is True