|
|
@@ -0,0 +1,406 @@
|
|
|
+"""
|
|
|
+Unit tests for the GPU fallback handler module.
|
|
|
+
|
|
|
+Tests cover automatic fallback to batch_size=1 on GPU errors.
|
|
|
+"""
|
|
|
+
|
|
|
+import sys
|
|
|
+from unittest.mock import Mock, MagicMock, patch
|
|
|
+
|
|
|
# Stub out the heavy GPU libraries before the module under test imports them,
# so these tests run without torch/transformers installed.
# Use a separate MagicMock per module: a single shared mock would let
# attribute state configured on "torch" leak into "transformers" and vice versa.
sys.modules["torch"] = MagicMock()
sys.modules["transformers"] = MagicMock()
|
|
|
+
|
|
|
+import pytest
|
|
|
+
|
|
|
+from src.translator.fallback_handler import (
|
|
|
+ FallbackHandler,
|
|
|
+ FallbackEvent,
|
|
|
+ FallbackReason,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
class MockEngine:
    """Mock translation engine for testing.

    Simulates a GPU failure on batched calls (when ``fail_on_batch`` is set)
    so the FallbackHandler's retry-with-batch_size=1 path can be exercised.
    """

    def __init__(self, fail_on_batch: bool = False, error_msg: str = ""):
        # When True, translate_batch raises for any batch_size > 1.
        self.fail_on_batch = fail_on_batch
        # Message for the simulated RuntimeError; empty means a default OOM message.
        self.error_msg = error_msg
        # Plain attribute instead of a property/setter pair: nothing is computed
        # or validated, and tests simply read or assign it directly
        # (e.g. ``engine.is_gpu_enabled = False`` for CPU-mode tests).
        self.is_gpu_enabled = True
        # Number of translate_batch invocations observed.
        self._call_count = 0

    def translate(self, text: str, src_lang: str = "zh", tgt_lang: str = "en",
                  max_length: int = None) -> str:
        """Translate a single text; never fails."""
        return f"Translated: {text}"

    def translate_batch(
        self,
        texts: list,
        src_lang: str = "zh",
        tgt_lang: str = "en",
        batch_size: int = 4,
        max_length: int = None
    ) -> list:
        """Translate a batch of texts.

        Raises RuntimeError when ``fail_on_batch`` is set and ``batch_size`` > 1;
        succeeds at batch_size=1 so the fallback retry can recover.
        """
        self._call_count += 1
        # Only fail on the large-batch attempt; the batch_size=1 retry succeeds.
        if self.fail_on_batch and batch_size > 1:
            raise RuntimeError(self.error_msg or "CUDA out of memory")
        return [f"Translated: {t}" for t in texts]
|
|
|
+
|
|
|
+
|
|
|
class TestFallbackEvent:
    """Tests for the FallbackEvent dataclass."""

    def test_create_fallback_event(self):
        """An event stores batch size, classified reason, and error message."""
        evt = FallbackEvent(
            original_batch_size=4,
            reason=FallbackReason.OUT_OF_MEMORY,
            error_message="CUDA out of memory",
        )

        assert evt.error_message == "CUDA out of memory"
        assert evt.reason == FallbackReason.OUT_OF_MEMORY
        assert evt.original_batch_size == 4

    def test_to_dict(self):
        """to_dict serializes every field, including the defaults."""
        payload = FallbackEvent(
            original_batch_size=4,
            reason=FallbackReason.OUT_OF_MEMORY,
            error_message="CUDA error",
        ).to_dict()

        assert payload["original_batch_size"] == 4
        assert payload["reason"] == "out_of_memory"
        assert payload["error_message"] == "CUDA error"
        assert "timestamp" in payload
        # ``recovered`` defaults to False until a retry succeeds.
        assert payload["recovered"] is False
|
|
|
+
|
|
|
+
|
|
|
class TestFallbackHandler:
    """Test cases for FallbackHandler class."""

    @staticmethod
    def _gpu_engine_failing_with(error_msg: str):
        """Build a mock GPU engine whose translate_batch raises RuntimeError(error_msg)
        for batch_size > 1 but succeeds at batch_size=1.

        Shared by the CUDA-error and runtime-error fallback tests, which were
        previously two duplicated inner classes differing only in the message
        the handler has to classify.
        """

        class _Engine:
            def __init__(self):
                self._gpu = True  # backing field; separate name so the property doesn't recurse
                self.call_count = 0

            def translate(self, text, src_lang="zh", tgt_lang="en", max_length=None):
                return f"Translated: {text}"

            def translate_batch(self, texts, src_lang="zh", tgt_lang="en",
                                batch_size=4, max_length=None):
                self.call_count += 1
                # Fail only on the large-batch attempt; the handler's
                # batch_size=1 retry must succeed.
                if batch_size > 1:
                    raise RuntimeError(error_msg)
                return [f"Translated: {t}" for t in texts]

            @property
            def is_gpu_enabled(self):
                return self._gpu

        return _Engine()

    def test_init(self):
        """Handler starts with no callback and no recorded fallbacks."""
        engine = MockEngine()
        handler = FallbackHandler(engine)

        assert handler.engine == engine
        assert handler.fallback_callback is None
        assert not handler.has_fallback_occurred()

    def test_init_with_callback(self):
        """A fallback callback passed at construction is stored."""
        callback = Mock()
        handler = FallbackHandler(MockEngine(), fallback_callback=callback)

        assert handler.fallback_callback == callback

    def test_translate_single(self):
        """Single-text translation passes through without any fallback."""
        handler = FallbackHandler(MockEngine())

        assert handler.translate("Hello") == "Translated: Hello"
        assert not handler.has_fallback_occurred()

    def test_translate_batch_success(self):
        """A healthy batch translation records no fallback."""
        handler = FallbackHandler(MockEngine())

        results = handler.translate_batch(["Hello", "World"])

        assert len(results) == 2
        assert results[0] == "Translated: Hello"
        assert results[1] == "Translated: World"
        assert not handler.has_fallback_occurred()

    def test_translate_batch_oom_fallback(self):
        """An OOM error triggers a recovered fallback event."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        results = handler.translate_batch(["Hello", "World"], batch_size=4)

        assert len(results) == 2
        assert handler.has_fallback_occurred()
        events = handler.get_fallback_events()
        assert len(events) == 1
        assert events[0].reason == FallbackReason.OUT_OF_MEMORY
        assert events[0].original_batch_size == 4
        assert events[0].recovered is True

    def test_translate_batch_cuda_error_fallback(self):
        """A generic CUDA error message is classified as CUDA_ERROR."""
        engine = self._gpu_engine_failing_with("CUDA error: invalid argument")
        handler = FallbackHandler(engine)

        results = handler.translate_batch(["Hello", "World"], batch_size=4)

        assert len(results) == 2
        events = handler.get_fallback_events()
        assert events[0].reason == FallbackReason.CUDA_ERROR

    def test_translate_batch_runtime_error_fallback(self):
        """A message matching the RUNTIME_ERROR pattern is classified as such."""
        engine = self._gpu_engine_failing_with("RuntimeError: CUDA device error")
        handler = FallbackHandler(engine)

        results = handler.translate_batch(["Hello", "World"], batch_size=4)

        assert len(results) == 2
        events = handler.get_fallback_events()
        assert events[0].reason == FallbackReason.RUNTIME_ERROR

    def test_fallback_callback_invoked(self):
        """The callback receives the FallbackEvent describing the fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        callback = Mock()
        handler = FallbackHandler(engine, fallback_callback=callback)

        handler.translate_batch(["Hello", "World"], batch_size=4)

        assert callback.called
        event_arg = callback.call_args[0][0]
        assert isinstance(event_arg, FallbackEvent)
        assert event_arg.reason == FallbackReason.OUT_OF_MEMORY

    def test_no_fallback_for_non_gpu_errors(self):
        """Errors that do not look GPU-related propagate unchanged."""
        engine = MockEngine(fail_on_batch=True, error_msg="Some other error")
        handler = FallbackHandler(engine)

        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=4)

        assert not handler.has_fallback_occurred()

    def test_no_fallback_when_cpu_mode(self):
        """On CPU, even GPU-looking errors are not intercepted."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        engine.is_gpu_enabled = False  # simulate CPU-only mode
        handler = FallbackHandler(engine)

        # The error must surface instead of being retried.
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=4)

    def test_no_fallback_when_batch_size_is_1(self):
        """No fallback is attempted when batch_size is already 1."""
        engine = Mock()
        engine._is_gpu_enabled = True

        def always_fail(texts, src_lang=None, tgt_lang=None, batch_size=4, max_length=None):
            raise RuntimeError("CUDA out of memory")

        engine.translate = Mock(return_value="Translated: test")
        engine.translate_batch = always_fail

        handler = FallbackHandler(engine)

        # Already at the minimum batch size: the error must propagate.
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=1)

        assert not handler.has_fallback_occurred()

    def test_fallback_failure_propagates(self):
        """If the batch_size=1 retry also fails, the event stays unrecovered."""
        engine = Mock()
        engine.is_gpu_enabled = True

        def fail_batch(*args, **kwargs):
            # Fails on every attempt, including the fallback retry.
            raise RuntimeError("CUDA out of memory")

        engine.translate = Mock(return_value="Translated: test")
        engine.translate_batch = fail_batch

        handler = FallbackHandler(engine)

        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello", "World"], batch_size=4)

        events = handler.get_fallback_events()
        assert events[0].recovered is False

    def test_get_fallback_summary(self):
        """Summary counts total fallbacks and successful recoveries."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Empty summary before anything happens.
        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 0
        assert summary["successful_recoveries"] == 0

        handler.translate_batch(["Hello"], batch_size=4)

        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 1
        assert summary["successful_recoveries"] == 1
        assert summary["last_event"] is not None

    def test_clear_fallback_history(self):
        """clear_fallback_history wipes all recorded events."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        handler.translate_batch(["Hello"], batch_size=4)
        assert handler.has_fallback_occurred()

        handler.clear_fallback_history()
        assert not handler.has_fallback_occurred()
        assert len(handler.get_fallback_events()) == 0

    def test_get_user_message_with_fallback(self):
        """After a recovered fallback the user message mentions the reduced batch size."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # No message until a fallback occurs.
        assert handler.get_user_message() is None

        handler.translate_batch(["Hello"], batch_size=4)
        message = handler.get_user_message()
        assert message is not None
        assert "reduced batch size" in message.lower()

    def test_get_user_message_after_failed_recovery(self):
        """After a failed recovery the user message mentions the failure."""
        engine = Mock()
        engine.is_gpu_enabled = True
        engine.translate = Mock(return_value="Translated")

        def fail_batch(*args, **kwargs):
            raise RuntimeError("CUDA out of memory")

        engine.translate_batch = fail_batch

        handler = FallbackHandler(engine)

        # The retry also fails, so the error must propagate
        # (same contract as test_fallback_failure_propagates).
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello"], batch_size=4)

        message = handler.get_user_message()
        assert message is not None
        assert "failed" in message.lower()

    def test_is_recommended_batch_size(self):
        """Recommended batch size flips to 1 after a fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Before fallback, any size > 1 is recommended.
        assert handler.is_recommended_batch_size(4)
        assert not handler.is_recommended_batch_size(1)

        # After fallback, only batch_size=1 is recommended.
        handler.translate_batch(["Hello"], batch_size=4)
        assert not handler.is_recommended_batch_size(4)
        assert handler.is_recommended_batch_size(1)

    def test_get_recommended_batch_size(self):
        """get_recommended_batch_size drops to 1 after a fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        # Before fallback.
        assert handler.get_recommended_batch_size() == 4

        # After fallback.
        handler.translate_batch(["Hello"], batch_size=4)
        assert handler.get_recommended_batch_size() == 1

    def test_multiple_fallback_events(self):
        """Each failing batch call records its own fallback event."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)

        for _ in range(3):
            handler.translate_batch(["Hello"], batch_size=4)

        events = handler.get_fallback_events()
        assert len(events) == 3

        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 3
|