|
|
@@ -0,0 +1,459 @@
|
|
|
+"""
|
|
|
+Integration tests for the full pipeline scheduler.
|
|
|
+"""
|
|
|
+
|
|
|
+import pytest
|
|
|
+import asyncio
|
|
|
+from pathlib import Path
|
|
|
+from datetime import datetime
|
|
|
+from unittest.mock import Mock, MagicMock, patch
|
|
|
+
|
|
|
+# Mock torch before importing anything that depends on it
|
|
|
+import sys
|
|
|
+sys.modules['torch'] = MagicMock()
|
|
|
+sys.modules['transformers'] = MagicMock()
|
|
|
+
|
|
|
+from src.scheduler import (
|
|
|
+ TaskQueue,
|
|
|
+ RetryManager,
|
|
|
+ ProgressNotifier,
|
|
|
+ ConsoleProgressObserver,
|
|
|
+ CallbackProgressObserver,
|
|
|
+ RecoveryManager,
|
|
|
+ ChapterTask,
|
|
|
+ TaskStatus,
|
|
|
+ SchedulerState,
|
|
|
+ RetryConfig,
|
|
|
+ PipelineProgress,
|
|
|
+ CheckpointData
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
@pytest.fixture
def temp_work_dir(tmp_path):
    """Provide an empty ``work`` directory located under pytest's ``tmp_path``."""
    directory = tmp_path.joinpath("work")
    directory.mkdir()
    return directory
|
|
|
+
|
|
|
+
|
|
|
@pytest.fixture
def sample_novel_file(tmp_path):
    """Write a short two-chapter sample novel to ``novel.txt`` and return its path.

    The file is written as UTF-8 under pytest's ``tmp_path``.
    """
    text = """第一章 开始

这是第一章的内容,包含一些文字。

林风站在山顶,看着远方的城市。

第二章 继续

这是第二章的内容。

他开始了新的旅程。"""
    target = tmp_path / "novel.txt"
    target.write_text(text, encoding="utf-8")
    return target
|
|
|
+
|
|
|
+
|
|
|
class TestTaskQueue:
    """Tests for TaskQueue: adding chapters, dispatching them and tracking status."""

    def test_add_and_get_task(self):
        """A newly added chapter is pending and can be looked up by its id."""
        queue = TaskQueue()
        task = queue.add_chapter("ch1", 0, "第一章", "这是内容")

        assert task.chapter_id == "ch1"
        assert task.status == TaskStatus.PENDING

        # Cover the "get" half of the test as well: the queue must hand the
        # chapter back when asked for it by id.
        retrieved = queue.get_task("ch1")
        assert retrieved is not None
        assert retrieved.chapter_id == "ch1"
        assert retrieved.status == TaskStatus.PENDING

    def test_get_next_pending(self):
        """Fetching the next pending task returns ch1 and marks it in progress."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        queue.add_chapter("ch2", 1, "第二章", "内容2")

        task = queue.get_next_pending()

        assert task is not None
        assert task.chapter_id == "ch1"
        assert task.status == TaskStatus.IN_PROGRESS

    def test_mark_completed(self):
        """Completing a dispatched task stores the translation and the status."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()

        queue.mark_completed("ch1", "Translated content")

        retrieved = queue.get_task("ch1")
        assert retrieved.status == TaskStatus.COMPLETED
        assert retrieved.translated_content == "Translated content"

    def test_mark_failed(self):
        """Failing a dispatched task records the status and the error text."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()

        queue.mark_failed("ch1", "Translation failed")

        retrieved = queue.get_task("ch1")
        assert retrieved.status == TaskStatus.FAILED
        assert "Translation failed" in retrieved.error_message

    def test_get_stats(self):
        """Statistics for two untouched chapters report two total, two pending."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        queue.add_chapter("ch2", 1, "第二章", "内容2")

        stats = queue.get_stats()

        assert stats.total == 2
        assert stats.pending == 2

    def test_reset_for_retry(self):
        """Resetting a failed task succeeds and puts it back to pending."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()
        queue.mark_failed("ch1", "Error")

        assert queue.reset_for_retry("ch1") is True

        task = queue.get_task("ch1")
        assert task.status == TaskStatus.PENDING

    def test_has_pending(self):
        """An empty queue has nothing pending; adding a chapter changes that."""
        queue = TaskQueue()
        assert not queue.has_pending()

        queue.add_chapter("ch1", 0, "第一章", "内容")
        assert queue.has_pending()

    def test_is_complete(self):
        """The queue is complete only after its single chapter is completed."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        assert not queue.is_complete()

        queue.get_next_pending()
        queue.mark_completed("ch1", "Translated")
        assert queue.is_complete()

    def test_iteration(self):
        """Iterating yields tasks ordered by chapter index, not insertion order."""
        queue = TaskQueue()
        # Deliberately inserted out of order.
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        queue.add_chapter("ch1", 0, "第一章", "内容1")

        tasks = list(queue)

        assert len(tasks) == 2
        assert tasks[0].chapter_index == 0
        assert tasks[1].chapter_index == 1
|
|
|
+
|
|
|
+
|
|
|
class TestRetryManager:
    """Tests for RetryManager: retry decisions, delays and history bookkeeping."""

    def test_should_retry_on_timeout(self):
        """A task with one prior retry is retried on a timeout error."""
        chapter = ChapterTask("ch1", 0, "第一章", "内容")
        chapter.retry_count = 1
        retry_manager = RetryManager()

        decision = retry_manager.should_retry(chapter, "Request timeout")

        assert decision is True

    def test_should_not_retry_exceeded(self):
        """A task that already has three retries is not retried again."""
        chapter = ChapterTask("ch1", 0, "第一章", "内容")
        chapter.retry_count = 3
        retry_manager = RetryManager()

        decision = retry_manager.should_retry(chapter, "Error")

        assert decision is False

    def test_get_retry_delay_exponential(self):
        """With exponential backoff and base 1.0 the delays are 1, 2 and 4."""
        retry_manager = RetryManager(
            RetryConfig(exponential_backoff=True, base_delay=1.0)
        )

        for attempt, expected_delay in ((1, 1.0), (2, 2.0), (3, 4.0)):
            assert retry_manager.get_retry_delay(attempt) == expected_delay

    def test_get_retry_delay_linear(self):
        """Without exponential backoff the delay stays at the base delay."""
        retry_manager = RetryManager(
            RetryConfig(exponential_backoff=False, base_delay=2.0)
        )

        for attempt in (1, 2):
            assert retry_manager.get_retry_delay(attempt) == 2.0

    def test_max_delay_cap(self):
        """The delay for a late attempt never exceeds ``max_delay``."""
        settings = RetryConfig(
            exponential_backoff=True,
            base_delay=1.0,
            max_delay=5.0,
        )
        retry_manager = RetryManager(settings)

        # Attempt 10 must not exceed the configured ceiling of 5.0.
        delay = retry_manager.get_retry_delay(10)
        assert delay <= 5.0

    def test_record_and_get_history(self):
        """A recorded retry is returned and appears in that chapter's history."""
        retry_manager = RetryManager()

        entry = retry_manager.record_retry("ch1", 1, "Error", 1.0)
        assert entry.chapter_id == "ch1"
        assert entry.attempt_number == 1

        entries = retry_manager.get_retry_history("ch1")
        assert len(entries) == 1

    def test_clear_retry_history(self):
        """Clearing a chapter's history leaves it empty."""
        retry_manager = RetryManager()
        retry_manager.record_retry("ch1", 1, "Error", 1.0)

        retry_manager.clear_retry_history("ch1")

        entries = retry_manager.get_retry_history("ch1")
        assert len(entries) == 0

    def test_get_stats(self):
        """Stats count total, successful and failed retries separately."""
        retry_manager = RetryManager()
        retry_manager.record_retry("ch1", 1, "Error", 1.0, success=True)
        retry_manager.record_retry("ch2", 1, "Error", 1.0, success=False)

        summary = retry_manager.get_stats()

        assert summary["total_retries"] == 2
        assert summary["successful_retries"] == 1
        assert summary["failed_retries"] == 1
|
|
|
+
|
|
|
+
|
|
|
class TestProgressNotifier:
    """Tests for ProgressNotifier: observer registration and event dispatch."""

    def test_register_and_notify(self):
        """A registered observer receives the chapter-complete notification."""
        notifier = ProgressNotifier()
        completed = []
        observer = CallbackProgressObserver(
            on_start=lambda total: None,
            # Record deliveries so the test can prove the notification arrived.
            on_chapter_complete=lambda task: completed.append(task),
        )

        notifier.register(observer)
        assert notifier.observer_count == 1

        task = ChapterTask("ch1", 0, "第一章", "内容")
        notifier.notify_chapter_complete(task)

        # NOTE(review): assumes the notifier forwards to observers
        # synchronously - confirm against ProgressNotifier.
        assert len(completed) == 1
        assert completed[0].chapter_id == "ch1"

    def test_unregister_observer(self):
        """Unregistering an observer drops the observer count back to zero."""
        notifier = ProgressNotifier()
        observer = CallbackProgressObserver()

        notifier.register(observer)
        assert notifier.observer_count == 1

        notifier.unregister(observer)
        assert notifier.observer_count == 0

    def test_event_history(self):
        """Two notifications produce two entries in the event history."""
        notifier = ProgressNotifier()
        task = ChapterTask("ch1", 0, "第一章", "内容")

        notifier.notify_chapter_start(task)
        notifier.notify_chapter_complete(task)

        history = notifier.get_event_history()
        assert len(history) == 2

    def test_clear_observers(self):
        """Clearing removes every registered observer."""
        notifier = ProgressNotifier()
        notifier.register(CallbackProgressObserver())
        notifier.register(CallbackProgressObserver())
        assert notifier.observer_count == 2

        notifier.clear_observers()
        assert notifier.observer_count == 0
|
|
|
+
|
|
|
+
|
|
|
class TestRecoveryManager:
    """Tests for RecoveryManager: checkpoint persistence and resume queries."""

    @staticmethod
    def _make_checkpoint(current_index, completed):
        """Build a ``work123`` checkpoint with no failed chapters."""
        return CheckpointData(
            work_id="work123",
            current_chapter_index=current_index,
            completed_indices=completed,
            failed_indices=[],
        )

    def test_save_and_load_checkpoint(self, temp_work_dir):
        """A saved checkpoint is detected and loads back with the same fields."""
        manager = RecoveryManager(temp_work_dir)
        snapshot = self._make_checkpoint(5, [0, 1, 2, 3, 4])

        manager.save_checkpoint(snapshot)
        assert manager.has_checkpoint() is True

        restored = manager.load_checkpoint()
        assert restored is not None
        assert restored.work_id == "work123"
        assert restored.current_chapter_index == 5

    def test_delete_checkpoint(self, temp_work_dir):
        """Deleting a saved checkpoint makes ``has_checkpoint`` report False."""
        manager = RecoveryManager(temp_work_dir)
        snapshot = self._make_checkpoint(0, [])

        manager.save_checkpoint(snapshot)
        assert manager.has_checkpoint() is True

        manager.delete_checkpoint()
        assert manager.has_checkpoint() is False

    def test_get_recovery_state(self, temp_work_dir):
        """The recovery state exposes recoverability, work id and resume index."""
        manager = RecoveryManager(temp_work_dir)
        snapshot = self._make_checkpoint(3, [0, 1, 2])

        manager.save_checkpoint(snapshot)
        info = manager.get_recovery_state()

        assert info is not None
        assert info["recoverable"] is True
        assert info["work_id"] == "work123"
        assert info["resume_index"] == 3

    def test_can_resume(self, temp_work_dir):
        """Resuming is impossible before a checkpoint is saved and possible after."""
        manager = RecoveryManager(temp_work_dir)
        assert manager.can_resume() is False

        snapshot = self._make_checkpoint(0, [])
        manager.save_checkpoint(snapshot)
        assert manager.can_resume() is True
|
|
|
+
|
|
|
+
|
|
|
class TestModels:
    """Tests for the scheduler's data models and enums."""

    def test_chapter_task_model(self):
        """A fresh task is neither finished nor retryable; completing finishes it."""
        chapter = ChapterTask("ch1", 0, "第一章", "内容")

        assert chapter.is_finished is False
        assert chapter.can_retry is False

        chapter.status = TaskStatus.COMPLETED
        assert chapter.is_finished is True

    def test_task_status_enum(self):
        """Each TaskStatus member carries its lowercase string value."""
        expectations = (
            (TaskStatus.PENDING, "pending"),
            (TaskStatus.IN_PROGRESS, "in_progress"),
            (TaskStatus.COMPLETED, "completed"),
            (TaskStatus.FAILED, "failed"),
        )
        for member, expected in expectations:
            assert member.value == expected

    def test_scheduler_state_enum(self):
        """Each SchedulerState member carries its lowercase string value."""
        expectations = (
            (SchedulerState.IDLE, "idle"),
            (SchedulerState.RUNNING, "running"),
            (SchedulerState.PAUSED, "paused"),
            (SchedulerState.COMPLETED, "completed"),
        )
        for member, expected in expectations:
            assert member.value == expected

    def test_pipeline_progress_model(self):
        """Five of ten chapters done gives five pending and a 0.5 completion rate."""
        snapshot = PipelineProgress(total_chapters=10, completed_chapters=5)

        assert snapshot.pending_chapters == 5
        assert snapshot.completion_rate == 0.5

    def test_checkpoint_data_model(self):
        """CheckpointData keeps the work id and current chapter index it was given."""
        fields = {
            "work_id": "work123",
            "current_chapter_index": 5,
            "completed_indices": [0, 1, 2, 3, 4],
            "failed_indices": [],
        }
        saved = CheckpointData(**fields)

        assert saved.work_id == "work123"
        assert saved.current_chapter_index == 5
|
|
|
+
|
|
|
+
|
|
|
class TestConsoleProgressObserver:
    """Tests for ConsoleProgressObserver construction and its callback surface."""

    def test_observer_creation(self):
        """The ``verbose`` flag passed at construction is exposed unchanged."""
        quiet_observer = ConsoleProgressObserver(verbose=False)
        assert quiet_observer.verbose is False

    def test_observer_methods_exist(self):
        """The observer exposes the expected event-handler attributes."""
        console = ConsoleProgressObserver()

        for handler_name in (
            "on_pipeline_start",
            "on_chapter_complete",
            "on_chapter_failed",
        ):
            assert hasattr(console, handler_name)
|
|
|
+
|
|
|
+
|
|
|
class TestCallbackProgressObserver:
    """Tests for CallbackProgressObserver delegating events to user callbacks."""

    def test_callback_invocation(self):
        """Pipeline-start and chapter-complete events reach their callbacks."""
        seen = []

        def on_start(total):
            seen.append("start")

        def on_complete(task):
            seen.append("complete")

        observer = CallbackProgressObserver(
            on_start=on_start,
            on_chapter_complete=on_complete,
        )

        observer.on_pipeline_start(10)
        finished = ChapterTask("ch1", 0, "Title", "Content")
        observer.on_chapter_complete(finished)

        assert "start" in seen
        assert "complete" in seen
|