- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
638 lines
23 KiB
Python
638 lines
23 KiB
Python
"""
|
||
AI 提示词管理 API
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
from typing import Optional, List
|
||
|
||
import pymysql
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||
|
||
from .auth import get_current_admin, require_superadmin, get_db_connection, AdminUserInfo
|
||
from .schemas import (
|
||
AIPromptCreate,
|
||
AIPromptUpdate,
|
||
AIPromptResponse,
|
||
AIPromptVersionResponse,
|
||
TenantPromptResponse,
|
||
TenantPromptUpdate,
|
||
ResponseModel,
|
||
)
|
||
|
||
# Router for the admin prompt-management endpoints; routes below are mounted
# under the "/prompts" prefix (relative to wherever this router is included).
router = APIRouter(tags=["提示词管理"], prefix="/prompts")
|
||
|
||
|
||
def log_operation(cursor, admin: AdminUserInfo, tenant_id: Optional[int], tenant_code: Optional[str],
                  operation_type: str, resource_type: str, resource_id: int,
                  resource_name: str, old_value: Optional[dict] = None,
                  new_value: Optional[dict] = None) -> None:
    """Insert one audit row into ``operation_logs`` using the caller's cursor.

    Nothing is committed here; the row is written on the caller's connection
    and the caller is responsible for committing it.

    Args:
        cursor: Open cursor on the caller's DB connection.
        admin: Acting admin; its ``id`` and ``username`` are recorded.
        tenant_id: Tenant the operation applies to, or ``None`` for
            tenant-independent resources (the global prompt templates).
        tenant_code: Code of that tenant, or ``None``.
        operation_type: Short verb such as ``"create"``, ``"update"``, ``"rollback"``.
        resource_type: Kind of resource touched, e.g. ``"prompt"``.
        resource_id: Primary key of the touched resource.
        resource_name: Human-readable name of the resource.
        old_value: State before the change, stored as JSON (NULL when ``None``).
        new_value: State after the change, stored as JSON (NULL when ``None``).
    """

    def _to_json(value: Optional[dict]) -> Optional[str]:
        # Explicit None check so an empty dict is recorded as "{}" instead of
        # being collapsed to NULL; ensure_ascii=False keeps Chinese text readable.
        return json.dumps(value, ensure_ascii=False) if value is not None else None

    cursor.execute(
        """
        INSERT INTO operation_logs
        (admin_user_id, admin_username, tenant_id, tenant_code, operation_type,
         resource_type, resource_id, resource_name, old_value, new_value)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """,
        (admin.id, admin.username, tenant_id, tenant_code, operation_type,
         resource_type, resource_id, resource_name,
         _to_json(old_value), _to_json(new_value)),
    )
|
||
|
||
|
||
@router.get("", response_model=List[AIPromptResponse], summary="获取提示词列表")
async def list_prompts(
    module: Optional[str] = Query(None, description="模块筛选"),
    is_active: Optional[bool] = Query(None, description="是否启用"),
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    List all AI prompt templates, optionally filtered.

    - **module**: module filter (course, exam, practice, ability)
    - **is_active**: enabled-flag filter

    Results are ordered by module, then id.
    """

    def _load_json(raw):
        # Best-effort decode of a JSON text column: an empty value, malformed
        # JSON or a non-text value yields None instead of failing the listing.
        # Only decode errors are swallowed (JSONDecodeError and
        # UnicodeDecodeError are ValueError subclasses).
        if not raw:
            return None
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            return None

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            conditions = []
            params = []

            if module:
                conditions.append("module = %s")
                params.append(module)

            if is_active is not None:
                conditions.append("is_active = %s")
                params.append(is_active)

            # Only fixed column predicates are interpolated into the SQL text;
            # the user-supplied values are bound by the driver via `params`.
            where_clause = " AND ".join(conditions) if conditions else "1=1"

            cursor.execute(
                f"""
                SELECT * FROM ai_prompts
                WHERE {where_clause}
                ORDER BY module, id
                """,
                params,
            )
            rows = cursor.fetchall()

            result = []
            for row in rows:
                temperature = row["temperature"]
                result.append(AIPromptResponse(
                    id=row["id"],
                    code=row["code"],
                    name=row["name"],
                    description=row["description"],
                    module=row["module"],
                    system_prompt=row["system_prompt"],
                    user_prompt_template=row["user_prompt_template"],
                    variables=_load_json(row.get("variables")),
                    output_schema=_load_json(row.get("output_schema")),
                    model_recommendation=row["model_recommendation"],
                    max_tokens=row["max_tokens"],
                    # `is not None` rather than truthiness: a stored temperature
                    # of 0 must come back as 0.0, not the 0.7 fallback.
                    temperature=float(temperature) if temperature is not None else 0.7,
                    is_system=row["is_system"],
                    is_active=row["is_active"],
                    version=row["version"],
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                ))

            return result
    finally:
        conn.close()
