| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459 |
"""
Integration tests for the full pipeline scheduler.
"""
import asyncio
import sys
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import pytest

# Stub out heavy ML dependencies before importing anything that depends on
# them. ``setdefault`` only installs the stub when the module is not already
# in ``sys.modules``, so a real torch/transformers imported earlier in the
# same pytest session is not clobbered for other test modules.
sys.modules.setdefault("torch", MagicMock())
sys.modules.setdefault("transformers", MagicMock())

from src.scheduler import (  # noqa: E402  (must follow the stubs above)
    TaskQueue,
    RetryManager,
    ProgressNotifier,
    ConsoleProgressObserver,
    CallbackProgressObserver,
    RecoveryManager,
    ChapterTask,
    TaskStatus,
    SchedulerState,
    RetryConfig,
    PipelineProgress,
    CheckpointData,
)
@pytest.fixture
def temp_work_dir(tmp_path):
    """Provide a freshly created ``work`` subdirectory under ``tmp_path``."""
    path = tmp_path.joinpath("work")
    path.mkdir()
    return path
@pytest.fixture
def sample_novel_file(tmp_path):
    """Write a short two-chapter Chinese novel to ``novel.txt`` and return its path."""
    text = """第一章 开始
这是第一章的内容,包含一些文字。
林风站在山顶,看着远方的城市。
第二章 继续
这是第二章的内容。
他开始了新的旅程。"""
    target = tmp_path.joinpath("novel.txt")
    # Explicit UTF-8 so the CJK text round-trips regardless of platform locale.
    target.write_text(text, encoding="utf-8")
    return target
