test_mcp_server.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
"""
Unit tests for the Novel Translator MCP Server.
This test suite validates all MCP tools and server functionality.
"""
import asyncio
import json
import sys
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
  11. # Mock the heavy ML imports before importing server modules
  12. sys_modules_mock = MagicMock()
  13. sys_modules_mock.transformers = MagicMock()
  14. sys_modules_mock.torch = MagicMock()
  15. with patch.dict('sys.modules', {
  16. 'transformers': sys_modules_mock.transformers,
  17. 'torch': sys_modules_mock.torch
  18. }):
  19. from src.mcp_server.server import (
  20. mcp,
  21. get_pipeline,
  22. get_glossary,
  23. get_cleaning_pipeline,
  24. get_fingerprint_service,
  25. create_task,
  26. update_progress,
  27. complete_task,
  28. notify_glossary_updated,
  29. )
  30. from src.glossary.models import Glossary, GlossaryEntry, TermCategory
# ============================================================================
# Fixtures
# ============================================================================
  34. @pytest.fixture
  35. def temp_file():
  36. """Create a temporary file for testing."""
  37. with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8') as f:
  38. f.write("## Chapter 1\n\nThis is test content.\n\n## Chapter 2\n\nMore content here.")
  39. temp_path = f.name
  40. yield temp_path
  41. Path(temp_path).unlink(missing_ok=True)
  42. @pytest.fixture
  43. def sample_glossary():
  44. """Create a sample glossary for testing."""
  45. glossary = Glossary()
  46. glossary.add(GlossaryEntry(
  47. source="林风",
  48. target="Lin Feng",
  49. category=TermCategory.CHARACTER,
  50. context="Main protagonist"
  51. ))
  52. glossary.add(GlossaryEntry(
  53. source="火球术",
  54. target="Fireball",
  55. category=TermCategory.SKILL,
  56. context="Magic spell"
  57. ))
  58. return glossary
  59. @pytest.fixture
  60. def mock_pipeline():
  61. """Create a mock translation pipeline."""
  62. pipeline = MagicMock()
  63. pipeline.translate = MagicMock(return_value=MagicMock(
  64. translated="Translated text",
  65. terms_used=["林风"]
  66. ))
  67. pipeline.translate_batch = MagicMock(return_value=[
  68. MagicMock(translated="Translation 1", terms_used=["term1"]),
  69. MagicMock(translated="Translation 2", terms_used=["term2"])
  70. ])
  71. pipeline.set_languages = MagicMock()
  72. pipeline.update_glossary = MagicMock()
  73. pipeline.src_lang = "zh"
  74. pipeline.tgt_lang = "en"
  75. return pipeline
  76. @pytest.fixture
  77. def mock_cleaning_pipeline():
  78. """Create a mock cleaning pipeline."""
  79. from src.cleaning.models import Chapter
  80. pipeline = MagicMock()
  81. pipeline.enable_cleaning = True
  82. pipeline.enable_splitting = True
  83. pipeline.process = MagicMock(return_value=[
  84. Chapter(index=0, title="Chapter 1", content="Content 1", char_count=100),
  85. Chapter(index=1, title="Chapter 2", content="Content 2", char_count=150),
  86. ])
  87. return pipeline
  88. @pytest.fixture
  89. def mock_fingerprint_service():
  90. """Create a mock fingerprint service."""
  91. service = MagicMock()
  92. service.check_before_import = MagicMock(return_value=(False, None))
  93. service.get_fingerprint = MagicMock(return_value="abc123def456")
  94. service.get_file_info = MagicMock(return_value={
  95. "fingerprint": "abc123",
  96. "metadata": {"size": 1000},
  97. "is_duplicate": False,
  98. "existing_work_id": None
  99. })
  100. return service
# ============================================================================
# Server State Tests
# ============================================================================
  104. class TestServerState:
  105. """Tests for server state management."""
  106. def test_get_glossary_returns_singleton(self):
  107. """Test that get_glossary returns the same instance."""
  108. g1 = get_glossary()
  109. g2 = get_glossary()
  110. assert g1 is g2
  111. def test_create_task_generates_unique_id(self):
  112. """Test that create_task generates unique task IDs."""
  113. task1 = create_task("test_type", 10)
  114. task2 = create_task("test_type", 10)
  115. assert task1 != task2