|
||
|
||
|
||
@router.get("/{prompt_id}", response_model=AIPromptResponse, summary="获取提示词详情")
async def get_prompt(
    prompt_id: int,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    Return a single prompt template by id.

    Besides serving HTTP requests, this coroutine is awaited directly by the
    create/update/rollback endpoints of this module to build their responses.

    Raises:
        HTTPException: 404 if no prompt with ``prompt_id`` exists.
    """

    def _load_json(raw):
        # Best-effort decode of a JSON text column; only decode errors
        # (ValueError subclasses) or non-text values (TypeError) yield None.
        if not raw:
            return None
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            return None

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,))
            row = cursor.fetchone()

            if not row:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="提示词不存在",
                )

            temperature = row["temperature"]
            return AIPromptResponse(
                id=row["id"],
                code=row["code"],
                name=row["name"],
                description=row["description"],
                module=row["module"],
                system_prompt=row["system_prompt"],
                user_prompt_template=row["user_prompt_template"],
                variables=_load_json(row.get("variables")),
                output_schema=_load_json(row.get("output_schema")),
                model_recommendation=row["model_recommendation"],
                max_tokens=row["max_tokens"],
                # `is not None` rather than truthiness: a stored temperature of 0
                # must come back as 0.0, not the 0.7 fallback.
                temperature=float(temperature) if temperature is not None else 0.7,
                is_system=row["is_system"],
                is_active=row["is_active"],
                version=row["version"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
            )
    finally:
        conn.close()
|
||
|
||
|
||
@router.post("", response_model=AIPromptResponse, summary="创建提示词")
async def create_prompt(
    data: AIPromptCreate,
    admin: AdminUserInfo = Depends(require_superadmin),
):
    """
    Create a new prompt template.

    Requires super-admin privileges. The template is always created as a
    non-system prompt (``is_system = FALSE``), and the creation is written to
    the operation log.

    Raises:
        HTTPException: 400 if a prompt with the same ``code`` already exists.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            # Early, friendly rejection of a duplicate code.
            cursor.execute("SELECT id FROM ai_prompts WHERE code = %s", (data.code,))
            if cursor.fetchone():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="提示词编码已存在",
                )

            # `is not None` (as update_prompt does) so an explicitly supplied
            # empty list/dict is stored as JSON instead of being turned into NULL.
            variables_json = (
                json.dumps(data.variables) if data.variables is not None else None
            )
            output_schema_json = (
                json.dumps(data.output_schema) if data.output_schema is not None else None
            )

            try:
                cursor.execute(
                    """
                    INSERT INTO ai_prompts
                    (code, name, description, module, system_prompt, user_prompt_template,
                     variables, output_schema, model_recommendation, max_tokens, temperature,
                     is_system, created_by)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, FALSE, %s)
                    """,
                    (data.code, data.name, data.description, data.module,
                     data.system_prompt, data.user_prompt_template,
                     variables_json, output_schema_json,
                     data.model_recommendation, data.max_tokens, data.temperature,
                     admin.id),
                )
            except pymysql.err.IntegrityError as exc:
                # A concurrent request can insert the same code between the
                # check above and this INSERT. Assuming ai_prompts.code has a
                # UNIQUE index (TODO confirm), MySQL reports ER_DUP_ENTRY (1062);
                # surface it as the same 400 instead of an unhandled 500.
                if exc.args and exc.args[0] == 1062:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="提示词编码已存在",
                    ) from exc
                raise
            prompt_id = cursor.lastrowid

            log_operation(
                cursor, admin, None, None,
                "create", "prompt", prompt_id, data.name,
                new_value=data.model_dump(),
            )

            conn.commit()

            return await get_prompt(prompt_id, admin)
    finally:
        conn.close()
