test_fallback_handler.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. """
  2. Unit tests for the GPU fallback handler module.
  3. Tests cover automatic fallback to batch_size=1 on GPU errors.
  4. """
import sys
from typing import Optional
from unittest.mock import Mock, MagicMock, patch
  7. # Mock torch and transformers before importing
  8. sys_mock = Mock()
  9. sys.modules["torch"] = sys_mock
  10. sys.modules["transformers"] = sys_mock
  11. import pytest
  12. from src.translator.fallback_handler import (
  13. FallbackHandler,
  14. FallbackEvent,
  15. FallbackReason,
  16. )
  17. class MockEngine:
  18. """Mock translation engine for testing."""
  19. def __init__(self, fail_on_batch: bool = False, error_msg: str = ""):
  20. self.fail_on_batch = fail_on_batch
  21. self.error_msg = error_msg
  22. self._is_gpu_enabled = True
  23. self._call_count = 0
  24. def translate(self, text: str, src_lang: str = "zh", tgt_lang: str = "en", max_length: int = None) -> str:
  25. return f"Translated: {text}"
  26. def translate_batch(
  27. self,
  28. texts: list,
  29. src_lang: str = "zh",
  30. tgt_lang: str = "en",
  31. batch_size: int = 4,
  32. max_length: int = None
  33. ) -> list:
  34. self._call_count += 1
  35. # Only fail on first call (with larger batch size), succeed with batch_size=1
  36. if self.fail_on_batch and batch_size > 1:
  37. raise RuntimeError(self.error_msg or "CUDA out of memory")
  38. return [f"Translated: {t}" for t in texts]
  39. # Use a property with setter for flexibility
  40. @property
  41. def is_gpu_enabled(self):
  42. return self._is_gpu_enabled
  43. @is_gpu_enabled.setter
  44. def is_gpu_enabled(self, value):
  45. self._is_gpu_enabled = value
  46. class TestFallbackEvent:
  47. """Test cases for FallbackEvent dataclass."""
  48. def test_create_fallback_event(self):
  49. """Test creating a fallback event."""
  50. event = FallbackEvent(
  51. original_batch_size=4,
  52. reason=FallbackReason.OUT_OF_MEMORY,
  53. error_message="CUDA out of memory",
  54. )
  55. assert event.original_batch_size == 4
  56. assert event.reason == FallbackReason.OUT_OF_MEMORY
  57. assert event.error_message == "CUDA out of memory"
  58. def test_to_dict(self):
  59. """Test converting fallback event to dictionary."""
  60. event = FallbackEvent(
  61. original_batch_size=4,
  62. reason=FallbackReason.OUT_OF_MEMORY,
  63. error_message="CUDA error",
  64. )
  65. data = event.to_dict()
  66. assert data["original_batch_size"] == 4
  67. assert data["reason"] == "out_of_memory"
  68. assert data["error_message"] == "CUDA error"
  69. assert "timestamp" in data
  70. assert data["recovered"] is False # default
  71. class TestFallbackHandler:
  72. """Test cases for FallbackHandler class."""
  73. def test_init(self):
  74. """Test FallbackHandler initialization."""
  75. engine = MockEngine()
  76. handler = FallbackHandler(engine)
  77. assert handler.engine == engine
  78. assert handler.fallback_callback is None
  79. assert not handler.has_fallback_occurred()
  80. def test_init_with_callback(self):
  81. """Test FallbackHandler with callback."""
  82. engine = MockEngine()
  83. callback = Mock()
  84. handler = FallbackHandler(engine, fallback_callback=callback)
  85. assert handler.fallback_callback == callback
  86. def test_translate_single(self):
  87. """Test single text translation (no fallback)."""
  88. engine = MockEngine()
  89. handler = FallbackHandler(engine)
  90. result = handler.translate("Hello")
  91. assert result == "Translated: Hello"
  92. assert not handler.has_fallback_occurred()
  93. def test_translate_batch_success(self):
  94. """Test successful batch translation."""
  95. engine = MockEngine()
  96. handler = FallbackHandler(engine)
  97. texts = ["Hello", "World"]
  98. results = handler.translate_batch(texts)
  99. assert len(results) == 2
  100. assert results[0] == "Translated: Hello"
  101. assert results[1] == "Translated: World"
  102. assert not handler.has_fallback_occurred()
  103. def test_translate_batch_oom_fallback(self):
  104. """Test batch translation with OOM fallback."""
  105. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  106. handler = FallbackHandler(engine)
  107. texts = ["Hello", "World"]
  108. results = handler.translate_batch(texts, batch_size=4)
  109. assert len(results) == 2
  110. assert handler.has_fallback_occurred()
  111. events = handler.get_fallback_events()
  112. assert len(events) == 1
  113. assert events[0].reason == FallbackReason.OUT_OF_MEMORY
  114. assert events[0].original_batch_size == 4
  115. assert events[0].recovered is True
  116. def test_translate_batch_cuda_error_fallback(self):
  117. """Test batch translation with CUDA error fallback."""
  118. # Create a custom mock that fails on batch_size > 1, succeeds on batch_size=1
  119. class CustomMockEngine:
  120. def __init__(self):
  121. self._gpu = True # Use different name to avoid property conflict
  122. self.call_count = 0
  123. def translate(self, text, src_lang="zh", tgt_lang="en", max_length=None):
  124. return f"Translated: {text}"
  125. def translate_batch(self, texts, src_lang="zh", tgt_lang="en", batch_size=4, max_length=None):
  126. self.call_count += 1
  127. # Only fail when batch_size > 1
  128. if batch_size > 1:
  129. raise RuntimeError("CUDA error: invalid argument")
  130. return [f"Translated: {t}" for t in texts]
  131. @property
  132. def is_gpu_enabled(self):
  133. return self._gpu
  134. engine = CustomMockEngine()
  135. handler = FallbackHandler(engine)
  136. texts = ["Hello", "World"]
  137. results = handler.translate_batch(texts, batch_size=4)
  138. assert len(results) == 2
  139. events = handler.get_fallback_events()
  140. assert events[0].reason == FallbackReason.CUDA_ERROR
  141. def test_translate_batch_runtime_error_fallback(self):
  142. """Test batch translation with runtime error fallback."""
  143. # Create a custom mock that fails on batch_size > 1, succeeds on batch_size=1
  144. class CustomMockEngine:
  145. def __init__(self):
  146. self._gpu = True # Use different name to avoid property conflict
  147. self.call_count = 0
  148. def translate(self, text, src_lang="zh", tgt_lang="en", max_length=None):
  149. return f"Translated: {text}"
  150. def translate_batch(self, texts, src_lang="zh", tgt_lang="en", batch_size=4, max_length=None):
  151. self.call_count += 1
  152. # Only fail when batch_size > 1
  153. # Use a message that matches RUNTIME_ERROR pattern
  154. if batch_size > 1:
  155. raise RuntimeError("RuntimeError: CUDA device error")
  156. return [f"Translated: {t}" for t in texts]
  157. @property
  158. def is_gpu_enabled(self):
  159. return self._gpu
  160. engine = CustomMockEngine()
  161. handler = FallbackHandler(engine)
  162. texts = ["Hello", "World"]
  163. results = handler.translate_batch(texts, batch_size=4)
  164. assert len(results) == 2
  165. events = handler.get_fallback_events()
  166. assert events[0].reason == FallbackReason.RUNTIME_ERROR
  167. def test_fallback_callback_invoked(self):
  168. """Test that fallback callback is invoked."""
  169. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  170. callback = Mock()
  171. handler = FallbackHandler(engine, fallback_callback=callback)
  172. texts = ["Hello", "World"]
  173. handler.translate_batch(texts, batch_size=4)
  174. assert callback.called
  175. event_arg = callback.call_args[0][0]
  176. assert isinstance(event_arg, FallbackEvent)
  177. assert event_arg.reason == FallbackReason.OUT_OF_MEMORY
  178. def test_no_fallback_for_non_gpu_errors(self):
  179. """Test that non-GPU errors are not caught."""
  180. engine = MockEngine(fail_on_batch=True, error_msg="Some other error")
  181. handler = FallbackHandler(engine)
  182. texts = ["Hello", "World"]
  183. with pytest.raises(RuntimeError):
  184. handler.translate_batch(texts, batch_size=4)
  185. assert not handler.has_fallback_occurred()
  186. def test_no_fallback_when_cpu_mode(self):
  187. """Test that fallback doesn't occur on CPU mode."""
  188. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  189. engine.is_gpu_enabled = False # Use property setter
  190. handler = FallbackHandler(engine)
  191. texts = ["Hello", "World"]
  192. # Should raise the error, not fall back
  193. with pytest.raises(RuntimeError):
  194. handler.translate_batch(texts, batch_size=4)
  195. def test_no_fallback_when_batch_size_is_1(self):
  196. """Test that fallback doesn't occur when batch_size is already 1."""
  197. # Use a custom engine that always fails
  198. engine = Mock()
  199. engine._is_gpu_enabled = True
  200. def always_fail(texts, src_lang=None, tgt_lang=None, batch_size=4, max_length=None):
  201. raise RuntimeError("CUDA out of memory")
  202. engine.translate = Mock(return_value="Translated: test")
  203. engine.translate_batch = always_fail
  204. handler = FallbackHandler(engine)
  205. texts = ["Hello", "World"]
  206. # Should raise the error, not fall back (already at batch_size=1)
  207. with pytest.raises(RuntimeError):
  208. handler.translate_batch(texts, batch_size=1)
  209. # Verify no fallback occurred
  210. assert not handler.has_fallback_occurred()
  211. def test_fallback_failure_propagates(self):
  212. """Test that fallback failure is propagated."""
  213. # Create an engine that fails both in batch and in retry
  214. engine = Mock()
  215. engine.is_gpu_enabled = True
  216. def fail_batch(*args, **kwargs):
  217. raise RuntimeError("CUDA out of memory")
  218. engine.translate = Mock(return_value="Translated: test")
  219. engine.translate_batch = fail_batch
  220. handler = FallbackHandler(engine)
  221. texts = ["Hello", "World"]
  222. with pytest.raises(RuntimeError):
  223. handler.translate_batch(texts, batch_size=4)
  224. events = handler.get_fallback_events()
  225. assert events[0].recovered is False
  226. def test_get_fallback_summary(self):
  227. """Test getting fallback summary."""
  228. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  229. handler = FallbackHandler(engine)
  230. # No events initially
  231. summary = handler.get_fallback_summary()
  232. assert summary["total_fallbacks"] == 0
  233. assert summary["successful_recoveries"] == 0
  234. # Trigger a fallback
  235. handler.translate_batch(["Hello"], batch_size=4)
  236. summary = handler.get_fallback_summary()
  237. assert summary["total_fallbacks"] == 1
  238. assert summary["successful_recoveries"] == 1
  239. assert summary["last_event"] is not None
  240. def test_clear_fallback_history(self):
  241. """Test clearing fallback history."""
  242. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  243. handler = FallbackHandler(engine)
  244. # Trigger a fallback
  245. handler.translate_batch(["Hello"], batch_size=4)
  246. assert handler.has_fallback_occurred()
  247. # Clear history
  248. handler.clear_fallback_history()
  249. assert not handler.has_fallback_occurred()
  250. assert len(handler.get_fallback_events()) == 0
  251. def test_get_user_message_with_fallback(self):
  252. """Test getting user message after fallback."""
  253. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  254. handler = FallbackHandler(engine)
  255. # Before fallback
  256. assert handler.get_user_message() is None
  257. # After fallback
  258. handler.translate_batch(["Hello"], batch_size=4)
  259. message = handler.get_user_message()
  260. assert message is not None
  261. assert "reduced batch size" in message.lower()
  262. def test_get_user_message_after_failed_recovery(self):
  263. """Test getting user message after failed recovery."""
  264. engine = Mock()
  265. engine.is_gpu_enabled = True
  266. engine.translate = Mock(return_value="Translated")
  267. def fail_batch(*args, **kwargs):
  268. raise RuntimeError("CUDA out of memory")
  269. engine.translate_batch = fail_batch
  270. handler = FallbackHandler(engine)
  271. try:
  272. handler.translate_batch(["Hello"], batch_size=4)
  273. except RuntimeError:
  274. pass
  275. message = handler.get_user_message()
  276. assert message is not None
  277. assert "failed" in message.lower()
  278. def test_is_recommended_batch_size(self):
  279. """Test checking recommended batch size."""
  280. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  281. handler = FallbackHandler(engine)
  282. # Before fallback, any size > 1 is recommended
  283. assert handler.is_recommended_batch_size(4)
  284. assert not handler.is_recommended_batch_size(1)
  285. # After fallback, only batch_size=1 is recommended
  286. handler.translate_batch(["Hello"], batch_size=4)
  287. assert not handler.is_recommended_batch_size(4)
  288. assert handler.is_recommended_batch_size(1)
  289. def test_get_recommended_batch_size(self):
  290. """Test getting recommended batch size."""
  291. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  292. handler = FallbackHandler(engine)
  293. # Before fallback
  294. assert handler.get_recommended_batch_size() == 4
  295. # After fallback
  296. handler.translate_batch(["Hello"], batch_size=4)
  297. assert handler.get_recommended_batch_size() == 1
  298. def test_multiple_fallback_events(self):
  299. """Test tracking multiple fallback events."""
  300. engine = MockEngine(fail_on_batch=True, error_msg="CUDA out of memory")
  301. handler = FallbackHandler(engine)
  302. # Trigger multiple fallbacks
  303. for _ in range(3):
  304. handler.translate_batch(["Hello"], batch_size=4)
  305. events = handler.get_fallback_events()
  306. assert len(events) == 3
  307. summary = handler.get_fallback_summary()
  308. assert summary["total_fallbacks"] == 3