""" AI 提示词管理 API """ import os import json from typing import Optional, List import pymysql from fastapi import APIRouter, Depends, HTTPException, status, Query from .auth import get_current_admin, require_superadmin, get_db_connection, AdminUserInfo from .schemas import ( AIPromptCreate, AIPromptUpdate, AIPromptResponse, AIPromptVersionResponse, TenantPromptResponse, TenantPromptUpdate, ResponseModel, ) router = APIRouter(prefix="/prompts", tags=["提示词管理"]) def log_operation(cursor, admin: AdminUserInfo, tenant_id: int, tenant_code: str, operation_type: str, resource_type: str, resource_id: int, resource_name: str, old_value: dict = None, new_value: dict = None): """记录操作日志""" cursor.execute( """ INSERT INTO operation_logs (admin_user_id, admin_username, tenant_id, tenant_code, operation_type, resource_type, resource_id, resource_name, old_value, new_value) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, (admin.id, admin.username, tenant_id, tenant_code, operation_type, resource_type, resource_id, resource_name, json.dumps(old_value, ensure_ascii=False) if old_value else None, json.dumps(new_value, ensure_ascii=False) if new_value else None) ) @router.get("", response_model=List[AIPromptResponse], summary="获取提示词列表") async def list_prompts( module: Optional[str] = Query(None, description="模块筛选"), is_active: Optional[bool] = Query(None, description="是否启用"), admin: AdminUserInfo = Depends(get_current_admin), ): """ 获取所有 AI 提示词模板 - **module**: 模块筛选(course, exam, practice, ability) - **is_active**: 是否启用 """ conn = get_db_connection() try: with conn.cursor() as cursor: conditions = [] params = [] if module: conditions.append("module = %s") params.append(module) if is_active is not None: conditions.append("is_active = %s") params.append(is_active) where_clause = " AND ".join(conditions) if conditions else "1=1" cursor.execute( f""" SELECT * FROM ai_prompts WHERE {where_clause} ORDER BY module, id """, params ) rows = cursor.fetchall() result = [] for row in rows: # 解析 JSON 字段 variables = None if row.get("variables"): try: variables = json.loads(row["variables"]) except: pass output_schema = None if row.get("output_schema"): try: output_schema = json.loads(row["output_schema"]) except: pass result.append(AIPromptResponse( id=row["id"], code=row["code"], name=row["name"], description=row["description"], module=row["module"], system_prompt=row["system_prompt"], user_prompt_template=row["user_prompt_template"], variables=variables, output_schema=output_schema, model_recommendation=row["model_recommendation"], max_tokens=row["max_tokens"], temperature=float(row["temperature"]) if row["temperature"] else 0.7, is_system=row["is_system"], is_active=row["is_active"], version=row["version"], created_at=row["created_at"], updated_at=row["updated_at"], )) return result finally: conn.close() @router.get("/{prompt_id}", response_model=AIPromptResponse, summary="获取提示词详情") async def get_prompt( prompt_id: int, admin: AdminUserInfo = Depends(get_current_admin), ): """获取提示词详情""" conn = get_db_connection() try: with conn.cursor() as cursor: cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,)) row = cursor.fetchone() if not row: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="提示词不存在", ) # 解析 JSON 字段 variables = None if row.get("variables"): try: variables = json.loads(row["variables"]) except: pass output_schema = None if row.get("output_schema"): try: output_schema = json.loads(row["output_schema"]) except: pass return AIPromptResponse( id=row["id"], code=row["code"], name=row["name"], description=row["description"], module=row["module"], system_prompt=row["system_prompt"], user_prompt_template=row["user_prompt_template"], variables=variables, output_schema=output_schema, model_recommendation=row["model_recommendation"], max_tokens=row["max_tokens"], temperature=float(row["temperature"]) if row["temperature"] else 0.7, is_system=row["is_system"], is_active=row["is_active"], version=row["version"], created_at=row["created_at"], updated_at=row["updated_at"], ) finally: conn.close() @router.post("", response_model=AIPromptResponse, summary="创建提示词") async def create_prompt( data: AIPromptCreate, admin: AdminUserInfo = Depends(require_superadmin), ): """ 创建新的提示词模板 需要超级管理员权限 """ conn = get_db_connection() try: with conn.cursor() as cursor: # 检查编码是否已存在 cursor.execute("SELECT id FROM ai_prompts WHERE code = %s", (data.code,)) if cursor.fetchone(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="提示词编码已存在", ) # 创建提示词 cursor.execute( """ INSERT INTO ai_prompts (code, name, description, module, system_prompt, user_prompt_template, variables, output_schema, model_recommendation, max_tokens, temperature, is_system, created_by) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, FALSE, %s) """, (data.code, data.name, data.description, data.module, data.system_prompt, data.user_prompt_template, json.dumps(data.variables) if data.variables else None, json.dumps(data.output_schema) if data.output_schema else None, data.model_recommendation, data.max_tokens, data.temperature, admin.id) ) prompt_id = cursor.lastrowid # 记录操作日志 log_operation( cursor, admin, None, None, "create", "prompt", prompt_id, data.name, new_value=data.model_dump() ) conn.commit() return await get_prompt(prompt_id, admin) finally: conn.close() @router.put("/{prompt_id}", response_model=AIPromptResponse, summary="更新提示词") async def update_prompt( prompt_id: int, data: AIPromptUpdate, admin: AdminUserInfo = Depends(get_current_admin), ): """ 更新提示词模板 更新会自动保存版本历史 """ conn = get_db_connection() try: with conn.cursor() as cursor: # 获取原提示词 cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,)) old_prompt = cursor.fetchone() if not old_prompt: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="提示词不存在", ) # 保存版本历史(如果系统提示词或用户提示词有变化) if data.system_prompt or data.user_prompt_template: new_version = old_prompt["version"] + 1 cursor.execute( """ INSERT INTO ai_prompt_versions (prompt_id, version, system_prompt, user_prompt_template, variables, output_schema, change_summary, created_by) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) """, (prompt_id, old_prompt["version"], old_prompt["system_prompt"], old_prompt["user_prompt_template"], old_prompt["variables"], old_prompt["output_schema"], f"版本 {old_prompt['version']} 备份", admin.id) ) else: new_version = old_prompt["version"] # 构建更新语句 update_fields = [] update_values = [] if data.name is not None: update_fields.append("name = %s") update_values.append(data.name) if data.description is not None: update_fields.append("description = %s") update_values.append(data.description) if data.system_prompt is not None: update_fields.append("system_prompt = %s") update_values.append(data.system_prompt) if data.user_prompt_template is not None: update_fields.append("user_prompt_template = %s") update_values.append(data.user_prompt_template) if data.variables is not None: update_fields.append("variables = %s") update_values.append(json.dumps(data.variables)) if data.output_schema is not None: update_fields.append("output_schema = %s") update_values.append(json.dumps(data.output_schema)) if data.model_recommendation is not None: update_fields.append("model_recommendation = %s") update_values.append(data.model_recommendation) if data.max_tokens is not None: update_fields.append("max_tokens = %s") update_values.append(data.max_tokens) if data.temperature is not None: update_fields.append("temperature = %s") update_values.append(data.temperature) if data.is_active is not None: update_fields.append("is_active = %s") update_values.append(data.is_active) if not update_fields: return await get_prompt(prompt_id, admin) # 更新版本号 if data.system_prompt or data.user_prompt_template: update_fields.append("version = %s") update_values.append(new_version) update_fields.append("updated_by = %s") update_values.append(admin.id) update_values.append(prompt_id) cursor.execute( f"UPDATE ai_prompts SET {', '.join(update_fields)} WHERE id = %s", update_values ) # 记录操作日志 log_operation( cursor, admin, None, None, "update", "prompt", prompt_id, old_prompt["name"], old_value={"version": old_prompt["version"]}, new_value=data.model_dump(exclude_unset=True) ) conn.commit() return await get_prompt(prompt_id, admin) finally: conn.close() @router.get("/{prompt_id}/versions", response_model=List[AIPromptVersionResponse], summary="获取提示词版本历史") async def get_prompt_versions( prompt_id: int, admin: AdminUserInfo = Depends(get_current_admin), ): """获取提示词的版本历史""" conn = get_db_connection() try: with conn.cursor() as cursor: cursor.execute( """ SELECT * FROM ai_prompt_versions WHERE prompt_id = %s ORDER BY version DESC """, (prompt_id,) ) rows = cursor.fetchall() result = [] for row in rows: variables = None if row.get("variables"): try: variables = json.loads(row["variables"]) except: pass result.append(AIPromptVersionResponse( id=row["id"], prompt_id=row["prompt_id"], version=row["version"], system_prompt=row["system_prompt"], user_prompt_template=row["user_prompt_template"], variables=variables, change_summary=row["change_summary"], created_at=row["created_at"], )) return result finally: conn.close() @router.post("/{prompt_id}/rollback/{version}", response_model=AIPromptResponse, summary="回滚提示词版本") async def rollback_prompt_version( prompt_id: int, version: int, admin: AdminUserInfo = Depends(get_current_admin), ): """回滚到指定版本的提示词""" conn = get_db_connection() try: with conn.cursor() as cursor: # 获取指定版本 cursor.execute( """ SELECT * FROM ai_prompt_versions WHERE prompt_id = %s AND version = %s """, (prompt_id, version) ) version_row = cursor.fetchone() if not version_row: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="版本不存在", ) # 获取当前提示词 cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,)) current = cursor.fetchone() if not current: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="提示词不存在", ) # 保存当前版本到历史 new_version = current["version"] + 1 cursor.execute( """ INSERT INTO ai_prompt_versions (prompt_id, version, system_prompt, user_prompt_template, variables, output_schema, change_summary, created_by) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) """, (prompt_id, current["version"], current["system_prompt"], current["user_prompt_template"], current["variables"], current["output_schema"], f"回滚前备份(版本 {current['version']})", admin.id) ) # 回滚 cursor.execute( """ UPDATE ai_prompts SET system_prompt = %s, user_prompt_template = %s, variables = %s, output_schema = %s, version = %s, updated_by = %s WHERE id = %s """, (version_row["system_prompt"], version_row["user_prompt_template"], version_row["variables"], version_row["output_schema"], new_version, admin.id, prompt_id) ) # 记录操作日志 log_operation( cursor, admin, None, None, "rollback", "prompt", prompt_id, current["name"], old_value={"version": current["version"]}, new_value={"version": new_version, "rollback_from": version} ) conn.commit() return await get_prompt(prompt_id, admin) finally: conn.close() @router.get("/tenants/{tenant_id}", response_model=List[TenantPromptResponse], summary="获取租户自定义提示词") async def get_tenant_prompts( tenant_id: int, admin: AdminUserInfo = Depends(get_current_admin), ): """获取租户的自定义提示词列表""" conn = get_db_connection() try: with conn.cursor() as cursor: cursor.execute( """ SELECT tp.*, ap.code as prompt_code, ap.name as prompt_name FROM tenant_prompts tp JOIN ai_prompts ap ON tp.prompt_id = ap.id WHERE tp.tenant_id = %s ORDER BY ap.module, ap.id """, (tenant_id,) ) rows = cursor.fetchall() return [ TenantPromptResponse( id=row["id"], tenant_id=row["tenant_id"], prompt_id=row["prompt_id"], prompt_code=row["prompt_code"], prompt_name=row["prompt_name"], system_prompt=row["system_prompt"], user_prompt_template=row["user_prompt_template"], is_active=row["is_active"], created_at=row["created_at"], updated_at=row["updated_at"], ) for row in rows ] finally: conn.close() @router.put("/tenants/{tenant_id}/{prompt_id}", response_model=ResponseModel, summary="更新租户自定义提示词") async def update_tenant_prompt( tenant_id: int, prompt_id: int, data: TenantPromptUpdate, admin: AdminUserInfo = Depends(get_current_admin), ): """创建或更新租户的自定义提示词""" conn = get_db_connection() try: with conn.cursor() as cursor: # 验证租户存在 cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,)) tenant = cursor.fetchone() if not tenant: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="租户不存在", ) # 验证提示词存在 cursor.execute("SELECT name FROM ai_prompts WHERE id = %s", (prompt_id,)) prompt = cursor.fetchone() if not prompt: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="提示词不存在", ) # 检查是否已有自定义 cursor.execute( """ SELECT id FROM tenant_prompts WHERE tenant_id = %s AND prompt_id = %s """, (tenant_id, prompt_id) ) existing = cursor.fetchone() if existing: # 更新 update_fields = [] update_values = [] if data.system_prompt is not None: update_fields.append("system_prompt = %s") update_values.append(data.system_prompt) if data.user_prompt_template is not None: update_fields.append("user_prompt_template = %s") update_values.append(data.user_prompt_template) if data.is_active is not None: update_fields.append("is_active = %s") update_values.append(data.is_active) if update_fields: update_fields.append("updated_by = %s") update_values.append(admin.id) update_values.append(existing["id"]) cursor.execute( f"UPDATE tenant_prompts SET {', '.join(update_fields)} WHERE id = %s", update_values ) else: # 创建 cursor.execute( """ INSERT INTO tenant_prompts (tenant_id, prompt_id, system_prompt, user_prompt_template, is_active, created_by) VALUES (%s, %s, %s, %s, %s, %s) """, (tenant_id, prompt_id, data.system_prompt, data.user_prompt_template, data.is_active if data.is_active is not None else True, admin.id) ) # 记录操作日志 log_operation( cursor, admin, tenant_id, tenant["code"], "update", "tenant_prompt", prompt_id, prompt["name"], new_value=data.model_dump(exclude_unset=True) ) conn.commit() return ResponseModel(message="自定义提示词已保存") finally: conn.close() @router.delete("/tenants/{tenant_id}/{prompt_id}", response_model=ResponseModel, summary="删除租户自定义提示词") async def delete_tenant_prompt( tenant_id: int, prompt_id: int, admin: AdminUserInfo = Depends(get_current_admin), ): """删除租户的自定义提示词(恢复使用默认)""" conn = get_db_connection() try: with conn.cursor() as cursor: cursor.execute( """ DELETE FROM tenant_prompts WHERE tenant_id = %s AND prompt_id = %s """, (tenant_id, prompt_id) ) if cursor.rowcount == 0: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="自定义提示词不存在", ) conn.commit() return ResponseModel(message="自定义提示词已删除,将使用默认模板") finally: conn.close()