class TestTaskQueue:
    """Test suite for TaskQueue."""

    def test_add_and_get_task(self):
        """Test adding a task and retrieving it again by chapter id."""
        queue = TaskQueue()
        task = queue.add_chapter(
            "ch1",
            0,
            "第一章",
            "这是内容"
        )
        assert task.chapter_id == "ch1"
        assert task.status == TaskStatus.PENDING
        # The test name promises retrieval, so verify get_task() finds it.
        retrieved = queue.get_task("ch1")
        assert retrieved is not None
        assert retrieved.chapter_id == "ch1"
        assert retrieved.chapter_index == 0
        assert retrieved.status == TaskStatus.PENDING

    def test_get_next_pending(self):
        """Test getting next pending task."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        task = queue.get_next_pending()
        assert task is not None
        assert task.chapter_id == "ch1"
        # Handing out a task also moves it to IN_PROGRESS.
        assert task.status == TaskStatus.IN_PROGRESS

    def test_mark_completed(self):
        """Test marking task as completed."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        # Claim the task first, mirroring pending -> in progress -> done.
        queue.get_next_pending()
        queue.mark_completed("ch1", "Translated content")
        retrieved = queue.get_task("ch1")
        assert retrieved.status == TaskStatus.COMPLETED
        assert retrieved.translated_content == "Translated content"

    def test_mark_failed(self):
        """Test marking task as failed."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()
        queue.mark_failed("ch1", "Translation failed")
        retrieved = queue.get_task("ch1")
        assert retrieved.status == TaskStatus.FAILED
        assert "Translation failed" in retrieved.error_message

    def test_get_stats(self):
        """Test getting queue statistics."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        stats = queue.get_stats()
        assert stats.total == 2
        assert stats.pending == 2

    def test_reset_for_retry(self):
        """Test resetting task for retry."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        queue.get_next_pending()
        queue.mark_failed("ch1", "Error")
        assert queue.reset_for_retry("ch1") is True
        task = queue.get_task("ch1")
        assert task.status == TaskStatus.PENDING

    def test_has_pending(self):
        """Test checking for pending tasks."""
        queue = TaskQueue()
        assert not queue.has_pending()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        assert queue.has_pending()

    def test_is_complete(self):
        """Test checking if queue is complete."""
        queue = TaskQueue()
        queue.add_chapter("ch1", 0, "第一章", "内容")
        assert not queue.is_complete()
        queue.get_next_pending()
        queue.mark_completed("ch1", "Translated")
        assert queue.is_complete()

    def test_iteration(self):
        """Test iterating over tasks."""
        queue = TaskQueue()
        # Added out of order; iteration must yield tasks by chapter_index.
        queue.add_chapter("ch2", 1, "第二章", "内容2")
        queue.add_chapter("ch1", 0, "第一章", "内容1")
        tasks = list(queue)
        assert len(tasks) == 2
        assert tasks[0].chapter_index == 0
        assert tasks[1].chapter_index == 1
class TestRetryManager:
    """Test suite for RetryManager."""

    def test_should_retry_on_timeout(self):
        """Test retry decision for timeout errors."""
        manager = RetryManager()
        task = ChapterTask("ch1", 0, "第一章", "内容")
        task.retry_count = 1
        assert manager.should_retry(task, "Request timeout") is True

    def test_should_not_retry_exceeded(self):
        """Test no retry after max attempts."""
        manager = RetryManager()
        task = ChapterTask("ch1", 0, "第一章", "内容")
        # NOTE(review): presumably the default RetryConfig allows at most 3
        # retries -- confirm against RetryConfig's default.
        task.retry_count = 3
        assert manager.should_retry(task, "Error") is False

    def test_get_retry_delay_exponential(self):
        """Test exponential backoff delay calculation."""
        config = RetryConfig(exponential_backoff=True, base_delay=1.0)
        manager = RetryManager(config)
        # base_delay * 2 ** (attempt - 1); powers of two compare exactly.
        assert manager.get_retry_delay(1) == 1.0
        assert manager.get_retry_delay(2) == 2.0
        assert manager.get_retry_delay(3) == 4.0

    def test_get_retry_delay_linear(self):
        """Test linear delay calculation."""
        config = RetryConfig(exponential_backoff=False, base_delay=2.0)
        manager = RetryManager(config)
        assert manager.get_retry_delay(1) == 2.0
        assert manager.get_retry_delay(2) == 2.0

    def test_max_delay_cap(self):
        """Test that delay is capped at max_delay."""
        config = RetryConfig(
            exponential_backoff=True,
            base_delay=1.0,
            max_delay=5.0
        )
        manager = RetryManager(config)
        # Uncapped, attempt 10 would be 1.0 * 2**9 = 512s, so the result must
        # be exactly the cap; ``<= 5.0`` alone would also accept a broken
        # implementation returning 0 or a negative delay.
        assert manager.get_retry_delay(10) == 5.0

    def test_record_and_get_history(self):
        """Test recording and retrieving retry history."""
        manager = RetryManager()
        record = manager.record_retry("ch1", 1, "Error", 1.0)
        assert record.chapter_id == "ch1"
        assert record.attempt_number == 1
        history = manager.get_retry_history("ch1")
        assert len(history) == 1

    def test_clear_retry_history(self):
        """Test clearing retry history."""
        manager = RetryManager()
        manager.record_retry("ch1", 1, "Error", 1.0)
        manager.clear_retry_history("ch1")
        history = manager.get_retry_history("ch1")
        assert len(history) == 0

    def test_get_stats(self):
        """Test getting retry statistics."""
        manager = RetryManager()
        manager.record_retry("ch1", 1, "Error", 1.0, success=True)
        manager.record_retry("ch2", 1, "Error", 1.0, success=False)
        stats = manager.get_stats()
        assert stats["total_retries"] == 2
        assert stats["successful_retries"] == 1
        assert stats["failed_retries"] == 1
class TestProgressNotifier:
    """Test suite for ProgressNotifier."""

    def test_register_and_notify(self):
        """Test that a registered observer receives chapter-complete events."""
        notifier = ProgressNotifier()
        completed = []
        observer = CallbackProgressObserver(
            on_start=lambda total: None,
            on_chapter_complete=lambda task: completed.append(task)
        )
        notifier.register(observer)
        assert notifier.observer_count == 1
        task = ChapterTask("ch1", 0, "第一章", "内容")
        notifier.notify_chapter_complete(task)
        # Verify delivery, not just registration: the callback fired once
        # for the notified chapter.
        assert [t.chapter_id for t in completed] == ["ch1"]

    def test_unregister_observer(self):
        """Test unregistering observer."""
        notifier = ProgressNotifier()
        observer = CallbackProgressObserver()
        notifier.register(observer)
        assert notifier.observer_count == 1
        notifier.unregister(observer)
        assert notifier.observer_count == 0

    def test_event_history(self):
        """Test event history tracking."""
        # No observers are registered; history is recorded regardless.
        notifier = ProgressNotifier()
        task = ChapterTask("ch1", 0, "第一章", "内容")
        notifier.notify_chapter_start(task)
        notifier.notify_chapter_complete(task)
        history = notifier.get_event_history()
        assert len(history) == 2

    def test_clear_observers(self):
        """Test clearing all observers."""
        notifier = ProgressNotifier()
        notifier.register(CallbackProgressObserver())
        notifier.register(CallbackProgressObserver())
        assert notifier.observer_count == 2
        notifier.clear_observers()
        assert notifier.observer_count == 0
class TestRecoveryManager:
    """Checkpoint persistence tests for RecoveryManager in a temp work dir."""

    @staticmethod
    def _checkpoint(current_index, completed):
        """Build a ``work123`` checkpoint with the given progress and no failures."""
        return CheckpointData(
            work_id="work123",
            current_chapter_index=current_index,
            completed_indices=completed,
            failed_indices=[],
        )

    def test_save_and_load_checkpoint(self, temp_work_dir):
        """A saved checkpoint is reported present and loads back with its fields."""
        manager = RecoveryManager(temp_work_dir)
        manager.save_checkpoint(self._checkpoint(5, [0, 1, 2, 3, 4]))
        assert manager.has_checkpoint() is True

        restored = manager.load_checkpoint()
        assert restored is not None
        assert restored.work_id == "work123"
        assert restored.current_chapter_index == 5

    def test_delete_checkpoint(self, temp_work_dir):
        """Deleting a saved checkpoint makes has_checkpoint() report False."""
        manager = RecoveryManager(temp_work_dir)
        manager.save_checkpoint(self._checkpoint(0, []))
        assert manager.has_checkpoint() is True

        manager.delete_checkpoint()
        assert manager.has_checkpoint() is False

    def test_get_recovery_state(self, temp_work_dir):
        """The recovery state dict reflects the saved checkpoint."""
        manager = RecoveryManager(temp_work_dir)
        manager.save_checkpoint(self._checkpoint(3, [0, 1, 2]))

        info = manager.get_recovery_state()
        assert info is not None
        assert info["recoverable"] is True
        assert info["work_id"] == "work123"
        assert info["resume_index"] == 3

    def test_can_resume(self, temp_work_dir):
        """Resuming is only possible once a checkpoint has been saved."""
        manager = RecoveryManager(temp_work_dir)
        assert manager.can_resume() is False

        snapshot = self._checkpoint(0, [])
        manager.save_checkpoint(snapshot)
        assert manager.can_resume() is True
class TestModels:
    """Tests for the scheduler's data models and enums."""

    def test_chapter_task_model(self):
        """A new ChapterTask is unfinished and not retryable; COMPLETED finishes it."""
        chapter = ChapterTask("ch1", 0, "第一章", "内容")
        assert chapter.is_finished is False
        assert chapter.can_retry is False

        chapter.status = TaskStatus.COMPLETED
        assert chapter.is_finished is True

    def test_task_status_enum(self):
        """Each TaskStatus member carries its expected string value."""
        cases = (
            (TaskStatus.PENDING, "pending"),
            (TaskStatus.IN_PROGRESS, "in_progress"),
            (TaskStatus.COMPLETED, "completed"),
            (TaskStatus.FAILED, "failed"),
        )
        for member, raw in cases:
            assert member.value == raw

    def test_scheduler_state_enum(self):
        """Each SchedulerState member carries its expected string value."""
        cases = (
            (SchedulerState.IDLE, "idle"),
            (SchedulerState.RUNNING, "running"),
            (SchedulerState.PAUSED, "paused"),
            (SchedulerState.COMPLETED, "completed"),
        )
        for member, raw in cases:
            assert member.value == raw

    def test_pipeline_progress_model(self):
        """Half-finished progress reports 5 pending chapters and a 0.5 rate."""
        half_done = PipelineProgress(total_chapters=10, completed_chapters=5)
        assert half_done.pending_chapters == 5
        assert half_done.completion_rate == 0.5

    def test_checkpoint_data_model(self):
        """CheckpointData exposes the values it was constructed with."""
        snapshot = CheckpointData(
            work_id="work123",
            current_chapter_index=5,
            completed_indices=[0, 1, 2, 3, 4],
            failed_indices=[],
        )
        assert snapshot.work_id == "work123"
        assert snapshot.current_chapter_index == 5
