"""
Unit tests for the GPU fallback handler module.

Tests cover automatic fallback to batch_size=1 on GPU errors.
"""
import sys
from typing import List, Optional
from unittest.mock import MagicMock, Mock, patch

import pytest

# Mock torch and transformers before importing the handler, presumably so the
# project module can be loaded without the real packages being installed.
# NOTE(review): both names share one Mock and this file never removes the
# entries, so the stubs stay in ``sys.modules`` for the rest of the process.
sys_mock = Mock()
sys.modules["torch"] = sys_mock
sys.modules["transformers"] = sys_mock

from src.translator.fallback_handler import (  # noqa: E402
    FallbackEvent,
    FallbackHandler,
    FallbackReason,
)
class MockEngine:
    """Minimal stand-in for the translation engine used by these tests.

    The engine can be configured to reject batched calls so that the
    fallback path can be exercised without real hardware.

    Args:
        fail_on_batch: When true, ``translate_batch`` raises ``RuntimeError``
            on every call made with ``batch_size > 1``.
        error_msg: Message of the raised error. An empty string selects the
            default ``"CUDA out of memory"``.
    """

    def __init__(self, fail_on_batch: bool = False, error_msg: str = "") -> None:
        self.fail_on_batch = fail_on_batch
        self.error_msg = error_msg
        self._is_gpu_enabled = True
        # Number of translate_batch invocations, failed ones included.
        self._call_count = 0

    def translate(
        self,
        text: str,
        src_lang: str = "zh",
        tgt_lang: str = "en",
        max_length: Optional[int] = None,
    ) -> str:
        """Return ``text`` prefixed with ``"Translated: "``.

        The language and length arguments are accepted but ignored.
        """
        return f"Translated: {text}"

    def translate_batch(
        self,
        texts: List[str],
        src_lang: str = "zh",
        tgt_lang: str = "en",
        batch_size: int = 4,
        max_length: Optional[int] = None,
    ) -> List[str]:
        """Translate every entry of ``texts``, optionally simulating a failure.

        Returns:
            One ``"Translated: <text>"`` string per input, in input order.

        Raises:
            RuntimeError: If ``fail_on_batch`` is set and ``batch_size > 1``.
        """
        self._call_count += 1
        # Every call with batch_size > 1 fails (not only the first one); a
        # retry with batch_size=1 is what succeeds.
        if self.fail_on_batch and batch_size > 1:
            raise RuntimeError(self.error_msg or "CUDA out of memory")
        return [f"Translated: {t}" for t in texts]

    # Exposed as a property with a setter so tests can switch the engine
    # between GPU and CPU mode after construction.
    @property
    def is_gpu_enabled(self) -> bool:
        """Whether the engine reports itself as running on a GPU."""
        return self._is_gpu_enabled

    @is_gpu_enabled.setter
    def is_gpu_enabled(self, value: bool) -> None:
        self._is_gpu_enabled = value
class TestFallbackEvent:
    """Checks for the ``FallbackEvent`` record: construction and dict export."""

    def test_create_fallback_event(self):
        """Constructor arguments are readable back as attributes."""
        evt = FallbackEvent(
            original_batch_size=4,
            reason=FallbackReason.OUT_OF_MEMORY,
            error_message="CUDA out of memory",
        )

        assert evt.original_batch_size == 4
        assert evt.reason == FallbackReason.OUT_OF_MEMORY
        assert evt.error_message == "CUDA out of memory"

    def test_to_dict(self):
        """``to_dict`` exports the fields, a timestamp and the recovery flag."""
        payload = FallbackEvent(
            original_batch_size=4,
            reason=FallbackReason.OUT_OF_MEMORY,
            error_message="CUDA error",
        ).to_dict()

        assert payload["original_batch_size"] == 4
        # The reason is exported as its string value rather than the enum member.
        assert payload["reason"] == "out_of_memory"
        assert payload["error_message"] == "CUDA error"
        assert "timestamp" in payload
        # ``recovered`` was not passed above, so this is the default value.
        assert payload["recovered"] is False
