"""
Unit tests for the GPU fallback handler module.

Tests cover automatic fallback to batch_size=1 on GPU errors.
"""

import sys
from typing import Optional
from unittest.mock import Mock, MagicMock, patch

# Mock torch and transformers before importing the module under test so the
# suite can run without the heavy ML dependencies installed.
# NOTE(review): these sys.modules entries are installed at import time and
# never restored, so they replace any real torch/transformers for every test
# module collected later in the same pytest session. Consider isolating this
# if other suites need the real packages.
sys_mock = Mock()
sys.modules["torch"] = sys_mock
sys.modules["transformers"] = sys_mock

import pytest  # noqa: E402

from src.translator.fallback_handler import (  # noqa: E402
    FallbackHandler,
    FallbackEvent,
    FallbackReason,
)


class MockEngine:
    """Mock translation engine for testing.

    When ``fail_on_batch`` is set, ``translate_batch`` raises
    ``RuntimeError(error_msg)`` (defaulting to ``"CUDA out of memory"``) for
    any ``batch_size > 1`` and succeeds for ``batch_size == 1``. This simulates
    a GPU error that a batch_size=1 retry recovers from.
    """

    def __init__(self, fail_on_batch: bool = False, error_msg: str = ""):
        self.fail_on_batch = fail_on_batch
        self.error_msg = error_msg
        self._is_gpu_enabled = True
        # Number of translate_batch invocations; useful when debugging retries.
        self._call_count = 0

    def translate(
        self,
        text: str,
        src_lang: str = "zh",
        tgt_lang: str = "en",
        max_length: Optional[int] = None,
    ) -> str:
        return f"Translated: {text}"

    def translate_batch(
        self,
        texts: list,
        src_lang: str = "zh",
        tgt_lang: str = "en",
        batch_size: int = 4,
        max_length: Optional[int] = None,
    ) -> list:
        self._call_count += 1
        # Fail only for batch sizes > 1 so the batch_size=1 retry succeeds.
        if self.fail_on_batch and batch_size > 1:
            raise RuntimeError(self.error_msg or "CUDA out of memory")
        return [f"Translated: {t}" for t in texts]

    # Property with a setter so tests can flip the engine into CPU mode.
    @property
    def is_gpu_enabled(self) -> bool:
        return self._is_gpu_enabled

    @is_gpu_enabled.setter
    def is_gpu_enabled(self, value: bool) -> None:
        self._is_gpu_enabled = value


def _always_failing_engine(
    translate_result: str = "Translated: test",
    error_msg: str = "CUDA out of memory",
) -> Mock:
    """Build a GPU-enabled mock engine whose translate_batch always raises.

    Unlike MockEngine, the failure also happens at batch_size=1, so any
    fallback retry fails as well.
    """
    engine = Mock()
    # Set explicitly: the handler reads ``is_gpu_enabled``, and relying on
    # Mock's auto-created (truthy) attribute would hide misconfiguration.
    engine.is_gpu_enabled = True
    engine.translate = Mock(return_value=translate_result)

    def fail_batch(*args, **kwargs):
        raise RuntimeError(error_msg)

    engine.translate_batch = fail_batch
    return engine


class TestFallbackEvent:
    """Test cases for FallbackEvent dataclass."""

    def test_create_fallback_event(self):
        """Test creating a fallback event."""
        event = FallbackEvent(
            original_batch_size=4,
            reason=FallbackReason.OUT_OF_MEMORY,
            error_message="CUDA out of memory",
        )
        assert event.original_batch_size == 4
        assert event.reason == FallbackReason.OUT_OF_MEMORY
        assert event.error_message == "CUDA out of memory"

    def test_to_dict(self):
        """Test converting fallback event to dictionary."""
        event = FallbackEvent(
            original_batch_size=4,
            reason=FallbackReason.OUT_OF_MEMORY,
            error_message="CUDA error",
        )
        data = event.to_dict()
        assert data["original_batch_size"] == 4
        assert data["reason"] == "out_of_memory"
        assert data["error_message"] == "CUDA error"
        assert "timestamp" in data
        assert data["recovered"] is False  # default