# ============================================================================
# Translation Tool Tests
# ============================================================================
  119. class TestTranslationTools:
  120. """Tests for translation tools."""
  121. @pytest.mark.asyncio
  122. async def test_translate_text_with_mock(self, mock_pipeline):
  123. """Test translate_text with mocked pipeline."""
  124. from src.mcp_server.server import translate_text
  125. with patch('src.mcp_server.server.get_pipeline', return_value=mock_pipeline):
  126. result = await translate_text(
  127. text="你好世界",
  128. src_lang="zh",
  129. tgt_lang="en"
  130. )
  131. assert result["success"] is True
  132. assert result["translated"] == "Translated text"
  133. assert result["terms_used"] == ["林风"]
  134. @pytest.mark.asyncio
  135. async def test_translate_text_empty_input(self):
  136. """Test translate_text with empty input."""
  137. from src.mcp_server.server import translate_text
  138. result = await translate_text(text="", src_lang="zh", tgt_lang="en")
  139. assert result["success"] is False
  140. assert "empty" in result["error"].lower()
  141. @pytest.mark.asyncio
  142. async def test_translate_batch_with_mock(self, mock_pipeline):
  143. """Test translate_batch with mocked pipeline."""
  144. from src.mcp_server.server import translate_batch
  145. with patch('src.mcp_server.server.get_pipeline', return_value=mock_pipeline):
  146. result = await translate_batch(
  147. texts=["Text 1", "Text 2"],
  148. src_lang="zh",
  149. tgt_lang="en"
  150. )
  151. assert result["success"] is True
  152. assert result["translations"] == ["Translation 1", "Translation 2"]
  153. assert len(result["terms_used"]) == 2
  154. @pytest.mark.asyncio
  155. async def test_translate_batch_empty_list(self):
  156. """Test translate_batch with empty list."""
  157. from src.mcp_server.server import translate_batch
  158. result = await translate_batch(texts=[], src_lang="zh", tgt_lang="en")
  159. assert result["success"] is False
  160. assert "empty" in result["error"].lower()
  161. @pytest.mark.asyncio
  162. async def test_translate_file_with_mock(
  163. self,
  164. temp_file,
  165. mock_pipeline,
  166. mock_cleaning_pipeline
  167. ):
  168. """Test translate_file with mocked dependencies."""
  169. from src.mcp_server.server import translate_file
  170. with patch('src.mcp_server.server.get_pipeline', return_value=mock_pipeline), \
  171. patch('src.mcp_server.server.get_cleaning_pipeline', return_value=mock_cleaning_pipeline), \
  172. patch('src.mcp_server.server.create_task', return_value="test-task-id"), \
  173. patch('src.mcp_server.server.update_progress', new_callable=AsyncMock), \
  174. patch('src.mcp_server.server.complete_task', new_callable=AsyncMock):
  175. # Create a temp output file
  176. with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_en.txt') as f:
  177. output_path = f.name
  178. try:
  179. result = await translate_file(
  180. file_path=temp_file,
  181. output_path=output_path
  182. )
  183. assert result["success"] is True
  184. assert "task_id" in result
  185. assert result["chapters_translated"] == 2
  186. finally:
  187. Path(output_path).unlink(missing_ok=True)
  188. @pytest.mark.asyncio
  189. async def test_translate_file_not_found(self):
  190. """Test translate_file with non-existent file."""
  191. from src.mcp_server.server import translate_file
  192. result = await translate_file(file_path="/nonexistent/file.txt")
  193. assert result["success"] is False
  194. assert "not found" in result["error"].lower()
