feat: 集成MinIO对象存储服务
All checks were successful
continuous-integration/drone/push Build is passing

- 新增storage_service.py封装MinIO操作
- 修改upload.py使用storage_service上传文件
- 修改course_service.py使用storage_service删除文件
- 适配preview.py支持从MinIO获取文件
- 适配knowledge_analysis_v2.py支持MinIO存储
- 在config.py添加MinIO配置项
- 添加minio依赖到requirements.txt

支持特性:
- 自动降级到本地存储(MinIO不可用时)
- 保持URL格式兼容(/static/uploads/)
- 文件自动缓存到本地(用于预览和分析)

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
yuliang_guo
2026-02-03 14:06:22 +08:00
parent fca82e2d44
commit 2f47193059
13 changed files with 1071 additions and 629 deletions

View File

@@ -8,6 +8,7 @@
- 写入数据库
提供稳定可靠的知识点分析能力。
支持MinIO和本地文件系统两种存储后端。
"""
import logging
@@ -20,6 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.exceptions import ExternalServiceError
from app.schemas.course import KnowledgePointCreate
from app.services.storage_service import storage_service
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
@@ -92,8 +94,8 @@ class KnowledgeAnalysisServiceV2:
f"file_url: {file_url}"
)
# 1. 解析文件路径
file_path = self._resolve_file_path(file_url)
# 1. 解析文件路径支持MinIO和本地文件系统
file_path = await self._resolve_file_path(file_url)
if not file_path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
@@ -160,11 +162,20 @@ class KnowledgeAnalysisServiceV2:
)
raise ExternalServiceError(f"知识点分析失败: {e}")
def _resolve_file_path(self, file_url: str) -> Path:
"""解析文件 URL 为本地路径"""
async def _resolve_file_path(self, file_url: str) -> Path:
"""
解析文件 URL 为本地路径
支持MinIO和本地文件系统。如果文件在MinIO中会先下载到本地缓存。
"""
if file_url.startswith(STATIC_UPLOADS_PREFIX):
relative_path = file_url.replace(STATIC_UPLOADS_PREFIX, '')
return Path(self.upload_path) / relative_path
object_name = file_url.replace(STATIC_UPLOADS_PREFIX, '')
# 使用storage_service获取文件路径自动处理MinIO下载
file_path = await storage_service.get_file_path(object_name)
if file_path:
return file_path
# 如果storage_service返回None尝试本地路径兼容旧数据
return Path(self.upload_path) / object_name
elif file_url.startswith('/'):
# 绝对路径
return Path(file_url)

View File

@@ -465,9 +465,7 @@ class CourseService(BaseService[Course]):
Returns:
是否删除成功
"""
import os
from pathlib import Path
from app.core.config import settings
from app.services.storage_service import storage_service
# 先确认课程存在
course = await self.get_by_id(db, course_id)
@@ -498,21 +496,18 @@ class CourseService(BaseService[Course]):
db.add(material)
await db.commit()
# 删除物理文件
# 删除物理文件使用storage_service
if file_url and file_url.startswith("/static/uploads/"):
try:
# 从URL中提取相对路径
relative_path = file_url.replace("/static/uploads/", "")
file_path = Path(settings.UPLOAD_PATH) / relative_path
# 检查文件是否存在并删除
if file_path.exists() and file_path.is_file():
os.remove(file_path)
logger.info(
"删除物理文件成功",
file_path=str(file_path),
material_id=material_id,
)
object_name = file_url.replace("/static/uploads/", "")
await storage_service.delete(object_name)
logger.info(
"删除物理文件成功",
object_name=object_name,
material_id=material_id,
storage="minio" if storage_service.is_minio_enabled else "local",
)
except Exception as e:
# 物理文件删除失败不影响业务流程,仅记录日志
logger.error(

View File

@@ -1,276 +1,276 @@
"""
钉钉开放平台 API 服务
用于通过钉钉 API 获取组织架构和员工信息
"""
import httpx
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from app.core.logger import get_logger
logger = get_logger(__name__)
class DingTalkService:
    """Client for the DingTalk Open Platform API.

    Obtains an access token with the application's credentials and uses it
    to read the department tree and the users of each department.
    """

    BASE_URL = "https://api.dingtalk.com"
    OAPI_URL = "https://oapi.dingtalk.com"

    def __init__(
        self,
        corp_id: str,
        client_id: str,
        client_secret: str
    ):
        """Store the credentials and start with an empty token cache.

        Args:
            corp_id: Enterprise CorpId.
            client_id: Application ClientId (AppKey).
            client_secret: Application ClientSecret (AppSecret).
        """
        self.corp_id = corp_id
        self.client_id = client_id
        self.client_secret = client_secret
        self._access_token: Optional[str] = None
        self._token_expires_at: Optional[datetime] = None

    async def get_access_token(self) -> str:
        """Return an access token, reusing the cached one while it is valid.

        The cached token is reused until five minutes before its recorded
        expiry; after that a new one is requested from the OAuth2 endpoint.

        Returns:
            The access token string.

        Raises:
            httpx.HTTPStatusError: If the token endpoint answers with an
                error status.
            KeyError: If the response carries no ``access_token`` field.
        """
        if self._access_token and self._token_expires_at:
            # Refresh five minutes early so a token never expires mid-request.
            if datetime.now() < self._token_expires_at - timedelta(minutes=5):
                return self._access_token

        url = f"{self.BASE_URL}/v1.0/oauth2/{self.corp_id}/token"
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "client_credentials"
        }

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            data = response.json()

        self._access_token = data["access_token"]
        expires_in = data.get("expires_in", 7200)
        self._token_expires_at = datetime.now() + timedelta(seconds=expires_in)
        logger.info(f"获取钉钉 Access Token 成功,有效期 {expires_in}")
        return self._access_token

    async def get_department_list(self, dept_id: int = 1) -> List[Dict[str, Any]]:
        """Return the direct sub-departments of ``dept_id``.

        Args:
            dept_id: Parent department id; the root department is 1.

        Returns:
            The ``result`` list of the API response (empty when absent).

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
            Exception: If the API reports a non-zero ``errcode``.
        """
        access_token = await self.get_access_token()
        url = f"{self.OAPI_URL}/topapi/v2/department/listsub"
        params = {"access_token": access_token}
        payload = {"dept_id": dept_id}

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, params=params, json=payload)
            response.raise_for_status()
            data = response.json()

        if data.get("errcode") != 0:
            raise Exception(f"获取部门列表失败: {data.get('errmsg')}")
        return data.get("result", [])

    async def get_all_departments(self) -> List[Dict[str, Any]]:
        """Walk the department tree from the root and return it flattened.

        Returns:
            Every department below the root, parents before their children.
            The root department itself is not included.
        """
        all_departments = []

        async def fetch_recursive(parent_id: int):
            departments = await self.get_department_list(parent_id)
            for dept in departments:
                all_departments.append(dept)
                # Descend into this department's own children.
                await fetch_recursive(dept["dept_id"])

        await fetch_recursive(1)  # start at the root department
        logger.info(f"获取到 {len(all_departments)} 个部门")
        return all_departments

    async def get_department_users(
        self,
        dept_id: int,
        cursor: int = 0,
        size: int = 100
    ) -> Dict[str, Any]:
        """Return one page of users for a department.

        Args:
            dept_id: Department id.
            cursor: Pagination cursor sent to the API.
            size: Page size sent to the API.

        Returns:
            The ``result`` dict of the API response (empty when absent).

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
            Exception: If the API reports a non-zero ``errcode``.
        """
        access_token = await self.get_access_token()
        url = f"{self.OAPI_URL}/topapi/v2/user/list"
        params = {"access_token": access_token}
        payload = {
            "dept_id": dept_id,
            "cursor": cursor,
            "size": size
        }

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, params=params, json=payload)
            response.raise_for_status()
            data = response.json()

        if data.get("errcode") != 0:
            raise Exception(f"获取部门用户失败: {data.get('errmsg')}")
        return data.get("result", {})

    async def get_all_employees(self) -> List[Dict[str, Any]]:
        """Collect the users of every department, de-duplicated by ``userid``.

        The root department is queried first, then every department returned
        by :meth:`get_all_departments`. A user seen in several departments is
        kept once, labelled with the first department it was found in.

        Returns:
            Employee dicts in the format of :meth:`_convert_user_to_employee`.
        """
        logger.info("开始从钉钉 API 获取所有员工...")

        departments = await self.get_all_departments()

        # Keyed by userid so a user listed in several departments is kept once.
        all_employees = {}
        for dept in [{"dept_id": 1, "name": "根部门"}] + departments:
            dept_id = dept["dept_id"]
            dept_name = dept["name"]
            cursor = 0

            while True:
                result = await self.get_department_users(dept_id, cursor)
                for user in result.get("list", []):
                    userid = user.get("userid")
                    if userid and userid not in all_employees:
                        all_employees[userid] = self._convert_user_to_employee(
                            user, dept_name
                        )

                if not result.get("has_more", False):
                    break

                # A page that claims more data but gives no new cursor would
                # make this loop re-fetch the same page forever; stop instead.
                next_cursor = result.get("next_cursor")
                if next_cursor is None or next_cursor == cursor:
                    logger.warning(f"部门 {dept_id} 的分页游标未前进,停止翻页")
                    break
                cursor = next_cursor

        employees = list(all_employees.values())
        logger.info(f"获取到 {len(employees)} 位在职员工")
        return employees

    def _convert_user_to_employee(
        self,
        user: Dict[str, Any],
        dept_name: str
    ) -> Dict[str, Any]:
        """Map a DingTalk user record onto the employee dict format.

        Args:
            user: User record as returned by the DingTalk API.
            dept_name: Name of the department the user was listed under.

        Returns:
            Employee dict; missing source fields fall back to the defaults
            given in the mapping below.
        """
        return {
            'full_name': user.get('name', ''),
            'phone': user.get('mobile', ''),
            'email': user.get('email', ''),
            'department': dept_name,
            'position': user.get('title', ''),
            'employee_no': user.get('job_number', ''),
            'is_leader': user.get('leader', False),
            'is_active': user.get('active', True),
            'dingtalk_id': user.get('userid', ''),
            'join_date': user.get('hired_date'),
            'work_location': user.get('work_place', ''),
            'avatar': user.get('avatar', ''),
        }

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the credentials work against the DingTalk API.

        Requests a token, lists the root department's children and fetches
        one user of the root department.

        Returns:
            A dict with ``success`` and ``message``; on success also
            ``corp_id``, ``department_count`` and ``has_employees``, on
            failure also ``error``. Failures are reported, never raised.
        """
        try:
            # 1. Token retrieval proves the credentials are accepted.
            await self.get_access_token()
            # 2. Children of the root department.
            departments = await self.get_department_list(1)
            # 3. A single user is enough to know whether any exist.
            result = await self.get_department_users(1, size=1)

            return {
                "success": True,
                "message": "连接成功",
                "corp_id": self.corp_id,
                "department_count": len(departments) + 1,  # +1 for the root
                "has_employees": result.get("has_more", False) or len(result.get("list", [])) > 0
            }
        except httpx.HTTPStatusError as e:
            error_detail = "HTTP错误"
            if e.response.status_code == 400:
                try:
                    error_data = e.response.json()
                    error_detail = error_data.get("message", str(e))
                except (ValueError, AttributeError):
                    # Body is not JSON or not a JSON object: keep the
                    # generic detail rather than failing the report.
                    pass
            return {
                "success": False,
                "message": f"连接失败: {error_detail}",
                "error": str(e)
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"连接失败: {str(e)}",
                "error": str(e)
            }
"""
钉钉开放平台 API 服务
用于通过钉钉 API 获取组织架构和员工信息
"""
import httpx
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from app.core.logger import get_logger
logger = get_logger(__name__)
class DingTalkService:
    """Client for the DingTalk Open Platform API.

    Obtains an access token with the application's credentials and uses it
    to read the department tree and the users of each department.
    """

    BASE_URL = "https://api.dingtalk.com"
    OAPI_URL = "https://oapi.dingtalk.com"

    def __init__(
        self,
        corp_id: str,
        client_id: str,
        client_secret: str
    ):
        """Store the credentials and start with an empty token cache.

        Args:
            corp_id: Enterprise CorpId.
            client_id: Application ClientId (AppKey).
            client_secret: Application ClientSecret (AppSecret).
        """
        self.corp_id = corp_id
        self.client_id = client_id
        self.client_secret = client_secret
        self._access_token: Optional[str] = None
        self._token_expires_at: Optional[datetime] = None

    async def get_access_token(self) -> str:
        """Return an access token, reusing the cached one while it is valid.

        The cached token is reused until five minutes before its recorded
        expiry; after that a new one is requested from the OAuth2 endpoint.

        Returns:
            The access token string.

        Raises:
            httpx.HTTPStatusError: If the token endpoint answers with an
                error status.
            KeyError: If the response carries no ``access_token`` field.
        """
        if self._access_token and self._token_expires_at:
            # Refresh five minutes early so a token never expires mid-request.
            if datetime.now() < self._token_expires_at - timedelta(minutes=5):
                return self._access_token

        url = f"{self.BASE_URL}/v1.0/oauth2/{self.corp_id}/token"
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "client_credentials"
        }

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            data = response.json()

        self._access_token = data["access_token"]
        expires_in = data.get("expires_in", 7200)
        self._token_expires_at = datetime.now() + timedelta(seconds=expires_in)
        logger.info(f"获取钉钉 Access Token 成功,有效期 {expires_in}")
        return self._access_token

    async def get_department_list(self, dept_id: int = 1) -> List[Dict[str, Any]]:
        """Return the direct sub-departments of ``dept_id``.

        Args:
            dept_id: Parent department id; the root department is 1.

        Returns:
            The ``result`` list of the API response (empty when absent).

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
            Exception: If the API reports a non-zero ``errcode``.
        """
        access_token = await self.get_access_token()
        url = f"{self.OAPI_URL}/topapi/v2/department/listsub"
        params = {"access_token": access_token}
        payload = {"dept_id": dept_id}

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, params=params, json=payload)
            response.raise_for_status()
            data = response.json()

        if data.get("errcode") != 0:
            raise Exception(f"获取部门列表失败: {data.get('errmsg')}")
        return data.get("result", [])

    async def get_all_departments(self) -> List[Dict[str, Any]]:
        """Walk the department tree from the root and return it flattened.

        Returns:
            Every department below the root, parents before their children.
            The root department itself is not included.
        """
        all_departments = []

        async def fetch_recursive(parent_id: int):
            departments = await self.get_department_list(parent_id)
            for dept in departments:
                all_departments.append(dept)
                # Descend into this department's own children.
                await fetch_recursive(dept["dept_id"])

        await fetch_recursive(1)  # start at the root department
        logger.info(f"获取到 {len(all_departments)} 个部门")
        return all_departments

    async def get_department_users(
        self,
        dept_id: int,
        cursor: int = 0,
        size: int = 100
    ) -> Dict[str, Any]:
        """Return one page of users for a department.

        Args:
            dept_id: Department id.
            cursor: Pagination cursor sent to the API.
            size: Page size sent to the API.

        Returns:
            The ``result`` dict of the API response (empty when absent).

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
            Exception: If the API reports a non-zero ``errcode``.
        """
        access_token = await self.get_access_token()
        url = f"{self.OAPI_URL}/topapi/v2/user/list"
        params = {"access_token": access_token}
        payload = {
            "dept_id": dept_id,
            "cursor": cursor,
            "size": size
        }

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, params=params, json=payload)
            response.raise_for_status()
            data = response.json()

        if data.get("errcode") != 0:
            raise Exception(f"获取部门用户失败: {data.get('errmsg')}")
        return data.get("result", {})

    async def get_all_employees(self) -> List[Dict[str, Any]]:
        """Collect the users of every department, de-duplicated by ``userid``.

        The root department is queried first, then every department returned
        by :meth:`get_all_departments`. A user seen in several departments is
        kept once, labelled with the first department it was found in.

        Returns:
            Employee dicts in the format of :meth:`_convert_user_to_employee`.
        """
        logger.info("开始从钉钉 API 获取所有员工...")

        departments = await self.get_all_departments()

        # Keyed by userid so a user listed in several departments is kept once.
        all_employees = {}
        for dept in [{"dept_id": 1, "name": "根部门"}] + departments:
            dept_id = dept["dept_id"]
            dept_name = dept["name"]
            cursor = 0

            while True:
                result = await self.get_department_users(dept_id, cursor)
                for user in result.get("list", []):
                    userid = user.get("userid")
                    if userid and userid not in all_employees:
                        all_employees[userid] = self._convert_user_to_employee(
                            user, dept_name
                        )

                if not result.get("has_more", False):
                    break

                # A page that claims more data but gives no new cursor would
                # make this loop re-fetch the same page forever; stop instead.
                next_cursor = result.get("next_cursor")
                if next_cursor is None or next_cursor == cursor:
                    logger.warning(f"部门 {dept_id} 的分页游标未前进,停止翻页")
                    break
                cursor = next_cursor

        employees = list(all_employees.values())
        logger.info(f"获取到 {len(employees)} 位在职员工")
        return employees

    def _convert_user_to_employee(
        self,
        user: Dict[str, Any],
        dept_name: str
    ) -> Dict[str, Any]:
        """Map a DingTalk user record onto the employee dict format.

        Args:
            user: User record as returned by the DingTalk API.
            dept_name: Name of the department the user was listed under.

        Returns:
            Employee dict; missing source fields fall back to the defaults
            given in the mapping below.
        """
        return {
            'full_name': user.get('name', ''),
            'phone': user.get('mobile', ''),
            'email': user.get('email', ''),
            'department': dept_name,
            'position': user.get('title', ''),
            'employee_no': user.get('job_number', ''),
            'is_leader': user.get('leader', False),
            'is_active': user.get('active', True),
            'dingtalk_id': user.get('userid', ''),
            'join_date': user.get('hired_date'),
            'work_location': user.get('work_place', ''),
            'avatar': user.get('avatar', ''),
        }

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the credentials work against the DingTalk API.

        Requests a token, lists the root department's children and fetches
        one user of the root department.

        Returns:
            A dict with ``success`` and ``message``; on success also
            ``corp_id``, ``department_count`` and ``has_employees``, on
            failure also ``error``. Failures are reported, never raised.
        """
        try:
            # 1. Token retrieval proves the credentials are accepted.
            await self.get_access_token()
            # 2. Children of the root department.
            departments = await self.get_department_list(1)
            # 3. A single user is enough to know whether any exist.
            result = await self.get_department_users(1, size=1)

            return {
                "success": True,
                "message": "连接成功",
                "corp_id": self.corp_id,
                "department_count": len(departments) + 1,  # +1 for the root
                "has_employees": result.get("has_more", False) or len(result.get("list", [])) > 0
            }
        except httpx.HTTPStatusError as e:
            error_detail = "HTTP错误"
            if e.response.status_code == 400:
                try:
                    error_data = e.response.json()
                    error_detail = error_data.get("message", str(e))
                except (ValueError, AttributeError):
                    # Body is not JSON or not a JSON object: keep the
                    # generic detail rather than failing the report.
                    pass
            return {
                "success": False,
                "message": f"连接失败: {error_detail}",
                "error": str(e)
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"连接失败: {str(e)}",
                "error": str(e)
            }