|
||
|
||
|
||
@router.put("/{prompt_id}", response_model=AIPromptResponse, summary="更新提示词")
async def update_prompt(
    prompt_id: int,
    data: AIPromptUpdate,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    Update a prompt template.

    Only fields supplied as non-None are changed. When any versioned content
    field (system prompt, user prompt template, variables, output schema) is
    supplied, the current content is first archived into
    ``ai_prompt_versions`` and the version number is incremented, so the
    change can be rolled back later.

    Raises:
        HTTPException: 404 if no prompt with ``prompt_id`` exists.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,))
            old_prompt = cursor.fetchone()

            if not old_prompt:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="提示词不存在",
                )

            # (column, new value) for every possibly-updated field, in a fixed
            # order; JSON columns are serialized here so values are final SQL params.
            candidates = [
                ("name", data.name),
                ("description", data.description),
                ("system_prompt", data.system_prompt),
                ("user_prompt_template", data.user_prompt_template),
                ("variables",
                 json.dumps(data.variables) if data.variables is not None else None),
                ("output_schema",
                 json.dumps(data.output_schema) if data.output_schema is not None else None),
                ("model_recommendation", data.model_recommendation),
                ("max_tokens", data.max_tokens),
                ("temperature", data.temperature),
                ("is_active", data.is_active),
            ]
            changes = [(column, value) for column, value in candidates if value is not None]

            if not changes:
                return await get_prompt(prompt_id, admin)

            # Archive the current content before touching any field that the
            # version rows snapshot (and that rollback restores). `is not None`
            # rather than truthiness so clearing a prompt to "" is versioned too.
            versioned_change = any(
                value is not None
                for value in (data.system_prompt, data.user_prompt_template,
                              data.variables, data.output_schema)
            )
            if versioned_change:
                cursor.execute(
                    """
                    INSERT INTO ai_prompt_versions
                    (prompt_id, version, system_prompt, user_prompt_template, variables,
                     output_schema, change_summary, created_by)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    """,
                    (prompt_id, old_prompt["version"],
                     old_prompt["system_prompt"], old_prompt["user_prompt_template"],
                     old_prompt["variables"], old_prompt["output_schema"],
                     f"版本 {old_prompt['version']} 备份",
                     admin.id),
                )
                changes.append(("version", old_prompt["version"] + 1))

            changes.append(("updated_by", admin.id))

            # Column names come from the fixed list above, never from input.
            set_clause = ", ".join(f"{column} = %s" for column, _ in changes)
            params = [value for _, value in changes] + [prompt_id]
            cursor.execute(f"UPDATE ai_prompts SET {set_clause} WHERE id = %s", params)

            log_operation(
                cursor, admin, None, None,
                "update", "prompt", prompt_id, old_prompt["name"],
                old_value={"version": old_prompt["version"]},
                new_value=data.model_dump(exclude_unset=True),
            )

            conn.commit()

            return await get_prompt(prompt_id, admin)
    finally:
        conn.close()
|
||
|
||
|
||
@router.get("/{prompt_id}/versions", response_model=List[AIPromptVersionResponse], summary="获取提示词版本历史")
async def get_prompt_versions(
    prompt_id: int,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    Return the archived versions of a prompt, newest version first.

    An unknown ``prompt_id`` yields an empty list (no 404 is raised).
    """

    def _load_json(raw):
        # Best-effort decode of a JSON text column; only decode errors
        # (ValueError subclasses) or non-text values (TypeError) yield None.
        if not raw:
            return None
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            return None

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT * FROM ai_prompt_versions
                WHERE prompt_id = %s
                ORDER BY version DESC
                """,
                (prompt_id,),
            )
            rows = cursor.fetchall()

            return [
                AIPromptVersionResponse(
                    id=row["id"],
                    prompt_id=row["prompt_id"],
                    version=row["version"],
                    system_prompt=row["system_prompt"],
                    user_prompt_template=row["user_prompt_template"],
                    variables=_load_json(row.get("variables")),
                    change_summary=row["change_summary"],
                    created_at=row["created_at"],
                )
                for row in rows
            ]
    finally:
        conn.close()