# ============================================================================
# Cleaning Tool Tests
# ============================================================================
  198. class TestCleaningTools:
  199. """Tests for cleaning tools."""
  200. @pytest.mark.asyncio
  201. async def test_clean_file_with_mock(self, temp_file, mock_cleaning_pipeline):
  202. """Test clean_file with mocked pipeline."""
  203. from src.mcp_server.server import clean_file
  204. with patch('src.mcp_server.server.get_cleaning_pipeline', return_value=mock_cleaning_pipeline):
  205. result = await clean_file(file_path=temp_file)
  206. assert result["success"] is True
  207. assert result["chapter_count"] == 2
  208. assert result["total_chars"] == 250
  209. assert len(result["chapters"]) == 2
  210. @pytest.mark.asyncio
  211. async def test_clean_file_not_found(self):
  212. """Test clean_file with non-existent file."""
  213. from src.mcp_server.server import clean_file
  214. result = await clean_file(file_path="/nonexistent/file.txt")
  215. assert result["success"] is False
  216. assert "not found" in result["error"].lower()
  217. @pytest.mark.asyncio
  218. async def test_split_chapters_with_mock(self, mock_cleaning_pipeline):
  219. """Test split_chapters with mocked splitter."""
  220. from src.mcp_server.server import split_chapters
  221. from src.cleaning.splitter import ChapterSplitter
  222. mock_splitter = MagicMock()
  223. mock_splitter.split = MagicMock(return_value=[
  224. MagicMock(index=0, title="Chapter 1", char_count=100, content="Content 1"),
  225. MagicMock(index=1, title="Chapter 2", char_count=150, content="Content 2"),
  226. ])
  227. with patch('src.cleaning.splitter.ChapterSplitter', return_value=mock_splitter):
  228. result = await split_chapters(text="## Chapter 1\n\nContent\n\n## Chapter 2\n\nMore content")
  229. assert result["success"] is True
  230. assert result["chapter_count"] == 2
  231. @pytest.mark.asyncio
  232. async def test_split_chapters_empty_text(self):
  233. """Test split_chapters with empty text."""
  234. from src.mcp_server.server import split_chapters
  235. result = await split_chapters(text="")
  236. assert result["success"] is False
  237. assert "empty" in result["error"].lower()
# ============================================================================
# Glossary Tool Tests
# ============================================================================
  241. class TestGlossaryTools:
  242. """Tests for glossary tools."""
  243. @pytest.mark.asyncio
  244. async def test_glossary_add(self):
  245. """Test adding a term to the glossary."""
  246. from src.mcp_server.server import glossary_add, get_glossary
  247. # Clear the glossary first
  248. get_glossary()._terms.clear()
  249. result = await glossary_add(
  250. source="林风",
  251. target="Lin Feng",
  252. category="character",
  253. context="Main protagonist"
  254. )
  255. assert result["success"] is True
  256. assert "entry" in result
  257. assert result["entry"]["source"] == "林风"
  258. assert result["entry"]["target"] == "Lin Feng"
  259. @pytest.mark.asyncio
  260. async def test_glossary_add_empty_source(self):
  261. """Test adding a term with empty source."""
  262. from src.mcp_server.server import glossary_add
  263. result = await glossary_add(source="", target="Lin Feng")
  264. assert result["success"] is False
  265. assert "empty" in result["error"].lower()
  266. @pytest.mark.asyncio
  267. async def test_glossary_add_empty_target(self):
  268. """Test adding a term with empty target."""
  269. from src.mcp_server.server import glossary_add
  270. result = await glossary_add(source="林风", target="")
  271. assert result["success"] is False
  272. assert "empty" in result["error"].lower()
  273. @pytest.mark.asyncio
  274. async def test_glossary_list(self, sample_glossary):
  275. """Test listing glossary entries."""
  276. from src.mcp_server.server import glossary_list, get_glossary
  277. # Replace with sample glossary
  278. with patch('src.mcp_server.server.get_glossary', return_value=sample_glossary):
  279. result = await glossary_list()
  280. assert result["success"] is True
  281. assert result["count"] == 2
  282. assert len(result["entries"]) == 2
  283. # Check entries
  284. sources = [e["source"] for e in result["entries"]]
  285. assert "林风" in sources
  286. assert "火球术" in sources
  287. @pytest.mark.asyncio
  288. async def test_glossary_clear(self):
  289. """Test clearing the glossary."""
  290. from src.mcp_server.server import glossary_clear, get_glossary
  291. # Add some entries first
  292. get_glossary().add(GlossaryEntry(
  293. source="Test",
  294. target="Test EN",
  295. category=TermCategory.OTHER
  296. ))
  297. result = await glossary_clear()
  298. assert result["success"] is True
  299. assert len(get_glossary()._terms) == 0
