|
@@ -19,6 +19,8 @@ from anthropic import Anthropic
|
|
|
from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
|
|
from config import MCP_SERVERS, ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL, ANTHROPIC_MODEL
|
|
|
from conversation_manager import ConversationManager
|
|
from conversation_manager import ConversationManager
|
|
|
from tool_handler import ToolCallHandler
|
|
from tool_handler import ToolCallHandler
|
|
|
|
|
+from tool_converter import ToolConverter
|
|
|
|
|
+from mcp_client import MCPClient
|
|
|
|
|
|
|
|
|
|
|
|
|
# 存储认证会话 (生产环境应使用 Redis 或数据库)
|
|
# 存储认证会话 (生产环境应使用 Redis 或数据库)
|
|
@@ -108,22 +110,37 @@ async def health():
|
|
|
async def chat(request: Request):
|
|
async def chat(request: Request):
|
|
|
"""
|
|
"""
|
|
|
聊天端点 - 接收用户消息,返回 Claude 响应(支持 MCP 工具调用)
|
|
聊天端点 - 接收用户消息,返回 Claude 响应(支持 MCP 工具调用)
|
|
|
|
|
+
|
|
|
|
|
+ 支持 MCP 认证:通过 X-MCP-Tokens header 传递 JWT tokens
|
|
|
"""
|
|
"""
|
|
|
try:
|
|
try:
|
|
|
data = await request.json()
|
|
data = await request.json()
|
|
|
message = data.get('message', '')
|
|
message = data.get('message', '')
|
|
|
conversation_history = data.get('history', [])
|
|
conversation_history = data.get('history', [])
|
|
|
session_id = request.headers.get('X-Session-ID')
|
|
session_id = request.headers.get('X-Session-ID')
|
|
|
|
|
+ mcp_tokens = request.headers.get('X-MCP-Tokens') # MCP tokens (JSON string)
|
|
|
|
|
|
|
|
if not message:
|
|
if not message:
|
|
|
raise HTTPException(status_code=400, detail="Message is required")
|
|
raise HTTPException(status_code=400, detail="Message is required")
|
|
|
|
|
|
|
|
- # 创建对话管理器
|
|
|
|
|
|
|
+ # 解析 MCP tokens
|
|
|
|
|
+ parsed_tokens = {}
|
|
|
|
|
+ if mcp_tokens:
|
|
|
|
|
+ if isinstance(mcp_tokens, str):
|
|
|
|
|
+ try:
|
|
|
|
|
+ parsed_tokens = json_module.loads(mcp_tokens)
|
|
|
|
|
+ except:
|
|
|
|
|
+ parsed_tokens = {}
|
|
|
|
|
+ else:
|
|
|
|
|
+ parsed_tokens = mcp_tokens
|
|
|
|
|
+
|
|
|
|
|
+ # 创建对话管理器(带 token)
|
|
|
conv_manager = ConversationManager(
|
|
conv_manager = ConversationManager(
|
|
|
api_key=ANTHROPIC_API_KEY,
|
|
api_key=ANTHROPIC_API_KEY,
|
|
|
base_url=ANTHROPIC_BASE_URL,
|
|
base_url=ANTHROPIC_BASE_URL,
|
|
|
model=ANTHROPIC_MODEL,
|
|
model=ANTHROPIC_MODEL,
|
|
|
- session_id=session_id
|
|
|
|
|
|
|
+ session_id=session_id,
|
|
|
|
|
+ mcp_tokens=parsed_tokens
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 格式化对话历史
|
|
# 格式化对话历史
|
|
@@ -163,19 +180,32 @@ async def chat(request: Request):
|
|
|
async def generate_chat_stream(
|
|
async def generate_chat_stream(
|
|
|
message: str,
|
|
message: str,
|
|
|
conversation_history: List[Dict[str, Any]],
|
|
conversation_history: List[Dict[str, Any]],
|
|
|
- session_id: Optional[str]
|
|
|
|
|
|
|
+ session_id: Optional[str],
|
|
|
|
|
+ mcp_tokens: Optional[Dict[str, str]] = None
|
|
|
):
|
|
):
|
|
|
"""生成 SSE 流式响应的异步生成器"""
|
|
"""生成 SSE 流式响应的异步生成器"""
|
|
|
try:
|
|
try:
|
|
|
# 发送开始事件
|
|
# 发送开始事件
|
|
|
yield f"event: start\ndata: {json_module.dumps({'status': 'started'})}\n\n"
|
|
yield f"event: start\ndata: {json_module.dumps({'status': 'started'})}\n\n"
|
|
|
|
|
|
|
|
- # 创建对话管理器
|
|
|
|
|
|
|
+ # 解析 MCP tokens (从 JSON 字符串)
|
|
|
|
|
+ parsed_tokens = {}
|
|
|
|
|
+ if mcp_tokens:
|
|
|
|
|
+ if isinstance(mcp_tokens, str):
|
|
|
|
|
+ try:
|
|
|
|
|
+ parsed_tokens = json_module.loads(mcp_tokens)
|
|
|
|
|
+ except:
|
|
|
|
|
+ parsed_tokens = {}
|
|
|
|
|
+ else:
|
|
|
|
|
+ parsed_tokens = mcp_tokens
|
|
|
|
|
+
|
|
|
|
|
+ # 创建对话管理器(带 token)
|
|
|
conv_manager = ConversationManager(
|
|
conv_manager = ConversationManager(
|
|
|
api_key=ANTHROPIC_API_KEY,
|
|
api_key=ANTHROPIC_API_KEY,
|
|
|
base_url=ANTHROPIC_BASE_URL,
|
|
base_url=ANTHROPIC_BASE_URL,
|
|
|
model=ANTHROPIC_MODEL,
|
|
model=ANTHROPIC_MODEL,
|
|
|
- session_id=session_id
|
|
|
|
|
|
|
+ session_id=session_id,
|
|
|
|
|
+ mcp_tokens=parsed_tokens
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 格式化对话历史
|
|
# 格式化对话历史
|
|
@@ -349,18 +379,21 @@ async def chat_stream(request: Request):
|
|
|
1. Claude 的思考过程
|
|
1. Claude 的思考过程
|
|
|
2. 工具调用状态
|
|
2. 工具调用状态
|
|
|
3. 最终响应
|
|
3. 最终响应
|
|
|
|
|
+
|
|
|
|
|
+ 支持 MCP 认证:通过 X-MCP-Tokens header 传递 JWT tokens
|
|
|
"""
|
|
"""
|
|
|
try:
|
|
try:
|
|
|
data = await request.json()
|
|
data = await request.json()
|
|
|
message = data.get('message', '')
|
|
message = data.get('message', '')
|
|
|
conversation_history = data.get('history', [])
|
|
conversation_history = data.get('history', [])
|
|
|
session_id = request.headers.get('X-Session-ID')
|
|
session_id = request.headers.get('X-Session-ID')
|
|
|
|
|
+ mcp_tokens = request.headers.get('X-MCP-Tokens') # MCP tokens (JSON string)
|
|
|
|
|
|
|
|
if not message:
|
|
if not message:
|
|
|
raise HTTPException(status_code=400, detail="Message is required")
|
|
raise HTTPException(status_code=400, detail="Message is required")
|
|
|
|
|
|
|
|
return StreamingResponse(
|
|
return StreamingResponse(
|
|
|
- generate_chat_stream(message, conversation_history, session_id),
|
|
|
|
|
|
|
+ generate_chat_stream(message, conversation_history, session_id, mcp_tokens),
|
|
|
media_type="text/event-stream",
|
|
media_type="text/event-stream",
|
|
|
headers={
|
|
headers={
|
|
|
'Cache-Control': 'no-cache',
|
|
'Cache-Control': 'no-cache',
|
|
@@ -399,15 +432,30 @@ async def list_mcp_servers():
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/mcp/tools")
|
|
@app.get("/api/mcp/tools")
|
|
|
-async def list_mcp_tools(x_session_id: Optional[str] = Header(None, alias='X-Session-ID')):
|
|
|
|
|
- """获取可用的 MCP 工具列表"""
|
|
|
|
|
|
|
+async def list_mcp_tools(
|
|
|
|
|
+ x_session_id: Optional[str] = Header(None, alias='X-Session-ID'),
|
|
|
|
|
+ x_mcp_tokens: Optional[str] = Header(None, alias='X-MCP-Tokens')
|
|
|
|
|
+):
|
|
|
|
|
+ """获取可用的 MCP 工具列表(支持带 token 的认证)"""
|
|
|
try:
|
|
try:
|
|
|
- # 使用静态方法获取工具
|
|
|
|
|
- tools = ConversationManager.get_tools(session_id=x_session_id)
|
|
|
|
|
|
|
+ # 解析 MCP tokens
|
|
|
|
|
+ parsed_tokens = {}
|
|
|
|
|
+ if x_mcp_tokens:
|
|
|
|
|
+ try:
|
|
|
|
|
+ parsed_tokens = json_module.loads(x_mcp_tokens)
|
|
|
|
|
+ except:
|
|
|
|
|
+ parsed_tokens = {}
|
|
|
|
|
+
|
|
|
|
|
+ # 使用带 token 的方法获取工具
|
|
|
|
|
+ tools = await MCPClient.get_all_tools_with_tokens_async(
|
|
|
|
|
+ session_id=x_session_id,
|
|
|
|
|
+ mcp_tokens=parsed_tokens
|
|
|
|
|
+ )
|
|
|
|
|
+ claude_tools = ToolConverter.convert_mcp_tools(tools)
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
- "tools": tools,
|
|
|
|
|
- "count": len(tools)
|
|
|
|
|
|
|
+ "tools": claude_tools,
|
|
|
|
|
+ "count": len(claude_tools)
|
|
|
}
|
|
}
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
import traceback
|
|
import traceback
|
|
@@ -431,32 +479,28 @@ async def login(request: Request):
|
|
|
"""
|
|
"""
|
|
|
try:
|
|
try:
|
|
|
data = await request.json()
|
|
data = await request.json()
|
|
|
- username = data.get('username')
|
|
|
|
|
|
|
+ # 支持 email 和 username 两种参数名
|
|
|
|
|
+ email = data.get('email') or data.get('username')
|
|
|
password = data.get('password')
|
|
password = data.get('password')
|
|
|
|
|
|
|
|
- if not username or not password:
|
|
|
|
|
- raise HTTPException(status_code=400, detail="Username and password are required")
|
|
|
|
|
-
|
|
|
|
|
- # 查找需要 JWT 认证的 MCP 服务器
|
|
|
|
|
- target_server = None
|
|
|
|
|
- for server_id, config in MCP_SERVERS.items():
|
|
|
|
|
- if config.get('auth_type') == 'jwt' and 'login_url' in config:
|
|
|
|
|
- target_server = config
|
|
|
|
|
- break
|
|
|
|
|
|
|
+ if not email or not password:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="Email and password are required")
|
|
|
|
|
|
|
|
|
|
+ # 查找用户 MCP 服务器
|
|
|
|
|
+ target_server = MCP_SERVERS.get('novel-platform-user')
|
|
|
if not target_server:
|
|
if not target_server:
|
|
|
- raise HTTPException(status_code=400, detail="No JWT-authenticated server configured")
|
|
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="Novel Platform User server not configured")
|
|
|
|
|
|
|
|
# 构建登录 URL
|
|
# 构建登录 URL
|
|
|
base_url = target_server.get('base_url', '')
|
|
base_url = target_server.get('base_url', '')
|
|
|
- login_path = target_server.get('login_url', '/api/auth/login')
|
|
|
|
|
|
|
+ login_path = target_server.get('login_url', '/api/v1/auth/login')
|
|
|
login_url = f"{base_url}{login_path}"
|
|
login_url = f"{base_url}{login_path}"
|
|
|
|
|
|
|
|
# 调用实际的登录接口(异步版本)
|
|
# 调用实际的登录接口(异步版本)
|
|
|
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
|
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
|
|
response = await http_client.post(
|
|
response = await http_client.post(
|
|
|
login_url,
|
|
login_url,
|
|
|
- json={"username": username, "password": password}
|
|
|
|
|
|
|
+ json={"email": email, "password": password}
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if response.status_code == 200:
|
|
if response.status_code == 200:
|
|
@@ -465,7 +509,8 @@ async def login(request: Request):
|
|
|
|
|
|
|
|
# 存储会话信息
|
|
# 存储会话信息
|
|
|
auth_sessions[session_id] = {
|
|
auth_sessions[session_id] = {
|
|
|
- "username": username,
|
|
|
|
|
|
|
+ "username": result.get("username", email),
|
|
|
|
|
+ "email": email,
|
|
|
"token": result.get("token"),
|
|
"token": result.get("token"),
|
|
|
"refresh_token": result.get("refresh_token"),
|
|
"refresh_token": result.get("refresh_token"),
|
|
|
"server": target_server.get("name")
|
|
"server": target_server.get("name")
|
|
@@ -474,7 +519,7 @@ async def login(request: Request):
|
|
|
return {
|
|
return {
|
|
|
"success": True,
|
|
"success": True,
|
|
|
"session_id": session_id,
|
|
"session_id": session_id,
|
|
|
- "username": username,
|
|
|
|
|
|
|
+ "username": result.get("username", email),
|
|
|
"server": target_server.get("name"),
|
|
"server": target_server.get("name"),
|
|
|
"token": result.get("token")
|
|
"token": result.get("token")
|
|
|
}
|
|
}
|
|
@@ -494,14 +539,16 @@ async def login(request: Request):
|
|
|
async def admin_login(request: Request):
|
|
async def admin_login(request: Request):
|
|
|
"""
|
|
"""
|
|
|
Novel Platform 管理员登录
|
|
Novel Platform 管理员登录
|
|
|
|
|
+ 代理到实际的管理员登录端点并返回 JWT Token
|
|
|
"""
|
|
"""
|
|
|
try:
|
|
try:
|
|
|
data = await request.json()
|
|
data = await request.json()
|
|
|
- username = data.get('username')
|
|
|
|
|
|
|
+ # 支持 email 和 username 两种参数名
|
|
|
|
|
+ email = data.get('email') or data.get('username')
|
|
|
password = data.get('password')
|
|
password = data.get('password')
|
|
|
|
|
|
|
|
- if not username or not password:
|
|
|
|
|
- raise HTTPException(status_code=400, detail="Username and password are required")
|
|
|
|
|
|
|
+ if not email or not password:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="Email and password are required")
|
|
|
|
|
|
|
|
# 查找管理员 MCP 服务器
|
|
# 查找管理员 MCP 服务器
|
|
|
target_server = MCP_SERVERS.get('novel-platform-admin')
|
|
target_server = MCP_SERVERS.get('novel-platform-admin')
|
|
@@ -510,14 +557,14 @@ async def admin_login(request: Request):
|
|
|
|
|
|
|
|
# 构建登录 URL
|
|
# 构建登录 URL
|
|
|
base_url = target_server.get('base_url', '')
|
|
base_url = target_server.get('base_url', '')
|
|
|
- login_path = target_server.get('login_url', '/api/auth/admin-login')
|
|
|
|
|
|
|
+ login_path = target_server.get('login_url', '/api/v1/auth/admin-login')
|
|
|
login_url = f"{base_url}{login_path}"
|
|
login_url = f"{base_url}{login_path}"
|
|
|
|
|
|
|
|
# 调用实际的登录接口(异步版本)
|
|
# 调用实际的登录接口(异步版本)
|
|
|
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
|
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
|
|
response = await http_client.post(
|
|
response = await http_client.post(
|
|
|
login_url,
|
|
login_url,
|
|
|
- json={"username": username, "password": password}
|
|
|
|
|
|
|
+ json={"email": email, "password": password}
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if response.status_code == 200:
|
|
if response.status_code == 200:
|
|
@@ -525,7 +572,8 @@ async def admin_login(request: Request):
|
|
|
session_id = str(uuid.uuid4())
|
|
session_id = str(uuid.uuid4())
|
|
|
|
|
|
|
|
auth_sessions[session_id] = {
|
|
auth_sessions[session_id] = {
|
|
|
- "username": username,
|
|
|
|
|
|
|
+ "username": result.get("username", email),
|
|
|
|
|
+ "email": email,
|
|
|
"token": result.get("token"),
|
|
"token": result.get("token"),
|
|
|
"refresh_token": result.get("refresh_token"),
|
|
"refresh_token": result.get("refresh_token"),
|
|
|
"server": target_server.get("name"),
|
|
"server": target_server.get("name"),
|
|
@@ -535,7 +583,7 @@ async def admin_login(request: Request):
|
|
|
return {
|
|
return {
|
|
|
"success": True,
|
|
"success": True,
|
|
|
"session_id": session_id,
|
|
"session_id": session_id,
|
|
|
- "username": username,
|
|
|
|
|
|
|
+ "username": result.get("username", email),
|
|
|
"server": target_server.get("name"),
|
|
"server": target_server.get("name"),
|
|
|
"role": "admin",
|
|
"role": "admin",
|
|
|
"token": result.get("token")
|
|
"token": result.get("token")
|