class TestFallbackHandler:
    """Behavioural tests for ``FallbackHandler``.

    The handler is driven with ``MockEngine`` (or a bare ``Mock``) to check
    when a failed batch call is retried with ``batch_size=1``, how the
    resulting events are recorded, and what the reporting helpers return.
    """

    @staticmethod
    def _always_failing_engine():
        """Return a GPU-enabled mock engine whose batch call always raises.

        ``translate_batch`` raises ``RuntimeError("CUDA out of memory")`` for
        every ``batch_size``, so a retry with ``batch_size=1`` cannot succeed.
        """
        def fail_batch(*args, **kwargs):
            raise RuntimeError("CUDA out of memory")

        engine = Mock()
        engine.is_gpu_enabled = True
        engine.translate = Mock(return_value="Translated: test")
        engine.translate_batch = fail_batch
        return engine

    def test_init(self):
        """Test FallbackHandler initialization."""
        engine = MockEngine()
        handler = FallbackHandler(engine)

        assert handler.engine is engine
        assert handler.fallback_callback is None
        assert not handler.has_fallback_occurred()

    def test_init_with_callback(self):
        """Test FallbackHandler with callback."""
        engine = MockEngine()
        callback = Mock()
        handler = FallbackHandler(engine, fallback_callback=callback)

        assert handler.fallback_callback is callback

    def test_translate_single(self):
        """Test single text translation (no fallback)."""
        handler = FallbackHandler(MockEngine())

        result = handler.translate("Hello")

        assert result == "Translated: Hello"
        assert not handler.has_fallback_occurred()

    def test_translate_batch_success(self):
        """Test successful batch translation."""
        handler = FallbackHandler(MockEngine())

        results = handler.translate_batch(["Hello", "World"])

        assert results == ["Translated: Hello", "Translated: World"]
        assert not handler.has_fallback_occurred()

    def test_translate_batch_oom_fallback(self):
        """Test batch translation with OOM fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        results = handler.translate_batch(["Hello", "World"], batch_size=4)

        assert len(results) == 2
        assert handler.has_fallback_occurred()
        events = handler.get_fallback_events()
        assert len(events) == 1
        assert events[0].reason == FallbackReason.OUT_OF_MEMORY
        assert events[0].original_batch_size == 4
        assert events[0].recovered is True

    def test_translate_batch_cuda_error_fallback(self):
        """Test batch translation with CUDA error fallback."""
        # MockEngine fails only while batch_size > 1, so the retry succeeds.
        engine = MockEngine(
            fail_on_batch=True, error_msg="CUDA error: invalid argument"
        )
        handler = FallbackHandler(engine)

        results = handler.translate_batch(["Hello", "World"], batch_size=4)

        assert len(results) == 2
        events = handler.get_fallback_events()
        assert len(events) == 1
        assert events[0].reason == FallbackReason.CUDA_ERROR

    def test_translate_batch_runtime_error_fallback(self):
        """Test batch translation with runtime error fallback."""
        # The message is chosen to match the RUNTIME_ERROR pattern.
        engine = MockEngine(
            fail_on_batch=True, error_msg="RuntimeError: CUDA device error"
        )
        handler = FallbackHandler(engine)

        results = handler.translate_batch(["Hello", "World"], batch_size=4)

        assert len(results) == 2
        events = handler.get_fallback_events()
        assert len(events) == 1
        assert events[0].reason == FallbackReason.RUNTIME_ERROR

    def test_fallback_callback_invoked(self):
        """Test that fallback callback is invoked."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        callback = Mock()
        handler = FallbackHandler(engine, fallback_callback=callback)

        handler.translate_batch(["Hello", "World"], batch_size=4)

        assert callback.called
        # The event is passed as the first positional argument.
        event_arg = callback.call_args[0][0]
        assert isinstance(event_arg, FallbackEvent)
        assert event_arg.reason == FallbackReason.OUT_OF_MEMORY

    def test_no_fallback_for_non_gpu_errors(self):
        """Test that non-GPU errors are not caught."""
        engine = MockEngine(fail_on_batch=True, error_msg="Some other error")
        handler = FallbackHandler(engine)

        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=4)

        assert not handler.has_fallback_occurred()

    def test_no_fallback_when_cpu_mode(self):
        """Test that fallback doesn't occur on CPU mode."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        engine.is_gpu_enabled = False  # Use property setter
        handler = FallbackHandler(engine)

        # Should raise the error, not fall back
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=4)

        # Same check as the sibling "no fallback" tests.
        assert not handler.has_fallback_occurred()

    def test_no_fallback_when_batch_size_is_1(self):
        """Test that fallback doesn't occur when batch_size is already 1."""
        handler = FallbackHandler(self._always_failing_engine())

        # Should raise the error, not fall back (already at batch_size=1)
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=1)

        # Verify no fallback occurred
        assert not handler.has_fallback_occurred()

    def test_fallback_failure_propagates(self):
        """Test that fallback failure is propagated."""
        # The engine fails both for the original batch and for the retry.
        handler = FallbackHandler(self._always_failing_engine())

        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=4)

        events = handler.get_fallback_events()
        assert events[0].recovered is False

    def test_get_fallback_summary(self):
        """Test getting fallback summary."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # No events initially
        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 0
        assert summary["successful_recoveries"] == 0

        # Trigger a fallback
        handler.translate_batch(["Hello"], batch_size=4)

        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 1
        assert summary["successful_recoveries"] == 1
        assert summary["last_event"] is not None

    def test_clear_fallback_history(self):
        """Test clearing fallback history."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Trigger a fallback
        handler.translate_batch(["Hello"], batch_size=4)
        assert handler.has_fallback_occurred()

        # Clear history
        handler.clear_fallback_history()
        assert not handler.has_fallback_occurred()
        assert len(handler.get_fallback_events()) == 0

    def test_get_user_message_with_fallback(self):
        """Test getting user message after fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Before fallback
        assert handler.get_user_message() is None

        # After fallback
        handler.translate_batch(["Hello"], batch_size=4)
        message = handler.get_user_message()
        assert message is not None
        assert "reduced batch size" in message.lower()

    def test_get_user_message_after_failed_recovery(self):
        """Test getting user message after failed recovery."""
        handler = FallbackHandler(self._always_failing_engine())

        # The retry fails as well; require the error instead of tolerating
        # a silent success.
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello"], batch_size=4)

        message = handler.get_user_message()
        assert message is not None
        assert "failed" in message.lower()

    def test_is_recommended_batch_size(self):
        """Test checking recommended batch size."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Before fallback, any size > 1 is recommended
        assert handler.is_recommended_batch_size(4)
        assert not handler.is_recommended_batch_size(1)

        # After fallback, only batch_size=1 is recommended
        handler.translate_batch(["Hello"], batch_size=4)
        assert not handler.is_recommended_batch_size(4)
        assert handler.is_recommended_batch_size(1)

    def test_get_recommended_batch_size(self):
        """Test getting recommended batch size."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Before fallback
        assert handler.get_recommended_batch_size() == 4

        # After fallback
        handler.translate_batch(["Hello"], batch_size=4)
        assert handler.get_recommended_batch_size() == 1

    def test_multiple_fallback_events(self):
        """Test tracking multiple fallback events."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Trigger multiple fallbacks
        for _ in range(3):
            handler.translate_batch(["Hello"], batch_size=4)

        events = handler.get_fallback_events()
        assert len(events) == 3
        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 3