# ============================================================================
# Fingerprint Tool Tests
# ============================================================================
  303. class TestFingerprintTools:
  304. """Tests for fingerprint tools."""
  305. @pytest.mark.asyncio
  306. async def test_check_duplicate(self, temp_file, mock_fingerprint_service):
  307. """Test checking for duplicate files."""
  308. from src.mcp_server.server import check_duplicate
  309. with patch('src.mcp_server.server.get_fingerprint_service', return_value=mock_fingerprint_service):
  310. result = await check_duplicate(file_path=temp_file)
  311. assert result["success"] is True
  312. assert result["is_duplicate"] is False
  313. assert result["fingerprint"] == "abc123def456"
  314. @pytest.mark.asyncio
  315. async def test_check_duplicate_not_found(self):
  316. """Test check_duplicate with non-existent file."""
  317. from src.mcp_server.server import check_duplicate
  318. result = await check_duplicate(file_path="/nonexistent/file.txt")
  319. assert result["success"] is False
  320. assert "not found" in result["error"].lower()
  321. @pytest.mark.asyncio
  322. async def test_get_fingerprint(self, temp_file, mock_fingerprint_service):
  323. """Test getting file fingerprint."""
  324. from src.mcp_server.server import get_fingerprint
  325. with patch('src.mcp_server.server.get_fingerprint_service', return_value=mock_fingerprint_service):
  326. result = await get_fingerprint(file_path=temp_file)
  327. assert result["success"] is True
  328. assert result["fingerprint"] == "abc123def456"
  329. assert "file_name" in result
  330. @pytest.mark.asyncio
  331. async def test_get_fingerprint_not_found(self):
  332. """Test get_fingerprint with non-existent file."""
  333. from src.mcp_server.server import get_fingerprint
  334. result = await get_fingerprint(file_path="/nonexistent/file.txt")
  335. assert result["success"] is False
  336. assert "not found" in result["error"].lower()
# ============================================================================
# Progress Resource Tests
# ============================================================================
  340. class TestProgressResources:
  341. """Tests for progress resources."""
  342. @pytest.mark.asyncio
  343. async def test_progress_resource_flow(self):
  344. """Test the full progress resource flow."""
  345. from src.mcp_server.server import get_progress_resource, list_all_progress
  346. # Create a task
  347. task_id = create_task("test_type", 10)
  348. # Update progress
  349. await update_progress(task_id, {"current": 5, "percent": 50.0})
  350. # Get progress
  351. progress_json = await get_progress_resource(task_id)
  352. progress = json.loads(progress_json)
  353. assert progress["task_id"] == task_id
  354. assert progress["current"] == 5
  355. assert progress["percent"] == 50.0
  356. # List all progress
  357. list_json = await list_all_progress()
  358. task_list = json.loads(list_json)
  359. assert task_list["count"] >= 1
  360. # Complete task
  361. await complete_task(task_id, success=True)
  362. # Verify completion
  363. final_progress = json.loads(await get_progress_resource(task_id))
  364. assert final_progress["status"] == "completed"
  365. @pytest.mark.asyncio
  366. async def test_progress_resource_not_found(self):
  367. """Test getting progress for non-existent task."""
  368. from src.mcp_server.server import get_progress_resource
  369. progress_json = await get_progress_resource("non-existent-task-id")
  370. progress = json.loads(progress_json)
  371. assert "error" in progress
  372. assert "not found" in progress["error"].lower()
# ============================================================================
# Integration Tests
# ============================================================================
  376. class TestIntegration:
  377. """Integration tests for complete workflows."""
  378. @pytest.mark.asyncio
  379. async def test_glossary_translation_workflow(self):
  380. """Test adding terms and using them in translation."""
  381. from src.mcp_server.server import glossary_add, glossary_list
  382. # Clear glossary
  383. get_glossary()._terms.clear()
  384. # Add terms
  385. await glossary_add(source="林风", target="Lin Feng", category="character")
  386. await glossary_add(source="青云宗", target="Qingyun Sect", category="organization")
  387. # List terms
  388. result = await glossary_list()
  389. assert result["success"] is True
  390. assert result["count"] == 2
  391. # Verify terms were added
  392. sources = {e["source"] for e in result["entries"]}
  393. assert "林风" in sources
  394. assert "青云宗" in sources
  395. if __name__ == "__main__":
  396. pytest.main([__file__, "-v", "--cov=src/mcp_server", "--cov-report=term"])