class TestConsoleProgressObserver:
    """Tests for constructing ConsoleProgressObserver and its hook interface."""

    def test_observer_creation(self):
        """The ``verbose`` constructor argument is exposed as an attribute."""
        console = ConsoleProgressObserver(verbose=False)
        assert console.verbose is False

    def test_observer_methods_exist(self):
        """The observer exposes the pipeline-start, complete and failure hooks."""
        console = ConsoleProgressObserver()
        hooks = ("on_pipeline_start", "on_chapter_complete", "on_chapter_failed")
        for hook in hooks:
            assert hasattr(console, hook)
class TestCallbackProgressObserver:
    """Test suite for CallbackProgressObserver."""

    def test_callback_invocation(self):
        """Test that callbacks are invoked in order with their arguments."""
        calls = []

        # Named functions instead of lambdas bound to names (PEP 8 / E731).
        def on_start(total):
            calls.append(("start", total))

        def on_complete(task):
            calls.append(("complete", task.chapter_id))

        observer = CallbackProgressObserver(
            on_start=on_start,
            on_chapter_complete=on_complete
        )
        observer.on_pipeline_start(10)
        observer.on_chapter_complete(ChapterTask("ch1", 0, "Title", "Content"))
        # The exact sequence also catches wrong ordering, duplicate calls,
        # and arguments that were not forwarded. Membership checks such as
        # ``"start" in calls`` cannot detect any of those.
        assert calls == [("start", 10), ("complete", "ch1")]
|