- 新增 platform_scheduled_tasks, platform_task_logs, platform_script_vars, platform_secrets 数据库表 - 实现 ScriptSDK 提供 AI/通知/DB/HTTP/变量存储/参数获取等功能 - 实现安全的脚本执行器,支持沙箱环境和禁止危险操作 - 实现 APScheduler 调度服务,支持简单时间点和 CRON 表达式 - 新增定时任务 API 路由,包含 CRUD、执行、日志、密钥管理 - 新增定时任务前端页面,支持脚本编辑、测试运行、日志查看
This commit is contained in:
@@ -1,399 +1,308 @@
|
||||
"""定时任务调度器服务"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
"""定时任务调度服务"""
|
||||
import json
|
||||
import httpx
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..database import SessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局调度器实例
|
||||
scheduler: Optional[AsyncIOScheduler] = None
|
||||
from ..models.scheduled_task import ScheduledTask, TaskLog
|
||||
from .script_executor import ScriptExecutor
|
||||
|
||||
|
||||
def get_scheduler() -> AsyncIOScheduler:
|
||||
"""获取调度器实例"""
|
||||
global scheduler
|
||||
if scheduler is None:
|
||||
scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||
return scheduler
|
||||
|
||||
|
||||
def get_db_session() -> Session:
|
||||
"""获取数据库会话"""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
async def send_alert(webhook: str, task_name: str, error_message: str):
|
||||
"""发送失败告警通知"""
|
||||
try:
|
||||
# 自动判断钉钉或企微
|
||||
if "dingtalk" in webhook or "oapi.dingtalk.com" in webhook:
|
||||
data = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": "定时任务执行失败",
|
||||
"text": f"### ⚠️ 定时任务执行失败\n\n**任务名称**:{task_name}\n\n**错误信息**:{error_message[:500]}\n\n**时间**:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 企微格式
|
||||
data = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"content": f"### ⚠️ 定时任务执行失败\n\n**任务名称**:{task_name}\n\n**错误信息**:{error_message[:500]}\n\n**时间**:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
await client.post(webhook, json=data)
|
||||
logger.info(f"Alert sent for task {task_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send alert: {e}")
|
||||
|
||||
|
||||
async def execute_task_with_retry(task_id: int, retry_count: int = 0, max_retries: int = 0, retry_interval: int = 60):
|
||||
"""带重试的任务执行"""
|
||||
success = await execute_task_once(task_id)
|
||||
class SchedulerService:
|
||||
"""调度服务 - 管理定时任务的调度和执行"""
|
||||
|
||||
if not success and retry_count < max_retries:
|
||||
logger.info(f"Task {task_id} failed, scheduling retry {retry_count + 1}/{max_retries} in {retry_interval}s")
|
||||
await asyncio.sleep(retry_interval)
|
||||
await execute_task_with_retry(task_id, retry_count + 1, max_retries, retry_interval)
|
||||
elif not success:
|
||||
# 所有重试都失败,发送告警
|
||||
db = get_db_session()
|
||||
_instance: Optional['SchedulerService'] = None
|
||||
_scheduler: Optional[AsyncIOScheduler] = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._scheduler is None:
|
||||
self._scheduler = AsyncIOScheduler(timezone='Asia/Shanghai')
|
||||
|
||||
@property
|
||||
def scheduler(self) -> AsyncIOScheduler:
|
||||
return self._scheduler
|
||||
|
||||
def start(self):
|
||||
"""启动调度器并加载所有任务"""
|
||||
if not self._scheduler.running:
|
||||
self._scheduler.start()
|
||||
self._load_all_tasks()
|
||||
print("调度器已启动")
|
||||
|
||||
def shutdown(self):
|
||||
"""关闭调度器"""
|
||||
if self._scheduler.running:
|
||||
self._scheduler.shutdown()
|
||||
print("调度器已关闭")
|
||||
|
||||
def _load_all_tasks(self):
|
||||
"""从数据库加载所有启用的任务"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("SELECT task_name, alert_on_failure, alert_webhook, last_run_message FROM platform_scheduled_tasks WHERE id = :id"),
|
||||
{"id": task_id}
|
||||
)
|
||||
task = result.mappings().first()
|
||||
if task and task["alert_on_failure"] and task["alert_webhook"]:
|
||||
await send_alert(task["alert_webhook"], task["task_name"], task["last_run_message"] or "未知错误")
|
||||
tasks = db.query(ScheduledTask).filter(ScheduledTask.status == 1).all()
|
||||
for task in tasks:
|
||||
self._add_task_to_scheduler(task)
|
||||
print(f"已加载 {len(tasks)} 个定时任务")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _add_task_to_scheduler(self, task: ScheduledTask):
|
||||
"""将任务添加到调度器"""
|
||||
job_id = f"task_{task.id}"
|
||||
|
||||
# 移除已存在的任务
|
||||
if self._scheduler.get_job(job_id):
|
||||
self._scheduler.remove_job(job_id)
|
||||
|
||||
if task.schedule_type == 'cron' and task.cron_expression:
|
||||
# CRON模式
|
||||
try:
|
||||
trigger = CronTrigger.from_crontab(task.cron_expression, timezone='Asia/Shanghai')
|
||||
self._scheduler.add_job(
|
||||
self._execute_task,
|
||||
trigger,
|
||||
id=job_id,
|
||||
args=[task.id],
|
||||
replace_existing=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"任务 {task.id} CRON表达式解析失败: {e}")
|
||||
|
||||
elif task.schedule_type == 'simple' and task.time_points:
|
||||
# 简单模式 - 多个时间点
|
||||
try:
|
||||
time_points = json.loads(task.time_points)
|
||||
for i, time_point in enumerate(time_points):
|
||||
hour, minute = map(int, time_point.split(':'))
|
||||
sub_job_id = f"{job_id}_{i}"
|
||||
self._scheduler.add_job(
|
||||
self._execute_task,
|
||||
CronTrigger(hour=hour, minute=minute, timezone='Asia/Shanghai'),
|
||||
id=sub_job_id,
|
||||
args=[task.id],
|
||||
replace_existing=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"任务 {task.id} 时间点解析失败: {e}")
|
||||
|
||||
def add_task(self, task_id: int):
|
||||
"""添加或更新任务调度"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if task and task.status == 1:
|
||||
self._add_task_to_scheduler(task)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def remove_task(self, task_id: int):
|
||||
"""移除任务调度"""
|
||||
job_id = f"task_{task_id}"
|
||||
|
||||
# 移除主任务
|
||||
if self._scheduler.get_job(job_id):
|
||||
self._scheduler.remove_job(job_id)
|
||||
|
||||
# 移除简单模式的子任务
|
||||
for i in range(24): # 最多24个时间点
|
||||
sub_job_id = f"{job_id}_{i}"
|
||||
if self._scheduler.get_job(sub_job_id):
|
||||
self._scheduler.remove_job(sub_job_id)
|
||||
|
||||
async def _execute_task(self, task_id: int):
|
||||
"""执行任务(带重试)"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
return
|
||||
|
||||
max_retries = task.retry_count or 0
|
||||
retry_interval = task.retry_interval or 60
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
success, output, error = await self._execute_task_once(db, task)
|
||||
|
||||
if success:
|
||||
return
|
||||
|
||||
# 如果还有重试机会
|
||||
if attempt < max_retries:
|
||||
print(f"任务 {task_id} 执行失败,{retry_interval}秒后重试 ({attempt + 1}/{max_retries})")
|
||||
await asyncio.sleep(retry_interval)
|
||||
else:
|
||||
# 最后一次失败,发送告警
|
||||
if task.alert_on_failure and task.alert_webhook:
|
||||
await self._send_alert(task, error)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _execute_task_once(self, db: Session, task: ScheduledTask):
|
||||
"""执行一次任务"""
|
||||
trace_id = f"{int(datetime.now().timestamp())}-{task.id}"
|
||||
started_at = datetime.now()
|
||||
|
||||
# 创建日志记录
|
||||
log = TaskLog(
|
||||
task_id=task.id,
|
||||
tenant_id=task.tenant_id,
|
||||
trace_id=trace_id,
|
||||
status='running',
|
||||
started_at=started_at
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
|
||||
success = False
|
||||
output = ''
|
||||
error = ''
|
||||
|
||||
try:
|
||||
# 解析输入参数
|
||||
params = {}
|
||||
if task.input_params:
|
||||
try:
|
||||
params = json.loads(task.input_params)
|
||||
except:
|
||||
pass
|
||||
|
||||
if task.task_type == 'webhook':
|
||||
success, output, error = await self._execute_webhook(task)
|
||||
else:
|
||||
success, output, error = await self._execute_script(db, task, trace_id, params)
|
||||
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
|
||||
# 更新日志
|
||||
finished_at = datetime.now()
|
||||
duration_ms = int((finished_at - started_at).total_seconds() * 1000)
|
||||
|
||||
log.status = 'success' if success else 'failed'
|
||||
log.finished_at = finished_at
|
||||
log.duration_ms = duration_ms
|
||||
log.output = output[:10000] if output else None # 限制长度
|
||||
log.error = error[:5000] if error else None
|
||||
|
||||
# 更新任务状态
|
||||
task.last_run_at = finished_at
|
||||
task.last_run_status = 'success' if success else 'failed'
|
||||
|
||||
db.commit()
|
||||
|
||||
return success, output, error
|
||||
|
||||
async def _execute_webhook(self, task: ScheduledTask):
|
||||
"""执行Webhook任务"""
|
||||
try:
|
||||
headers = {}
|
||||
if task.webhook_headers:
|
||||
headers = json.loads(task.webhook_headers)
|
||||
|
||||
body = {}
|
||||
if task.input_params:
|
||||
body = json.loads(task.input_params)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
if task.webhook_method.upper() == 'GET':
|
||||
response = await client.get(task.webhook_url, headers=headers, params=body)
|
||||
else:
|
||||
response = await client.post(task.webhook_url, headers=headers, json=body)
|
||||
|
||||
response.raise_for_status()
|
||||
return True, response.text[:5000], ''
|
||||
|
||||
except Exception as e:
|
||||
return False, '', str(e)
|
||||
|
||||
async def _execute_script(self, db: Session, task: ScheduledTask, trace_id: str, params: dict):
|
||||
"""执行脚本任务"""
|
||||
if not task.script_content:
|
||||
return False, '', '脚本内容为空'
|
||||
|
||||
executor = ScriptExecutor(db)
|
||||
success, output, error = executor.execute(
|
||||
script_content=task.script_content,
|
||||
task_id=task.id,
|
||||
tenant_id=task.tenant_id,
|
||||
trace_id=trace_id,
|
||||
params=params,
|
||||
timeout=task.script_timeout or 300
|
||||
)
|
||||
|
||||
return success, output, error
|
||||
|
||||
async def _send_alert(self, task: ScheduledTask, error: str):
|
||||
"""发送失败告警"""
|
||||
if not task.alert_webhook:
|
||||
return
|
||||
|
||||
content = f"""### 定时任务执行失败告警
|
||||
|
||||
**任务名称**: {task.task_name}
|
||||
**任务ID**: {task.id}
|
||||
**租户**: {task.tenant_id or '全局'}
|
||||
**失败时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**错误信息**:
|
||||
```
|
||||
{error[:500] if error else '未知错误'}
|
||||
```"""
|
||||
|
||||
try:
|
||||
# 判断是钉钉还是企微
|
||||
if 'dingtalk' in task.alert_webhook or 'oapi.dingtalk.com' in task.alert_webhook:
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"title": "任务失败告警", "text": content}
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
await client.post(task.alert_webhook, json=payload)
|
||||
except Exception as e:
|
||||
print(f"发送告警失败: {e}")
|
||||
|
||||
async def run_task_now(self, task_id: int) -> dict:
|
||||
"""立即执行任务(手动触发)"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
return {"success": False, "error": "任务不存在"}
|
||||
|
||||
# 解析参数
|
||||
params = {}
|
||||
if task.input_params:
|
||||
try:
|
||||
params = json.loads(task.input_params)
|
||||
except:
|
||||
pass
|
||||
|
||||
success, output, error = await self._execute_task_once(db, task)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"output": output,
|
||||
"error": error
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def execute_task(task_id: int):
|
||||
"""执行定时任务入口(处理重试配置)"""
|
||||
db = get_db_session()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("SELECT retry_count, retry_interval FROM platform_scheduled_tasks WHERE id = :id"),
|
||||
{"id": task_id}
|
||||
)
|
||||
task = result.mappings().first()
|
||||
if task:
|
||||
max_retries = task.get("retry_count", 0) or 0
|
||||
retry_interval = task.get("retry_interval", 60) or 60
|
||||
await execute_task_with_retry(task_id, 0, max_retries, retry_interval)
|
||||
else:
|
||||
await execute_task_once(task_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def execute_task_once(task_id: int) -> bool:
|
||||
"""执行一次定时任务,返回是否成功"""
|
||||
db = get_db_session()
|
||||
log_id = None
|
||||
success = False
|
||||
|
||||
try:
|
||||
# 1. 查询任务配置
|
||||
result = db.execute(
|
||||
text("SELECT * FROM platform_scheduled_tasks WHERE id = :id AND is_enabled = 1"),
|
||||
{"id": task_id}
|
||||
)
|
||||
task = result.mappings().first()
|
||||
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found or disabled")
|
||||
return True # 不需要重试
|
||||
|
||||
# 2. 更新任务状态为运行中
|
||||
db.execute(
|
||||
text("UPDATE platform_scheduled_tasks SET last_run_status = 'running', last_run_at = NOW() WHERE id = :id"),
|
||||
{"id": task_id}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
# 3. 创建执行日志
|
||||
db.execute(
|
||||
text("""
|
||||
INSERT INTO platform_task_logs (task_id, tenant_id, started_at, status)
|
||||
VALUES (:task_id, :tenant_id, NOW(), 'running')
|
||||
"""),
|
||||
{"task_id": task_id, "tenant_id": task["tenant_id"]}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
# 获取刚插入的日志ID
|
||||
result = db.execute(text("SELECT LAST_INSERT_ID() as id"))
|
||||
log_id = result.scalar()
|
||||
|
||||
# 生成 trace_id
|
||||
trace_id = f"task_{task_id}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
||||
|
||||
# 4. 根据执行类型分发
|
||||
execution_type = task.get("execution_type", "webhook")
|
||||
|
||||
if execution_type == "script":
|
||||
# 脚本执行模式
|
||||
from .script_executor import execute_script as run_script
|
||||
|
||||
script_content = task.get("script_content", "")
|
||||
if not script_content:
|
||||
status = "failed"
|
||||
error_message = "脚本内容为空"
|
||||
response_code = None
|
||||
response_body = ""
|
||||
else:
|
||||
script_result = await run_script(
|
||||
task_id=task_id,
|
||||
tenant_id=task["tenant_id"],
|
||||
script_content=script_content,
|
||||
trace_id=trace_id
|
||||
)
|
||||
|
||||
if script_result.success:
|
||||
status = "success"
|
||||
error_message = None
|
||||
else:
|
||||
status = "failed"
|
||||
error_message = script_result.error
|
||||
|
||||
response_code = None
|
||||
response_body = script_result.output[:5000] if script_result.output else ""
|
||||
|
||||
# 添加日志到响应体
|
||||
if script_result.logs:
|
||||
response_body += "\n\n--- 执行日志 ---\n" + "\n".join(script_result.logs[-20:])
|
||||
else:
|
||||
# Webhook 执行模式
|
||||
webhook_url = task["webhook_url"]
|
||||
input_params = task["input_params"] or {}
|
||||
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
response = await client.post(
|
||||
webhook_url,
|
||||
json=input_params,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
response_code = response.status_code
|
||||
response_body = response.text[:5000] if response.text else "" # 限制存储长度
|
||||
|
||||
if response.is_success:
|
||||
status = "success"
|
||||
error_message = None
|
||||
else:
|
||||
status = "failed"
|
||||
error_message = f"HTTP {response_code}"
|
||||
|
||||
# 5. 更新执行日志
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE platform_task_logs
|
||||
SET finished_at = NOW(), status = :status, response_code = :code,
|
||||
response_body = :body, error_message = :error
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{
|
||||
"id": log_id,
|
||||
"status": status,
|
||||
"code": response_code,
|
||||
"body": response_body,
|
||||
"error": error_message
|
||||
}
|
||||
)
|
||||
|
||||
# 6. 更新任务状态
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE platform_scheduled_tasks
|
||||
SET last_run_status = :status, last_run_message = :message
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{
|
||||
"id": task_id,
|
||||
"status": status,
|
||||
"message": error_message or "执行成功"
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Task {task_id} executed with status: {status}")
|
||||
success = (status == "success")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Task {task_id} execution error: {str(e)}")
|
||||
success = False
|
||||
|
||||
# 更新失败状态
|
||||
try:
|
||||
if log_id:
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE platform_task_logs
|
||||
SET finished_at = NOW(), status = 'failed', error_message = :error
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{"id": log_id, "error": str(e)[:1000]}
|
||||
)
|
||||
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE platform_scheduled_tasks
|
||||
SET last_run_status = 'failed', last_run_message = :message
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{"id": task_id, "message": str(e)[:500]}
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return success
|
||||
|
||||
|
||||
def add_task_to_scheduler(task: Dict[str, Any]):
|
||||
"""将任务添加到调度器"""
|
||||
global scheduler
|
||||
if scheduler is None:
|
||||
return
|
||||
|
||||
task_id = task["id"]
|
||||
schedule_type = task["schedule_type"]
|
||||
|
||||
# 先移除已有的任务(如果存在)
|
||||
remove_task_from_scheduler(task_id)
|
||||
|
||||
if schedule_type == "cron":
|
||||
# CRON 模式
|
||||
cron_expr = task["cron_expression"]
|
||||
if cron_expr:
|
||||
try:
|
||||
trigger = CronTrigger.from_crontab(cron_expr, timezone="Asia/Shanghai")
|
||||
scheduler.add_job(
|
||||
execute_task,
|
||||
trigger,
|
||||
args=[task_id],
|
||||
id=f"task_{task_id}_cron",
|
||||
replace_existing=True
|
||||
)
|
||||
logger.info(f"Added cron task {task_id}: {cron_expr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add cron task {task_id}: {e}")
|
||||
else:
|
||||
# 简单模式 - 多个时间点
|
||||
time_points = task.get("time_points") or []
|
||||
if isinstance(time_points, str):
|
||||
import json
|
||||
time_points = json.loads(time_points)
|
||||
|
||||
for i, time_point in enumerate(time_points):
|
||||
try:
|
||||
hour, minute = time_point.split(":")
|
||||
trigger = CronTrigger(
|
||||
hour=int(hour),
|
||||
minute=int(minute),
|
||||
timezone="Asia/Shanghai"
|
||||
)
|
||||
scheduler.add_job(
|
||||
execute_task,
|
||||
trigger,
|
||||
args=[task_id],
|
||||
id=f"task_{task_id}_time_{i}",
|
||||
replace_existing=True
|
||||
)
|
||||
logger.info(f"Added simple task {task_id} at {time_point}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add time point {time_point} for task {task_id}: {e}")
|
||||
|
||||
|
||||
def remove_task_from_scheduler(task_id: int):
|
||||
"""从调度器移除任务"""
|
||||
global scheduler
|
||||
if scheduler is None:
|
||||
return
|
||||
|
||||
# 移除所有相关的 job
|
||||
jobs_to_remove = []
|
||||
for job in scheduler.get_jobs():
|
||||
if job.id.startswith(f"task_{task_id}_"):
|
||||
jobs_to_remove.append(job.id)
|
||||
|
||||
for job_id in jobs_to_remove:
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
logger.info(f"Removed job {job_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remove job {job_id}: {e}")
|
||||
|
||||
|
||||
def load_all_tasks():
|
||||
"""从数据库加载所有启用的任务"""
|
||||
db = get_db_session()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("SELECT * FROM platform_scheduled_tasks WHERE is_enabled = 1")
|
||||
)
|
||||
tasks = result.mappings().all()
|
||||
|
||||
for task in tasks:
|
||||
add_task_to_scheduler(dict(task))
|
||||
|
||||
logger.info(f"Loaded {len(tasks)} scheduled tasks")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
"""启动调度器"""
|
||||
global scheduler
|
||||
scheduler = get_scheduler()
|
||||
|
||||
# 加载所有任务
|
||||
load_all_tasks()
|
||||
|
||||
# 启动调度器
|
||||
if not scheduler.running:
|
||||
scheduler.start()
|
||||
logger.info("Scheduler started")
|
||||
|
||||
|
||||
def shutdown_scheduler():
|
||||
"""关闭调度器"""
|
||||
global scheduler
|
||||
if scheduler and scheduler.running:
|
||||
scheduler.shutdown(wait=False)
|
||||
logger.info("Scheduler shutdown")
|
||||
|
||||
|
||||
def reload_task(task_id: int):
|
||||
"""重新加载单个任务"""
|
||||
db = get_db_session()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("SELECT * FROM platform_scheduled_tasks WHERE id = :id"),
|
||||
{"id": task_id}
|
||||
)
|
||||
task = result.mappings().first()
|
||||
|
||||
if task and task["is_enabled"]:
|
||||
add_task_to_scheduler(dict(task))
|
||||
else:
|
||||
remove_task_from_scheduler(task_id)
|
||||
finally:
|
||||
db.close()
|
||||
# 全局调度器实例
|
||||
scheduler_service = SchedulerService()
|
||||
|
||||
@@ -1,262 +1,246 @@
|
||||
"""脚本执行器 - 安全执行 Python 脚本"""
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
"""脚本执行器 - 安全执行Python脚本"""
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .script_sdk import ScriptSDK
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 执行超时时间(秒)
|
||||
SCRIPT_TIMEOUT = 300 # 5 分钟
|
||||
|
||||
# 禁止导入的模块
|
||||
FORBIDDEN_MODULES = {
|
||||
'os', 'subprocess', 'sys', 'builtins', '__builtins__',
|
||||
'importlib', 'eval', 'exec', 'compile',
|
||||
'open', 'file', 'input',
|
||||
'socket', 'multiprocessing', 'threading',
|
||||
'pickle', 'marshal', 'ctypes',
|
||||
'code', 'codeop', 'pty', 'tty',
|
||||
'os', 'subprocess', 'shutil', 'pathlib',
|
||||
'socket', 'ftplib', 'telnetlib', 'smtplib',
|
||||
'pickle', 'shelve', 'marshal',
|
||||
'ctypes', 'multiprocessing',
|
||||
'__builtins__', 'builtins',
|
||||
'importlib', 'imp',
|
||||
'code', 'codeop', 'compile',
|
||||
}
|
||||
|
||||
# 允许的内置函数
|
||||
ALLOWED_BUILTINS = {
|
||||
'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
|
||||
'callable', 'chr', 'complex', 'dict', 'dir', 'divmod', 'enumerate',
|
||||
'filter', 'float', 'format', 'frozenset', 'getattr', 'hasattr', 'hash',
|
||||
'hex', 'id', 'int', 'isinstance', 'issubclass', 'iter', 'len', 'list',
|
||||
'map', 'max', 'min', 'next', 'object', 'oct', 'ord', 'pow', 'print',
|
||||
'range', 'repr', 'reversed', 'round', 'set', 'setattr', 'slice',
|
||||
'sorted', 'str', 'sum', 'tuple', 'type', 'vars', 'zip',
|
||||
'True', 'False', 'None',
|
||||
'Exception', 'BaseException', 'ValueError', 'TypeError', 'KeyError',
|
||||
'IndexError', 'AttributeError', 'RuntimeError', 'StopIteration',
|
||||
}
|
||||
|
||||
|
||||
class ScriptExecutionResult:
|
||||
"""脚本执行结果"""
|
||||
class ScriptExecutor:
|
||||
"""脚本执行器"""
|
||||
|
||||
def __init__(
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def execute(
|
||||
self,
|
||||
success: bool,
|
||||
output: str = "",
|
||||
error: str = None,
|
||||
logs: list = None,
|
||||
execution_time_ms: int = 0
|
||||
):
|
||||
self.success = success
|
||||
self.output = output
|
||||
self.error = error
|
||||
self.logs = logs or []
|
||||
self.execution_time_ms = execution_time_ms
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
"error": self.error,
|
||||
"logs": self.logs,
|
||||
"execution_time_ms": self.execution_time_ms
|
||||
}
|
||||
|
||||
|
||||
def create_safe_builtins() -> Dict[str, Any]:
|
||||
"""创建安全的内置函数集"""
|
||||
import builtins
|
||||
|
||||
# 允许的内置函数
|
||||
allowed = [
|
||||
'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
|
||||
'callable', 'chr', 'complex', 'dict', 'divmod', 'enumerate',
|
||||
'filter', 'float', 'format', 'frozenset', 'getattr', 'hasattr',
|
||||
'hash', 'hex', 'id', 'int', 'isinstance', 'issubclass', 'iter',
|
||||
'len', 'list', 'map', 'max', 'min', 'next', 'object', 'oct',
|
||||
'ord', 'pow', 'print', 'range', 'repr', 'reversed', 'round',
|
||||
'set', 'slice', 'sorted', 'str', 'sum', 'tuple', 'type', 'zip',
|
||||
'True', 'False', 'None',
|
||||
]
|
||||
|
||||
safe_builtins = {}
|
||||
for name in allowed:
|
||||
if hasattr(builtins, name):
|
||||
safe_builtins[name] = getattr(builtins, name)
|
||||
|
||||
# 添加安全的 import 函数
|
||||
def safe_import(name, *args, **kwargs):
|
||||
"""安全的 import 函数,只允许特定模块"""
|
||||
allowed_modules = {
|
||||
'json', 'datetime', 'time', 're', 'math', 'random',
|
||||
'collections', 'itertools', 'functools', 'operator',
|
||||
'string', 'textwrap', 'unicodedata',
|
||||
'hashlib', 'base64', 'urllib.parse',
|
||||
}
|
||||
script_content: str,
|
||||
task_id: int,
|
||||
tenant_id: Optional[str] = None,
|
||||
trace_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 300
|
||||
) -> Tuple[bool, str, str]:
|
||||
"""执行脚本
|
||||
|
||||
if name in FORBIDDEN_MODULES:
|
||||
raise ImportError(f"禁止导入模块: {name}")
|
||||
Args:
|
||||
script_content: Python脚本内容
|
||||
task_id: 任务ID
|
||||
tenant_id: 租户ID
|
||||
trace_id: 追踪ID
|
||||
params: 输入参数
|
||||
timeout: 超时秒数
|
||||
|
||||
Returns:
|
||||
(success, output, error)
|
||||
"""
|
||||
# 创建SDK实例
|
||||
sdk = ScriptSDK(
|
||||
db=self.db,
|
||||
task_id=task_id,
|
||||
tenant_id=tenant_id,
|
||||
trace_id=trace_id,
|
||||
params=params or {}
|
||||
)
|
||||
|
||||
if name not in allowed_modules and not name.startswith('urllib.parse'):
|
||||
raise ImportError(f"不允许导入模块: {name},允许的模块: {', '.join(sorted(allowed_modules))}")
|
||||
|
||||
return __builtins__['__import__'](name, *args, **kwargs)
|
||||
|
||||
safe_builtins['__import__'] = safe_import
|
||||
|
||||
return safe_builtins
|
||||
|
||||
|
||||
async def execute_script(
|
||||
task_id: int,
|
||||
tenant_id: str,
|
||||
script_content: str,
|
||||
trace_id: str = None
|
||||
) -> ScriptExecutionResult:
|
||||
"""
|
||||
执行 Python 脚本
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
tenant_id: 租户 ID
|
||||
script_content: 脚本内容
|
||||
trace_id: 追踪 ID
|
||||
|
||||
Returns:
|
||||
ScriptExecutionResult: 执行结果
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
sdk = None
|
||||
|
||||
try:
|
||||
# 创建 SDK 实例
|
||||
sdk = ScriptSDK(tenant_id, task_id, trace_id)
|
||||
# 检查脚本安全性
|
||||
check_result = self._check_script_safety(script_content)
|
||||
if check_result:
|
||||
return False, '', f"脚本安全检查失败: {check_result}"
|
||||
|
||||
# 准备执行环境
|
||||
script_globals = {
|
||||
'__builtins__': create_safe_builtins(),
|
||||
'__name__': '__script__',
|
||||
|
||||
# SDK 实例
|
||||
'sdk': sdk,
|
||||
|
||||
# 快捷方法(同步包装)
|
||||
'ai': lambda *args, **kwargs: asyncio.get_event_loop().run_until_complete(sdk.ai_chat(*args, **kwargs)),
|
||||
'dingtalk': lambda *args, **kwargs: asyncio.get_event_loop().run_until_complete(sdk.send_dingtalk(*args, **kwargs)),
|
||||
'wecom': lambda *args, **kwargs: asyncio.get_event_loop().run_until_complete(sdk.send_wecom(*args, **kwargs)),
|
||||
'http_get': lambda *args, **kwargs: asyncio.get_event_loop().run_until_complete(sdk.http_get(*args, **kwargs)),
|
||||
'http_post': lambda *args, **kwargs: asyncio.get_event_loop().run_until_complete(sdk.http_post(*args, **kwargs)),
|
||||
|
||||
# 同步方法
|
||||
'db': sdk.db_query,
|
||||
'get_var': sdk.get_var,
|
||||
'set_var': sdk.set_var,
|
||||
'delete_var': sdk.delete_var,
|
||||
'log': sdk.log,
|
||||
|
||||
# 常用模块
|
||||
'json': __import__('json'),
|
||||
'datetime': __import__('datetime'),
|
||||
're': __import__('re'),
|
||||
'math': __import__('math'),
|
||||
'random': __import__('random'),
|
||||
}
|
||||
safe_globals = self._create_safe_globals(sdk)
|
||||
|
||||
# 捕获输出
|
||||
stdout = io.StringIO()
|
||||
stderr = io.StringIO()
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
stdout_capture = StringIO()
|
||||
stderr_capture = StringIO()
|
||||
|
||||
sdk.log("脚本开始执行")
|
||||
|
||||
# 编译并执行脚本
|
||||
try:
|
||||
# 编译脚本
|
||||
code = compile(script_content, '<script>', 'exec')
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
|
||||
# 执行(带超时)
|
||||
async def run_script():
|
||||
with redirect_stdout(stdout), redirect_stderr(stderr):
|
||||
exec(code, script_globals)
|
||||
# 编译并执行脚本
|
||||
compiled = compile(script_content, '<script>', 'exec')
|
||||
exec(compiled, safe_globals)
|
||||
|
||||
await asyncio.wait_for(run_script(), timeout=SCRIPT_TIMEOUT)
|
||||
# 获取输出
|
||||
stdout_output = stdout_capture.getvalue()
|
||||
sdk_output = sdk.get_output()
|
||||
|
||||
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
sdk.log(f"脚本执行完成,耗时 {execution_time}ms")
|
||||
# 合并输出
|
||||
output = '\n'.join(filter(None, [sdk_output, stdout_output]))
|
||||
|
||||
return ScriptExecutionResult(
|
||||
success=True,
|
||||
output=stdout.getvalue(),
|
||||
logs=sdk.get_logs(),
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
error_msg = f"脚本执行超时(超过 {SCRIPT_TIMEOUT} 秒)"
|
||||
sdk.log(error_msg, level="ERROR")
|
||||
|
||||
return ScriptExecutionResult(
|
||||
success=False,
|
||||
output=stdout.getvalue(),
|
||||
error=error_msg,
|
||||
logs=sdk.get_logs(),
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
|
||||
except SyntaxError as e:
|
||||
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
error_msg = f"语法错误: 第 {e.lineno} 行 - {e.msg}"
|
||||
sdk.log(error_msg, level="ERROR")
|
||||
|
||||
return ScriptExecutionResult(
|
||||
success=False,
|
||||
output=stdout.getvalue(),
|
||||
error=error_msg,
|
||||
logs=sdk.get_logs(),
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
return True, output, ''
|
||||
|
||||
except Exception as e:
|
||||
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
error_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
|
||||
return False, sdk.get_output(), error_msg
|
||||
|
||||
# 获取详细的错误堆栈
|
||||
tb = traceback.format_exc()
|
||||
sdk.log(f"执行错误: {error_msg}\n{tb}", level="ERROR")
|
||||
|
||||
return ScriptExecutionResult(
|
||||
success=False,
|
||||
output=stdout.getvalue(),
|
||||
error=error_msg,
|
||||
logs=sdk.get_logs(),
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
except Exception as e:
|
||||
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
error_msg = f"执行器错误: {str(e)}"
|
||||
logger.error(f"Script executor error: {e}", exc_info=True)
|
||||
def _check_script_safety(self, script_content: str) -> Optional[str]:
|
||||
"""检查脚本安全性
|
||||
|
||||
return ScriptExecutionResult(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
logs=sdk.get_logs() if sdk else [],
|
||||
execution_time_ms=execution_time
|
||||
Returns:
|
||||
错误消息,如果安全则返回None
|
||||
"""
|
||||
# 检查危险导入
|
||||
import_patterns = [
|
||||
'import os', 'from os',
|
||||
'import subprocess', 'from subprocess',
|
||||
'import shutil', 'from shutil',
|
||||
'import socket', 'from socket',
|
||||
'__import__',
|
||||
'eval(', 'exec(',
|
||||
'compile(',
|
||||
'open(', # 禁止文件操作
|
||||
]
|
||||
|
||||
script_lower = script_content.lower()
|
||||
for pattern in import_patterns:
|
||||
if pattern.lower() in script_lower:
|
||||
return f"禁止使用: {pattern}"
|
||||
|
||||
return None
|
||||
|
||||
def _create_safe_globals(self, sdk: ScriptSDK) -> Dict[str, Any]:
|
||||
"""创建安全的执行环境"""
|
||||
import json
|
||||
import re
|
||||
import math
|
||||
import random
|
||||
import hashlib
|
||||
import base64
|
||||
from datetime import datetime, date, timedelta
|
||||
from urllib.parse import urlencode, quote, unquote
|
||||
|
||||
# 安全的内置函数
|
||||
safe_builtins = {name: getattr(__builtins__, name, None)
|
||||
for name in ALLOWED_BUILTINS
|
||||
if hasattr(__builtins__, name) or name in dir(__builtins__)}
|
||||
|
||||
# 如果 __builtins__ 是字典
|
||||
if isinstance(__builtins__, dict):
|
||||
safe_builtins = {name: __builtins__.get(name)
|
||||
for name in ALLOWED_BUILTINS
|
||||
if name in __builtins__}
|
||||
|
||||
# 添加常用异常
|
||||
safe_builtins['Exception'] = Exception
|
||||
safe_builtins['ValueError'] = ValueError
|
||||
safe_builtins['TypeError'] = TypeError
|
||||
safe_builtins['KeyError'] = KeyError
|
||||
safe_builtins['IndexError'] = IndexError
|
||||
|
||||
return {
|
||||
'__builtins__': safe_builtins,
|
||||
'__name__': '__main__',
|
||||
|
||||
# SDK函数(全局可用)
|
||||
'log': sdk.log,
|
||||
'print': sdk.print,
|
||||
'ai': sdk.ai,
|
||||
'dingtalk': sdk.dingtalk,
|
||||
'wecom': sdk.wecom,
|
||||
'http_get': sdk.http_get,
|
||||
'http_post': sdk.http_post,
|
||||
'db_query': sdk.db_query,
|
||||
'get_var': sdk.get_var,
|
||||
'set_var': sdk.set_var,
|
||||
'del_var': sdk.del_var,
|
||||
'get_param': sdk.get_param,
|
||||
'get_params': sdk.get_params,
|
||||
'get_tenants': sdk.get_tenants,
|
||||
'get_tenant_config': sdk.get_tenant_config,
|
||||
'get_all_tenant_configs': sdk.get_all_tenant_configs,
|
||||
'get_secret': sdk.get_secret,
|
||||
|
||||
# 当前上下文
|
||||
'task_id': sdk.task_id,
|
||||
'tenant_id': sdk.tenant_id,
|
||||
'trace_id': sdk.trace_id,
|
||||
|
||||
# 安全的标准库
|
||||
'json': json,
|
||||
're': re,
|
||||
'math': math,
|
||||
'random': random,
|
||||
'hashlib': hashlib,
|
||||
'base64': base64,
|
||||
'datetime': datetime,
|
||||
'date': date,
|
||||
'timedelta': timedelta,
|
||||
'urlencode': urlencode,
|
||||
'quote': quote,
|
||||
'unquote': unquote,
|
||||
}
|
||||
|
||||
def test_script(
|
||||
self,
|
||||
script_content: str,
|
||||
task_id: int = 0,
|
||||
tenant_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试脚本(用于调试)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": bool,
|
||||
"output": str,
|
||||
"error": str,
|
||||
"duration_ms": int,
|
||||
"logs": [...]
|
||||
}
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
success, output, error = self.execute(
|
||||
script_content=script_content,
|
||||
task_id=task_id,
|
||||
tenant_id=tenant_id,
|
||||
trace_id=f"test-{start_time.timestamp()}",
|
||||
params=params
|
||||
)
|
||||
|
||||
finally:
|
||||
# 清理资源
|
||||
if sdk:
|
||||
sdk.cleanup()
|
||||
|
||||
|
||||
async def test_script(
|
||||
tenant_id: str,
|
||||
script_content: str
|
||||
) -> ScriptExecutionResult:
|
||||
"""
|
||||
测试执行脚本(不记录日志到数据库)
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
script_content: 脚本内容
|
||||
|
||||
Returns:
|
||||
ScriptExecutionResult: 执行结果
|
||||
"""
|
||||
return await execute_script(
|
||||
task_id=0, # 测试用 ID
|
||||
tenant_id=tenant_id,
|
||||
script_content=script_content,
|
||||
trace_id=f"test_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
)
|
||||
duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"output": output,
|
||||
"error": error,
|
||||
"duration_ms": duration_ms
|
||||
}
|
||||
|
||||
@@ -1,79 +1,113 @@
|
||||
"""脚本执行 SDK - 提供给 Python 脚本使用的内置能力"""
|
||||
"""脚本执行SDK - 为Python脚本提供内置功能"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..database import SessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScriptSDK:
|
||||
"""
|
||||
脚本执行 SDK
|
||||
"""脚本SDK - 提供AI、通知、数据库、HTTP、变量存储等功能"""
|
||||
|
||||
提供以下能力:
|
||||
- AI 大模型调用
|
||||
- 钉钉/企微通知
|
||||
- 数据库查询(只读)
|
||||
- HTTP 请求
|
||||
- 变量存储(跨执行持久化)
|
||||
- 日志记录
|
||||
- 多租户遍历
|
||||
- 密钥管理
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, task_id: int, trace_id: str = None):
|
||||
self.tenant_id = tenant_id
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
task_id: int,
|
||||
tenant_id: Optional[str] = None,
|
||||
trace_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
self.db = db
|
||||
self.task_id = task_id
|
||||
self.trace_id = trace_id or f"script_{task_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
self._logs: List[str] = []
|
||||
self._db: Optional[Session] = None
|
||||
self._tenants_cache: Optional[List[Dict]] = None # 租户列表缓存
|
||||
self.tenant_id = tenant_id
|
||||
self.trace_id = trace_id
|
||||
self.params = params or {}
|
||||
|
||||
self._logs: List[Dict] = []
|
||||
self._output: List[str] = []
|
||||
self._tenants_cache: Dict = {}
|
||||
|
||||
# AI 配置
|
||||
self._ai_base_url = os.getenv('OPENAI_BASE_URL', 'https://api.4sapi.net/v1')
|
||||
self._ai_api_key = os.getenv('OPENAI_API_KEY', 'sk-9yMCXjRGANbacz20kJY8doSNy6Rf446aYwmgGIuIXQ7DAyBw')
|
||||
self._ai_model = os.getenv('OPENAI_MODEL', 'gemini-2.5-flash')
|
||||
|
||||
def _get_db(self) -> Session:
|
||||
"""获取数据库会话"""
|
||||
if self._db is None:
|
||||
self._db = SessionLocal()
|
||||
return self._db
|
||||
# ==================== 参数获取 ====================
|
||||
|
||||
def _close_db(self):
|
||||
"""关闭数据库会话"""
|
||||
if self._db:
|
||||
self._db.close()
|
||||
self._db = None
|
||||
def get_param(self, key: str, default: Any = None) -> Any:
|
||||
"""获取任务参数
|
||||
|
||||
Args:
|
||||
key: 参数名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
参数值
|
||||
"""
|
||||
return self.params.get(key, default)
|
||||
|
||||
# ============ AI 服务 ============
|
||||
def get_params(self) -> Dict[str, Any]:
|
||||
"""获取所有任务参数
|
||||
|
||||
Returns:
|
||||
所有参数字典
|
||||
"""
|
||||
return self.params.copy()
|
||||
|
||||
async def ai_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = None,
|
||||
model: str = "gemini-2.5-flash",
|
||||
# ==================== 日志 ====================
|
||||
|
||||
def log(self, message: str, level: str = 'INFO') -> None:
|
||||
"""记录日志
|
||||
|
||||
Args:
|
||||
message: 日志内容
|
||||
level: 日志级别 (INFO, WARN, ERROR)
|
||||
"""
|
||||
log_entry = {
|
||||
'time': datetime.now().isoformat(),
|
||||
'level': level.upper(),
|
||||
'message': message
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._output.append(f"[{level.upper()}] {message}")
|
||||
|
||||
def print(self, *args, **kwargs) -> None:
|
||||
"""打印输出(兼容print)"""
|
||||
message = ' '.join(str(arg) for arg in args)
|
||||
self._output.append(message)
|
||||
|
||||
def get_logs(self) -> List[Dict]:
|
||||
"""获取所有日志"""
|
||||
return self._logs
|
||||
|
||||
def get_output(self) -> str:
|
||||
"""获取所有输出"""
|
||||
return '\n'.join(self._output)
|
||||
|
||||
# ==================== AI 调用 ====================
|
||||
|
||||
def ai(
|
||||
self,
|
||||
prompt: str,
|
||||
system: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000
|
||||
) -> str:
|
||||
"""
|
||||
调用大模型
|
||||
"""调用AI模型
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system: 系统提示词(可选)
|
||||
model: 模型名称,默认 gemini-2.5-flash
|
||||
temperature: 温度,默认 0.7
|
||||
max_tokens: 最大 token 数,默认 2000
|
||||
system: 系统提示词
|
||||
model: 模型名称(默认gemini-2.5-flash)
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
AI 生成的文本
|
||||
AI响应内容
|
||||
"""
|
||||
# 使用 4sapi 作为 AI 服务
|
||||
api_key = "sk-9yMCXjRGANbacz20kJY8doSNy6Rf446aYwmgGIuIXQ7DAyBw"
|
||||
base_url = "https://4sapi.com/v1"
|
||||
model = model or self._ai_model
|
||||
|
||||
messages = []
|
||||
if system:
|
||||
@@ -81,11 +115,11 @@ class ScriptSDK:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
response = await client.post(
|
||||
f"{base_url}/chat/completions",
|
||||
with httpx.Client(timeout=60) as client:
|
||||
response = client.post(
|
||||
f"{self._ai_base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Authorization": f"Bearer {self._ai_api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
@@ -97,515 +131,349 @@ class ScriptSDK:
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
self.log(f"AI 调用成功,模型: {model},响应长度: {len(content)}")
|
||||
content = data['choices'][0]['message']['content']
|
||||
self.log(f"AI调用成功: {len(content)} 字符")
|
||||
return content
|
||||
except Exception as e:
|
||||
self.log(f"AI 调用失败: {str(e)}", level="ERROR")
|
||||
self.log(f"AI调用失败: {str(e)}", 'ERROR')
|
||||
raise
|
||||
|
||||
# ============ 通知服务 ============
|
||||
# ==================== 通知 ====================
|
||||
|
||||
async def send_dingtalk(
|
||||
self,
|
||||
webhook: str,
|
||||
content: str,
|
||||
msg_type: str = "text",
|
||||
at_mobiles: List[str] = None,
|
||||
at_all: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
发送钉钉群消息
|
||||
def dingtalk(self, webhook: str, content: str, title: Optional[str] = None, at_all: bool = False) -> bool:
|
||||
"""发送钉钉消息
|
||||
|
||||
Args:
|
||||
webhook: 钉钉机器人 Webhook URL
|
||||
content: 消息内容
|
||||
msg_type: 消息类型,text 或 markdown
|
||||
at_mobiles: @的手机号列表
|
||||
webhook: 钉钉机器人webhook地址
|
||||
content: 消息内容(支持Markdown)
|
||||
title: 消息标题
|
||||
at_all: 是否@所有人
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
if msg_type == "text":
|
||||
data = {
|
||||
"msgtype": "text",
|
||||
"text": {"content": content},
|
||||
"at": {
|
||||
"atMobiles": at_mobiles or [],
|
||||
"isAtAll": at_all
|
||||
}
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": content[:20] if len(content) > 20 else content,
|
||||
"text": content
|
||||
},
|
||||
"at": {
|
||||
"atMobiles": at_mobiles or [],
|
||||
"isAtAll": at_all
|
||||
}
|
||||
}
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": title or "通知",
|
||||
"text": content + ("\n@所有人" if at_all else "")
|
||||
},
|
||||
"at": {"isAtAll": at_all}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(webhook, json=data)
|
||||
with httpx.Client(timeout=10) as client:
|
||||
response = client.post(webhook, json=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode") == 0:
|
||||
self.log(f"钉钉消息发送成功")
|
||||
return True
|
||||
else:
|
||||
self.log(f"钉钉消息发送失败: {result.get('errmsg')}", level="ERROR")
|
||||
return False
|
||||
success = result.get('errcode') == 0
|
||||
self.log(f"钉钉消息发送{'成功' if success else '失败'}")
|
||||
return success
|
||||
except Exception as e:
|
||||
self.log(f"钉钉消息发送异常: {str(e)}", level="ERROR")
|
||||
self.log(f"钉钉消息发送失败: {str(e)}", 'ERROR')
|
||||
return False
|
||||
|
||||
async def send_wecom(
|
||||
self,
|
||||
webhook: str,
|
||||
content: str,
|
||||
msg_type: str = "text"
|
||||
) -> bool:
|
||||
"""
|
||||
发送企业微信群消息
|
||||
def wecom(self, webhook: str, content: str, msg_type: str = 'markdown') -> bool:
|
||||
"""发送企业微信消息
|
||||
|
||||
Args:
|
||||
webhook: 企微机器人 Webhook URL
|
||||
webhook: 企微机器人webhook地址
|
||||
content: 消息内容
|
||||
msg_type: 消息类型,text 或 markdown
|
||||
msg_type: 消息类型 (text, markdown)
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
if msg_type == "text":
|
||||
data = {
|
||||
"msgtype": "text",
|
||||
"text": {"content": content}
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
if msg_type == 'markdown':
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content}
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"msgtype": "text",
|
||||
"text": {"content": content}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(webhook, json=data)
|
||||
with httpx.Client(timeout=10) as client:
|
||||
response = client.post(webhook, json=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode") == 0:
|
||||
self.log(f"企微消息发送成功")
|
||||
return True
|
||||
else:
|
||||
self.log(f"企微消息发送失败: {result.get('errmsg')}", level="ERROR")
|
||||
return False
|
||||
success = result.get('errcode') == 0
|
||||
self.log(f"企微消息发送{'成功' if success else '失败'}")
|
||||
return success
|
||||
except Exception as e:
|
||||
self.log(f"企微消息发送异常: {str(e)}", level="ERROR")
|
||||
self.log(f"企微消息发送失败: {str(e)}", 'ERROR')
|
||||
return False
|
||||
|
||||
# ============ 数据库查询 ============
|
||||
# ==================== HTTP 请求 ====================
|
||||
|
||||
def db_query(self, sql: str, params: Dict[str, Any] = None) -> List[Dict]:
|
||||
def http_get(self, url: str, headers: Optional[Dict] = None, params: Optional[Dict] = None, timeout: int = 30) -> Dict:
|
||||
"""发起HTTP GET请求
|
||||
|
||||
Returns:
|
||||
{"status": 200, "data": ..., "text": "..."}
|
||||
"""
|
||||
执行 SQL 查询(只读)
|
||||
try:
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.get(url, headers=headers, params=params)
|
||||
return {
|
||||
"status": response.status_code,
|
||||
"data": response.json() if response.headers.get('content-type', '').startswith('application/json') else None,
|
||||
"text": response.text
|
||||
}
|
||||
except Exception as e:
|
||||
self.log(f"HTTP GET 失败: {str(e)}", 'ERROR')
|
||||
raise
|
||||
|
||||
def http_post(self, url: str, data: Any = None, headers: Optional[Dict] = None, timeout: int = 30) -> Dict:
|
||||
"""发起HTTP POST请求
|
||||
|
||||
Returns:
|
||||
{"status": 200, "data": ..., "text": "..."}
|
||||
"""
|
||||
try:
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(url, json=data, headers=headers)
|
||||
return {
|
||||
"status": response.status_code,
|
||||
"data": response.json() if response.headers.get('content-type', '').startswith('application/json') else None,
|
||||
"text": response.text
|
||||
}
|
||||
except Exception as e:
|
||||
self.log(f"HTTP POST 失败: {str(e)}", 'ERROR')
|
||||
raise
|
||||
|
||||
# ==================== 数据库查询(只读)====================
|
||||
|
||||
def db_query(self, sql: str, params: Optional[Dict] = None) -> List[Dict]:
|
||||
"""执行只读SQL查询
|
||||
|
||||
Args:
|
||||
sql: SQL 语句(仅支持 SELECT)
|
||||
params: 查询参数
|
||||
sql: SQL语句(必须是SELECT)
|
||||
params: 参数字典
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
"""
|
||||
# 安全检查:只允许 SELECT
|
||||
sql_upper = sql.strip().upper()
|
||||
if not sql_upper.startswith("SELECT"):
|
||||
raise ValueError("只允许 SELECT 查询")
|
||||
if not sql_upper.startswith('SELECT'):
|
||||
raise ValueError("只允许执行SELECT查询")
|
||||
|
||||
forbidden = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE"]
|
||||
# 禁止危险操作
|
||||
forbidden = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'TRUNCATE', 'ALTER', 'CREATE']
|
||||
for word in forbidden:
|
||||
if word in sql_upper:
|
||||
raise ValueError(f"禁止使用 {word} 语句")
|
||||
raise ValueError(f"禁止执行 {word} 操作")
|
||||
|
||||
db = self._get_db()
|
||||
try:
|
||||
result = db.execute(text(sql), params or {})
|
||||
rows = [dict(row) for row in result.mappings().all()]
|
||||
self.log(f"SQL 查询成功,返回 {len(rows)} 条记录")
|
||||
from sqlalchemy import text
|
||||
result = self.db.execute(text(sql), params or {})
|
||||
columns = result.keys()
|
||||
rows = [dict(zip(columns, row)) for row in result.fetchall()]
|
||||
self.log(f"SQL查询返回 {len(rows)} 条记录")
|
||||
return rows
|
||||
except Exception as e:
|
||||
self.log(f"SQL 查询失败: {str(e)}", level="ERROR")
|
||||
self.log(f"SQL查询失败: {str(e)}", 'ERROR')
|
||||
raise
|
||||
|
||||
# ============ HTTP 请求 ============
|
||||
|
||||
async def http_get(
|
||||
self,
|
||||
url: str,
|
||||
headers: Dict[str, str] = None,
|
||||
params: Dict[str, Any] = None,
|
||||
timeout: int = 30
|
||||
) -> Dict:
|
||||
"""
|
||||
发送 HTTP GET 请求
|
||||
|
||||
Args:
|
||||
url: 请求 URL
|
||||
headers: 请求头
|
||||
params: 查询参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据(JSON 解析后)
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=float(timeout)) as client:
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
self.log(f"HTTP GET {url} -> {response.status_code}")
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except:
|
||||
return {"text": response.text, "status_code": response.status_code}
|
||||
except Exception as e:
|
||||
self.log(f"HTTP GET 失败: {str(e)}", level="ERROR")
|
||||
raise
|
||||
|
||||
async def http_post(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict[str, Any] = None,
|
||||
json_data: Dict[str, Any] = None,
|
||||
headers: Dict[str, str] = None,
|
||||
timeout: int = 30
|
||||
) -> Dict:
|
||||
"""
|
||||
发送 HTTP POST 请求
|
||||
|
||||
Args:
|
||||
url: 请求 URL
|
||||
data: 表单数据
|
||||
json_data: JSON 数据
|
||||
headers: 请求头
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据(JSON 解析后)
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=float(timeout)) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
data=data,
|
||||
json=json_data,
|
||||
headers=headers
|
||||
)
|
||||
self.log(f"HTTP POST {url} -> {response.status_code}")
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except:
|
||||
return {"text": response.text, "status_code": response.status_code}
|
||||
except Exception as e:
|
||||
self.log(f"HTTP POST 失败: {str(e)}", level="ERROR")
|
||||
raise
|
||||
|
||||
# ============ 变量存储 ============
|
||||
# ==================== 变量存储 ====================
|
||||
|
||||
def get_var(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
获取存储的变量
|
||||
"""获取持久化变量
|
||||
|
||||
Args:
|
||||
key: 变量名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
变量值(JSON 解析后)
|
||||
变量值
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("SELECT var_value FROM platform_script_vars WHERE tenant_id = :tid AND var_key = :key"),
|
||||
{"tid": self.tenant_id, "key": key}
|
||||
)
|
||||
row = result.first()
|
||||
if row:
|
||||
try:
|
||||
return json.loads(row[0])
|
||||
except:
|
||||
return row[0]
|
||||
return default
|
||||
except Exception as e:
|
||||
self.log(f"获取变量失败: {str(e)}", level="ERROR")
|
||||
return default
|
||||
from ..models.scheduled_task import ScriptVar
|
||||
|
||||
var = self.db.query(ScriptVar).filter(
|
||||
ScriptVar.task_id == self.task_id,
|
||||
ScriptVar.tenant_id == self.tenant_id,
|
||||
ScriptVar.var_key == key
|
||||
).first()
|
||||
|
||||
if var and var.var_value:
|
||||
try:
|
||||
return json.loads(var.var_value)
|
||||
except:
|
||||
return var.var_value
|
||||
return default
|
||||
|
||||
def set_var(self, key: str, value: Any) -> bool:
|
||||
"""
|
||||
存储变量(跨执行持久化)
|
||||
def set_var(self, key: str, value: Any) -> None:
|
||||
"""设置持久化变量
|
||||
|
||||
Args:
|
||||
key: 变量名
|
||||
value: 变量值(会 JSON 序列化)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
value: 变量值(会JSON序列化)
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
value_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
|
||||
|
||||
# 使用 REPLACE INTO 实现 upsert
|
||||
db.execute(
|
||||
text("""
|
||||
REPLACE INTO platform_script_vars (tenant_id, var_key, var_value, updated_at)
|
||||
VALUES (:tid, :key, :value, NOW())
|
||||
"""),
|
||||
{"tid": self.tenant_id, "key": key, "value": value_str}
|
||||
from ..models.scheduled_task import ScriptVar
|
||||
|
||||
var = self.db.query(ScriptVar).filter(
|
||||
ScriptVar.task_id == self.task_id,
|
||||
ScriptVar.tenant_id == self.tenant_id,
|
||||
ScriptVar.var_key == key
|
||||
).first()
|
||||
|
||||
value_json = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
if var:
|
||||
var.var_value = value_json
|
||||
else:
|
||||
var = ScriptVar(
|
||||
task_id=self.task_id,
|
||||
tenant_id=self.tenant_id,
|
||||
var_key=key,
|
||||
var_value=value_json
|
||||
)
|
||||
db.commit()
|
||||
self.log(f"变量已存储: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"存储变量失败: {str(e)}", level="ERROR")
|
||||
return False
|
||||
self.db.add(var)
|
||||
|
||||
self.db.commit()
|
||||
self.log(f"变量 {key} 已保存")
|
||||
|
||||
def delete_var(self, key: str) -> bool:
|
||||
"""
|
||||
删除变量
|
||||
def del_var(self, key: str) -> bool:
|
||||
"""删除持久化变量"""
|
||||
from ..models.scheduled_task import ScriptVar
|
||||
|
||||
result = self.db.query(ScriptVar).filter(
|
||||
ScriptVar.task_id == self.task_id,
|
||||
ScriptVar.tenant_id == self.tenant_id,
|
||||
ScriptVar.var_key == key
|
||||
).delete()
|
||||
|
||||
self.db.commit()
|
||||
return result > 0
|
||||
|
||||
# ==================== 租户配置 ====================
|
||||
|
||||
def get_tenants(self, app_code: Optional[str] = None) -> List[Dict]:
|
||||
"""获取租户列表
|
||||
|
||||
Args:
|
||||
key: 变量名
|
||||
app_code: 可选,按应用代码筛选
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
租户列表 [{"tenant_id": ..., "tenant_name": ...}, ...]
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
db.execute(
|
||||
text("DELETE FROM platform_script_vars WHERE tenant_id = :tid AND var_key = :key"),
|
||||
{"tid": self.tenant_id, "key": key}
|
||||
)
|
||||
db.commit()
|
||||
self.log(f"变量已删除: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log(f"删除变量失败: {str(e)}", level="ERROR")
|
||||
return False
|
||||
|
||||
# ============ 多租户遍历 ============
|
||||
|
||||
def get_tenants(self, app_code: str = None) -> List[Dict]:
|
||||
"""
|
||||
获取租户列表(用于多租户任务遍历)
|
||||
from ..models.tenant import Tenant
|
||||
from ..models.tenant_app import TenantApp
|
||||
|
||||
Args:
|
||||
app_code: 应用代码(可选),筛选订阅了该应用的租户
|
||||
if app_code:
|
||||
# 获取订阅了该应用的租户
|
||||
tenant_ids = self.db.query(TenantApp.tenant_id).filter(
|
||||
TenantApp.app_code == app_code,
|
||||
TenantApp.status == 1
|
||||
).all()
|
||||
tenant_ids = [t[0] for t in tenant_ids]
|
||||
|
||||
Returns:
|
||||
租户列表 [{"code": "xxx", "name": "租户名", "custom_configs": {...}}]
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
if app_code:
|
||||
# 筛选订阅了指定应用的租户
|
||||
result = db.execute(
|
||||
text("""
|
||||
SELECT DISTINCT t.code, t.name, ta.custom_configs
|
||||
FROM platform_tenants t
|
||||
INNER JOIN platform_tenant_apps ta ON t.code = ta.tenant_id
|
||||
WHERE ta.app_code = :app_code AND t.status = 1
|
||||
"""),
|
||||
{"app_code": app_code}
|
||||
)
|
||||
else:
|
||||
# 获取所有启用的租户
|
||||
result = db.execute(
|
||||
text("SELECT code, name FROM platform_tenants WHERE status = 1")
|
||||
)
|
||||
|
||||
tenants = []
|
||||
for row in result.mappings().all():
|
||||
tenant = dict(row)
|
||||
# 解析 custom_configs
|
||||
if "custom_configs" in tenant and tenant["custom_configs"]:
|
||||
try:
|
||||
tenant["custom_configs"] = json.loads(tenant["custom_configs"])
|
||||
except:
|
||||
pass
|
||||
tenants.append(tenant)
|
||||
|
||||
self.log(f"获取租户列表成功,共 {len(tenants)} 个租户")
|
||||
return tenants
|
||||
except Exception as e:
|
||||
self.log(f"获取租户列表失败: {str(e)}", level="ERROR")
|
||||
return []
|
||||
tenants = self.db.query(Tenant).filter(
|
||||
Tenant.code.in_(tenant_ids),
|
||||
Tenant.status == 'active'
|
||||
).all()
|
||||
else:
|
||||
tenants = self.db.query(Tenant).filter(Tenant.status == 'active').all()
|
||||
|
||||
return [{"tenant_id": t.code, "tenant_name": t.name} for t in tenants]
|
||||
|
||||
def get_tenant_config(self, tenant_id: str, app_code: str, key: str = None) -> Any:
|
||||
"""
|
||||
获取指定租户的应用配置
|
||||
def get_tenant_config(self, tenant_id: str, app_code: str, key: Optional[str] = None) -> Any:
|
||||
"""获取租户的应用配置
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
app_code: 应用代码
|
||||
key: 配置项键名(可选,不传返回全部配置)
|
||||
key: 配置键(可选,不提供则返回所有配置)
|
||||
|
||||
Returns:
|
||||
配置值或配置字典
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("""
|
||||
SELECT custom_configs FROM platform_tenant_apps
|
||||
WHERE tenant_id = :tenant_id AND app_code = :app_code
|
||||
"""),
|
||||
{"tenant_id": tenant_id, "app_code": app_code}
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row or not row[0]:
|
||||
return None if key else {}
|
||||
|
||||
try:
|
||||
configs = json.loads(row[0])
|
||||
except:
|
||||
configs = {}
|
||||
|
||||
if key:
|
||||
return configs.get(key)
|
||||
return configs
|
||||
except Exception as e:
|
||||
self.log(f"获取租户配置失败: {str(e)}", level="ERROR")
|
||||
from ..models.tenant_app import TenantApp
|
||||
|
||||
tenant_app = self.db.query(TenantApp).filter(
|
||||
TenantApp.tenant_id == tenant_id,
|
||||
TenantApp.app_code == app_code
|
||||
).first()
|
||||
|
||||
if not tenant_app:
|
||||
return None if key else {}
|
||||
|
||||
# 解析 custom_configs
|
||||
configs = {}
|
||||
if hasattr(tenant_app, 'custom_configs') and tenant_app.custom_configs:
|
||||
try:
|
||||
configs = json.loads(tenant_app.custom_configs) if isinstance(tenant_app.custom_configs, str) else tenant_app.custom_configs
|
||||
except:
|
||||
pass
|
||||
|
||||
if key:
|
||||
return configs.get(key)
|
||||
return configs
|
||||
|
||||
def get_all_tenant_configs(self, app_code: str) -> List[Dict]:
|
||||
"""
|
||||
获取所有租户的应用配置(便捷方法,用于批量操作)
|
||||
"""获取所有租户的应用配置
|
||||
|
||||
Args:
|
||||
app_code: 应用代码
|
||||
|
||||
Returns:
|
||||
[{"tenant_id": "xxx", "tenant_name": "租户名", "configs": {...}}]
|
||||
[{"tenant_id": ..., "tenant_name": ..., "configs": {...}}, ...]
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
result = db.execute(
|
||||
text("""
|
||||
SELECT t.code as tenant_id, t.name as tenant_name, ta.custom_configs
|
||||
FROM platform_tenants t
|
||||
INNER JOIN platform_tenant_apps ta ON t.code = ta.tenant_id
|
||||
WHERE ta.app_code = :app_code AND t.status = 1
|
||||
"""),
|
||||
{"app_code": app_code}
|
||||
)
|
||||
from ..models.tenant import Tenant
|
||||
from ..models.tenant_app import TenantApp
|
||||
|
||||
tenant_apps = self.db.query(TenantApp).filter(
|
||||
TenantApp.app_code == app_code,
|
||||
TenantApp.status == 1
|
||||
).all()
|
||||
|
||||
result = []
|
||||
for ta in tenant_apps:
|
||||
tenant = self.db.query(Tenant).filter(Tenant.code == ta.tenant_id).first()
|
||||
configs = {}
|
||||
if hasattr(ta, 'custom_configs') and ta.custom_configs:
|
||||
try:
|
||||
configs = json.loads(ta.custom_configs) if isinstance(ta.custom_configs, str) else ta.custom_configs
|
||||
except:
|
||||
pass
|
||||
|
||||
tenants = []
|
||||
for row in result.mappings().all():
|
||||
configs = {}
|
||||
if row["custom_configs"]:
|
||||
try:
|
||||
configs = json.loads(row["custom_configs"])
|
||||
except:
|
||||
pass
|
||||
tenants.append({
|
||||
"tenant_id": row["tenant_id"],
|
||||
"tenant_name": row["tenant_name"],
|
||||
"configs": configs
|
||||
})
|
||||
|
||||
self.log(f"获取 {app_code} 应用的租户配置,共 {len(tenants)} 个")
|
||||
return tenants
|
||||
except Exception as e:
|
||||
self.log(f"获取租户配置失败: {str(e)}", level="ERROR")
|
||||
return []
|
||||
result.append({
|
||||
"tenant_id": ta.tenant_id,
|
||||
"tenant_name": tenant.name if tenant else ta.tenant_id,
|
||||
"configs": configs
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
# ============ 密钥管理 ============
|
||||
# ==================== 密钥管理 ====================
|
||||
|
||||
def get_secret(self, key: str) -> Optional[str]:
|
||||
"""
|
||||
获取密钥(优先读取租户级密钥,其次读取全局密钥)
|
||||
"""获取密钥(优先租户级,其次全局)
|
||||
|
||||
Args:
|
||||
key: 密钥名称
|
||||
key: 密钥名
|
||||
|
||||
Returns:
|
||||
密钥值(如不存在返回 None)
|
||||
密钥值
|
||||
"""
|
||||
db = self._get_db()
|
||||
try:
|
||||
# 优先查询租户级密钥
|
||||
result = db.execute(
|
||||
text("""
|
||||
SELECT secret_value FROM platform_secrets
|
||||
WHERE (tenant_id = :tenant_id OR tenant_id IS NULL)
|
||||
AND secret_key = :key
|
||||
ORDER BY tenant_id DESC
|
||||
LIMIT 1
|
||||
"""),
|
||||
{"tenant_id": self.tenant_id, "key": key}
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if row:
|
||||
self.log(f"获取密钥成功: {key}")
|
||||
return row[0]
|
||||
|
||||
self.log(f"密钥不存在: {key}", level="WARN")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.log(f"获取密钥失败: {str(e)}", level="ERROR")
|
||||
return None
|
||||
|
||||
# ============ 日志 ============
|
||||
|
||||
def log(self, message: str, level: str = "INFO"):
|
||||
"""
|
||||
记录日志
|
||||
from ..models.scheduled_task import Secret
|
||||
|
||||
Args:
|
||||
message: 日志内容
|
||||
level: 日志级别(INFO, WARN, ERROR)
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
log_entry = f"[{timestamp}] [{level}] {message}"
|
||||
self._logs.append(log_entry)
|
||||
# 先查租户级
|
||||
if self.tenant_id:
|
||||
secret = self.db.query(Secret).filter(
|
||||
Secret.tenant_id == self.tenant_id,
|
||||
Secret.secret_key == key
|
||||
).first()
|
||||
if secret:
|
||||
return secret.secret_value
|
||||
|
||||
# 同时输出到标准日志
|
||||
if level == "ERROR":
|
||||
logger.error(f"[Script {self.task_id}] {message}")
|
||||
else:
|
||||
logger.info(f"[Script {self.task_id}] {message}")
|
||||
# 再查全局
|
||||
secret = self.db.query(Secret).filter(
|
||||
Secret.tenant_id.is_(None),
|
||||
Secret.secret_key == key
|
||||
).first()
|
||||
|
||||
# 写入 platform_logs
|
||||
try:
|
||||
db = self._get_db()
|
||||
db.execute(
|
||||
text("""
|
||||
INSERT INTO platform_logs
|
||||
(trace_id, app_code, module, level, message, created_at)
|
||||
VALUES (:trace_id, :app_code, :module, :level, :message, NOW())
|
||||
"""),
|
||||
{
|
||||
"trace_id": self.trace_id,
|
||||
"app_code": "000-platform",
|
||||
"module": "script",
|
||||
"level": level,
|
||||
"message": message[:2000] # 限制长度
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write log to database: {e}")
|
||||
|
||||
def get_logs(self) -> List[str]:
|
||||
"""获取所有日志"""
|
||||
return self._logs.copy()
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
self._close_db()
|
||||
return secret.secret_value if secret else None
|
||||
|
||||
Reference in New Issue
Block a user