diff --git a/TEST_REPORT_2026-01-31.md b/TEST_REPORT_2026-01-31.md index 7bd9382..e9c9e16 100644 --- a/TEST_REPORT_2026-01-31.md +++ b/TEST_REPORT_2026-01-31.md @@ -1,155 +1,155 @@ -# KPL 考培练系统测试报告 - -**测试环境**: dev (https://kpl.ireborn.com.cn) -**测试时间**: 2026-01-31 -**测试人员**: AI 自动化测试系统 - ---- - -## 一、测试概要 - -| 模块 | 测试用例数 | 通过 | 失败 | 警告 | -|------|-----------|------|------|------| -| 认证模块 | 7 | 5 | 2 | 0 | -| 课程管理 | 7 | 7 | 0 | 0 | -| 成长路径 | 4 | 4 | 0 | 0 | -| 岗位管理 | 2 | 2 | 0 | 0 | -| 考试模块 | 3 | 2 | 1 | 0 | -| AI练习 | 3 | 2 | 0 | 1 | -| 通知系统 | 2 | 2 | 0 | 0 | -| 极端边界 | 8 | 7 | 0 | 1 | -| 安全测试 | 7 | 5 | 0 | 2 | -| **合计** | **43** | **36** | **3** | **4** | - -**通过率**: 83.7% - ---- - -## 二、发现的问题 - -### 严重 (High) - -#### 1. 错误密码登录返回200 -- **位置**: `POST /api/v1/auth/login` -- **描述**: 使用错误密码登录时返回 HTTP 200,应返回 401 -- **影响**: 可能导致暴力破解攻击难以被检测 -- **建议**: 检查登录逻辑,确保密码错误时返回 401 - -#### 2. XSS 内容被原样存储 -- **位置**: `POST /api/v1/courses` (name, description 字段) -- **描述**: `` 等 XSS 代码被原样存入数据库 -- **影响**: 潜在的存储型 XSS 攻击风险 -- **建议**: - - 输入时转义或过滤 HTML 标签 - - 输出时使用 HTML 实体编码 - -### 中等 (Medium) - -#### 3. 不存在用户登录返回422 -- **位置**: `POST /api/v1/auth/login` -- **描述**: 登录不存在的用户返回 422,应返回 401 -- **影响**: 用户枚举风险(可判断用户是否存在) -- **建议**: 统一返回 401 "用户名或密码错误" - -#### 4. API 限流未配置 -- **位置**: 全局 -- **描述**: 10次快速请求未触发限流 -- **影响**: 可能被恶意请求攻击 -- **建议**: 配置 API 限流中间件 - -### 低等 (Low) - -#### 5. 越权访问返回404而非403 -- **位置**: `GET /api/v1/admin/users` -- **描述**: 普通用户访问管理接口返回 404 而非 403 -- **影响**: 信息泄露(可探测接口是否存在) -- **建议**: 统一返回 403 Forbidden - -#### 6. 
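Some API endpoints return 404
-- **Location**:
-  - `GET /api/v1/exams` (exam list)
-  - `GET /api/v1/practice/sessions` (practice records)
-- **Description**: These endpoints return 404; the paths may have changed, or the endpoints may not be implemented
-- **Recommendation**: Confirm the API paths or add the missing implementations
-
----
-
-## 3. Test Details
-
-### 3.1 Authentication Module Tests
-
-| Test | Result | Notes |
-|--------|------|------|
-| Valid login | ✓ PASS | HTTP 200, token obtained |
-| Login with wrong password | ✗ FAIL | HTTP 200 (should return 401) |
-| Login as nonexistent user | ✗ FAIL | HTTP 422 (should return 401) |
-| Token validation | ✓ PASS | HTTP 200 |
-| Access with invalid token | ✓ PASS | HTTP 401 |
-| Access without token | ✓ PASS | HTTP 403 |
-| Get user info | ✓ PASS | HTTP 200 |
-
-### 3.2 Course Management Tests
-
-| Test | Result | Notes |
-|--------|------|------|
-| List courses | ✓ PASS | Total courses: 16 |
-| Create course | ✓ PASS | HTTP 201 |
-| Get course details | ✓ PASS | HTTP 200 |
-| Update course | ✓ PASS | HTTP 200 |
-| Get exam settings | ✓ PASS | HTTP 200 |
-| Update exam settings | ✓ PASS | HTTP 200 |
-| Get nonexistent course | ✓ PASS | HTTP 404 |
-
-### 3.3 Edge Case Tests
-
-| Test | Result | Notes |
-|--------|------|------|
-| Create course with empty name | ✓ PASS | Correctly returns 422 |
-| Overlong name (1000 characters) | ✓ PASS | Correctly returns 422 |
-| XSS injection | ⚠ WARN | Content stored verbatim |
-| SQL injection | ✓ PASS | Injection blocked |
-| Negative pagination parameters | ✓ PASS | Correctly returns 422 |
-| Oversized page size (10000) | ✓ PASS | Correctly returns 422 |
-| Unicode/Emoji | ✓ PASS | Handled correctly |
-| Special characters | ✓ PASS | Handled correctly |
-
-### 3.4 Security Tests
-
-| Test | Result | Notes |
-|--------|------|------|
-| Privilege escalation | ⚠ WARN | Returns 404 instead of 403 |
-| Forged token | ✓ PASS | Correctly rejected |
-| Expired token | ✓ PASS | Correctly rejected |
-| Access to other users' data | ✓ PASS | Access restricted |
-| Sensitive information disclosure | ✓ PASS | No password/token leaked |
-| API rate limiting | ⚠ INFO | Rate limiting not triggered |
-| Directory traversal | ✓ PASS | Attack blocked |
-
----
-
-## 4. Fix Priorities
-
-### P0 - Fix immediately
-1. Fix the wrong-password login returning 200
-2. Add XSS input filtering/output encoding
-
-### P1 - Fix soon
-3. Standardize login error responses on 401
-4. Configure API rate limiting
-
-### P2 - Schedule a fix
-5. Return 403 consistently for privilege escalation attempts
-6. 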
确认并修复404的API端点 - ---- - -## 五、测试环境信息 - -- **后端容器**: kpl-backend-dev -- **数据库**: MySQL 8.0 -- **测试账号**: admin / admin123 -- **测试时间**: 2026-01-31 10:30 UTC+8 - ---- - -*本报告由自动化测试系统生成* +# KPL 考培练系统测试报告 + +**测试环境**: dev (https://kpl.ireborn.com.cn) +**测试时间**: 2026-01-31 +**测试人员**: AI 自动化测试系统 + +--- + +## 一、测试概要 + +| 模块 | 测试用例数 | 通过 | 失败 | 警告 | +|------|-----------|------|------|------| +| 认证模块 | 7 | 5 | 2 | 0 | +| 课程管理 | 7 | 7 | 0 | 0 | +| 成长路径 | 4 | 4 | 0 | 0 | +| 岗位管理 | 2 | 2 | 0 | 0 | +| 考试模块 | 3 | 2 | 1 | 0 | +| AI练习 | 3 | 2 | 0 | 1 | +| 通知系统 | 2 | 2 | 0 | 0 | +| 极端边界 | 8 | 7 | 0 | 1 | +| 安全测试 | 7 | 5 | 0 | 2 | +| **合计** | **43** | **36** | **3** | **4** | + +**通过率**: 83.7% + +--- + +## 二、发现的问题 + +### 严重 (High) + +#### 1. 错误密码登录返回200 +- **位置**: `POST /api/v1/auth/login` +- **描述**: 使用错误密码登录时返回 HTTP 200,应返回 401 +- **影响**: 可能导致暴力破解攻击难以被检测 +- **建议**: 检查登录逻辑,确保密码错误时返回 401 + +#### 2. XSS 内容被原样存储 +- **位置**: `POST /api/v1/courses` (name, description 字段) +- **描述**: `` 等 XSS 代码被原样存入数据库 +- **影响**: 潜在的存储型 XSS 攻击风险 +- **建议**: + - 输入时转义或过滤 HTML 标签 + - 输出时使用 HTML 实体编码 + +### 中等 (Medium) + +#### 3. 不存在用户登录返回422 +- **位置**: `POST /api/v1/auth/login` +- **描述**: 登录不存在的用户返回 422,应返回 401 +- **影响**: 用户枚举风险(可判断用户是否存在) +- **建议**: 统一返回 401 "用户名或密码错误" + +#### 4. API 限流未配置 +- **位置**: 全局 +- **描述**: 10次快速请求未触发限流 +- **影响**: 可能被恶意请求攻击 +- **建议**: 配置 API 限流中间件 + +### 低等 (Low) + +#### 5. 越权访问返回404而非403 +- **位置**: `GET /api/v1/admin/users` +- **描述**: 普通用户访问管理接口返回 404 而非 403 +- **影响**: 信息泄露(可探测接口是否存在) +- **建议**: 统一返回 403 Forbidden + +#### 6. 
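Some API endpoints return 404
+- **Location**:
+  - `GET /api/v1/exams` (exam list)
+  - `GET /api/v1/practice/sessions` (practice records)
+- **Description**: These endpoints return 404; the paths may have changed, or the endpoints may not be implemented
+- **Recommendation**: Confirm the API paths or add the missing implementations
+
+---
+
+## 3. Test Details
+
+### 3.1 Authentication Module Tests
+
+| Test | Result | Notes |
+|--------|------|------|
+| Valid login | ✓ PASS | HTTP 200, token obtained |
+| Login with wrong password | ✗ FAIL | HTTP 200 (should return 401) |
+| Login as nonexistent user | ✗ FAIL | HTTP 422 (should return 401) |
+| Token validation | ✓ PASS | HTTP 200 |
+| Access with invalid token | ✓ PASS | HTTP 401 |
+| Access without token | ✓ PASS | HTTP 403 |
+| Get user info | ✓ PASS | HTTP 200 |
+
+### 3.2 Course Management Tests
+
+| Test | Result | Notes |
+|--------|------|------|
+| List courses | ✓ PASS | Total courses: 16 |
+| Create course | ✓ PASS | HTTP 201 |
+| Get course details | ✓ PASS | HTTP 200 |
+| Update course | ✓ PASS | HTTP 200 |
+| Get exam settings | ✓ PASS | HTTP 200 |
+| Update exam settings | ✓ PASS | HTTP 200 |
+| Get nonexistent course | ✓ PASS | HTTP 404 |
+
+### 3.3 Edge Case Tests
+
+| Test | Result | Notes |
+|--------|------|------|
+| Create course with empty name | ✓ PASS | Correctly returns 422 |
+| Overlong name (1000 characters) | ✓ PASS | Correctly returns 422 |
+| XSS injection | ⚠ WARN | Content stored verbatim |
+| SQL injection | ✓ PASS | Injection blocked |
+| Negative pagination parameters | ✓ PASS | Correctly returns 422 |
+| Oversized page size (10000) | ✓ PASS | Correctly returns 422 |
+| Unicode/Emoji | ✓ PASS | Handled correctly |
+| Special characters | ✓ PASS | Handled correctly |
+
+### 3.4 Security Tests
+
+| Test | Result | Notes |
+|--------|------|------|
+| Privilege escalation | ⚠ WARN | Returns 404 instead of 403 |
+| Forged token | ✓ PASS | Correctly rejected |
+| Expired token | ✓ PASS | Correctly rejected |
+| Access to other users' data | ✓ PASS | Access restricted |
+| Sensitive information disclosure | ✓ PASS | No password/token leaked |
+| API rate limiting | ⚠ INFO | Rate limiting not triggered |
+| Directory traversal | ✓ PASS | Attack blocked |
+
+---
+
+## 4. Fix Priorities
+
+### P0 - Fix immediately
+1. Fix the wrong-password login returning 200
+2. Add XSS input filtering/output encoding
+
+### P1 - Fix soon
+3. Standardize login error responses on 401
+4. Configure API rate limiting
+
+### P2 - Schedule a fix
+5. Return 403 consistently for privilege escalation attempts
+6. 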
Confirm and fix the API endpoints that return 404
+
+---
+
+## 5. Test Environment
+
+- **Backend container**: kpl-backend-dev
+- **Database**: MySQL 8.0
+- **Test account**: admin / admin123
+- **Test time**: 2026-01-31 10:30 UTC+8
+
+---
+
+*This report was generated by an automated testing system*
diff --git a/backend/.env.ex b/backend/.env.ex
index bf227f6..a27766d 100644
--- a/backend/.env.ex
+++ b/backend/.env.ex
@@ -41,7 +41,7 @@ UPLOAD_DIR=uploads
 COZE_OAUTH_CLIENT_ID=1114009328887
 COZE_OAUTH_PUBLIC_KEY_ID=GGs9pw0BDHx2k9vGGehUyRgKV-PyUWLBncDs-YNNN_I
 COZE_OAUTH_PRIVATE_KEY_PATH=/app/secrets/coze_private_key.pem
-COZE_PRACTICE_BOT_ID=7560643598174683145
+COZE_PRACTICE_BOT_ID=7602204855037591602

 # Dify workflow API key configuration
 # 01 - Knowledge point analysis
diff --git a/backend/app/api/v1/preview.py b/backend/app/api/v1/preview.py
index 0287951..f4fd6d7 100644
--- a/backend/app/api/v1/preview.py
+++ b/backend/app/api/v1/preview.py
@@ -1,6 +1,8 @@
 """
 File preview API
 Provides online preview of course materials
+
+Supports two storage backends: MinIO and the local file system
 """
 import logging
 from pathlib import Path
@@ -15,6 +17,7 @@ from app.core.config import settings
 from app.models.user import User
 from app.models.course import CourseMaterial
 from app.services.document_converter import document_converter
+from app.services.storage_service import storage_service

 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -81,10 +84,12 @@ def get_preview_type(file_ext: str) -> str:
     return PreviewType.DOWNLOAD


-def get_file_path_from_url(file_url: str) -> Optional[Path]:
+async def get_file_path_from_url(file_url: str) -> Optional[Path]:
     """
     Get the local file path from a file URL

+    Supports MinIO and the local file system. If the file is in MinIO, it is first downloaded to a local cache.
+
     Args:
         file_url: file URL (e.g. /static/uploads/courses/1/xxx.pdf)

@@ -94,11 +99,12 @@ def get_file_path_from_url(file_url: str) -> Optional[Path]:
     try:
         # Strip the /static/uploads/ prefix
         if file_url.startswith('/static/uploads/'):
-            relative_path = file_url.replace('/static/uploads/', '')
-            full_path = Path(settings.UPLOAD_PATH) / relative_path
-            return full_path
+            object_name = file_url.replace('/static/uploads/', '')
+            # Use storage_service to get the file path (handles the MinIO download automatically)
+            return await 
storage_service.get_file_path(object_name)
         return None
-    except Exception:
+    except Exception as e:
+        logger.error(f"获取文件路径失败: {e}")
         return None


@@ -158,7 +164,7 @@ async def get_material_preview(
     # Handle according to the preview type
     if preview_type == PreviewType.TEXT:
         # Text type: read the file contents
-        file_path = get_file_path_from_url(material.file_url)
+        file_path = await get_file_path_from_url(material.file_url)
         if file_path and file_path.exists():
             try:
                 with open(file_path, 'r', encoding='utf-8') as f:
@@ -176,7 +182,7 @@ async def get_material_preview(

     elif preview_type == PreviewType.EXCEL_HTML:
         # Convert the Excel file to an HTML preview
-        file_path = get_file_path_from_url(material.file_url)
+        file_path = await get_file_path_from_url(material.file_url)
         if file_path and file_path.exists():
             converted_url = document_converter.convert_excel_to_html(
                 str(file_path),
@@ -200,7 +206,7 @@ async def get_material_preview(

     elif preview_type == PreviewType.PDF and document_converter.is_convertible(file_ext):
         # Office document: needs conversion to PDF
-        file_path = get_file_path_from_url(material.file_url)
+        file_path = await get_file_path_from_url(material.file_url)
         if file_path and file_path.exists():
             # Run the conversion
             converted_url = document_converter.convert_to_pdf(
diff --git a/backend/app/api/v1/upload.py b/backend/app/api/v1/upload.py
index 3f91c81..3da068c 100644
--- a/backend/app/api/v1/upload.py
+++ b/backend/app/api/v1/upload.py
@@ -1,5 +1,9 @@
 """
 File upload API
+
+Supports two storage backends:
+1. MinIO object storage (recommended for production)
+2. 
Local file system (development environment or fallback)
 """
 import os
 import shutil
@@ -17,6 +21,7 @@ from app.models.user import User
 from app.models.course import Course
 from app.schemas.base import ResponseModel
 from app.core.logger import get_logger
+from app.services.storage_service import storage_service

 logger = get_logger(__name__)

@@ -93,16 +98,13 @@ async def upload_file(
         # Generate a unique filename
         unique_filename = generate_unique_filename(file.filename)

-        # Get the upload path
-        upload_path = get_upload_path(file_type)
-        file_path = upload_path / unique_filename
-
-        # Save the file
-        with open(file_path, "wb") as f:
-            f.write(contents)
-
-        # Build the file access URL
-        file_url = f"/static/uploads/{file_type}/{unique_filename}"
+        # Upload the file through storage_service
+        object_name = f"{file_type}/{unique_filename}"
+        file_url = await storage_service.upload(
+            contents,
+            object_name,
+            content_type=file.content_type
+        )

         logger.info(
             "文件上传成功",
@@ -111,6 +113,7 @@ async def upload_file(
             saved_filename=unique_filename,
             file_size=file_size,
             file_type=file_type,
+            storage="minio" if storage_service.is_minio_enabled else "local",
         )

         return ResponseModel(
@@ -184,17 +187,13 @@ async def upload_course_material(
         # Generate a unique filename
         unique_filename = generate_unique_filename(file.filename)

-        # Create the course-specific directory
-        course_upload_path = Path(settings.UPLOAD_PATH) / "courses" / str(course_id)
-        course_upload_path.mkdir(parents=True, exist_ok=True)
-
-        # Save the file
-        file_path = course_upload_path / unique_filename
-        with open(file_path, "wb") as f:
-            f.write(contents)
-
-        # Build the file access URL
-        file_url = f"/static/uploads/courses/{course_id}/{unique_filename}"
+        # Upload the file through storage_service
+        object_name = f"courses/{course_id}/{unique_filename}"
+        file_url = await storage_service.upload(
+            contents,
+            object_name,
+            content_type=file.content_type
+        )

         logger.info(
             "课程资料上传成功",
@@ -203,6 +202,7 @@ async def upload_course_material(
             original_filename=file.filename,
             saved_filename=unique_filename,
             file_size=file_size,
+            storage="minio" if storage_service.is_minio_enabled else "local",
         )

         return ResponseModel(
@@ 
-243,24 +243,24 @@ async def delete_file(
                 detail="无效的文件URL"
             )

-        # Convert to the actual file path
-        relative_path = file_url.replace("/static/uploads/", "")
-        file_path = Path(settings.UPLOAD_PATH) / relative_path
+        # Extract the object name from the URL
+        object_name = file_url.replace("/static/uploads/", "")

         # Check whether the file exists
-        if not file_path.exists():
+        if not await storage_service.exists(object_name):
             raise HTTPException(
                 status_code=status.HTTP_404_NOT_FOUND,
                 detail="文件不存在"
             )

-        # Delete the file
-        os.remove(file_path)
+        # Delete the file through storage_service
+        await storage_service.delete(object_name)

         logger.info(
             "文件删除成功",
             user_id=current_user.id,
             file_url=file_url,
+            storage="minio" if storage_service.is_minio_enabled else "local",
         )

         return ResponseModel(data=True, message="文件删除成功")
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 012c9b8..e458c04 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -106,6 +106,14 @@ class Settings(BaseSettings):
         """Get the full path for uploaded files"""
         import os
         return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.UPLOAD_DIR)
+
+    # MinIO object storage configuration
+    MINIO_ENABLED: bool = Field(default=True, description="是否启用MinIO存储")
+    MINIO_ENDPOINT: str = Field(default="kaopeilian-minio:9000", description="MinIO服务地址")
+    MINIO_ACCESS_KEY: str = Field(default="kaopeilian_admin", description="MinIO访问密钥")
+    MINIO_SECRET_KEY: str = Field(default="KplMinio2026!@#", description="MinIO秘密密钥")
+    MINIO_SECURE: bool = Field(default=False, description="是否使用HTTPS")
+    MINIO_PUBLIC_URL: str = Field(default="", description="MinIO公开访问URL(留空则使用Nginx代理)")

     # Coze platform configuration (practice dialogues, course broadcasts, etc.)
     COZE_API_BASE: Optional[str] = Field(default="https://api.coze.cn")
diff --git a/backend/app/core/sanitize.py b/backend/app/core/sanitize.py
index 75b389b..ac2cb0c 100644
--- a/backend/app/core/sanitize.py
+++ b/backend/app/core/sanitize.py
@@ -1,136 +1,136 @@
-"""
-Input sanitization and XSS protection utilities
-"""
-import re
-import html
-from typing import Optional
-
-
-# Dangerous HTML tags and attributes
-DANGEROUS_TAGS = [
-    'script', 
'iframe', 'object', 'embed', 'form', 'input', - 'textarea', 'button', 'select', 'style', 'link', 'meta', - 'base', 'applet', 'frame', 'frameset', 'layer', 'ilayer', - 'bgsound', 'xml', 'blink', 'marquee' -] - -DANGEROUS_ATTRS = [ - 'onclick', 'ondblclick', 'onmousedown', 'onmouseup', 'onmouseover', - 'onmousemove', 'onmouseout', 'onkeypress', 'onkeydown', 'onkeyup', - 'onload', 'onerror', 'onabort', 'onblur', 'onchange', 'onfocus', - 'onreset', 'onsubmit', 'onunload', 'onbeforeunload', 'onresize', - 'onscroll', 'ondrag', 'ondragend', 'ondragenter', 'ondragleave', - 'ondragover', 'ondragstart', 'ondrop', 'onmousewheel', 'onwheel', - 'oncopy', 'oncut', 'onpaste', 'oncontextmenu', 'oninput', 'oninvalid', - 'onsearch', 'onselect', 'ontoggle', 'formaction', 'xlink:href' -] - - -def sanitize_html(text: Optional[str]) -> Optional[str]: - """ - 清理HTML内容,移除危险标签和属性 - - Args: - text: 输入文本 - - Returns: - 清理后的文本 - """ - if text is None: - return None - - if not isinstance(text, str): - return text - - result = text - - # 移除危险标签 - for tag in DANGEROUS_TAGS: - # 移除开标签 - pattern = re.compile(rf'<{tag}[^>]*>', re.IGNORECASE) - result = pattern.sub('', result) - # 移除闭标签 - pattern = re.compile(rf'', re.IGNORECASE) - result = pattern.sub('', result) - - # 移除危险属性 - for attr in DANGEROUS_ATTRS: - pattern = re.compile(rf'\s*{attr}\s*=\s*["\'][^"\']*["\']', re.IGNORECASE) - result = pattern.sub('', result) - # 也处理没有引号的情况 - pattern = re.compile(rf'\s*{attr}\s*=\s*\S+', re.IGNORECASE) - result = pattern.sub('', result) - - # 移除 javascript: 协议 - pattern = re.compile(r'javascript\s*:', re.IGNORECASE) - result = pattern.sub('', result) - - # 移除 data: 协议(可能包含恶意代码) - pattern = re.compile(r'data\s*:\s*text/html', re.IGNORECASE) - result = pattern.sub('', result) - - # 移除 vbscript: 协议 - pattern = re.compile(r'vbscript\s*:', re.IGNORECASE) - result = pattern.sub('', result) - - return result - - -def escape_html(text: Optional[str]) -> Optional[str]: - """ - 转义HTML特殊字符 - - Args: - text: 输入文本 - - 
Returns: - 转义后的文本 - """ - if text is None: - return None - - if not isinstance(text, str): - return text - - return html.escape(text, quote=True) - - -def strip_tags(text: Optional[str]) -> Optional[str]: - """ - 完全移除所有HTML标签 - - Args: - text: 输入文本 - - Returns: - 移除标签后的纯文本 - """ - if text is None: - return None - - if not isinstance(text, str): - return text - - # 移除所有HTML标签 - clean = re.compile('<[^>]*>') - return clean.sub('', text) - - -def sanitize_input(text: Optional[str], strict: bool = False) -> Optional[str]: - """ - 清理用户输入 - - Args: - text: 输入文本 - strict: 是否使用严格模式(完全移除所有HTML标签) - - Returns: - 清理后的文本 - """ - if text is None: - return None - - if strict: - return strip_tags(text) - else: - return sanitize_html(text) +""" +输入清理和XSS防护工具 +""" +import re +import html +from typing import Optional + + +# 危险的HTML标签和属性 +DANGEROUS_TAGS = [ + 'script', 'iframe', 'object', 'embed', 'form', 'input', + 'textarea', 'button', 'select', 'style', 'link', 'meta', + 'base', 'applet', 'frame', 'frameset', 'layer', 'ilayer', + 'bgsound', 'xml', 'blink', 'marquee' +] + +DANGEROUS_ATTRS = [ + 'onclick', 'ondblclick', 'onmousedown', 'onmouseup', 'onmouseover', + 'onmousemove', 'onmouseout', 'onkeypress', 'onkeydown', 'onkeyup', + 'onload', 'onerror', 'onabort', 'onblur', 'onchange', 'onfocus', + 'onreset', 'onsubmit', 'onunload', 'onbeforeunload', 'onresize', + 'onscroll', 'ondrag', 'ondragend', 'ondragenter', 'ondragleave', + 'ondragover', 'ondragstart', 'ondrop', 'onmousewheel', 'onwheel', + 'oncopy', 'oncut', 'onpaste', 'oncontextmenu', 'oninput', 'oninvalid', + 'onsearch', 'onselect', 'ontoggle', 'formaction', 'xlink:href' +] + + +def sanitize_html(text: Optional[str]) -> Optional[str]: + """ + 清理HTML内容,移除危险标签和属性 + + Args: + text: 输入文本 + + Returns: + 清理后的文本 + """ + if text is None: + return None + + if not isinstance(text, str): + return text + + result = text + + # 移除危险标签 + for tag in DANGEROUS_TAGS: + # 移除开标签 + pattern = re.compile(rf'<{tag}[^>]*>', re.IGNORECASE) + result = 
pattern.sub('', result) + # 移除闭标签 + pattern = re.compile(rf'', re.IGNORECASE) + result = pattern.sub('', result) + + # 移除危险属性 + for attr in DANGEROUS_ATTRS: + pattern = re.compile(rf'\s*{attr}\s*=\s*["\'][^"\']*["\']', re.IGNORECASE) + result = pattern.sub('', result) + # 也处理没有引号的情况 + pattern = re.compile(rf'\s*{attr}\s*=\s*\S+', re.IGNORECASE) + result = pattern.sub('', result) + + # 移除 javascript: 协议 + pattern = re.compile(r'javascript\s*:', re.IGNORECASE) + result = pattern.sub('', result) + + # 移除 data: 协议(可能包含恶意代码) + pattern = re.compile(r'data\s*:\s*text/html', re.IGNORECASE) + result = pattern.sub('', result) + + # 移除 vbscript: 协议 + pattern = re.compile(r'vbscript\s*:', re.IGNORECASE) + result = pattern.sub('', result) + + return result + + +def escape_html(text: Optional[str]) -> Optional[str]: + """ + 转义HTML特殊字符 + + Args: + text: 输入文本 + + Returns: + 转义后的文本 + """ + if text is None: + return None + + if not isinstance(text, str): + return text + + return html.escape(text, quote=True) + + +def strip_tags(text: Optional[str]) -> Optional[str]: + """ + 完全移除所有HTML标签 + + Args: + text: 输入文本 + + Returns: + 移除标签后的纯文本 + """ + if text is None: + return None + + if not isinstance(text, str): + return text + + # 移除所有HTML标签 + clean = re.compile('<[^>]*>') + return clean.sub('', text) + + +def sanitize_input(text: Optional[str], strict: bool = False) -> Optional[str]: + """ + 清理用户输入 + + Args: + text: 输入文本 + strict: 是否使用严格模式(完全移除所有HTML标签) + + Returns: + 清理后的文本 + """ + if text is None: + return None + + if strict: + return strip_tags(text) + else: + return sanitize_html(text) diff --git a/backend/app/services/ai/knowledge_analysis_v2.py b/backend/app/services/ai/knowledge_analysis_v2.py index 80bb6ea..a1811f9 100644 --- a/backend/app/services/ai/knowledge_analysis_v2.py +++ b/backend/app/services/ai/knowledge_analysis_v2.py @@ -8,6 +8,7 @@ - 写入数据库 提供稳定可靠的知识点分析能力。 +支持MinIO和本地文件系统两种存储后端。 """ import logging @@ -20,6 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from 
app.core.config import settings
 from app.core.exceptions import ExternalServiceError
 from app.schemas.course import KnowledgePointCreate
+from app.services.storage_service import storage_service

 from .ai_service import AIService, AIResponse
 from .llm_json_parser import parse_with_fallback, clean_llm_output
@@ -92,8 +94,8 @@ class KnowledgeAnalysisServiceV2:
                 f"file_url: {file_url}"
             )

-            # 1. Resolve the file path
-            file_path = self._resolve_file_path(file_url)
+            # 1. Resolve the file path (supports MinIO and the local file system)
+            file_path = await self._resolve_file_path(file_url)
             if not file_path.exists():
                 raise FileNotFoundError(f"文件不存在: {file_path}")

@@ -160,11 +162,20 @@ class KnowledgeAnalysisServiceV2:
             )
             raise ExternalServiceError(f"知识点分析失败: {e}")

-    def _resolve_file_path(self, file_url: str) -> Path:
-        """Resolve a file URL to a local path"""
+    async def _resolve_file_path(self, file_url: str) -> Path:
+        """
+        Resolve a file URL to a local path
+
+        Supports MinIO and the local file system. If the file is in MinIO, it is first downloaded to a local cache.
+        """
         if file_url.startswith(STATIC_UPLOADS_PREFIX):
-            relative_path = file_url.replace(STATIC_UPLOADS_PREFIX, '')
-            return Path(self.upload_path) / relative_path
+            object_name = file_url.replace(STATIC_UPLOADS_PREFIX, '')
+            # Use storage_service to get the file path (handles the MinIO download automatically)
+            file_path = await storage_service.get_file_path(object_name)
+            if file_path:
+                return file_path
+            # If storage_service returns None, fall back to the local path (for compatibility with legacy data)
+            return Path(self.upload_path) / object_name
         elif file_url.startswith('/'):
             # Absolute path
             return Path(file_url)
diff --git a/backend/app/services/course_service.py b/backend/app/services/course_service.py
index ba175f4..1ec30d6 100644
--- a/backend/app/services/course_service.py
+++ b/backend/app/services/course_service.py
@@ -465,9 +465,7 @@ class CourseService(BaseService[Course]):
         Returns:
             Whether the deletion succeeded
         """
-        import os
-        from pathlib import Path
-        from app.core.config import settings
+        from app.services.storage_service import storage_service

         # First confirm that the course exists
         course = await self.get_by_id(db, course_id)
@@ -498,21 +496,18 @@ class CourseService(BaseService[Course]):
         db.add(material)
         await 
db.commit()

-        # Delete the physical file
+        # Delete the physical file (using storage_service)
         if file_url and file_url.startswith("/static/uploads/"):
             try:
                 # Extract the relative path from the URL
-                relative_path = file_url.replace("/static/uploads/", "")
-                file_path = Path(settings.UPLOAD_PATH) / relative_path
-
-                # Check that the file exists, then delete it
-                if file_path.exists() and file_path.is_file():
-                    os.remove(file_path)
-                    logger.info(
-                        "删除物理文件成功",
-                        file_path=str(file_path),
-                        material_id=material_id,
-                    )
+                object_name = file_url.replace("/static/uploads/", "")
+                await storage_service.delete(object_name)
+                logger.info(
+                    "删除物理文件成功",
+                    object_name=object_name,
+                    material_id=material_id,
+                    storage="minio" if storage_service.is_minio_enabled else "local",
+                )
             except Exception as e:
                 # A failed physical file deletion does not affect the business flow; it is only logged
                 logger.error(
diff --git a/backend/app/services/dingtalk_service.py b/backend/app/services/dingtalk_service.py
index 0381961..c4191c1 100644
--- a/backend/app/services/dingtalk_service.py
+++ b/backend/app/services/dingtalk_service.py
@@ -1,276 +1,276 @@
-"""
-DingTalk Open Platform API service
-Fetches the organization structure and employee information through the DingTalk API
-"""
-
-import httpx
-from typing import List, Dict, Any, Optional
-from datetime import datetime, timedelta
-from app.core.logger import get_logger
-
-logger = get_logger(__name__)
-
-
-class DingTalkService:
-    """DingTalk API service"""
-    
-    BASE_URL = "https://api.dingtalk.com"
-    OAPI_URL = "https://oapi.dingtalk.com"
-    
-    def __init__(
-        self,
-        corp_id: str,
-        client_id: str,
-        client_secret: str
-    ):
-        """
-        Initialize the DingTalk service
-        
-        Args:
-            corp_id: enterprise CorpId
-            client_id: application ClientId (AppKey)
-            client_secret: application ClientSecret (AppSecret)
-        """
-        self.corp_id = corp_id
-        self.client_id = client_id
-        self.client_secret = client_secret
-        self._access_token: Optional[str] = None
-        self._token_expires_at: Optional[datetime] = None
-    
-    async def get_access_token(self) -> str:
-        """
-        Get a DingTalk access token
-        
-        Obtained through the new OAuth2 endpoint
-        
-        Returns:
-            access_token
-        """
-        # Check whether the cached token is still valid
-        if self._access_token and self._token_expires_at:
-            if datetime.now() < self._token_expires_at - 
timedelta(minutes=5): - return self._access_token - - url = f"{self.BASE_URL}/v1.0/oauth2/{self.corp_id}/token" - - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "grant_type": "client_credentials" - } - - async with httpx.AsyncClient(timeout=30) as client: - response = await client.post(url, json=payload) - response.raise_for_status() - data = response.json() - - self._access_token = data["access_token"] - expires_in = data.get("expires_in", 7200) - self._token_expires_at = datetime.now() + timedelta(seconds=expires_in) - - logger.info(f"获取钉钉 Access Token 成功,有效期 {expires_in} 秒") - return self._access_token - - async def get_department_list(self, dept_id: int = 1) -> List[Dict[str, Any]]: - """ - 获取部门列表 - - Args: - dept_id: 父部门ID,根部门为1 - - Returns: - 部门列表 - """ - access_token = await self.get_access_token() - url = f"{self.OAPI_URL}/topapi/v2/department/listsub" - - params = {"access_token": access_token} - payload = {"dept_id": dept_id} - - async with httpx.AsyncClient(timeout=30) as client: - response = await client.post(url, params=params, json=payload) - response.raise_for_status() - data = response.json() - - if data.get("errcode") != 0: - raise Exception(f"获取部门列表失败: {data.get('errmsg')}") - - return data.get("result", []) - - async def get_all_departments(self) -> List[Dict[str, Any]]: - """ - 递归获取所有部门 - - Returns: - 所有部门列表(扁平化) - """ - all_departments = [] - - async def fetch_recursive(parent_id: int): - departments = await self.get_department_list(parent_id) - for dept in departments: - all_departments.append(dept) - # 递归获取子部门 - await fetch_recursive(dept["dept_id"]) - - await fetch_recursive(1) # 从根部门开始 - logger.info(f"获取到 {len(all_departments)} 个部门") - return all_departments - - async def get_department_users( - self, - dept_id: int, - cursor: int = 0, - size: int = 100 - ) -> Dict[str, Any]: - """ - 获取部门用户列表 - - Args: - dept_id: 部门ID - cursor: 分页游标 - size: 每页大小,最大100 - - Returns: - 用户列表和分页信息 - """ - access_token = 
await self.get_access_token() - url = f"{self.OAPI_URL}/topapi/v2/user/list" - - params = {"access_token": access_token} - payload = { - "dept_id": dept_id, - "cursor": cursor, - "size": size - } - - async with httpx.AsyncClient(timeout=30) as client: - response = await client.post(url, params=params, json=payload) - response.raise_for_status() - data = response.json() - - if data.get("errcode") != 0: - raise Exception(f"获取部门用户失败: {data.get('errmsg')}") - - return data.get("result", {}) - - async def get_all_employees(self) -> List[Dict[str, Any]]: - """ - 获取所有在职员工 - - 遍历所有部门获取员工列表 - - Returns: - 员工列表 - """ - logger.info("开始从钉钉 API 获取所有员工...") - - # 1. 获取所有部门 - departments = await self.get_all_departments() - - # 创建部门ID到名称的映射 - dept_map = {dept["dept_id"]: dept["name"] for dept in departments} - dept_map[1] = "根部门" # 添加根部门 - - # 2. 遍历所有部门获取员工 - all_employees = {} # 使用字典去重(按 userid) - - for dept in [{"dept_id": 1, "name": "根部门"}] + departments: - dept_id = dept["dept_id"] - dept_name = dept["name"] - - cursor = 0 - while True: - result = await self.get_department_users(dept_id, cursor) - users = result.get("list", []) - - for user in users: - userid = user.get("userid") - if userid and userid not in all_employees: - # 转换为统一格式 - employee = self._convert_user_to_employee(user, dept_name) - all_employees[userid] = employee - - # 检查是否还有更多数据 - if not result.get("has_more", False): - break - cursor = result.get("next_cursor", 0) - - employees = list(all_employees.values()) - logger.info(f"获取到 {len(employees)} 位在职员工") - return employees - - def _convert_user_to_employee( - self, - user: Dict[str, Any], - dept_name: str - ) -> Dict[str, Any]: - """ - 将钉钉用户数据转换为员工数据格式 - - Args: - user: 钉钉用户数据 - dept_name: 部门名称 - - Returns: - 标准员工数据格式 - """ - return { - 'full_name': user.get('name', ''), - 'phone': user.get('mobile', ''), - 'email': user.get('email', ''), - 'department': dept_name, - 'position': user.get('title', ''), - 'employee_no': user.get('job_number', ''), - 
'is_leader': user.get('leader', False), - 'is_active': user.get('active', True), - 'dingtalk_id': user.get('userid', ''), - 'join_date': user.get('hired_date'), - 'work_location': user.get('work_place', ''), - 'avatar': user.get('avatar', ''), - } - - async def test_connection(self) -> Dict[str, Any]: - """ - 测试钉钉 API 连接 - - Returns: - 测试结果 - """ - try: - # 1. 测试获取 token - token = await self.get_access_token() - - # 2. 测试获取根部门信息 - departments = await self.get_department_list(1) - - # 3. 获取根部门员工数量 - result = await self.get_department_users(1, size=1) - - return { - "success": True, - "message": "连接成功", - "corp_id": self.corp_id, - "department_count": len(departments) + 1, # +1 是根部门 - "has_employees": result.get("has_more", False) or len(result.get("list", [])) > 0 - } - - except httpx.HTTPStatusError as e: - error_detail = "HTTP错误" - if e.response.status_code == 400: - try: - error_data = e.response.json() - error_detail = error_data.get("message", str(e)) - except: - pass - return { - "success": False, - "message": f"连接失败: {error_detail}", - "error": str(e) - } - except Exception as e: - return { - "success": False, - "message": f"连接失败: {str(e)}", - "error": str(e) - } +""" +钉钉开放平台 API 服务 +用于通过钉钉 API 获取组织架构和员工信息 +""" + +import httpx +from typing import List, Dict, Any, Optional +from datetime import datetime, timedelta +from app.core.logger import get_logger + +logger = get_logger(__name__) + + +class DingTalkService: + """钉钉 API 服务""" + + BASE_URL = "https://api.dingtalk.com" + OAPI_URL = "https://oapi.dingtalk.com" + + def __init__( + self, + corp_id: str, + client_id: str, + client_secret: str + ): + """ + 初始化钉钉服务 + + Args: + corp_id: 企业 CorpId + client_id: 应用 ClientId (AppKey) + client_secret: 应用 ClientSecret (AppSecret) + """ + self.corp_id = corp_id + self.client_id = client_id + self.client_secret = client_secret + self._access_token: Optional[str] = None + self._token_expires_at: Optional[datetime] = None + + async def get_access_token(self) -> str: + """ + 
获取钉钉 Access Token + + 使用新版 OAuth2 接口获取 + + Returns: + access_token + """ + # 检查缓存的 token 是否有效 + if self._access_token and self._token_expires_at: + if datetime.now() < self._token_expires_at - timedelta(minutes=5): + return self._access_token + + url = f"{self.BASE_URL}/v1.0/oauth2/{self.corp_id}/token" + + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "client_credentials" + } + + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post(url, json=payload) + response.raise_for_status() + data = response.json() + + self._access_token = data["access_token"] + expires_in = data.get("expires_in", 7200) + self._token_expires_at = datetime.now() + timedelta(seconds=expires_in) + + logger.info(f"获取钉钉 Access Token 成功,有效期 {expires_in} 秒") + return self._access_token + + async def get_department_list(self, dept_id: int = 1) -> List[Dict[str, Any]]: + """ + 获取部门列表 + + Args: + dept_id: 父部门ID,根部门为1 + + Returns: + 部门列表 + """ + access_token = await self.get_access_token() + url = f"{self.OAPI_URL}/topapi/v2/department/listsub" + + params = {"access_token": access_token} + payload = {"dept_id": dept_id} + + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post(url, params=params, json=payload) + response.raise_for_status() + data = response.json() + + if data.get("errcode") != 0: + raise Exception(f"获取部门列表失败: {data.get('errmsg')}") + + return data.get("result", []) + + async def get_all_departments(self) -> List[Dict[str, Any]]: + """ + 递归获取所有部门 + + Returns: + 所有部门列表(扁平化) + """ + all_departments = [] + + async def fetch_recursive(parent_id: int): + departments = await self.get_department_list(parent_id) + for dept in departments: + all_departments.append(dept) + # 递归获取子部门 + await fetch_recursive(dept["dept_id"]) + + await fetch_recursive(1) # 从根部门开始 + logger.info(f"获取到 {len(all_departments)} 个部门") + return all_departments + + async def get_department_users( + self, + dept_id: 
int, + cursor: int = 0, + size: int = 100 + ) -> Dict[str, Any]: + """ + 获取部门用户列表 + + Args: + dept_id: 部门ID + cursor: 分页游标 + size: 每页大小,最大100 + + Returns: + 用户列表和分页信息 + """ + access_token = await self.get_access_token() + url = f"{self.OAPI_URL}/topapi/v2/user/list" + + params = {"access_token": access_token} + payload = { + "dept_id": dept_id, + "cursor": cursor, + "size": size + } + + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post(url, params=params, json=payload) + response.raise_for_status() + data = response.json() + + if data.get("errcode") != 0: + raise Exception(f"获取部门用户失败: {data.get('errmsg')}") + + return data.get("result", {}) + + async def get_all_employees(self) -> List[Dict[str, Any]]: + """ + 获取所有在职员工 + + 遍历所有部门获取员工列表 + + Returns: + 员工列表 + """ + logger.info("开始从钉钉 API 获取所有员工...") + + # 1. 获取所有部门 + departments = await self.get_all_departments() + + # 创建部门ID到名称的映射 + dept_map = {dept["dept_id"]: dept["name"] for dept in departments} + dept_map[1] = "根部门" # 添加根部门 + + # 2. 
Iterate over every department and collect employees
+        all_employees = {}  # Dict used to deduplicate by userid
+
+        for dept in [{"dept_id": 1, "name": "根部门"}] + departments:
+            dept_id = dept["dept_id"]
+            dept_name = dept["name"]
+
+            cursor = 0
+            while True:
+                result = await self.get_department_users(dept_id, cursor)
+                users = result.get("list", [])
+
+                for user in users:
+                    userid = user.get("userid")
+                    if userid and userid not in all_employees:
+                        # Convert to the unified format
+                        employee = self._convert_user_to_employee(user, dept_name)
+                        all_employees[userid] = employee
+
+                # Check whether more pages remain
+                if not result.get("has_more", False):
+                    break
+                cursor = result.get("next_cursor", 0)
+
+        employees = list(all_employees.values())
+        logger.info(f"Fetched {len(employees)} active employees")
+        return employees
+
+    def _convert_user_to_employee(
+        self,
+        user: Dict[str, Any],
+        dept_name: str
+    ) -> Dict[str, Any]:
+        """
+        Convert DingTalk user data to the employee data format.
+
+        Args:
+            user: DingTalk user data
+            dept_name: Department name
+
+        Returns:
+            Employee data in the standard format
+        """
+        return {
+            'full_name': user.get('name', ''),
+            'phone': user.get('mobile', ''),
+            'email': user.get('email', ''),
+            'department': dept_name,
+            'position': user.get('title', ''),
+            'employee_no': user.get('job_number', ''),
+            'is_leader': user.get('leader', False),
+            'is_active': user.get('active', True),
+            'dingtalk_id': user.get('userid', ''),
+            'join_date': user.get('hired_date'),
+            'work_location': user.get('work_place', ''),
+            'avatar': user.get('avatar', ''),
+        }
+
+    async def test_connection(self) -> Dict[str, Any]:
+        """
+        Test the DingTalk API connection.
+
+        Returns:
+            Test result
+        """
+        try:
+            # 1. Test fetching a token
+            token = await self.get_access_token()
+
+            # 2. Test fetching root department info
+            departments = await self.get_department_list(1)
+
+            # 3.
Get the number of employees in the root department
+            result = await self.get_department_users(1, size=1)
+
+            return {
+                "success": True,
+                "message": "Connection succeeded",
+                "corp_id": self.corp_id,
+                "department_count": len(departments) + 1,  # +1 for the root department
+                "has_employees": result.get("has_more", False) or len(result.get("list", [])) > 0
+            }
+
+        except httpx.HTTPStatusError as e:
+            error_detail = "HTTP error"
+            if e.response.status_code == 400:
+                try:
+                    error_data = e.response.json()
+                    error_detail = error_data.get("message", str(e))
+                except Exception:
+                    pass
+            return {
+                "success": False,
+                "message": f"Connection failed: {error_detail}",
+                "error": str(e)
+            }
+        except Exception as e:
+            return {
+                "success": False,
+                "message": f"Connection failed: {str(e)}",
+                "error": str(e)
+            }
diff --git a/backend/app/services/storage_service.py b/backend/app/services/storage_service.py
new file mode 100644
index 0000000..014e290
--- /dev/null
+++ b/backend/app/services/storage_service.py
@@ -0,0 +1,422 @@
+"""
+Unified file storage service
+Supports MinIO object storage and is compatible with the local file system
+
+Usage:
+    from app.services.storage_service import storage_service
+
+    # Upload a file
+    file_url = await storage_service.upload(file_data, "courses/1/doc.pdf")
+
+    # Download a file
+    file_data = await storage_service.download("courses/1/doc.pdf")
+
+    # Delete a file
+    await storage_service.delete("courses/1/doc.pdf")
+"""
+
+import os
+import io
+import logging
+from pathlib import Path
+from typing import Optional, Union, BinaryIO
+from datetime import timedelta
+
+from minio import Minio
+from minio.error import S3Error
+
+from app.core.config import settings
+
+logger = logging.getLogger(__name__)
+
+
+class StorageService:
+    """
+    Unified file storage service
+
+    Supports two storage backends:
+    1. MinIO object storage (recommended; production)
+    2.
Local file system (development, or fallback when MinIO is unavailable)
+    """
+
+    def __init__(self):
+        self._client: Optional[Minio] = None
+        self._initialized = False
+        self._use_minio = False
+
+    def _ensure_initialized(self):
+        """Ensure the service is initialized."""
+        if self._initialized:
+            return
+
+        self._initialized = True
+
+        # Check whether MinIO is enabled
+        if not settings.MINIO_ENABLED:
+            logger.info("MinIO is disabled; using local file storage")
+            self._use_minio = False
+            return
+
+        try:
+            self._client = Minio(
+                settings.MINIO_ENDPOINT,
+                access_key=settings.MINIO_ACCESS_KEY,
+                secret_key=settings.MINIO_SECRET_KEY,
+                secure=settings.MINIO_SECURE,
+            )
+
+            # Verify the connection and make sure the bucket exists
+            bucket_name = self._get_bucket_name()
+            if not self._client.bucket_exists(bucket_name):
+                self._client.make_bucket(bucket_name)
+                logger.info(f"Created MinIO bucket: {bucket_name}")
+
+            # Set the bucket policy to public read
+            self._set_bucket_public_read(bucket_name)
+
+            self._use_minio = True
+            logger.info(f"MinIO storage service initialized - endpoint: {settings.MINIO_ENDPOINT}, bucket: {bucket_name}")
+
+        except Exception as e:
+            logger.warning(f"MinIO initialization failed; falling back to local storage: {e}")
+            self._use_minio = False
+
+    def _get_bucket_name(self) -> str:
+        """Get the bucket name for the current tenant."""
+        return f"kpl-{settings.TENANT_CODE}"
+
+    def _set_bucket_public_read(self, bucket_name: str):
+        """Make the bucket publicly readable."""
+        try:
+            # Anonymous read policy
+            policy = {
+                "Version": "2012-10-17",
+                "Statement": [
+                    {
+                        "Effect": "Allow",
+                        "Principal": {"AWS": "*"},
+                        "Action": ["s3:GetObject"],
+                        "Resource": [f"arn:aws:s3:::{bucket_name}/*"]
+                    }
+                ]
+            }
+            import json
+            self._client.set_bucket_policy(bucket_name, json.dumps(policy))
+        except Exception as e:
+            logger.warning(f"Failed to set the bucket public-read policy: {e}")
+
+    def _normalize_object_name(self, object_name: str) -> str:
+        """Normalize the object name by removing the leading slash and upload prefix."""
+        if object_name.startswith('/'):
+            object_name = object_name[1:]
+        if object_name.startswith('static/uploads/'):
+            object_name = object_name[len('static/uploads/'):]
+        return object_name
+
+    def _get_file_url(self, object_name: str) -> str:
+        """Get the file access URL."""
+        object_name =
self._normalize_object_name(object_name)
+        # Always return a /static/uploads/ URL; Nginx proxies it to MinIO
+        return f"/static/uploads/{object_name}"
+
+    def _get_local_path(self, object_name: str) -> Path:
+        """Get the local file path."""
+        object_name = self._normalize_object_name(object_name)
+        return Path(settings.UPLOAD_PATH) / object_name
+
+    async def upload(
+        self,
+        file_data: Union[bytes, BinaryIO],
+        object_name: str,
+        content_type: Optional[str] = None,
+    ) -> str:
+        """
+        Upload a file.
+
+        Args:
+            file_data: File data (bytes or a file object)
+            object_name: Object name (e.g. courses/1/doc.pdf)
+            content_type: MIME type of the file
+
+        Returns:
+            File access URL
+        """
+        self._ensure_initialized()
+        object_name = self._normalize_object_name(object_name)
+
+        # Convert to bytes
+        if isinstance(file_data, bytes):
+            data = file_data
+        else:
+            data = file_data.read()
+
+        if self._use_minio:
+            return await self._upload_to_minio(data, object_name, content_type)
+        else:
+            return await self._upload_to_local(data, object_name)
+
+    async def _upload_to_minio(
+        self,
+        data: bytes,
+        object_name: str,
+        content_type: Optional[str] = None,
+    ) -> str:
+        """Upload to MinIO."""
+        try:
+            bucket_name = self._get_bucket_name()
+
+            # Auto-detect content_type
+            if not content_type:
+                content_type = self._guess_content_type(object_name)
+
+            self._client.put_object(
+                bucket_name,
+                object_name,
+                io.BytesIO(data),
+                length=len(data),
+                content_type=content_type,
+            )
+
+            file_url = self._get_file_url(object_name)
+            logger.info(f"Uploaded file to MinIO: {object_name} -> {file_url}")
+            return file_url
+
+        except S3Error as e:
+            logger.error(f"MinIO upload failed: {e}")
+            # Fall back to local storage
+            return await self._upload_to_local(data, object_name)
+
+    async def _upload_to_local(self, data: bytes, object_name: str) -> str:
+        """Upload to the local file system."""
+        try:
+            file_path = self._get_local_path(object_name)
+            file_path.parent.mkdir(parents=True, exist_ok=True)
+
+            with open(file_path, 'wb') as f:
+                f.write(data)
+
+            file_url = self._get_file_url(object_name)
+            logger.info(f"Uploaded file to local storage: {object_name} -> {file_url}")
+            return file_url
+
+
except Exception as e:
+            logger.error(f"Local file upload failed: {e}")
+            raise
+
+    async def download(self, object_name: str) -> Optional[bytes]:
+        """
+        Download a file.
+
+        Args:
+            object_name: Object name
+
+        Returns:
+            File data, or None if the file does not exist
+        """
+        self._ensure_initialized()
+        object_name = self._normalize_object_name(object_name)
+
+        if self._use_minio:
+            return await self._download_from_minio(object_name)
+        else:
+            return await self._download_from_local(object_name)
+
+    async def _download_from_minio(self, object_name: str) -> Optional[bytes]:
+        """Download from MinIO."""
+        try:
+            bucket_name = self._get_bucket_name()
+            response = self._client.get_object(bucket_name, object_name)
+            data = response.read()
+            response.close()
+            response.release_conn()
+            return data
+        except S3Error as e:
+            if e.code == 'NoSuchKey':
+                logger.warning(f"File not found in MinIO: {object_name}")
+                # Try the local copy (for the migration transition period)
+                return await self._download_from_local(object_name)
+            logger.error(f"MinIO download failed: {e}")
+            return None
+
+    async def _download_from_local(self, object_name: str) -> Optional[bytes]:
+        """Download from the local file system."""
+        try:
+            file_path = self._get_local_path(object_name)
+            if not file_path.exists():
+                logger.warning(f"Local file not found: {file_path}")
+                return None
+
+            with open(file_path, 'rb') as f:
+                return f.read()
+        except Exception as e:
+            logger.error(f"Local file download failed: {e}")
+            return None
+
+    async def delete(self, object_name: str) -> bool:
+        """
+        Delete a file.
+
+        Args:
+            object_name: Object name
+
+        Returns:
+            Whether the deletion succeeded
+        """
+        self._ensure_initialized()
+        object_name = self._normalize_object_name(object_name)
+
+        success = True
+
+        # Delete from MinIO
+        if self._use_minio:
+            try:
+                bucket_name = self._get_bucket_name()
+                self._client.remove_object(bucket_name, object_name)
+                logger.info(f"Deleted file from MinIO: {object_name}")
+            except S3Error as e:
+                if e.code != 'NoSuchKey':
+                    logger.error(f"MinIO file deletion failed: {e}")
+                    success = False
+
+        # Also delete the local copy (to clean up completely)
+        try:
+            file_path = self._get_local_path(object_name)
+            if file_path.exists():
+                os.remove(file_path)
+                logger.info(f"Deleted local file: {file_path}")
+        except Exception as e:
+            logger.warning(f"Local file deletion failed: {e}")
+
+        return success
+
+    async def exists(self, object_name: str) -> bool:
+        """
+        Check whether a file exists.
+
+        Args:
+            object_name: Object name
+
+        Returns:
+            Whether the file exists
+        """
+        self._ensure_initialized()
+        object_name = self._normalize_object_name(object_name)
+
+        if self._use_minio:
+            try:
+                bucket_name = self._get_bucket_name()
+                self._client.stat_object(bucket_name, object_name)
+                return True
+            except S3Error:
+                pass
+
+        # Check the local file
+        file_path = self._get_local_path(object_name)
+        return file_path.exists()
+
+    async def get_file_path(self, object_name: str) -> Optional[Path]:
+        """
+        Get the local path of a file (for cases that need local file operations).
+
+        If the file is in MinIO, it is first downloaded to a temporary directory.
+
+        Args:
+            object_name: Object name
+
+        Returns:
+            Local file path, or None if the file does not exist
+        """
+        self._ensure_initialized()
+        object_name = self._normalize_object_name(object_name)
+
+        # Check the local copy first
+        local_path = self._get_local_path(object_name)
+        if local_path.exists():
+            return local_path
+
+        # If MinIO is enabled, try downloading into the local cache
+        if self._use_minio:
+            try:
+                data = await self._download_from_minio(object_name)
+                if data:
+                    # Save to the local cache
+                    local_path.parent.mkdir(parents=True, exist_ok=True)
+                    with open(local_path, 'wb') as f:
+                        f.write(data)
+                    logger.info(f"Downloaded file from MinIO to local cache: {object_name}")
+                    return local_path
+            except Exception as e:
+                logger.error(f"Failed to download MinIO file to local storage: {e}")
+
+        return None
+
+    def get_presigned_url(self, object_name: str, expires: int = 3600) -> Optional[str]:
+        """
+        Get a presigned URL (for direct access to MinIO).
+
+        Args:
+            object_name: Object name
+            expires: Expiry time in seconds
+
+        Returns:
+            Presigned URL, or None if MinIO is not enabled
+        """
+        self._ensure_initialized()
+
+        if not self._use_minio:
+            return None
+
+        object_name = self._normalize_object_name(object_name)
+
+        try:
+            bucket_name = self._get_bucket_name()
+            url = self._client.presigned_get_object(
+                bucket_name,
+                object_name,
+                expires=timedelta(seconds=expires)
+            )
+            return url
+        except S3Error as e:
+            logger.error(f"Failed to get presigned URL: {e}")
+            return None
+
+    def _guess_content_type(self, filename: str) -> str:
+
"""根据文件名猜测MIME类型""" + ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else '' + content_types = { + 'pdf': 'application/pdf', + 'doc': 'application/msword', + 'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'xls': 'application/vnd.ms-excel', + 'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'ppt': 'application/vnd.ms-powerpoint', + 'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + 'txt': 'text/plain', + 'md': 'text/markdown', + 'html': 'text/html', + 'htm': 'text/html', + 'csv': 'text/csv', + 'json': 'application/json', + 'xml': 'application/xml', + 'zip': 'application/zip', + 'png': 'image/png', + 'jpg': 'image/jpeg', + 'jpeg': 'image/jpeg', + 'gif': 'image/gif', + 'webp': 'image/webp', + 'mp3': 'audio/mpeg', + 'wav': 'audio/wav', + 'mp4': 'video/mp4', + 'webm': 'video/webm', + } + return content_types.get(ext, 'application/octet-stream') + + @property + def is_minio_enabled(self) -> bool: + """检查MinIO是否启用""" + self._ensure_initialized() + return self._use_minio + + +# 全局单例 +storage_service = StorageService() diff --git a/backend/requirements.txt b/backend/requirements.txt index ad4e27a..1afa852 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -31,6 +31,9 @@ PyMySQL==1.1.0 httpx==0.27.2 aiofiles==23.2.1 +# 对象存储(MinIO) +minio>=7.2.0 + # 日志 structlog==23.2.0 diff --git a/frontend/src/api/dashboard.ts b/frontend/src/api/dashboard.ts index ffb4a99..9105380 100644 --- a/frontend/src/api/dashboard.ts +++ b/frontend/src/api/dashboard.ts @@ -64,9 +64,7 @@ export interface TrendData { export interface LevelDistribution { levels: number[] counts: number[] -} - -// 实时动态 +}// 实时动态 export interface ActivityItem { id: number user_id: number diff --git a/frontend/src/api/task.ts b/frontend/src/api/task.ts index 1c83a1b..cfe7e0c 100644 --- a/frontend/src/api/task.ts +++ b/frontend/src/api/task.ts @@ -111,5 +111,4 @@ export function 
deleteTask(id: number): Promise> { */ export function sendTaskReminder(id: number): Promise> { return http.post(`/api/v1/manager/tasks/${id}/remind`) -} - +} \ No newline at end of file