test_translator.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. """
  2. Unit tests for the translator module.
  3. Tests cover:
  4. - TranslationEngine initialization and basic translation
  5. - TranslationPipeline with glossary integration
  6. - ProgressReporter callbacks
  7. - ChapterTranslator (mocked)
  8. """
  9. import pytest
  10. from pathlib import Path
  11. from unittest.mock import Mock, MagicMock, patch
  12. from datetime import datetime
  13. from src.translator.engine import TranslationEngine
  14. from src.translator.pipeline import TranslationPipeline, TranslationResult
  15. from src.translator.progress import ProgressReporter, ProgressStatus, ConsoleProgressReporter
  16. from src.translator.chapter_translator import ChapterTranslator
  17. from src.glossary.models import Glossary, GlossaryEntry, TermCategory
  18. # ============================================================================
  19. # Test TranslationEngine (Story 5.1)
  20. # ============================================================================
  21. class TestTranslationEngine:
  22. """Tests for TranslationEngine class."""
  23. @pytest.fixture
  24. def mock_transformers(self):
  25. """Mock the transformers library."""
  26. with patch('src.translator.engine.M2M100ForConditionalGeneration') as mock_model, \
  27. patch('src.translator.engine.M2M100Tokenizer') as mock_tokenizer:
  28. # Setup mock tokenizer
  29. mock_tok_instance = MagicMock()
  30. mock_tok_instance.src_lang = "zh"
  31. mock_tok_instance.lang_code_to_id = {"zh": 1, "en": 2, "fr": 3}
  32. mock_tok_instance.return_tensors = "pt"
  33. mock_tokenizer.from_pretrained.return_value = mock_tok_instance
  34. # Setup mock model
  35. mock_model_instance = MagicMock()
  36. mock_model.from_pretrained.return_value = mock_model_instance
  37. mock_model_instance.eval.return_value = None
  38. yield {
  39. "model": mock_model,
  40. "tokenizer": mock_tokenizer,
  41. "model_instance": mock_model_instance,
  42. "tokenizer_instance": mock_tok_instance
  43. }
  44. @pytest.fixture
  45. def mock_model_path(self, tmp_path):
  46. """Create a temporary mock model directory."""
  47. model_dir = tmp_path / "m2m100_418M"
  48. model_dir.mkdir()
  49. (model_dir / "config.json").write_text("{}")
  50. return str(model_dir)
  51. def test_engine_init_with_mock_path(self, mock_transformers, mock_model_path):
  52. """Test engine initialization with a mock model path."""
  53. mock_transformers["tokenizer_instance"].batch_decode.return_value = ["Hello world"]
  54. engine = TranslationEngine(model_path=mock_model_path)
  55. assert engine.model_path == mock_model_path
  56. assert engine.device in ("cpu", "cuda")
  57. def test_engine_init_import_error(self):
  58. """Test that ImportError is raised when transformers is not available."""
  59. with patch('src.translator.engine.M2M100ForConditionalGeneration', None):
  60. with pytest.raises(ImportError, match="transformers library"):
  61. TranslationEngine(model_path="/fake/path")
  62. def test_translate_single_text(self, mock_transformers, mock_model_path):
  63. """Test basic single-text translation."""
  64. mock_tok = mock_transformers["tokenizer_instance"]
  65. mock_tok.batch_decode.return_value = ["Hello world"]
  66. engine = TranslationEngine(model_path=mock_model_path)
  67. result = engine.translate("你好世界", src_lang="zh", tgt_lang="en")
  68. assert result == "Hello world"
  69. mock_tok.batch_decode.assert_called_once()
  70. def test_translate_empty_text_raises_error(self, mock_transformers, mock_model_path):
  71. """Test that translating empty text raises ValueError."""
  72. mock_tok = mock_transformers["tokenizer_instance"]
  73. mock_tok.batch_decode.return_value = ["Hello"]
  74. engine = TranslationEngine(model_path=mock_model_path)
  75. with pytest.raises(ValueError, match="cannot be empty"):
  76. engine.translate("", src_lang="zh", tgt_lang="en")
  77. def test_translate_batch(self, mock_transformers, mock_model_path):
  78. """Test batch translation."""
  79. mock_tok = mock_transformers["tokenizer_instance"]
  80. mock_tok.batch_decode.return_value = ["Hello", "World", "Test"]
  81. engine = TranslationEngine(model_path=mock_model_path)
  82. results = engine.translate_batch(
  83. ["你好", "世界", "测试"],
  84. src_lang="zh",
  85. tgt_lang="en",
  86. batch_size=3
  87. )
  88. assert len(results) == 3
  89. assert results == ["Hello", "World", "Test"]
  90. def test_translate_batch_empty_raises_error(self, mock_transformers, mock_model_path):
  91. """Test that empty batch list raises ValueError."""
  92. mock_tok = mock_transformers["tokenizer_instance"]
  93. mock_tok.batch_decode.return_value = []
  94. engine = TranslationEngine(model_path=mock_model_path)
  95. with pytest.raises(ValueError, match="cannot be empty"):
  96. engine.translate_batch([], src_lang="zh", tgt_lang="en")
  97. def test_is_language_supported(self, mock_transformers, mock_model_path):
  98. """Test language support checking."""
  99. mock_tok = mock_transformers["tokenizer_instance"]
  100. mock_tok.batch_decode.return_value = ["Hello"]
  101. mock_tok.lang_code_to_id = {"zh": 1, "en": 2, "fr": 3}
  102. engine = TranslationEngine(model_path=mock_model_path)
  103. assert engine.is_language_supported("zh") is True
  104. assert engine.is_language_supported("en") is True
  105. assert engine.is_language_supported("de") is False
  106. # ============================================================================
  107. # Test TranslationPipeline (Story 5.2)
  108. # ============================================================================
  109. class TestTranslationPipeline:
  110. """Tests for TranslationPipeline class."""
  111. @pytest.fixture
  112. def mock_engine(self):
  113. """Create a mock translation engine."""
  114. engine = MagicMock(spec=TranslationEngine)
  115. engine.translate.return_value = "Lin Feng is a disciple"
  116. engine.translate_batch.return_value = ["Hello", "World"]
  117. engine.is_language_supported.return_value = True
  118. return engine
  119. @pytest.fixture
  120. def sample_glossary(self):
  121. """Create a sample glossary."""
  122. glossary = Glossary()
  123. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  124. glossary.add(GlossaryEntry("青云宗", "Qingyun Sect", TermCategory.LOCATION))
  125. return glossary
  126. def test_pipeline_init(self, mock_engine):
  127. """Test pipeline initialization."""
  128. pipeline = TranslationPipeline(mock_engine)
  129. assert pipeline.engine is mock_engine
  130. assert pipeline.src_lang == "zh"
  131. assert pipeline.tgt_lang == "en"
  132. def test_pipeline_with_glossary(self, mock_engine, sample_glossary):
  133. """Test pipeline with glossary."""
  134. pipeline = TranslationPipeline(mock_engine, sample_glossary)
  135. assert pipeline.has_glossary is True
  136. assert len(pipeline.glossary) == 2
  137. def test_translate_simple(self, mock_engine):
  138. """Test simple translation without glossary."""
  139. pipeline = TranslationPipeline(mock_engine)
  140. result = pipeline.translate("Hello world")
  141. assert result == "Lin Feng is a disciple"
  142. mock_engine.translate.assert_called_once()
  143. def test_translate_with_glossary(self, mock_engine, sample_glossary):
  144. """Test translation with glossary preprocessing."""
  145. mock_engine.translate.return_value = "__en__Lin Feng is a disciple of __en__Qingyun Sect"
  146. pipeline = TranslationPipeline(mock_engine, sample_glossary)
  147. result = pipeline.translate("林风是青云宗的弟子")
  148. # After post-processing, placeholders should be replaced
  149. assert "Lin Feng" in result
  150. assert "Qingyun Sect" in result
  151. def test_translate_return_details(self, mock_engine, sample_glossary):
  152. """Test translation with detailed result."""
  153. mock_engine.translate.return_value = "__en__Lin Feng is here"
  154. pipeline = TranslationPipeline(mock_engine, sample_glossary)
  155. result = pipeline.translate("林风在这里", return_details=True)
  156. assert isinstance(result, TranslationResult)
  157. assert result.original == "林风在这里"
  158. assert "Lin Feng" in result.translated
  159. assert len(result.terms_used) > 0
  160. assert isinstance(result.placeholder_map, dict)
  161. def test_translate_batch(self, mock_engine):
  162. """Test batch translation."""
  163. mock_engine.translate_batch.return_value = ["Result 1", "Result 2"]
  164. pipeline = TranslationPipeline(mock_engine)
  165. results = pipeline.translate_batch(["Text 1", "Text 2"])
  166. assert len(results) == 2
  167. mock_engine.translate_batch.assert_called_once()
  168. def test_add_term(self, mock_engine):
  169. """Test adding a term to the pipeline glossary."""
  170. pipeline = TranslationPipeline(mock_engine)
  171. entry = GlossaryEntry("test", "TEST", TermCategory.OTHER)
  172. pipeline.add_term(entry)
  173. assert pipeline.has_glossary is True
  174. assert "test" in pipeline.glossary
  175. def test_set_languages_valid(self, mock_engine):
  176. """Test setting valid languages."""
  177. pipeline = TranslationPipeline(mock_engine)
  178. mock_engine.is_language_supported.return_value = True
  179. pipeline.set_languages("en", "fr")
  180. assert pipeline.src_lang == "en"
  181. assert pipeline.tgt_lang == "fr"
  182. def test_set_languages_invalid(self, mock_engine):
  183. """Test setting invalid language raises error."""
  184. pipeline = TranslationPipeline(mock_engine)
  185. mock_engine.is_language_supported.side_effect = lambda x: x in ["zh", "en"]
  186. with pytest.raises(ValueError, match="not supported"):
  187. pipeline.set_languages("xx", "yy")
  188. # ============================================================================
  189. # Test ProgressReporter (Story 5.4)
  190. # ============================================================================
  191. class TestProgressReporter:
  192. """Tests for ProgressReporter class."""
  193. def test_reporter_init(self):
  194. """Test reporter initialization."""
  195. callback = Mock()
  196. reporter = ProgressReporter(callback)
  197. assert reporter.callback is callback
  198. assert reporter.total == 0
  199. assert reporter.completed == 0
  200. def test_on_start(self):
  201. """Test start event."""
  202. callback = Mock()
  203. reporter = ProgressReporter(callback)
  204. reporter.on_start(total=10)
  205. assert reporter.total == 10
  206. callback.assert_called_once()
  207. status, data = callback.call_args[0]
  208. assert status == ProgressStatus.START
  209. assert data["total"] == 10
  210. def test_on_chapter_complete(self):
  211. """Test chapter complete event."""
  212. callback = Mock()
  213. reporter = ProgressReporter(callback)
  214. reporter.on_start(total=5)
  215. reporter.on_chapter_complete(chapter_index=0, chapter_title="Chapter 1")
  216. assert reporter.completed == 1
  217. assert reporter.progress_percent == 20.0
  218. def test_on_chapter_failed(self):
  219. """Test chapter failed event."""
  220. callback = Mock()
  221. reporter = ProgressReporter(callback)
  222. reporter.on_start(total=5)
  223. error = Exception("Test error")
  224. reporter.on_chapter_failed(chapter_index=0, error=error)
  225. assert reporter.failed == 1
  226. def test_on_complete(self):
  227. """Test complete event."""
  228. callback = Mock()
  229. reporter = ProgressReporter(callback)
  230. reporter.on_start(total=3)
  231. reporter.on_chapter_complete(chapter_index=0)
  232. reporter.on_chapter_complete(chapter_index=1)
  233. reporter.on_complete()
  234. assert reporter.is_complete is True
  235. assert reporter.duration_seconds is not None
  236. def test_progress_percent(self):
  237. """Test progress percentage calculation."""
  238. reporter = ProgressReporter()
  239. reporter.on_start(total=10)
  240. assert reporter.progress_percent == 0.0
  241. for i in range(5):
  242. reporter.on_chapter_complete(chapter_index=i)
  243. assert reporter.progress_percent == 50.0
  244. def test_get_summary(self):
  245. """Test getting progress summary."""
  246. reporter = ProgressReporter()
  247. reporter.on_start(total=10)
  248. reporter.on_chapter_complete(chapter_index=0)
  249. reporter.on_chapter_complete(chapter_index=1)
  250. reporter.on_chapter_failed(chapter_index=2, error=Exception("test"))
  251. summary = reporter.get_summary()
  252. assert summary["total"] == 10
  253. assert summary["completed"] == 2
  254. assert summary["failed"] == 1
  255. assert summary["remaining"] == 7
  256. assert summary["progress_percent"] == 20.0
  257. class TestConsoleProgressReporter:
  258. """Tests for ConsoleProgressReporter class."""
  259. def test_console_reporter_init(self):
  260. """Test console reporter initialization."""
  261. reporter = ConsoleProgressReporter(show_details=True)
  262. assert reporter.show_details is True
  263. assert reporter.reporter is not None
  264. def test_get_reporter(self):
  265. """Test getting underlying reporter."""
  266. console = ConsoleProgressReporter()
  267. reporter = console.get_reporter()
  268. assert isinstance(reporter, ProgressReporter)
  269. # ============================================================================
  270. # Test ChapterTranslator (Story 5.3)
  271. # ============================================================================
  272. class TestChapterTranslator:
  273. """Tests for ChapterTranslator class."""
  274. @pytest.fixture
  275. def mock_pipeline(self):
  276. """Create a mock translation pipeline."""
  277. pipeline = MagicMock(spec=TranslationPipeline)
  278. pipeline.translate.return_value = "Translated text"
  279. return pipeline
  280. @pytest.fixture
  281. def mock_repository(self):
  282. """Create a mock repository."""
  283. repo = MagicMock()
  284. repo.save_chapter = MagicMock()
  285. repo.get_pending_chapters.return_value = []
  286. repo.get_chapters.return_value = []
  287. repo.get_failed_chapters.return_value = []
  288. repo.record_failure = MagicMock()
  289. repo.update_work_status = MagicMock()
  290. return repo
  291. @pytest.fixture
  292. def sample_chapter(self):
  293. """Create a sample chapter."""
  294. from src.repository.models import ChapterItem, ChapterStatus
  295. return ChapterItem(
  296. work_id="test_work",
  297. chapter_index=0,
  298. title="Test Chapter",
  299. content="Test content for translation.",
  300. status=ChapterStatus.PENDING
  301. )
  302. def test_translator_init(self, mock_pipeline, mock_repository):
  303. """Test translator initialization."""
  304. translator = ChapterTranslator(mock_pipeline, mock_repository)
  305. assert translator.pipeline is mock_pipeline
  306. assert translator.repository is mock_repository
  307. def test_split_paragraphs_simple(self, mock_pipeline, mock_repository):
  308. """Test splitting simple paragraphs."""
  309. translator = ChapterTranslator(mock_pipeline, mock_repository)
  310. content = "Para 1\n\nPara 2\n\nPara 3"
  311. segments = translator._split_paragraphs(content)
  312. assert len(segments) == 3
  313. assert segments[0] == "Para 1"
  314. assert segments[1] == "Para 2"
  315. assert segments[2] == "Para 3"
  316. def test_split_long_paragraph(self, mock_pipeline, mock_repository):
  317. """Test splitting a long paragraph."""
  318. translator = ChapterTranslator(mock_pipeline, mock_repository)
  319. # Create a long paragraph
  320. long_text = "。".join(["Sentence " + str(i) for i in range(100)])
  321. segments = translator._split_long_paragraph(long_text)
  322. assert len(segments) > 1
  323. # Each segment should be under the max length
  324. for seg in segments:
  325. assert len(seg) <= translator.MAX_SEGMENT_LENGTH + 100 # Allow some buffer
  326. def test_translate_chapter_success(
  327. self, mock_pipeline, mock_repository, sample_chapter
  328. ):
  329. """Test successful chapter translation."""
  330. translator = ChapterTranslator(mock_pipeline, mock_repository)
  331. result = translator.translate_chapter("test_work", sample_chapter)
  332. assert result.status == "completed"
  333. assert result.translation is not None
  334. mock_repository.save_chapter.assert_called()
  335. def test_translate_chapter_already_completed(
  336. self, mock_pipeline, mock_repository
  337. ):
  338. """Test skipping already translated chapter."""
  339. from src.repository.models import ChapterItem, ChapterStatus
  340. chapter = ChapterItem(
  341. work_id="test_work",
  342. chapter_index=0,
  343. title="Test",
  344. content="Content",
  345. status=ChapterStatus.COMPLETED,
  346. translation="Already translated"
  347. )
  348. translator = ChapterTranslator(mock_pipeline, mock_repository)
  349. result = translator.translate_chapter("test_work", chapter)
  350. assert result.translation == "Already translated"
  351. # translate should not be called
  352. mock_pipeline.translate.assert_not_called()
  353. def test_translate_work_empty(self, mock_pipeline, mock_repository):
  354. """Test translating work with no pending chapters."""
  355. mock_repository.get_pending_chapters.return_value = []
  356. mock_repository.get_chapters.return_value = []
  357. translator = ChapterTranslator(mock_pipeline, mock_repository)
  358. translator.translate_work("test_work")
  359. # Should not crash, should just return
  360. mock_pipeline.translate.assert_not_called()
  361. def test_retry_failed_chapters(self, mock_pipeline, mock_repository):
  362. """Test retrying failed chapters."""
  363. from src.repository.models import ChapterItem, ChapterStatus
  364. failed_chapter = ChapterItem(
  365. work_id="test_work",
  366. chapter_index=0,
  367. title="Failed",
  368. content="Content",
  369. status=ChapterStatus.FAILED,
  370. retry_count=0
  371. )
  372. mock_repository.get_failed_chapters.return_value = [failed_chapter]
  373. translator = ChapterTranslator(mock_pipeline, mock_repository)
  374. translator.retry_failed_chapters("test_work")
  375. assert mock_pipeline.translate.called
  376. def test_set_progress_callback(self, mock_pipeline, mock_repository):
  377. """Test setting a new progress callback."""
  378. translator = ChapterTranslator(mock_pipeline, mock_repository)
  379. new_callback = Mock()
  380. translator.set_progress_callback(new_callback)
  381. assert translator.progress_reporter.callback is new_callback
  382. # ============================================================================
  383. # Integration Tests (with mocked external dependencies)
  384. # ============================================================================
  385. class TestIntegration:
  386. """Integration tests for the translator module."""
  387. @pytest.fixture
  388. def full_pipeline(self, tmp_path):
  389. """Create a full pipeline with mocked model but real other components."""
  390. with patch('src.translator.engine.M2M100ForConditionalGeneration') as mock_model, \
  391. patch('src.translator.engine.M2M100Tokenizer') as mock_tokenizer:
  392. # Setup mocks
  393. mock_tok_instance = MagicMock()
  394. mock_tok_instance.src_lang = "zh"
  395. mock_tok_instance.lang_code_to_id = {"zh": 1, "en": 2}
  396. mock_tokenizer.from_pretrained.return_value = mock_tok_instance
  397. mock_model_instance = MagicMock()
  398. mock_model.from_pretrained.return_value = mock_model_instance
  399. # Create mock model directory
  400. model_dir = tmp_path / "model"
  401. model_dir.mkdir()
  402. (model_dir / "config.json").write_text("{}")
  403. # Return configured components
  404. mock_tok_instance.batch_decode.return_value = ["Translated text"]
  405. engine = TranslationEngine(model_path=str(model_dir))
  406. glossary = Glossary()
  407. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  408. pipeline = TranslationPipeline(engine, glossary)
  409. return {
  410. "engine": engine,
  411. "pipeline": pipeline,
  412. "glossary": glossary,
  413. "mock_tok": mock_tok_instance
  414. }
  415. def test_full_pipeline_translate(self, full_pipeline):
  416. """Test full pipeline from text to translation."""
  417. pipeline = full_pipeline["pipeline"]
  418. mock_tok = full_pipeline["mock_tok"]
  419. # Setup mock to return text with placeholder
  420. mock_tok.batch_decode.return_value = ["__en__Lin Feng is here"]
  421. result = pipeline.translate("林风在这里")
  422. assert "Lin Feng" in result
  423. def test_full_pipeline_statistics(self, full_pipeline):
  424. """Test getting statistics from pipeline."""
  425. pipeline = full_pipeline["pipeline"]
  426. stats = pipeline.get_statistics("林风是林风的剑")
  427. assert "林风" in stats
  428. assert stats["林风"] == 2