2
0

test_integration.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. """
  2. Integration tests for the full pipeline scheduler.
  3. """
  4. import pytest
  5. import asyncio
  6. from pathlib import Path
  7. from datetime import datetime
  8. from unittest.mock import Mock, MagicMock, patch
  9. # Mock torch before importing anything that depends on it
  10. import sys
  11. sys.modules['torch'] = MagicMock()
  12. sys.modules['transformers'] = MagicMock()
  13. from src.scheduler import (
  14. TaskQueue,
  15. RetryManager,
  16. ProgressNotifier,
  17. ConsoleProgressObserver,
  18. CallbackProgressObserver,
  19. RecoveryManager,
  20. ChapterTask,
  21. TaskStatus,
  22. SchedulerState,
  23. RetryConfig,
  24. PipelineProgress,
  25. CheckpointData
  26. )
  27. @pytest.fixture
  28. def temp_work_dir(tmp_path):
  29. """Create temporary working directory."""
  30. work_dir = tmp_path / "work"
  31. work_dir.mkdir()
  32. return work_dir
  33. @pytest.fixture
  34. def sample_novel_file(tmp_path):
  35. """Create a sample novel file for testing."""
  36. novel_file = tmp_path / "novel.txt"
  37. content = """第一章 开始
  38. 这是第一章的内容,包含一些文字。
  39. 林风站在山顶,看着远方的城市。
  40. 第二章 继续
  41. 这是第二章的内容。
  42. 他开始了新的旅程。"""
  43. novel_file.write_text(content, encoding="utf-8")
  44. return novel_file
  45. class TestTaskQueue:
  46. """Test suite for TaskQueue."""
  47. def test_add_and_get_task(self):
  48. """Test adding and retrieving tasks."""
  49. queue = TaskQueue()
  50. task = queue.add_chapter(
  51. "ch1",
  52. 0,
  53. "第一章",
  54. "这是内容"
  55. )
  56. assert task.chapter_id == "ch1"
  57. assert task.status == TaskStatus.PENDING
  58. def test_get_next_pending(self):
  59. """Test getting next pending task."""
  60. queue = TaskQueue()
  61. queue.add_chapter("ch1", 0, "第一章", "内容1")
  62. queue.add_chapter("ch2", 1, "第二章", "内容2")
  63. task = queue.get_next_pending()
  64. assert task is not None
  65. assert task.chapter_id == "ch1"
  66. assert task.status == TaskStatus.IN_PROGRESS
  67. def test_mark_completed(self):
  68. """Test marking task as completed."""
  69. queue = TaskQueue()
  70. queue.add_chapter("ch1", 0, "第一章", "内容")
  71. queue.get_next_pending()
  72. queue.mark_completed("ch1", "Translated content")
  73. retrieved = queue.get_task("ch1")
  74. assert retrieved.status == TaskStatus.COMPLETED
  75. assert retrieved.translated_content == "Translated content"
  76. def test_mark_failed(self):
  77. """Test marking task as failed."""
  78. queue = TaskQueue()
  79. queue.add_chapter("ch1", 0, "第一章", "内容")
  80. queue.get_next_pending()
  81. queue.mark_failed("ch1", "Translation failed")
  82. retrieved = queue.get_task("ch1")
  83. assert retrieved.status == TaskStatus.FAILED
  84. assert "Translation failed" in retrieved.error_message
  85. def test_get_stats(self):
  86. """Test getting queue statistics."""
  87. queue = TaskQueue()
  88. queue.add_chapter("ch1", 0, "第一章", "内容1")
  89. queue.add_chapter("ch2", 1, "第二章", "内容2")
  90. stats = queue.get_stats()
  91. assert stats.total == 2
  92. assert stats.pending == 2
  93. def test_reset_for_retry(self):
  94. """Test resetting task for retry."""
  95. queue = TaskQueue()
  96. queue.add_chapter("ch1", 0, "第一章", "内容")
  97. queue.get_next_pending()
  98. queue.mark_failed("ch1", "Error")
  99. assert queue.reset_for_retry("ch1") is True
  100. task = queue.get_task("ch1")
  101. assert task.status == TaskStatus.PENDING
  102. def test_has_pending(self):
  103. """Test checking for pending tasks."""
  104. queue = TaskQueue()
  105. assert not queue.has_pending()
  106. queue.add_chapter("ch1", 0, "第一章", "内容")
  107. assert queue.has_pending()
  108. def test_is_complete(self):
  109. """Test checking if queue is complete."""
  110. queue = TaskQueue()
  111. queue.add_chapter("ch1", 0, "第一章", "内容")
  112. assert not queue.is_complete()
  113. queue.get_next_pending()
  114. queue.mark_completed("ch1", "Translated")
  115. assert queue.is_complete()
  116. def test_iteration(self):
  117. """Test iterating over tasks."""
  118. queue = TaskQueue()
  119. queue.add_chapter("ch2", 1, "第二章", "内容2")
  120. queue.add_chapter("ch1", 0, "第一章", "内容1")
  121. tasks = list(queue)
  122. assert len(tasks) == 2
  123. assert tasks[0].chapter_index == 0
  124. assert tasks[1].chapter_index == 1
  125. class TestRetryManager:
  126. """Test suite for RetryManager."""
  127. def test_should_retry_on_timeout(self):
  128. """Test retry decision for timeout errors."""
  129. manager = RetryManager()
  130. task = ChapterTask("ch1", 0, "第一章", "内容")
  131. task.retry_count = 1
  132. assert manager.should_retry(task, "Request timeout") is True
  133. def test_should_not_retry_exceeded(self):
  134. """Test no retry after max attempts."""
  135. manager = RetryManager()
  136. task = ChapterTask("ch1", 0, "第一章", "内容")
  137. task.retry_count = 3
  138. assert manager.should_retry(task, "Error") is False
  139. def test_get_retry_delay_exponential(self):
  140. """Test exponential backoff delay calculation."""
  141. config = RetryConfig(exponential_backoff=True, base_delay=1.0)
  142. manager = RetryManager(config)
  143. assert manager.get_retry_delay(1) == 1.0
  144. assert manager.get_retry_delay(2) == 2.0
  145. assert manager.get_retry_delay(3) == 4.0
  146. def test_get_retry_delay_linear(self):
  147. """Test linear delay calculation."""
  148. config = RetryConfig(exponential_backoff=False, base_delay=2.0)
  149. manager = RetryManager(config)
  150. assert manager.get_retry_delay(1) == 2.0
  151. assert manager.get_retry_delay(2) == 2.0
  152. def test_max_delay_cap(self):
  153. """Test that delay is capped at max_delay."""
  154. config = RetryConfig(
  155. exponential_backoff=True,
  156. base_delay=1.0,
  157. max_delay=5.0
  158. )
  159. manager = RetryManager(config)
  160. # Should cap at 5.0
  161. assert manager.get_retry_delay(10) <= 5.0
  162. def test_record_and_get_history(self):
  163. """Test recording and retrieving retry history."""
  164. manager = RetryManager()
  165. record = manager.record_retry("ch1", 1, "Error", 1.0)
  166. assert record.chapter_id == "ch1"
  167. assert record.attempt_number == 1
  168. history = manager.get_retry_history("ch1")
  169. assert len(history) == 1
  170. def test_clear_retry_history(self):
  171. """Test clearing retry history."""
  172. manager = RetryManager()
  173. manager.record_retry("ch1", 1, "Error", 1.0)
  174. manager.clear_retry_history("ch1")
  175. history = manager.get_retry_history("ch1")
  176. assert len(history) == 0
  177. def test_get_stats(self):
  178. """Test getting retry statistics."""
  179. manager = RetryManager()
  180. manager.record_retry("ch1", 1, "Error", 1.0, success=True)
  181. manager.record_retry("ch2", 1, "Error", 1.0, success=False)
  182. stats = manager.get_stats()
  183. assert stats["total_retries"] == 2
  184. assert stats["successful_retries"] == 1
  185. assert stats["failed_retries"] == 1
  186. class TestProgressNotifier:
  187. """Test suite for ProgressNotifier."""
  188. def test_register_and_notify(self):
  189. """Test registering observer and sending notifications."""
  190. notifier = ProgressNotifier()
  191. observer = CallbackProgressObserver(
  192. on_start=lambda total: None,
  193. on_chapter_complete=lambda task: None
  194. )
  195. notifier.register(observer)
  196. assert notifier.observer_count == 1
  197. task = ChapterTask("ch1", 0, "第一章", "内容")
  198. notifier.notify_chapter_complete(task)
  199. def test_unregister_observer(self):
  200. """Test unregistering observer."""
  201. notifier = ProgressNotifier()
  202. observer = CallbackProgressObserver()
  203. notifier.register(observer)
  204. assert notifier.observer_count == 1
  205. notifier.unregister(observer)
  206. assert notifier.observer_count == 0
  207. def test_event_history(self):
  208. """Test event history tracking."""
  209. notifier = ProgressNotifier()
  210. task = ChapterTask("ch1", 0, "第一章", "内容")
  211. notifier.notify_chapter_start(task)
  212. notifier.notify_chapter_complete(task)
  213. history = notifier.get_event_history()
  214. assert len(history) == 2
  215. def test_clear_observers(self):
  216. """Test clearing all observers."""
  217. notifier = ProgressNotifier()
  218. notifier.register(CallbackProgressObserver())
  219. notifier.register(CallbackProgressObserver())
  220. assert notifier.observer_count == 2
  221. notifier.clear_observers()
  222. assert notifier.observer_count == 0
  223. class TestRecoveryManager:
  224. """Test suite for RecoveryManager."""
  225. def test_save_and_load_checkpoint(self, temp_work_dir):
  226. """Test saving and loading checkpoint."""
  227. recovery = RecoveryManager(temp_work_dir)
  228. checkpoint = CheckpointData(
  229. work_id="work123",
  230. current_chapter_index=5,
  231. completed_indices=[0, 1, 2, 3, 4],
  232. failed_indices=[]
  233. )
  234. recovery.save_checkpoint(checkpoint)
  235. assert recovery.has_checkpoint() is True
  236. loaded = recovery.load_checkpoint()
  237. assert loaded is not None
  238. assert loaded.work_id == "work123"
  239. assert loaded.current_chapter_index == 5
  240. def test_delete_checkpoint(self, temp_work_dir):
  241. """Test deleting checkpoint."""
  242. recovery = RecoveryManager(temp_work_dir)
  243. checkpoint = CheckpointData(
  244. work_id="work123",
  245. current_chapter_index=0,
  246. completed_indices=[],
  247. failed_indices=[]
  248. )
  249. recovery.save_checkpoint(checkpoint)
  250. assert recovery.has_checkpoint() is True
  251. recovery.delete_checkpoint()
  252. assert recovery.has_checkpoint() is False
  253. def test_get_recovery_state(self, temp_work_dir):
  254. """Test getting recovery state."""
  255. recovery = RecoveryManager(temp_work_dir)
  256. checkpoint = CheckpointData(
  257. work_id="work123",
  258. current_chapter_index=3,
  259. completed_indices=[0, 1, 2],
  260. failed_indices=[]
  261. )
  262. recovery.save_checkpoint(checkpoint)
  263. state = recovery.get_recovery_state()
  264. assert state is not None
  265. assert state["recoverable"] is True
  266. assert state["work_id"] == "work123"
  267. assert state["resume_index"] == 3
  268. def test_can_resume(self, temp_work_dir):
  269. """Test checking if resume is possible."""
  270. recovery = RecoveryManager(temp_work_dir)
  271. assert recovery.can_resume() is False
  272. checkpoint = CheckpointData(
  273. work_id="work123",
  274. current_chapter_index=0,
  275. completed_indices=[],
  276. failed_indices=[]
  277. )
  278. recovery.save_checkpoint(checkpoint)
  279. assert recovery.can_resume() is True
  280. class TestModels:
  281. """Test suite for scheduler models."""
  282. def test_chapter_task_model(self):
  283. """Test ChapterTask model."""
  284. task = ChapterTask(
  285. "ch1",
  286. 0,
  287. "第一章",
  288. "内容"
  289. )
  290. assert task.is_finished is False
  291. assert task.can_retry is False
  292. task.status = TaskStatus.COMPLETED
  293. assert task.is_finished is True
  294. def test_task_status_enum(self):
  295. """Test TaskStatus enum values."""
  296. assert TaskStatus.PENDING.value == "pending"
  297. assert TaskStatus.IN_PROGRESS.value == "in_progress"
  298. assert TaskStatus.COMPLETED.value == "completed"
  299. assert TaskStatus.FAILED.value == "failed"
  300. def test_scheduler_state_enum(self):
  301. """Test SchedulerState enum values."""
  302. assert SchedulerState.IDLE.value == "idle"
  303. assert SchedulerState.RUNNING.value == "running"
  304. assert SchedulerState.PAUSED.value == "paused"
  305. assert SchedulerState.COMPLETED.value == "completed"
  306. def test_pipeline_progress_model(self):
  307. """Test PipelineProgress model."""
  308. progress = PipelineProgress(
  309. total_chapters=10,
  310. completed_chapters=5
  311. )
  312. assert progress.pending_chapters == 5
  313. assert progress.completion_rate == 0.5
  314. def test_checkpoint_data_model(self):
  315. """Test CheckpointData model."""
  316. checkpoint = CheckpointData(
  317. work_id="work123",
  318. current_chapter_index=5,
  319. completed_indices=[0, 1, 2, 3, 4],
  320. failed_indices=[]
  321. )
  322. assert checkpoint.work_id == "work123"
  323. assert checkpoint.current_chapter_index == 5
  324. class TestConsoleProgressObserver:
  325. """Test suite for ConsoleProgressObserver."""
  326. def test_observer_creation(self):
  327. """Test creating console observer."""
  328. observer = ConsoleProgressObserver(verbose=False)
  329. assert observer.verbose is False
  330. def test_observer_methods_exist(self):
  331. """Test that all observer methods exist."""
  332. observer = ConsoleProgressObserver()
  333. assert hasattr(observer, "on_pipeline_start")
  334. assert hasattr(observer, "on_chapter_complete")
  335. assert hasattr(observer, "on_chapter_failed")
  336. class TestCallbackProgressObserver:
  337. """Test suite for CallbackProgressObserver."""
  338. def test_callback_invocation(self):
  339. """Test that callbacks are invoked."""
  340. calls = []
  341. on_start = lambda total: calls.append("start")
  342. on_complete = lambda task: calls.append("complete")
  343. observer = CallbackProgressObserver(
  344. on_start=on_start,
  345. on_chapter_complete=on_complete
  346. )
  347. observer.on_pipeline_start(10)
  348. observer.on_chapter_complete(ChapterTask("ch1", 0, "Title", "Content"))
  349. assert "start" in calls
  350. assert "complete" in calls