|
||
|
||
|
||
@router.post("/{prompt_id}/rollback/{version}", response_model=AIPromptResponse, summary="回滚提示词版本")
async def rollback_prompt_version(
    prompt_id: int,
    version: int,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    Restore a prompt's content from the archived row for ``version``.

    The live content is archived first (so the rollback can itself be undone),
    then system prompt, user prompt template, variables and output schema are
    copied back from the archive and the version number is bumped by one. The
    rollback is written to the operation log.

    Raises:
        HTTPException: 404 if the requested version or the prompt is missing.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT * FROM ai_prompt_versions
                WHERE prompt_id = %s AND version = %s
                """,
                (prompt_id, version),
            )
            snapshot = cur.fetchone()
            if not snapshot:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="版本不存在",
                )

            cur.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,))
            live = cur.fetchone()
            if not live:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="提示词不存在",
                )

            archived_version = live["version"]
            bumped_version = archived_version + 1

            # Archive what is live right now before overwriting it.
            archive_params = (
                prompt_id,
                archived_version,
                live["system_prompt"],
                live["user_prompt_template"],
                live["variables"],
                live["output_schema"],
                f"回滚前备份(版本 {archived_version})",
                admin.id,
            )
            cur.execute(
                """
                INSERT INTO ai_prompt_versions
                (prompt_id, version, system_prompt, user_prompt_template, variables,
                 output_schema, change_summary, created_by)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                archive_params,
            )

            # Copy the archived content back onto the live row.
            restore_params = (
                snapshot["system_prompt"],
                snapshot["user_prompt_template"],
                snapshot["variables"],
                snapshot["output_schema"],
                bumped_version,
                admin.id,
                prompt_id,
            )
            cur.execute(
                """
                UPDATE ai_prompts
                SET system_prompt = %s, user_prompt_template = %s, variables = %s,
                    output_schema = %s, version = %s, updated_by = %s
                WHERE id = %s
                """,
                restore_params,
            )

            log_operation(
                cur, admin, None, None,
                "rollback", "prompt", prompt_id, live["name"],
                old_value={"version": archived_version},
                new_value={"version": bumped_version, "rollback_from": version},
            )

            conn.commit()

            return await get_prompt(prompt_id, admin)
    finally:
        conn.close()
