Files
012-kaopeilian/backend/app/services/ai/course_chat_service.py
111 442ac78b56
Some checks failed
continuous-integration/drone/push Build is failing
sync: 同步服务器最新代码 (2026-01-27)
更新内容:
- 后端 AI 服务优化(能力分析、知识点解析等)
- 前端考试和陪练界面更新
- 修复多个 prompt 和 JSON 解析问题
- 更新 Coze 语音客户端
2026-01-27 10:03:28 +08:00

791 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
课程对话服务 V2 - Python 原生实现
功能:
- 查询课程知识点作为知识库
- 调用 AI 进行对话
- 支持流式输出
- 多轮对话历史管理(Redis 缓存)
提供稳定可靠的课程对话能力。
"""
import json
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService
from .prompts.course_chat_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
KNOWLEDGE_ITEM_TEMPLATE,
CONVERSATION_WINDOW_SIZE,
CONVERSATION_TTL,
MAX_KNOWLEDGE_POINTS,
MAX_KNOWLEDGE_BASE_LENGTH,
DEFAULT_CHAT_MODEL,
DEFAULT_TEMPERATURE,
)
logger = logging.getLogger(__name__)
# Redis key prefix/suffix of the per-user conversation index
# (the full key is PREFIX + user_id + SUFFIX)
CONVERSATION_INDEX_PREFIX = "course_chat:user:"
CONVERSATION_INDEX_SUFFIX = ":conversations"
# Redis key prefix of the per-conversation metadata entry
CONVERSATION_META_PREFIX = "course_chat:meta:"
# Expiry of the conversation index (kept equal to the conversation data TTL)
CONVERSATION_INDEX_TTL = CONVERSATION_TTL
class CourseChatServiceV2:
    """
    Course chat service V2 (pure-Python implementation).

    Answers a user's question about a course by sending the course's
    knowledge points, the recent conversation history and the question to
    the AI service. Conversation history, a per-user conversation index and
    per-conversation metadata are stored in Redis.

    Example:
        ```python
        service = CourseChatServiceV2()

        # Non-streaming chat
        response = await service.chat(
            db=db_session,
            course_id=1,
            query="什么是玻尿酸?",
            user_id=1,
            conversation_id=None,
        )

        # Streaming chat: every item is an (event, data) tuple
        async for event, data in service.chat_stream(
            db=db_session,
            course_id=1,
            query="什么是玻尿酸?",
            user_id=1,
            conversation_id=None,
        ):
            if event == "chunk":
                print(data, end="", flush=True)
        ```
    """

    # Redis key prefix of the serialized message history of a conversation
    CONVERSATION_KEY_PREFIX = "course_chat:conversation:"
    # Module identifier handed to AIService
    MODULE_CODE = "course_chat"
    # Number of characters of the first user message kept as conversation name
    _NAME_MAX_LENGTH = 30
    # Number of characters of the last assistant message kept as list preview
    _PREVIEW_MAX_LENGTH = 100

    def __init__(self):
        """Create the service.

        No state is held: an AIService is created inside each chat call so
        that the caller's db session can be passed to it.
        """

    # ------------------------------------------------------------------
    # Chat entry points
    # ------------------------------------------------------------------

    async def chat(
        self,
        db: AsyncSession,
        course_id: int,
        query: str,
        user_id: int,
        conversation_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Chat with a course (non-streaming).

        Args:
            db: Database session.
            course_id: Course ID.
            query: The user's question.
            user_id: User ID.
            conversation_id: ID of the conversation to continue; a new
                conversation is started when it is empty.

        Returns:
            A dict with ``success``, ``answer``, ``conversation_id`` and the
            ``ai_provider`` / ``ai_model`` / ``ai_tokens`` / ``ai_latency_ms``
            values taken from the AI response.

        Raises:
            ExternalServiceError: On any failure; the original exception is
                attached as ``__cause__``.
        """
        try:
            logger.info(
                f"开始课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
                f"conversation_id: {conversation_id}"
            )

            knowledge_base = await self._load_knowledge_base(db, course_id)
            conversation_id, is_new_conversation = self._resolve_conversation_id(
                user_id, course_id, conversation_id
            )
            messages = await self._build_messages(
                knowledge_base=knowledge_base,
                query=query,
                user_id=user_id,
                conversation_id=conversation_id
            )

            # db_session is passed so that AIService can record the call
            ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
            response = await ai_service.chat(
                messages=messages,
                model=DEFAULT_CHAT_MODEL,
                temperature=DEFAULT_TEMPERATURE,
                prompt_name="course_chat"
            )
            answer = response.content

            await self._persist_turn(
                user_id=user_id,
                course_id=course_id,
                conversation_id=conversation_id,
                is_new_conversation=is_new_conversation,
                query=query,
                answer=answer,
            )

            logger.info(
                f"课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
                f"provider: {response.provider}, tokens: {response.total_tokens}"
            )
            return {
                "success": True,
                "answer": answer,
                "conversation_id": conversation_id,
                "ai_provider": response.provider,
                "ai_model": response.model,
                "ai_tokens": response.total_tokens,
                "ai_latency_ms": response.latency_ms,
            }
        except Exception as e:
            logger.error(
                f"课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
                exc_info=True
            )
            # Chain the cause so the original traceback is not lost
            raise ExternalServiceError(f"课程对话失败: {e}") from e

    async def chat_stream(
        self,
        db: AsyncSession,
        course_id: int,
        query: str,
        user_id: int,
        conversation_id: Optional[str] = None
    ) -> AsyncGenerator[Tuple[str, Optional[str]], None]:
        """
        Chat with a course (streaming).

        Args:
            db: Database session.
            course_id: Course ID.
            query: The user's question.
            user_id: User ID.
            conversation_id: ID of the conversation to continue; a new
                conversation is started when it is empty.

        Yields:
            ``(event, data)`` tuples:
            - ``("conversation_started", conversation_id)``: only for a new
              conversation, before any text.
            - ``("chunk", text)``: one piece of the answer.
            - ``("done", full_answer)``: end of the answer.
            - ``("error", message)``: an exception occurred; nothing is
              raised to the consumer.
        """
        try:
            logger.info(
                f"开始流式课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
                f"conversation_id: {conversation_id}"
            )

            knowledge_base = await self._load_knowledge_base(db, course_id)
            conversation_id, is_new_conversation = self._resolve_conversation_id(
                user_id, course_id, conversation_id
            )
            if is_new_conversation:
                yield ("conversation_started", conversation_id)

            messages = await self._build_messages(
                knowledge_base=knowledge_base,
                query=query,
                user_id=user_id,
                conversation_id=conversation_id
            )

            # db_session is passed so that AIService can record the call
            ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
            chunks: List[str] = []
            async for chunk in ai_service.chat_stream(
                messages=messages,
                model=DEFAULT_CHAT_MODEL,
                temperature=DEFAULT_TEMPERATURE,
                prompt_name="course_chat"
            ):
                chunks.append(chunk)
                yield ("chunk", chunk)
            full_answer = "".join(chunks)

            # Persist BEFORE yielding "done": a consumer that stops iterating
            # once it sees "done" would otherwise close this generator and the
            # turn would never be saved.
            await self._persist_turn(
                user_id=user_id,
                course_id=course_id,
                conversation_id=conversation_id,
                is_new_conversation=is_new_conversation,
                query=query,
                answer=full_answer,
            )
            logger.info(
                f"流式课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
                f"answer_length: {len(full_answer)}"
            )
            yield ("done", full_answer)
        except Exception as e:
            logger.error(
                f"流式课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
                exc_info=True
            )
            yield ("error", str(e))

    # ------------------------------------------------------------------
    # Helpers shared by chat() and chat_stream()
    # ------------------------------------------------------------------

    async def _load_knowledge_base(self, db: AsyncSession, course_id: int) -> str:
        """Return the course knowledge text, or a placeholder when there is none."""
        knowledge_base = await self._get_course_knowledge(db, course_id)
        if not knowledge_base:
            logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
            knowledge_base = "(该课程暂无知识点内容)"
        return knowledge_base

    def _resolve_conversation_id(
        self,
        user_id: int,
        course_id: int,
        conversation_id: Optional[str]
    ) -> Tuple[str, bool]:
        """Return ``(conversation_id, is_new)``, generating an ID when none is given."""
        # NOTE(review): a caller-supplied conversation_id is not checked for
        # ownership here, unlike get_conversation_messages() — confirm whether
        # ids with a foreign "conv_{user_id}_" prefix should be rejected.
        if conversation_id:
            return conversation_id, False
        conversation_id = self._generate_conversation_id(user_id, course_id)
        logger.info(f"创建新会话: {conversation_id}")
        return conversation_id, True

    async def _persist_turn(
        self,
        *,
        user_id: int,
        course_id: int,
        conversation_id: str,
        is_new_conversation: bool,
        query: str,
        answer: str
    ) -> None:
        """Save one question/answer pair and refresh the user's conversation index."""
        await self._save_conversation_history(
            conversation_id=conversation_id,
            user_message=query,
            assistant_message=answer
        )
        if is_new_conversation:
            # The first user message doubles as the conversation name
            await self._add_to_conversation_index(
                user_id, conversation_id, course_id,
                first_message=query
            )
        else:
            await self._update_conversation_index(user_id, conversation_id)

    async def _get_course_knowledge(
        self,
        db: AsyncSession,
        course_id: int
    ) -> str:
        """
        Build the knowledge-base text from the course's knowledge points.

        Args:
            db: Database session.
            course_id: Course ID.

        Returns:
            The formatted knowledge points joined by newlines, at most
            ``MAX_KNOWLEDGE_BASE_LENGTH`` characters long; ``""`` when the
            course has no knowledge points.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            query = text("""
                SELECT kp.name, kp.description
                FROM knowledge_points kp
                INNER JOIN course_materials cm ON kp.material_id = cm.id
                WHERE kp.course_id = :course_id
                AND kp.is_deleted = 0
                AND cm.is_deleted = 0
                ORDER BY kp.id
                LIMIT :limit
            """)
            result = await db.execute(
                query,
                {"course_id": course_id, "limit": MAX_KNOWLEDGE_POINTS}
            )
            rows = result.fetchall()
            if not rows:
                logger.warning(f"课程 {course_id} 没有关联的知识点")
                return ""

            knowledge_items: List[str] = []
            total_length = 0
            for row in rows:
                item = KNOWLEDGE_ITEM_TEMPLATE.format(
                    name=row[0] or "",
                    description=row[1] or ""
                )
                # Count the "\n" that join() inserts between items so the
                # final text really stays within the limit.
                separator_length = 1 if knowledge_items else 0
                if total_length + separator_length + len(item) > MAX_KNOWLEDGE_BASE_LENGTH:
                    logger.warning(
                        f"知识库文本已达到最大长度限制 {MAX_KNOWLEDGE_BASE_LENGTH}"
                        f"停止添加更多知识点"
                    )
                    break
                knowledge_items.append(item)
                total_length += separator_length + len(item)

            knowledge_base = "\n".join(knowledge_items)
            logger.info(
                f"获取课程知识点成功 - course_id: {course_id}, "
                f"count: {len(knowledge_items)}, length: {len(knowledge_base)}"
            )
            return knowledge_base
        except Exception as e:
            logger.error(f"获取课程知识点失败: {e}")
            raise

    async def _build_messages(
        self,
        knowledge_base: str,
        query: str,
        user_id: int,
        conversation_id: str
    ) -> List[Dict[str, str]]:
        """
        Build the message list sent to the AI service.

        Args:
            knowledge_base: Knowledge-base text inserted into the system prompt.
            query: The current user question.
            user_id: User ID (currently not used by this method).
            conversation_id: Conversation whose stored history is included.

        Returns:
            ``[system, *recent history, user]`` messages.
        """
        history = self._trim_history(
            await self._get_conversation_history(conversation_id)
        )
        messages: List[Dict[str, str]] = [
            {"role": "system", "content": SYSTEM_PROMPT.format(knowledge_base=knowledge_base)},
            *history,
            {"role": "user", "content": USER_PROMPT.format(query=query)},
        ]
        logger.debug(
            f"构建消息列表 - total: {len(messages)}, history: {len(history)}"
        )
        return messages

    @staticmethod
    def _trim_history(history: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Keep only the newest ``CONVERSATION_WINDOW_SIZE`` question/answer pairs."""
        max_messages = CONVERSATION_WINDOW_SIZE * 2
        if max_messages <= 0:
            # history[-0:] would return the whole list instead of nothing
            return []
        return history[-max_messages:]

    def _generate_conversation_id(self, user_id: int, course_id: int) -> str:
        """Return a new ID of the form ``conv_{user_id}_{course_id}_{8 hex chars}``."""
        unique_id = uuid.uuid4().hex[:8]
        return f"conv_{user_id}_{course_id}_{unique_id}"

    # ------------------------------------------------------------------
    # Conversation history (Redis)
    # ------------------------------------------------------------------

    async def _get_conversation_history(
        self,
        conversation_id: str
    ) -> List[Dict[str, str]]:
        """
        Load the stored history of a conversation from Redis.

        Args:
            conversation_id: Conversation ID.

        Returns:
            ``[{"role": "user" | "assistant", "content": "..."}]``; an empty
            list when nothing is stored or when reading fails (failures are
            logged, never raised).
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            data = await redis.get(f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}")
            if not data:
                return []
            history = json.loads(data)
            # Ignore a corrupted payload that is not a message list
            return history if isinstance(history, list) else []
        except RuntimeError:
            # Treated as "Redis not initialised": continue without history
            logger.warning("Redis 未初始化,无法获取会话历史")
            return []
        except Exception as e:
            logger.warning(f"获取会话历史失败: {e}")
            return []

    async def _save_conversation_history(
        self,
        conversation_id: str,
        user_message: str,
        assistant_message: str
    ) -> None:
        """
        Append one question/answer pair to the stored history in Redis.

        The history is trimmed to the conversation window and written back
        with ``CONVERSATION_TTL``. Failures are logged, never raised.

        Args:
            conversation_id: Conversation ID.
            user_message: The user's message.
            assistant_message: The AI reply.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"

            history = await self._get_conversation_history(conversation_id)
            history.append({"role": "user", "content": user_message})
            history.append({"role": "assistant", "content": assistant_message})
            history = self._trim_history(history)

            await redis.setex(
                key,
                CONVERSATION_TTL,
                json.dumps(history, ensure_ascii=False)
            )
            logger.debug(
                f"保存会话历史成功 - conversation_id: {conversation_id}, "
                f"messages: {len(history)}"
            )
        except RuntimeError:
            # Treated as "Redis not initialised": skip saving
            logger.warning("Redis 未初始化,无法保存会话历史")
        except Exception as e:
            logger.warning(f"保存会话历史失败: {e}")

    async def get_conversation_messages(
        self,
        conversation_id: str,
        user_id: int
    ) -> List[Dict[str, Any]]:
        """
        Return the stored messages of a conversation.

        Args:
            conversation_id: Conversation ID.
            user_id: Requesting user; the ID must start with
                ``conv_{user_id}_`` or nothing is returned.

        Returns:
            ``[{"id": index, "role": ..., "content": ...}]`` in stored order,
            or ``[]`` when the conversation does not belong to the user.
        """
        if not conversation_id.startswith(f"conv_{user_id}_"):
            logger.warning(
                f"用户 {user_id} 尝试访问不属于自己的会话: {conversation_id}"
            )
            return []
        history = await self._get_conversation_history(conversation_id)
        return [
            {"id": index, "role": msg["role"], "content": msg["content"]}
            for index, msg in enumerate(history)
        ]

    # ------------------------------------------------------------------
    # Conversation index and metadata (Redis)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_conversation_name(first_message: str) -> str:
        """Derive a conversation name: newlines flattened, at most 30 chars plus "..."."""
        cleaned = first_message.replace('\n', ' ').strip()
        limit = CourseChatServiceV2._NAME_MAX_LENGTH
        # Decide on the cleaned text, so a short message padded with
        # whitespace does not get a spurious ellipsis.
        if len(cleaned) > limit:
            return cleaned[:limit] + "..."
        return cleaned

    async def _add_to_conversation_index(
        self,
        user_id: int,
        conversation_id: str,
        course_id: int,
        first_message: str = ""
    ) -> None:
        """
        Register a new conversation in the user's index and store its metadata.

        Failures are logged, never raised.

        Args:
            user_id: User ID.
            conversation_id: Conversation ID.
            course_id: Course ID.
            first_message: First user message, used to derive the name.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()

            # 1. Sorted-set index, scored by the current timestamp
            index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
            timestamp = time.time()
            await redis.zadd(index_key, {conversation_id: timestamp})
            await redis.expire(index_key, CONVERSATION_INDEX_TTL)

            # 2. Metadata entry, including the derived conversation name
            conversation_name = (
                self._make_conversation_name(first_message) if first_message else ""
            )
            meta_data = {
                "conversation_id": conversation_id,
                "user_id": user_id,
                "course_id": course_id,
                "name": conversation_name,
                "created_at": timestamp,
                "updated_at": timestamp,
            }
            await redis.setex(
                f"{CONVERSATION_META_PREFIX}{conversation_id}",
                CONVERSATION_INDEX_TTL,
                json.dumps(meta_data, ensure_ascii=False)
            )
            logger.debug(
                f"会话已添加到索引 - user_id: {user_id}, conversation_id: {conversation_id}, "
                f"name: {conversation_name}"
            )
        except RuntimeError:
            logger.warning("Redis 未初始化,无法添加会话索引")
        except Exception as e:
            logger.warning(f"添加会话索引失败: {e}")

    async def _update_conversation_index(
        self,
        user_id: int,
        conversation_id: str
    ) -> None:
        """
        Refresh the last-active time of an existing conversation.

        Updates the score in the user's index and ``updated_at`` in the
        metadata (when metadata exists). Failures are logged, never raised.

        Args:
            user_id: User ID.
            conversation_id: Conversation ID.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()

            index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
            timestamp = time.time()
            await redis.zadd(index_key, {conversation_id: timestamp})
            await redis.expire(index_key, CONVERSATION_INDEX_TTL)

            meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
            meta_data = await redis.get(meta_key)
            if meta_data:
                meta = json.loads(meta_data)
                meta["updated_at"] = timestamp
                await redis.setex(
                    meta_key,
                    CONVERSATION_INDEX_TTL,
                    json.dumps(meta, ensure_ascii=False)
                )
            logger.debug(
                f"会话索引已更新 - user_id: {user_id}, conversation_id: {conversation_id}"
            )
        except RuntimeError:
            logger.warning("Redis 未初始化,无法更新会话索引")
        except Exception as e:
            logger.warning(f"更新会话索引失败: {e}")

    async def _load_conversation_meta(
        self,
        redis: Any,
        conv_id: str,
        user_id: int
    ) -> Dict[str, Any]:
        """Return stored metadata, or metadata reconstructed from the ID when none is stored."""
        meta_data = await redis.get(f"{CONVERSATION_META_PREFIX}{conv_id}")
        if meta_data:
            if isinstance(meta_data, bytes):
                meta_data = meta_data.decode('utf-8')
            return json.loads(meta_data)

        # ID format: conv_{user_id}_{course_id}_{uuid}. A malformed ID falls
        # back to course 0 instead of failing the whole listing.
        parts = conv_id.split('_')
        try:
            course_id = int(parts[2])
        except (IndexError, ValueError):
            course_id = 0
        now = time.time()
        return {
            "conversation_id": conv_id,
            "user_id": user_id,
            "course_id": course_id,
            "created_at": now,
            "updated_at": now,
        }

    def _summarize_conversation(
        self,
        conv_id: str,
        meta: Dict[str, Any],
        history: List[Dict[str, str]]
    ) -> Dict[str, Any]:
        """Build one list entry (name, preview, counters) for a conversation."""
        # Preview: the most recent assistant message, cut to 100 characters
        last_message = ""
        for msg in reversed(history):
            if msg["role"] == "assistant":
                content = msg["content"]
                last_message = content[:self._PREVIEW_MAX_LENGTH]
                if len(content) > self._PREVIEW_MAX_LENGTH:
                    last_message += "..."
                break

        # Name: the stored one if present, else derived from the first user message
        conversation_name = meta.get("name", "")
        if not conversation_name:
            first_user = next(
                (msg for msg in history if msg["role"] == "user"), None
            )
            if first_user is not None:
                conversation_name = self._make_conversation_name(first_user["content"])

        return {
            "id": conv_id,
            "name": conversation_name,
            "course_id": meta.get("course_id"),
            "created_at": meta.get("created_at"),
            "updated_at": meta.get("updated_at"),
            "last_message": last_message,
            "message_count": len(history),
        }

    async def _list_conversations(
        self,
        user_id: int,
        limit: int,
        course_id: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        List a user's conversations, newest first, optionally for one course.

        The course filter is applied while scanning the index, so up to
        ``limit`` matching conversations are returned even when newer
        conversations of other courses exist. Failures are logged and
        produce an empty list.
        """
        if limit <= 0:
            # zrevrange(0, limit - 1) with limit <= 0 would address the index
            # from its end and return (almost) everything.
            return []
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()

            index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
            # Redis may truncate the index only when no filter applies; with a
            # course filter every id has to be inspected.
            stop = limit - 1 if course_id is None else -1
            conversation_ids = await redis.zrevrange(index_key, 0, stop)
            if not conversation_ids:
                logger.debug(f"用户 {user_id} 没有会话记录")
                return []

            conversations: List[Dict[str, Any]] = []
            for conv_id in conversation_ids:
                if isinstance(conv_id, bytes):
                    conv_id = conv_id.decode('utf-8')
                meta = await self._load_conversation_meta(redis, conv_id, user_id)
                if course_id is not None and meta.get("course_id") != course_id:
                    continue
                history = await self._get_conversation_history(conv_id)
                conversations.append(
                    self._summarize_conversation(conv_id, meta, history)
                )
                if len(conversations) >= limit:
                    break

            logger.info(f"获取用户会话列表 - user_id: {user_id}, count: {len(conversations)}")
            return conversations
        except RuntimeError:
            logger.warning("Redis 未初始化,无法获取会话列表")
            return []
        except Exception as e:
            logger.warning(f"获取会话列表失败: {e}")
            return []

    async def list_user_conversations(
        self,
        user_id: int,
        limit: int = 20
    ) -> List[Dict[str, Any]]:
        """
        List a user's conversations.

        Args:
            user_id: User ID.
            limit: Maximum number of conversations; ``<= 0`` yields ``[]``.

        Returns:
            Conversation entries ordered by last update, newest first.
        """
        return await self._list_conversations(user_id, limit)

    # Alias methods called by the API layer
    async def get_conversations(
        self,
        user_id: int,
        course_id: Optional[int] = None,
        limit: int = 20
    ) -> List[Dict[str, Any]]:
        """
        List a user's conversations, optionally restricted to one course.

        Args:
            user_id: User ID.
            course_id: When given, only conversations of this course.
            limit: Maximum number of conversations; ``<= 0`` yields ``[]``.

        Returns:
            Conversation entries ordered by last update, newest first.
        """
        return await self._list_conversations(user_id, limit, course_id)

    async def get_messages(
        self,
        conversation_id: str,
        user_id: int,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """
        Return the stored messages of a conversation (alias method).

        Args:
            conversation_id: Conversation ID.
            user_id: Requesting user, used for the ownership check.
            limit: Maximum number of messages; the most recent ones are
                kept. ``<= 0`` yields ``[]``.

        Returns:
            Message entries in stored order.
        """
        if limit <= 0:
            return []
        # user_id (not limit) is what the ownership check needs
        messages = await self.get_conversation_messages(conversation_id, user_id)
        return messages[-limit:]
# Module-level shared service instance
course_chat_service_v2 = CourseChatServiceV2()