"""
Tests for the pipeline scheduler components.

Covers the task queue, retry manager, progress notifier and observers,
checkpoint-based recovery manager, and the scheduler data models exported
by ``src.scheduler``.
"""

import asyncio
import sys
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import pytest

# Stub out heavy ML dependencies before importing anything that depends on
# them. ``setdefault`` only installs a stub when the module is not already in
# ``sys.modules``; an unconditional assignment would replace a real, already
# imported torch/transformers for every other test module in the same session.
sys.modules.setdefault("torch", MagicMock())
sys.modules.setdefault("transformers", MagicMock())

from src.scheduler import (  # noqa: E402  (must come after the stubs above)
    TaskQueue,
    RetryManager,
    ProgressNotifier,
    ConsoleProgressObserver,
    CallbackProgressObserver,
    RecoveryManager,
    ChapterTask,
    TaskStatus,
    SchedulerState,
    RetryConfig,
    PipelineProgress,
    CheckpointData,
)


@pytest.fixture
def temp_work_dir(tmp_path):
    """Create and return an empty ``work`` directory under pytest's tmp_path."""
    work_dir = tmp_path / "work"
    work_dir.mkdir()
    return work_dir


@pytest.fixture
def sample_novel_file(tmp_path):
    """Write a small two-chapter sample novel (UTF-8) and return its path."""
    # NOTE(review): not referenced by the tests below; kept for other users.
    novel_file = tmp_path / "novel.txt"
    content = """第一章 开始 这是第一章的内容,包含一些文字。 林风站在山顶,看着远方的城市。 第二章 继续 这是第二章的内容。 他开始了新的旅程。"""
    novel_file.write_text(content, encoding="utf-8")
    return novel_file


class TestTaskQueue:
    """Test suite for TaskQueue."""

    def test_add_and_get_task(self):
        """Test adding and retrieving tasks."""
        queue = TaskQueue()
        task = queue.add_chapter("ch1", 0, "第一章", "这是内容")
        assert task.chapter_id == "ch1"
        assert task.status == TaskStatus.PENDING
        # The added task must also be retrievable by its id.
        assert queue.get_task("ch1").chapter_id == "ch1"

    def test_get_next_pending(self):
        """Test getting next pending task."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        task = queue.get_next_pending()
        assert task is not None
        assert task.chapter_id == "ch1"
        # Handing a task out moves it out of the pending state.
        assert task.status == TaskStatus.IN_PROGRESS

    def test_mark_completed(self):
        """Test marking task as completed."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()
        queue.mark_completed("ch1", "Translated content")
        retrieved = queue.get_task("ch1")
        assert retrieved.status == TaskStatus.COMPLETED
        assert retrieved.translated_content == "Translated content"

    def test_mark_failed(self):
        """Test marking task as failed."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()
        queue.mark_failed("ch1", "Translation failed")
        retrieved = queue.get_task("ch1")
        assert retrieved.status == TaskStatus.FAILED
        assert "Translation failed" in retrieved.error_message

    def test_get_stats(self):
        """Test getting queue statistics."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        stats = queue.get_stats()
        assert stats.total == 2
        assert stats.pending == 2

    def test_reset_for_retry(self):
        """Test resetting task for retry."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()
        queue.mark_failed("ch1", "Error")
        assert queue.reset_for_retry("ch1") is True
        task = queue.get_task("ch1")
        assert task.status == TaskStatus.PENDING

    def test_has_pending(self):
        """Test checking for pending tasks."""
        queue = TaskQueue()
        assert not queue.has_pending()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        assert queue.has_pending()

    def test_is_complete(self):
        """Test checking if queue is complete."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        assert not queue.is_complete()
        queue.get_next_pending()
        queue.mark_completed("ch1", "Translated")
        assert queue.is_complete()

    def test_iteration(self):
        """Test iterating over tasks."""
        queue = TaskQueue()
        # Insert out of order: iteration is expected to follow chapter_index.
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        tasks = list(queue)
        assert len(tasks) == 2
        assert tasks[0].chapter_index == 0
        assert tasks[1].chapter_index == 1


