- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
166 lines
5.1 KiB
Python
166 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
测试与课程对话功能 - Dify 集成
|
||
"""
|
||
|
||
import asyncio
|
||
import httpx
|
||
import json
|
||
|
||
# 测试配置
|
||
API_BASE_URL = "http://localhost:8000"
|
||
LOGIN_ENDPOINT = f"{API_BASE_URL}/api/v1/auth/login"
|
||
CHAT_ENDPOINT = f"{API_BASE_URL}/api/v1/course/chat"
|
||
|
||
# 测试账号
|
||
TEST_USER = {
|
||
"username": "test_user",
|
||
"password": "123456"
|
||
}
|
||
|
||
|
||
async def login() -> str:
|
||
"""登录获取 access_token"""
|
||
print("🔑 正在登录...")
|
||
|
||
async with httpx.AsyncClient() as client:
|
||
response = await client.post(
|
||
LOGIN_ENDPOINT,
|
||
data={
|
||
"username": TEST_USER["username"],
|
||
"password": TEST_USER["password"]
|
||
}
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
token = data.get("access_token")
|
||
print(f"✅ 登录成功,token: {token[:20]}...")
|
||
return token
|
||
else:
|
||
print(f"❌ 登录失败: {response.status_code} - {response.text}")
|
||
raise Exception("登录失败")
|
||
|
||
|
||
async def test_course_chat(token: str, course_id: int, query: str, conversation_id: str = None):
|
||
"""测试课程对话"""
|
||
print(f"\n💬 测试与课程 {course_id} 对话")
|
||
print(f"问题: {query}")
|
||
if conversation_id:
|
||
print(f"会话ID: {conversation_id}")
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {token}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
payload = {
|
||
"course_id": course_id,
|
||
"query": query
|
||
}
|
||
|
||
if conversation_id:
|
||
payload["conversation_id"] = conversation_id
|
||
|
||
new_conversation_id = None
|
||
answer = ""
|
||
|
||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||
async with client.stream("POST", CHAT_ENDPOINT, headers=headers, json=payload) as response:
|
||
if response.status_code != 200:
|
||
error_text = await response.aread()
|
||
print(f"❌ API 调用失败: {response.status_code} - {error_text}")
|
||
return None, None
|
||
|
||
print("\n📡 SSE 事件流:")
|
||
print("-" * 60)
|
||
|
||
async for line in response.aiter_lines():
|
||
if not line or not line.strip():
|
||
continue
|
||
|
||
if line.startswith("data: "):
|
||
data_str = line[6:]
|
||
|
||
try:
|
||
event_data = json.loads(data_str)
|
||
event_type = event_data.get("event")
|
||
|
||
if event_type == "conversation_started":
|
||
new_conversation_id = event_data.get("conversation_id")
|
||
print(f"🆕 会话已创建: {new_conversation_id}")
|
||
|
||
elif event_type == "message_content":
|
||
answer = event_data.get("answer", "")
|
||
print(f"\n💡 AI 回答:\n{answer}")
|
||
|
||
elif event_type == "message_end":
|
||
print("\n✅ 消息接收完成")
|
||
|
||
elif event_type == "error":
|
||
error_msg = event_data.get("message", "未知错误")
|
||
print(f"\n❌ 错误: {error_msg}")
|
||
|
||
except json.JSONDecodeError as e:
|
||
print(f"⚠️ 解析失败: {e} - {data_str[:100]}")
|
||
|
||
print("-" * 60)
|
||
return new_conversation_id, answer
|
||
|
||
|
||
async def main():
|
||
"""主测试流程"""
|
||
print("=" * 60)
|
||
print("🧪 与课程对话功能测试 - Dify 集成")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
# 1. 登录
|
||
token = await login()
|
||
|
||
# 2. 首次对话
|
||
print("\n" + "=" * 60)
|
||
print("测试场景 1: 首次对话(创建新会话)")
|
||
print("=" * 60)
|
||
|
||
conversation_id, answer1 = await test_course_chat(
|
||
token=token,
|
||
course_id=1,
|
||
query="这门课程讲什么?"
|
||
)
|
||
|
||
if not conversation_id:
|
||
print("\n❌ 测试失败:未获取到 conversation_id")
|
||
return
|
||
|
||
# 3. 续接对话
|
||
print("\n" + "=" * 60)
|
||
print("测试场景 2: 续接对话(使用已有会话)")
|
||
print("=" * 60)
|
||
|
||
_, answer2 = await test_course_chat(
|
||
token=token,
|
||
course_id=1,
|
||
query="能详细说说吗?",
|
||
conversation_id=conversation_id
|
||
)
|
||
|
||
# 4. 测试总结
|
||
print("\n" + "=" * 60)
|
||
print("📊 测试总结")
|
||
print("=" * 60)
|
||
print(f"✅ 首次对话: {'成功' if answer1 else '失败'}")
|
||
print(f"✅ 续接对话: {'成功' if answer2 else '失败'}")
|
||
print(f"✅ 会话管理: conversation_id = {conversation_id}")
|
||
print("\n🎉 所有测试通过!")
|
||
|
||
except Exception as e:
|
||
print(f"\n❌ 测试失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main())
|
||
|