|
||
|
||
|
||
@router.get("/tenants/{tenant_id}", response_model=List[TenantPromptResponse], summary="获取租户自定义提示词")
async def get_tenant_prompts(
    tenant_id: int,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    List a tenant's custom prompt overrides.

    Each entry is joined with its base template to expose the template's code
    and name; entries are ordered by the template's module, then template id.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT tp.*, ap.code as prompt_code, ap.name as prompt_name
                FROM tenant_prompts tp
                JOIN ai_prompts ap ON tp.prompt_id = ap.id
                WHERE tp.tenant_id = %s
                ORDER BY ap.module, ap.id
                """,
                (tenant_id,),
            )

            overrides = []
            for record in cur.fetchall():
                overrides.append(
                    TenantPromptResponse(
                        id=record["id"],
                        tenant_id=record["tenant_id"],
                        prompt_id=record["prompt_id"],
                        prompt_code=record["prompt_code"],
                        prompt_name=record["prompt_name"],
                        system_prompt=record["system_prompt"],
                        user_prompt_template=record["user_prompt_template"],
                        is_active=record["is_active"],
                        created_at=record["created_at"],
                        updated_at=record["updated_at"],
                    )
                )
            return overrides
    finally:
        conn.close()
|
||
|
||
|
||
@router.put("/tenants/{tenant_id}/{prompt_id}", response_model=ResponseModel, summary="更新租户自定义提示词")
async def update_tenant_prompt(
    tenant_id: int,
    prompt_id: int,
    data: TenantPromptUpdate,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    Create or update a tenant's custom override of a prompt template.

    If the tenant has no override yet, one is inserted (``is_active``
    defaults to True when not supplied). Otherwise only the supplied non-None
    fields are updated. The operation is always logged as ``"update"``.

    Raises:
        HTTPException: 404 if the tenant or the prompt does not exist.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
            tenant = cur.fetchone()
            if not tenant:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="租户不存在",
                )

            cur.execute("SELECT name FROM ai_prompts WHERE id = %s", (prompt_id,))
            prompt = cur.fetchone()
            if not prompt:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="提示词不存在",
                )

            cur.execute(
                """
                SELECT id FROM tenant_prompts
                WHERE tenant_id = %s AND prompt_id = %s
                """,
                (tenant_id, prompt_id),
            )
            existing = cur.fetchone()

            if not existing:
                # No override yet: insert one, enabled unless told otherwise.
                active = True if data.is_active is None else data.is_active
                cur.execute(
                    """
                    INSERT INTO tenant_prompts
                    (tenant_id, prompt_id, system_prompt, user_prompt_template, is_active, created_by)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    """,
                    (tenant_id, prompt_id, data.system_prompt, data.user_prompt_template,
                     active, admin.id),
                )
            else:
                # Patch only the fields the caller actually supplied.
                supplied = [
                    (column, value)
                    for column, value in (
                        ("system_prompt", data.system_prompt),
                        ("user_prompt_template", data.user_prompt_template),
                        ("is_active", data.is_active),
                    )
                    if value is not None
                ]
                if supplied:
                    supplied.append(("updated_by", admin.id))
                    assignments = ", ".join(f"{column} = %s" for column, _ in supplied)
                    cur.execute(
                        f"UPDATE tenant_prompts SET {assignments} WHERE id = %s",
                        [value for _, value in supplied] + [existing["id"]],
                    )

            log_operation(
                cur, admin, tenant_id, tenant["code"],
                "update", "tenant_prompt", prompt_id, prompt["name"],
                new_value=data.model_dump(exclude_unset=True),
            )

            conn.commit()

            return ResponseModel(message="自定义提示词已保存")
    finally:
        conn.close()
|
||
|
||
|
||
@router.delete("/tenants/{tenant_id}/{prompt_id}", response_model=ResponseModel, summary="删除租户自定义提示词")
async def delete_tenant_prompt(
    tenant_id: int,
    prompt_id: int,
    admin: AdminUserInfo = Depends(get_current_admin),
):
    """
    Delete a tenant's custom prompt override (the default template is used again).

    Like the other mutating endpoints in this module, the deletion is written
    to the operation log; the deleted override's content is kept as
    ``old_value`` so it can be reconstructed from the audit trail.

    Raises:
        HTTPException: 404 if the tenant has no override for this prompt.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            # Fetch the override plus the tenant code / prompt name needed for
            # the audit row. LEFT JOINs so an override whose tenant or template
            # row has disappeared can still be deleted.
            cursor.execute(
                """
                SELECT tp.system_prompt, tp.user_prompt_template, tp.is_active,
                       t.code AS tenant_code, ap.name AS prompt_name
                FROM tenant_prompts tp
                LEFT JOIN tenants t ON t.id = tp.tenant_id
                LEFT JOIN ai_prompts ap ON ap.id = tp.prompt_id
                WHERE tp.tenant_id = %s AND tp.prompt_id = %s
                """,
                (tenant_id, prompt_id),
            )
            override = cursor.fetchone()
            if not override:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="自定义提示词不存在",
                )

            cursor.execute(
                """
                DELETE FROM tenant_prompts
                WHERE tenant_id = %s AND prompt_id = %s
                """,
                (tenant_id, prompt_id),
            )

            # Another request may have deleted the row after the SELECT above.
            if cursor.rowcount == 0:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="自定义提示词不存在",
                )

            log_operation(
                cursor, admin, tenant_id, override["tenant_code"],
                "delete", "tenant_prompt", prompt_id,
                # Fall back to the id if the template row is gone (LEFT JOIN miss).
                override["prompt_name"] if override["prompt_name"] is not None else str(prompt_id),
                old_value={
                    "system_prompt": override["system_prompt"],
                    "user_prompt_template": override["user_prompt_template"],
                    "is_active": override["is_active"],
                },
            )

            conn.commit()

            return ResponseModel(message="自定义提示词已删除,将使用默认模板")
    finally:
        conn.close()
|
||
|