class TestRetryManager:
    """Test suite for RetryManager."""

    def test_should_retry_on_timeout(self):
        """Test retry decision for timeout errors."""
        manager = RetryManager()
        task = ChapterTask("ch1", 0, "第一章", "内容")
        task.retry_count = 1
        assert manager.should_retry(task, "Request timeout") is True

    def test_should_not_retry_exceeded(self):
        """Test no retry after max attempts."""
        manager = RetryManager()
        task = ChapterTask("ch1", 0, "第一章", "内容")
        # With the default config, 3 prior attempts exhaust the retry budget.
        task.retry_count = 3
        assert manager.should_retry(task, "Error") is False

    def test_get_retry_delay_exponential(self):
        """Test exponential backoff delay calculation."""
        config = RetryConfig(exponential_backoff=True, base_delay=1.0)
        manager = RetryManager(config)
        assert manager.get_retry_delay(1) == 1.0
        assert manager.get_retry_delay(2) == 2.0
        assert manager.get_retry_delay(3) == 4.0

    def test_get_retry_delay_linear(self):
        """Test linear delay calculation."""
        config = RetryConfig(exponential_backoff=False, base_delay=2.0)
        manager = RetryManager(config)
        assert manager.get_retry_delay(1) == 2.0
        assert manager.get_retry_delay(2) == 2.0

    def test_max_delay_cap(self):
        """Test that delay is capped at max_delay."""
        config = RetryConfig(
            exponential_backoff=True,
            base_delay=1.0,
            max_delay=5.0,
        )
        manager = RetryManager(config)
        # The uncapped value for attempt 10 (1.0 * 2**9) far exceeds the cap,
        # so the result must be exactly the cap, not merely at most the cap
        # (a bare ``<=`` check would also accept a delay of 0).
        assert manager.get_retry_delay(10) == 5.0

    def test_record_and_get_history(self):
        """Test recording and retrieving retry history."""
        manager = RetryManager()
        record = manager.record_retry("ch1", 1, "Error", 1.0)
        assert record.chapter_id == "ch1"
        assert record.attempt_number == 1
        history = manager.get_retry_history("ch1")
        assert len(history) == 1

    def test_clear_retry_history(self):
        """Test clearing retry history."""
        manager = RetryManager()
        manager.record_retry("ch1", 1, "Error", 1.0)
        manager.clear_retry_history("ch1")
        history = manager.get_retry_history("ch1")
        assert len(history) == 0

    def test_get_stats(self):
        """Test getting retry statistics."""
        manager = RetryManager()
        manager.record_retry("ch1", 1, "Error", 1.0, success=True)
        manager.record_retry("ch2", 1, "Error", 1.0, success=False)
        stats = manager.get_stats()
        assert stats["total_retries"] == 2
        assert stats["successful_retries"] == 1
        assert stats["failed_retries"] == 1


class TestProgressNotifier:
    """Test suite for ProgressNotifier."""

    def test_register_and_notify(self):
        """Test registering an observer and delivering a notification to it."""
        notifier = ProgressNotifier()
        completed = []
        observer = CallbackProgressObserver(
            on_start=lambda total: None,
            on_chapter_complete=completed.append,
        )
        notifier.register(observer)
        assert notifier.observer_count == 1

        task = ChapterTask("ch1", 0, "第一章", "内容")
        notifier.notify_chapter_complete(task)
        # The registered observer's callback must actually receive the task;
        # otherwise this test would only prove that notify() does not raise.
        assert [t.chapter_id for t in completed] == ["ch1"]

    def test_unregister_observer(self):
        """Test unregistering observer."""
        notifier = ProgressNotifier()
        observer = CallbackProgressObserver()
        notifier.register(observer)
        assert notifier.observer_count == 1
        notifier.unregister(observer)
        assert notifier.observer_count == 0

    def test_event_history(self):
        """Test event history tracking."""
        notifier = ProgressNotifier()
        task = ChapterTask("ch1", 0, "第一章", "内容")
        notifier.notify_chapter_start(task)
        notifier.notify_chapter_complete(task)
        history = notifier.get_event_history()
        assert len(history) == 2

    def test_clear_observers(self):
        """Test clearing all observers."""
        notifier = ProgressNotifier()
        notifier.register(CallbackProgressObserver())
        notifier.register(CallbackProgressObserver())
        assert notifier.observer_count == 2
        notifier.clear_observers()
        assert notifier.observer_count == 0