View File

@@ -0,0 +1,422 @@
"""
统一文件存储服务
支持MinIO对象存储兼容本地文件系统
使用方式:
from app.services.storage_service import storage_service
# 上传文件
file_url = await storage_service.upload(file_data, "courses/1/doc.pdf")
# 下载文件
file_data = await storage_service.download("courses/1/doc.pdf")
# 删除文件
await storage_service.delete("courses/1/doc.pdf")
"""
import os
import io
import logging
from pathlib import Path
from typing import Optional, Union, BinaryIO
from datetime import timedelta
from minio import Minio
from minio.error import S3Error
from app.core.config import settings
logger = logging.getLogger(__name__)
class StorageService:
    """Unified file storage over MinIO with a local-filesystem fallback.

    Two backends are supported:
    1. MinIO object storage, used when ``settings.MINIO_ENABLED`` is set and
       the client can be initialised.
    2. The local directory ``settings.UPLOAD_PATH``, used otherwise.

    Initialisation is lazy: the backend is chosen on first use.

    NOTE(review): the MinIO client calls are synchronous and run directly
    inside the ``async`` methods — confirm that blocking the event loop for
    their duration is acceptable.
    """

    # Prefix stripped from incoming names so URLs and object names both work.
    _URL_PREFIX = 'static/uploads/'

    # Extension -> MIME type table used by _guess_content_type.
    _CONTENT_TYPES = {
        'pdf': 'application/pdf',
        'doc': 'application/msword',
        'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        'xls': 'application/vnd.ms-excel',
        'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        'ppt': 'application/vnd.ms-powerpoint',
        'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        'txt': 'text/plain',
        'md': 'text/markdown',
        'html': 'text/html',
        'htm': 'text/html',
        'csv': 'text/csv',
        'json': 'application/json',
        'xml': 'application/xml',
        'zip': 'application/zip',
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'gif': 'image/gif',
        'webp': 'image/webp',
        'mp3': 'audio/mpeg',
        'wav': 'audio/wav',
        'mp4': 'video/mp4',
        'webm': 'video/webm',
    }

    def __init__(self):
        """Create an uninitialised service; the backend is picked lazily."""
        self._client: Optional[Minio] = None
        self._initialized = False
        self._use_minio = False

    def _ensure_initialized(self):
        """Select the storage backend on first call; later calls are no-ops.

        The initialised flag is set before MinIO is contacted, so a failed
        MinIO setup falls back to local storage and is not retried.
        """
        if self._initialized:
            return
        self._initialized = True

        if not settings.MINIO_ENABLED:
            logger.info("MinIO未启用使用本地文件存储")
            self._use_minio = False
            return

        try:
            self._client = Minio(
                settings.MINIO_ENDPOINT,
                access_key=settings.MINIO_ACCESS_KEY,
                secret_key=settings.MINIO_SECRET_KEY,
                secure=settings.MINIO_SECURE,
            )
            # bucket_exists doubles as the connectivity check.
            bucket_name = self._get_bucket_name()
            if not self._client.bucket_exists(bucket_name):
                self._client.make_bucket(bucket_name)
                logger.info(f"创建MinIO bucket: {bucket_name}")
                # NOTE(review): the read policy is applied only to a newly
                # created bucket here — confirm this matches the intended
                # layout (the reviewed copy had lost its indentation).
                self._set_bucket_public_read(bucket_name)
            self._use_minio = True
            logger.info(f"MinIO存储服务初始化成功 - endpoint: {settings.MINIO_ENDPOINT}, bucket: {bucket_name}")
        except Exception as e:
            logger.warning(f"MinIO初始化失败降级为本地存储: {e}")
            self._use_minio = False

    def _get_bucket_name(self) -> str:
        """Return the bucket name derived from ``settings.TENANT_CODE``."""
        return f"kpl-{settings.TENANT_CODE}"

    def _set_bucket_public_read(self, bucket_name: str):
        """Attach an anonymous ``s3:GetObject`` policy to the bucket.

        Failures are logged as warnings and not raised.
        """
        try:
            policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"AWS": "*"},
                        "Action": ["s3:GetObject"],
                        "Resource": [f"arn:aws:s3:::{bucket_name}/*"]
                    }
                ]
            }
            import json
            self._client.set_bucket_policy(bucket_name, json.dumps(policy))
        except Exception as e:
            logger.warning(f"设置bucket公开读取策略失败: {e}")

    def _normalize_object_name(self, object_name: str) -> str:
        """Reduce a URL or path to a bare object name.

        Strips leading slashes and one leading ``static/uploads/`` prefix.
        Only the prefix is removed: the same text occurring later in the
        name is part of the object key and is kept. Leading slashes are
        removed entirely because a remaining one would make
        ``Path(base) / name`` discard ``base``.
        """
        object_name = object_name.lstrip('/')
        if object_name.startswith(self._URL_PREFIX):
            object_name = object_name[len(self._URL_PREFIX):].lstrip('/')
        return object_name

    def _get_file_url(self, object_name: str) -> str:
        """Return the ``/static/uploads/...`` URL for an object name."""
        object_name = self._normalize_object_name(object_name)
        # The same URL shape is returned for both backends.
        return f"/static/uploads/{object_name}"

    def _get_local_path(self, object_name: str) -> Path:
        """Return the path of the object under ``settings.UPLOAD_PATH``.

        NOTE(review): ``..`` segments are not rejected here; this assumes
        object names are server-generated — confirm against callers.
        """
        object_name = self._normalize_object_name(object_name)
        return Path(settings.UPLOAD_PATH) / object_name

    async def upload(
        self,
        file_data: Union[bytes, BinaryIO],
        object_name: str,
        content_type: Optional[str] = None,
    ) -> str:
        """Store a file and return its access URL.

        Args:
            file_data: File content as bytes (``bytearray`` / ``memoryview``
                are accepted too) or a readable binary file object.
            object_name: Object name, e.g. ``courses/1/doc.pdf``.
            content_type: MIME type; guessed from the name when omitted.

        Returns:
            The ``/static/uploads/...`` URL of the stored file.
        """
        self._ensure_initialized()
        object_name = self._normalize_object_name(object_name)

        if isinstance(file_data, bytes):
            data = file_data
        elif isinstance(file_data, (bytearray, memoryview)):
            # Bytes-like input has no read(); copy it into bytes.
            data = bytes(file_data)
        else:
            data = file_data.read()

        if self._use_minio:
            return await self._upload_to_minio(data, object_name, content_type)
        return await self._upload_to_local(data, object_name)

    async def _upload_to_minio(
        self,
        data: bytes,
        object_name: str,
        content_type: Optional[str] = None,
    ) -> str:
        """Upload to MinIO; on ``S3Error`` store the file locally instead."""
        try:
            bucket_name = self._get_bucket_name()
            if not content_type:
                content_type = self._guess_content_type(object_name)
            self._client.put_object(
                bucket_name,
                object_name,
                io.BytesIO(data),
                length=len(data),
                content_type=content_type,
            )
            file_url = self._get_file_url(object_name)
            logger.info(f"文件上传到MinIO成功: {object_name} -> {file_url}")
            return file_url
        except S3Error as e:
            logger.error(f"MinIO上传失败: {e}")
            # Degrade to local storage so the upload still succeeds.
            return await self._upload_to_local(data, object_name)

    async def _upload_to_local(self, data: bytes, object_name: str) -> str:
        """Write the file under the upload directory, creating parents.

        Raises:
            Exception: Any error from the filesystem is logged and re-raised.
        """
        try:
            file_path = self._get_local_path(object_name)
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'wb') as f:
                f.write(data)
            file_url = self._get_file_url(object_name)
            logger.info(f"文件上传到本地成功: {object_name} -> {file_url}")
            return file_url
        except Exception as e:
            logger.error(f"本地文件上传失败: {e}")
            raise

    async def download(self, object_name: str) -> Optional[bytes]:
        """Return the file content, or ``None`` when it cannot be read.

        Args:
            object_name: Object name or ``/static/uploads/...`` URL.
        """
        self._ensure_initialized()
        object_name = self._normalize_object_name(object_name)
        if self._use_minio:
            return await self._download_from_minio(object_name)
        return await self._download_from_local(object_name)

    async def _download_from_minio(self, object_name: str) -> Optional[bytes]:
        """Read an object from MinIO.

        A missing key falls back to the local copy; any other ``S3Error``
        is logged and yields ``None``.
        """
        response = None
        try:
            bucket_name = self._get_bucket_name()
            response = self._client.get_object(bucket_name, object_name)
            return response.read()
        except S3Error as e:
            if e.code == 'NoSuchKey':
                logger.warning(f"MinIO文件不存在: {object_name}")
                # Files uploaded before the migration may exist only locally.
                return await self._download_from_local(object_name)
            logger.error(f"MinIO下载失败: {e}")
            return None
        finally:
            # Release the connection even when read() fails part-way.
            if response is not None:
                response.close()
                response.release_conn()

    async def _download_from_local(self, object_name: str) -> Optional[bytes]:
        """Read a file from the upload directory; ``None`` if absent or unreadable."""
        try:
            file_path = self._get_local_path(object_name)
            if not file_path.exists():
                logger.warning(f"本地文件不存在: {file_path}")
                return None
            with open(file_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"本地文件下载失败: {e}")
            return None

    async def delete(self, object_name: str) -> bool:
        """Delete the object from MinIO (when enabled) and from local disk.

        Args:
            object_name: Object name or ``/static/uploads/...`` URL.

        Returns:
            ``False`` only when MinIO removal fails with an error other than
            a missing key; local removal problems are logged and ignored.
        """
        self._ensure_initialized()
        object_name = self._normalize_object_name(object_name)
        success = True

        if self._use_minio:
            try:
                bucket_name = self._get_bucket_name()
                self._client.remove_object(bucket_name, object_name)
                logger.info(f"MinIO文件删除成功: {object_name}")
            except S3Error as e:
                if e.code != 'NoSuchKey':
                    logger.error(f"MinIO文件删除失败: {e}")
                    success = False

        # Always remove the local copy too (fallback uploads, cached files).
        try:
            file_path = self._get_local_path(object_name)
            if file_path.exists():
                os.remove(file_path)
                logger.info(f"本地文件删除成功: {file_path}")
        except Exception as e:
            logger.warning(f"本地文件删除失败: {e}")

        return success

    async def exists(self, object_name: str) -> bool:
        """Return whether the object exists in MinIO or on local disk.

        Args:
            object_name: Object name or ``/static/uploads/...`` URL.
        """
        self._ensure_initialized()
        object_name = self._normalize_object_name(object_name)

        if self._use_minio:
            try:
                bucket_name = self._get_bucket_name()
                self._client.stat_object(bucket_name, object_name)
                return True
            except S3Error:
                # Not found (or not reachable) in MinIO: check local disk.
                pass

        return self._get_local_path(object_name).exists()

    async def get_file_path(self, object_name: str) -> Optional[Path]:
        """Return a local path for the object, downloading it if needed.

        An existing local file is returned as is. Otherwise, with MinIO
        enabled, the object is downloaded and cached at that local path.

        Args:
            object_name: Object name or ``/static/uploads/...`` URL.

        Returns:
            The local path, or ``None`` when the file is available nowhere.
        """
        self._ensure_initialized()
        object_name = self._normalize_object_name(object_name)

        local_path = self._get_local_path(object_name)
        if local_path.exists():
            return local_path

        if self._use_minio:
            try:
                data = await self._download_from_minio(object_name)
                # Compare with None: an empty object is still a real file.
                if data is not None:
                    local_path.parent.mkdir(parents=True, exist_ok=True)
                    # Write beside the target and rename, so an interrupted
                    # write never leaves a truncated file that the exists()
                    # check above would later accept as the cached copy.
                    tmp_path = local_path.with_name(local_path.name + '.part')
                    with open(tmp_path, 'wb') as f:
                        f.write(data)
                    os.replace(tmp_path, local_path)
                    logger.info(f"从MinIO下载文件到本地缓存: {object_name}")
                    return local_path
            except Exception as e:
                logger.error(f"下载MinIO文件到本地失败: {e}")

        return None

    def get_presigned_url(self, object_name: str, expires: int = 3600) -> Optional[str]:
        """Return a presigned GET URL for direct MinIO access.

        Args:
            object_name: Object name or ``/static/uploads/...`` URL.
            expires: Lifetime of the URL in seconds.

        Returns:
            The URL, or ``None`` when MinIO is not in use or signing fails.
        """
        self._ensure_initialized()
        if not self._use_minio:
            return None

        object_name = self._normalize_object_name(object_name)
        try:
            bucket_name = self._get_bucket_name()
            return self._client.presigned_get_object(
                bucket_name,
                object_name,
                expires=timedelta(seconds=expires)
            )
        except S3Error as e:
            logger.error(f"获取预签名URL失败: {e}")
            return None

    def _guess_content_type(self, filename: str) -> str:
        """Guess the MIME type from the file extension (case-insensitive).

        Unknown or missing extensions map to ``application/octet-stream``.
        """
        ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
        return self._CONTENT_TYPES.get(ext, 'application/octet-stream')

    @property
    def is_minio_enabled(self) -> bool:
        """Whether MinIO is the active backend (triggers lazy initialisation)."""
        self._ensure_initialized()
        return self._use_minio
# Module-level singleton; import it with
# `from app.services.storage_service import storage_service`.
storage_service = StorageService()