class TestFallbackHandler:
    """Test cases for FallbackHandler class."""

    def test_init(self):
        """Test FallbackHandler initialization."""
        engine = MockEngine()
        handler = FallbackHandler(engine)
        assert handler.engine == engine
        assert handler.fallback_callback is None
        assert not handler.has_fallback_occurred()

    def test_init_with_callback(self):
        """Test FallbackHandler with callback."""
        engine = MockEngine()
        callback = Mock()
        handler = FallbackHandler(engine, fallback_callback=callback)
        assert handler.fallback_callback == callback

    def test_translate_single(self):
        """Test single text translation (no fallback)."""
        engine = MockEngine()
        handler = FallbackHandler(engine)
        result = handler.translate("Hello")
        assert result == "Translated: Hello"
        assert not handler.has_fallback_occurred()

    def test_translate_batch_success(self):
        """Test successful batch translation."""
        engine = MockEngine()
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        results = handler.translate_batch(texts)
        assert len(results) == 2
        assert results[0] == "Translated: Hello"
        assert results[1] == "Translated: World"
        assert not handler.has_fallback_occurred()

    def test_translate_batch_oom_fallback(self):
        """Test batch translation with OOM fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        results = handler.translate_batch(texts, batch_size=4)
        assert len(results) == 2
        assert handler.has_fallback_occurred()
        events = handler.get_fallback_events()
        assert len(events) == 1
        assert events[0].reason == FallbackReason.OUT_OF_MEMORY
        assert events[0].original_batch_size == 4
        assert events[0].recovered is True

    def test_translate_batch_cuda_error_fallback(self):
        """Test batch translation with CUDA error fallback."""
        # MockEngine fails for batch_size > 1 and succeeds at batch_size=1.
        engine = MockEngine(
            fail_on_batch=True, error_msg="CUDA error: invalid argument"
        )
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        results = handler.translate_batch(texts, batch_size=4)
        assert len(results) == 2
        events = handler.get_fallback_events()
        assert events[0].reason == FallbackReason.CUDA_ERROR

    def test_translate_batch_runtime_error_fallback(self):
        """Test batch translation with runtime error fallback."""
        # Message chosen to match the handler's RUNTIME_ERROR pattern.
        engine = MockEngine(
            fail_on_batch=True, error_msg="RuntimeError: CUDA device error"
        )
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        results = handler.translate_batch(texts, batch_size=4)
        assert len(results) == 2
        events = handler.get_fallback_events()
        assert events[0].reason == FallbackReason.RUNTIME_ERROR

    def test_fallback_callback_invoked(self):
        """Test that fallback callback is invoked."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        callback = Mock()
        handler = FallbackHandler(engine, fallback_callback=callback)
        texts = ["Hello", "World"]
        handler.translate_batch(texts, batch_size=4)
        assert callback.called
        event_arg = callback.call_args[0][0]
        assert isinstance(event_arg, FallbackEvent)
        assert event_arg.reason == FallbackReason.OUT_OF_MEMORY

    def test_no_fallback_for_non_gpu_errors(self):
        """Test that non-GPU errors are not caught."""
        engine = MockEngine(fail_on_batch=True, error_msg="Some other error")
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        with pytest.raises(RuntimeError):
            handler.translate_batch(texts, batch_size=4)
        assert not handler.has_fallback_occurred()

    def test_no_fallback_when_cpu_mode(self):
        """Test that fallback doesn't occur on CPU mode."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        engine.is_gpu_enabled = False  # Use property setter
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        # Should raise the error, not fall back
        with pytest.raises(RuntimeError):
            handler.translate_batch(texts, batch_size=4)
        # The docstring promises no fallback; verify it, as the non-GPU test does.
        assert not handler.has_fallback_occurred()

    def test_no_fallback_when_batch_size_is_1(self):
        """Test that fallback doesn't occur when batch_size is already 1."""
        # Engine that fails at every batch size, with is_gpu_enabled set
        # explicitly (previously only an unused ``_is_gpu_enabled`` was set).
        engine = _always_failing_engine(translate_result="Translated: test")
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        # Should raise the error, not fall back (already at batch_size=1)
        with pytest.raises(RuntimeError):
            handler.translate_batch(texts, batch_size=1)
        # Verify no fallback occurred
        assert not handler.has_fallback_occurred()

    def test_fallback_failure_propagates(self):
        """Test that fallback failure is propagated."""
        # Engine fails both in the batch call and in the batch_size=1 retry.
        engine = _always_failing_engine(translate_result="Translated: test")
        handler = FallbackHandler(engine)
        texts = ["Hello", "World"]
        with pytest.raises(RuntimeError):
            handler.translate_batch(texts, batch_size=4)
        events = handler.get_fallback_events()
        assert events, "expected the attempted fallback to be recorded"
        assert events[0].recovered is False

    def test_get_fallback_summary(self):
        """Test getting fallback summary."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        # No events initially
        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 0
        assert summary["successful_recoveries"] == 0
        # Trigger a fallback
        handler.translate_batch(["Hello"], batch_size=4)
        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 1
        assert summary["successful_recoveries"] == 1
        assert summary["last_event"] is not None

    def test_clear_fallback_history(self):
        """Test clearing fallback history."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        # Trigger a fallback
        handler.translate_batch(["Hello"], batch_size=4)
        assert handler.has_fallback_occurred()
        # Clear history
        handler.clear_fallback_history()
        assert not handler.has_fallback_occurred()
        assert len(handler.get_fallback_events()) == 0

    def test_get_user_message_with_fallback(self):
        """Test getting user message after fallback."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        # Before fallback
        assert handler.get_user_message() is None
        # After fallback
        handler.translate_batch(["Hello"], batch_size=4)
        message = handler.get_user_message()
        assert message is not None
        assert "reduced batch size" in message.lower()

    def test_get_user_message_after_failed_recovery(self):
        """Test getting user message after failed recovery."""
        engine = _always_failing_engine(translate_result="Translated")
        handler = FallbackHandler(engine)
        # The failure itself is covered elsewhere; here only the message matters.
        with pytest.raises(RuntimeError):
            handler.translate_batch(["Hello"], batch_size=4)
        message = handler.get_user_message()
        assert message is not None
        assert "failed" in message.lower()

    def test_is_recommended_batch_size(self):
        """Test checking recommended batch size."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        # Before fallback, any size > 1 is recommended
        assert handler.is_recommended_batch_size(4)
        assert not handler.is_recommended_batch_size(1)
        # After fallback, only batch_size=1 is recommended
        handler.translate_batch(["Hello"], batch_size=4)
        assert not handler.is_recommended_batch_size(4)
        assert handler.is_recommended_batch_size(1)

    def test_get_recommended_batch_size(self):
        """Test getting recommended batch size."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        # Before fallback
        assert handler.get_recommended_batch_size() == 4
        # After fallback
        handler.translate_batch(["Hello"], batch_size=4)
        assert handler.get_recommended_batch_size() == 1

    def test_multiple_fallback_events(self):
        """Test tracking multiple fallback events."""
        engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
        handler = FallbackHandler(engine)
        # Trigger multiple fallbacks
        for _ in range(3):
            handler.translate_batch(["Hello"], batch_size=4)
        events = handler.get_fallback_events()
        assert len(events) == 3
        summary = handler.get_fallback_summary()
        assert summary["total_fallbacks"] == 3