class TestRecoveryManager:
    """Test suite for RecoveryManager."""

    def test_save_and_load_checkpoint(self, temp_work_dir):
        """Test saving and loading checkpoint."""
        recovery = RecoveryManager(temp_work_dir)
        checkpoint = CheckpointData(
            work_id="work123",
            current_chapter_index=5,
            completed_indices=[0, 1, 2, 3, 4],
            failed_indices=[],
        )
        recovery.save_checkpoint(checkpoint)
        assert recovery.has_checkpoint() is True
        loaded = recovery.load_checkpoint()
        assert loaded is not None
        assert loaded.work_id == "work123"
        assert loaded.current_chapter_index == 5

    def test_delete_checkpoint(self, temp_work_dir):
        """Test deleting checkpoint."""
        recovery = RecoveryManager(temp_work_dir)
        checkpoint = CheckpointData(
            work_id="work123",
            current_chapter_index=0,
            completed_indices=[],
            failed_indices=[],
        )
        recovery.save_checkpoint(checkpoint)
        assert recovery.has_checkpoint() is True
        recovery.delete_checkpoint()
        assert recovery.has_checkpoint() is False

    def test_get_recovery_state(self, temp_work_dir):
        """Test getting recovery state."""
        recovery = RecoveryManager(temp_work_dir)
        checkpoint = CheckpointData(
            work_id="work123",
            current_chapter_index=3,
            completed_indices=[0, 1, 2],
            failed_indices=[],
        )
        recovery.save_checkpoint(checkpoint)
        state = recovery.get_recovery_state()
        assert state is not None
        assert state["recoverable"] is True
        assert state["work_id"] == "work123"
        assert state["resume_index"] == 3

    def test_can_resume(self, temp_work_dir):
        """Test checking if resume is possible."""
        recovery = RecoveryManager(temp_work_dir)
        # A fresh work directory has no checkpoint to resume from.
        assert recovery.can_resume() is False
        checkpoint = CheckpointData(
            work_id="work123",
            current_chapter_index=0,
            completed_indices=[],
            failed_indices=[],
        )
        recovery.save_checkpoint(checkpoint)
        assert recovery.can_resume() is True


class TestModels:
    """Test suite for scheduler models."""

    def test_chapter_task_model(self):
        """Test ChapterTask model."""
        task = ChapterTask("ch1", 0, "第一章", "内容")
        assert task.is_finished is False
        assert task.can_retry is False
        task.status = TaskStatus.COMPLETED
        assert task.is_finished is True

    def test_task_status_enum(self):
        """Test TaskStatus enum values."""
        assert TaskStatus.PENDING.value == "pending"
        assert TaskStatus.IN_PROGRESS.value == "in_progress"
        assert TaskStatus.COMPLETED.value == "completed"
        assert TaskStatus.FAILED.value == "failed"

    def test_scheduler_state_enum(self):
        """Test SchedulerState enum values."""
        assert SchedulerState.IDLE.value == "idle"
        assert SchedulerState.RUNNING.value == "running"
        assert SchedulerState.PAUSED.value == "paused"
        assert SchedulerState.COMPLETED.value == "completed"

    def test_pipeline_progress_model(self):
        """Test PipelineProgress model."""
        progress = PipelineProgress(total_chapters=10, completed_chapters=5)
        assert progress.pending_chapters == 5
        assert progress.completion_rate == 0.5

    def test_checkpoint_data_model(self):
        """Test CheckpointData model."""
        checkpoint = CheckpointData(
            work_id="work123",
            current_chapter_index=5,
            completed_indices=[0, 1, 2, 3, 4],
            failed_indices=[],
        )
        assert checkpoint.work_id == "work123"
        assert checkpoint.current_chapter_index == 5


class TestConsoleProgressObserver:
    """Test suite for ConsoleProgressObserver."""

    def test_observer_creation(self):
        """Test creating console observer."""
        observer = ConsoleProgressObserver(verbose=False)
        assert observer.verbose is False

    def test_observer_methods_exist(self):
        """Test that all observer hook methods exist and are callable."""
        observer = ConsoleProgressObserver()
        for name in ("on_pipeline_start", "on_chapter_complete", "on_chapter_failed"):
            # hasattr alone would also accept a non-callable attribute.
            assert callable(getattr(observer, name, None)), name


class TestCallbackProgressObserver:
    """Test suite for CallbackProgressObserver."""

    def test_callback_invocation(self):
        """Test that callbacks are invoked exactly once each, in call order."""
        calls = []

        def on_start(total):
            calls.append("start")

        def on_complete(task):
            calls.append("complete")

        observer = CallbackProgressObserver(
            on_start=on_start,
            on_chapter_complete=on_complete,
        )
        observer.on_pipeline_start(10)
        observer.on_chapter_complete(ChapterTask("ch1", 0, "Title", "Content"))
        # An exact comparison also catches duplicate or out-of-order calls,
        # which a plain membership check would miss.
        assert calls == ["start", "complete"]