test_term_injector.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. """
  2. Unit tests for the term injection module.
  3. Tests cover term injection, validation, and statistics tracking.
  4. """
  5. import sys
  6. from unittest.mock import Mock
  7. # Mock torch and transformers before importing
  8. sys_mock = Mock()
  9. sys.modules["torch"] = sys_mock
  10. sys.modules["transformers"] = sys_mock
  11. import pytest
  12. from src.glossary.models import Glossary, GlossaryEntry, TermCategory
  13. from src.translator.term_injector import (
  14. TermInjector,
  15. TermValidator,
  16. TermStatistics,
  17. TermValidationResult,
  18. TermResult,
  19. TermUsageRecord,
  20. )
  21. class TestTermInjector:
  22. """Test cases for TermInjector class."""
  23. def test_init(self):
  24. """Test TermInjector initialization."""
  25. glossary = Glossary()
  26. injector = TermInjector(glossary, "zh", "en")
  27. assert injector.glossary == glossary
  28. assert injector.src_lang == "zh"
  29. assert injector.tgt_lang == "en"
  30. def test_generate_prompt_with_terms(self):
  31. """Test prompt generation with glossary terms."""
  32. glossary = Glossary()
  33. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER, "Protagonist"))
  34. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  35. injector = TermInjector(glossary, "zh", "en")
  36. prompt = injector.generate_prompt("林风释放了火球术")
  37. # Check that prompt contains key elements
  38. assert "Chinese" in prompt or "English" in prompt
  39. assert "林风" in prompt
  40. assert "Lin Feng" in prompt
  41. assert "火球术" in prompt
  42. assert "Fireball" in prompt
  43. assert "character" in prompt.lower()
  44. assert "skill" in prompt.lower()
  45. def test_generate_prompt_without_examples(self):
  46. """Test prompt generation without few-shot examples."""
  47. glossary = Glossary()
  48. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  49. injector = TermInjector(glossary, "zh", "en")
  50. prompt = injector.generate_prompt("林风来了", include_examples=False)
  51. assert "林风" in prompt
  52. assert "Lin Feng" in prompt
  53. assert "Examples:" not in prompt
  54. def test_generate_prompt_empty_glossary(self):
  55. """Test prompt generation with empty glossary."""
  56. glossary = Glossary()
  57. injector = TermInjector(glossary, "zh", "en")
  58. prompt = injector.generate_prompt("测试文本")
  59. # Should just return the source text
  60. assert prompt == "测试文本"
  61. def test_generate_prompt_max_examples(self):
  62. """Test prompt generation with limited examples."""
  63. glossary = Glossary()
  64. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  65. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  66. glossary.add(GlossaryEntry("青云宗", "Qingyun Sect", TermCategory.ORGANIZATION))
  67. glossary.add(GlossaryEntry("龙剑", "Dragon Sword", TermCategory.ITEM))
  68. injector = TermInjector(glossary, "zh", "en")
  69. prompt = injector.generate_prompt("林风使用龙剑", max_examples=2)
  70. # Should limit examples
  71. assert "Examples:" in prompt
  72. def test_inject_terms(self):
  73. """Test term injection into source text."""
  74. glossary = Glossary()
  75. injector = TermInjector(glossary, "zh", "en")
  76. result = injector.inject_terms("林风释放了火球术", ["林风", "火球术"])
  77. assert "[TERM:林风]" in result
  78. assert "[TERM:火球术]" in result
  79. def test_get_relevant_terms(self):
  80. """Test getting relevant terms for source text."""
  81. glossary = Glossary()
  82. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  83. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  84. glossary.add(GlossaryEntry("青云宗", "Qingyun Sect", TermCategory.ORGANIZATION))
  85. injector = TermInjector(glossary, "zh", "en")
  86. relevant = injector._get_relevant_terms("林风释放了火球术")
  87. # Should find two terms
  88. assert len(relevant) == 2
  89. sources = [t.source for t in relevant]
  90. assert "林风" in sources
  91. assert "火球术" in sources
  92. assert "青云宗" not in sources
  93. def test_build_terminology_table(self):
  94. """Test terminology table building."""
  95. glossary = Glossary()
  96. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  97. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  98. injector = TermInjector(glossary, "zh", "en")
  99. terms = glossary.get_all()
  100. table = injector._build_terminology_table(terms)
  101. assert "Terminology Table:" in table
  102. assert "林风" in table
  103. assert "Lin Feng" in table
  104. assert "Character:" in table
  105. class TestTermValidator:
  106. """Test cases for TermValidator class."""
  107. def test_init(self):
  108. """Test TermValidator initialization."""
  109. glossary = Glossary()
  110. validator = TermValidator(glossary)
  111. assert validator.glossary == glossary
  112. def test_validate_translation_success(self):
  113. """Test successful validation."""
  114. glossary = Glossary()
  115. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  116. validator = TermValidator(glossary)
  117. result = validator.validate_translation(
  118. source="林风来了",
  119. target="Lin Feng came"
  120. )
  121. assert result.is_valid is True
  122. assert result.success_rate == 100.0
  123. assert len(result.term_results) == 1
  124. assert result.term_results["林风"].success is True
  125. def test_validate_translation_failure(self):
  126. """Test validation with missing expected translation."""
  127. glossary = Glossary()
  128. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  129. validator = TermValidator(glossary)
  130. result = validator.validate_translation(
  131. source="林风来了",
  132. target="Lin came" # Missing "Feng"
  133. )
  134. assert result.is_valid is False
  135. assert result.success_rate < 100.0
  136. assert len(result.issues) > 0
  137. def test_validate_translation_multiple_terms(self):
  138. """Test validation with multiple terms."""
  139. glossary = Glossary()
  140. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  141. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  142. validator = TermValidator(glossary)
  143. result = validator.validate_translation(
  144. source="林风释放了火球术",
  145. target="Lin Feng released Fireball"
  146. )
  147. assert result.is_valid is True
  148. assert len(result.term_results) == 2
  149. assert result.term_results["林风"].success is True
  150. assert result.term_results["火球术"].success is True
  151. def test_validate_translation_partial_match(self):
  152. """Test validation with partial term match."""
  153. glossary = Glossary()
  154. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  155. validator = TermValidator(glossary)
  156. # Multiple occurrences but only one translated correctly
  157. result = validator.validate_translation(
  158. source="林风说,林风知道",
  159. target="Lin Feng said, Lin knows" # Only one "Lin Feng"
  160. )
  161. # Should detect the mismatch
  162. assert result.success_rate < 100
  163. def test_validate_translation_empty_source(self):
  164. """Test validation with source text that has no terms."""
  165. glossary = Glossary()
  166. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  167. validator = TermValidator(glossary)
  168. result = validator.validate_translation(
  169. source="这是一个测试", # No glossary terms
  170. target="This is a test"
  171. )
  172. assert result.is_valid is True
  173. assert len(result.term_results) == 0
  174. def test_check_term_consistency(self):
  175. """Test individual term consistency check."""
  176. glossary = Glossary()
  177. entry = GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER)
  178. glossary.add(entry)
  179. validator = TermValidator(glossary)
  180. # Successful case
  181. result = validator._check_term_consistency(
  182. source="林风来了",
  183. target="Lin Feng came",
  184. entry=entry
  185. )
  186. assert result.success is True
  187. assert result.source == "林风"
  188. assert result.expected == "Lin Feng"
  189. # Failed case
  190. result = validator._check_term_consistency(
  191. source="林风来了",
  192. target="Lin came",
  193. entry=entry
  194. )
  195. assert result.success is False
  196. def test_successful_terms_property(self):
  197. """Test successful_terms property."""
  198. glossary = Glossary()
  199. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  200. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  201. validator = TermValidator(glossary)
  202. result = validator.validate_translation(
  203. source="林风释放了火球术",
  204. target="Lin Feng released Fireball"
  205. )
  206. assert "林风" in result.successful_terms
  207. assert "火球术" in result.successful_terms
  208. def test_failed_terms_property(self):
  209. """Test failed_terms property."""
  210. glossary = Glossary()
  211. glossary.add(GlossaryEntry("林风", "Lin Feng", TermCategory.CHARACTER))
  212. glossary.add(GlossaryEntry("火球术", "Fireball", TermCategory.SKILL))
  213. validator = TermValidator(glossary)
  214. result = validator.validate_translation(
  215. source="林风释放了火球术",
  216. target="Lin released Fireball" # Missing "Feng"
  217. )
  218. assert "林风" in result.failed_terms
  219. assert "火球术" not in result.failed_terms
  220. class TestTermStatistics:
  221. """Test cases for TermStatistics class."""
  222. def test_init(self):
  223. """Test TermStatistics initialization."""
  224. stats = TermStatistics()
  225. assert len(stats) == 0
  226. assert stats.get_statistics()["total_usages"] == 0
  227. def test_record_usage_success(self):
  228. """Test recording successful term usage."""
  229. stats = TermStatistics()
  230. stats.record_usage(
  231. term="林风",
  232. expected="Lin Feng",
  233. success=True,
  234. context="source_count=1, target_count=1"
  235. )
  236. assert len(stats) == 1
  237. stat_dict = stats.get_statistics()
  238. assert stat_dict["total_usages"] == 1
  239. assert stat_dict["total_successes"] == 1
  240. assert stat_dict["total_failures"] == 0
  241. def test_record_usage_failure(self):
  242. """Test recording failed term usage."""
  243. stats = TermStatistics()
  244. stats.record_usage(
  245. term="林风",
  246. expected="Lin Feng",
  247. success=False,
  248. context="source_count=1, target_count=0"
  249. )
  250. assert len(stats) == 1
  251. stat_dict = stats.get_statistics()
  252. assert stat_dict["total_usages"] == 1
  253. assert stat_dict["total_successes"] == 0
  254. assert stat_dict["total_failures"] == 1
  255. def test_record_multiple_usages(self):
  256. """Test recording multiple term usages."""
  257. stats = TermStatistics()
  258. stats.record_usage("林风", "Lin Feng", True)
  259. stats.record_usage("林风", "Lin Feng", True)
  260. stats.record_usage("火球术", "Fireball", False)
  261. stat_dict = stats.get_statistics()
  262. assert stat_dict["total_usages"] == 3
  263. assert stat_dict["total_successes"] == 2
  264. assert stat_dict["total_failures"] == 1
  265. assert stat_dict["unique_terms"] == 2
  266. def test_get_term_statistics(self):
  267. """Test getting statistics for a specific term."""
  268. stats = TermStatistics()
  269. stats.record_usage("林风", "Lin Feng", True)
  270. stats.record_usage("林风", "Lin Feng", False)
  271. term_stats = stats.get_term_statistics("林风")
  272. assert term_stats is not None
  273. assert term_stats["term"] == "林风"
  274. assert term_stats["total_usages"] == 2
  275. assert term_stats["successes"] == 1
  276. assert term_stats["failures"] == 1
  277. assert term_stats["success_rate"] == 50.0
  278. def test_get_term_statistics_nonexistent(self):
  279. """Test getting statistics for a non-existent term."""
  280. stats = TermStatistics()
  281. term_stats = stats.get_term_statistics("nonexistent")
  282. assert term_stats is None
  283. def test_get_failed_terms(self):
  284. """Test getting list of failed terms."""
  285. stats = TermStatistics()
  286. stats.record_usage("林风", "Lin Feng", True)
  287. stats.record_usage("火球术", "Fireball", False)
  288. stats.record_usage("青云宗", "Qingyun Sect", False)
  289. failed = stats.get_failed_terms()
  290. assert "林风" not in failed
  291. assert "火球术" in failed
  292. assert "青云宗" in failed
  293. def test_reset(self):
  294. """Test resetting statistics."""
  295. stats = TermStatistics()
  296. stats.record_usage("林风", "Lin Feng", True)
  297. stats.record_usage("火球术", "Fireball", False)
  298. assert len(stats) == 2
  299. stats.reset()
  300. assert len(stats) == 0
  301. stat_dict = stats.get_statistics()
  302. assert stat_dict["total_usages"] == 0
  303. def test_merge(self):
  304. """Test merging statistics from another instance."""
  305. stats1 = TermStatistics()
  306. stats1.record_usage("林风", "Lin Feng", True)
  307. stats2 = TermStatistics()
  308. stats2.record_usage("火球术", "Fireball", False)
  309. stats1.merge(stats2)
  310. assert len(stats1) == 2
  311. stat_dict = stats1.get_statistics()
  312. assert stat_dict["unique_terms"] == 2
  313. def test_get_records(self):
  314. """Test getting all recorded usage records."""
  315. stats = TermStatistics()
  316. stats.record_usage("林风", "Lin Feng", True)
  317. stats.record_usage("火球术", "Fireball", False)
  318. records = stats.get_records()
  319. assert len(records) == 2
  320. assert all(isinstance(r, TermUsageRecord) for r in records)
  321. def test_generate_report(self):
  322. """Test generating a human-readable report."""
  323. stats = TermStatistics()
  324. stats.record_usage("林风", "Lin Feng", True)
  325. stats.record_usage("火球术", "Fireball", False)
  326. report = stats.generate_report()
  327. assert "Term Translation Statistics" in report
  328. assert "Total usages: 2" in report
  329. assert "林风" in report
  330. assert "火球术" in report
  331. assert "Failed terms:" in report
  332. def test_overall_success_rate_calculation(self):
  333. """Test overall success rate calculation."""
  334. stats = TermStatistics()
  335. stats.record_usage("林风", "Lin Feng", True)
  336. stats.record_usage("林风", "Lin Feng", True)
  337. stats.record_usage("火球术", "Fireball", True)
  338. stats.record_usage("青云宗", "Qingyun Sect", False)
  339. stat_dict = stats.get_statistics()
  340. assert stat_dict["overall_success_rate"] == 75.0
  341. def test_empty_statistics(self):
  342. """Test statistics with no records."""
  343. stats = TermStatistics()
  344. stat_dict = stats.get_statistics()
  345. assert stat_dict["total_usages"] == 0
  346. assert stat_dict["unique_terms"] == 0
  347. assert stat_dict["overall_success_rate"] == 100.0
  348. report = stats.generate_report()
  349. assert "Total usages: 0" in report
  350. class TestTermUsageRecord:
  351. """Test cases for TermUsageRecord dataclass."""
  352. def test_create_record(self):
  353. """Test creating a usage record."""
  354. record = TermUsageRecord(
  355. term="林风",
  356. expected="Lin Feng",
  357. success=True,
  358. context="source_count=1"
  359. )
  360. assert record.term == "林风"
  361. assert record.expected == "Lin Feng"
  362. assert record.success is True
  363. assert record.context == "source_count=1"
  364. class TestTermResult:
  365. """Test cases for TermResult dataclass."""
  366. def test_create_result(self):
  367. """Test creating a term result."""
  368. result = TermResult(
  369. source="林风",
  370. expected="Lin Feng",
  371. found=True,
  372. success=True,
  373. context="Valid translation"
  374. )
  375. assert result.source == "林风"
  376. assert result.expected == "Lin Feng"
  377. assert result.found is True
  378. assert result.success is True
  379. assert result.context == "Valid translation"
  380. class TestTermValidationResult:
  381. """Test cases for TermValidationResult dataclass."""
  382. def test_create_validation_result(self):
  383. """Test creating a validation result."""
  384. term_results = {
  385. "林风": TermResult("林风", "Lin Feng", True, True),
  386. "火球术": TermResult("火球术", "Fireball", False, False)
  387. }
  388. result = TermValidationResult(
  389. is_valid=False,
  390. term_results=term_results,
  391. success_rate=50.0,
  392. issues=["Fireball not found"]
  393. )
  394. assert result.is_valid is False
  395. assert result.success_rate == 50.0
  396. assert len(result.issues) == 1
  397. assert "林风" in result.successful_terms
  398. assert "火球术" in result.failed_terms