Files
000-platform/backend/app/services/script_executor.py
Admin 3ebd8b20a4
All checks were successful
continuous-integration/drone/push Build is passing
fix: 添加受限的 __import__ 函数支持白名单模块导入
解决脚本执行时 KeyError: '__import__' 错误
2026-01-28 17:34:38 +08:00

286 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""脚本执行器 - 安全执行Python脚本"""
import sys
import traceback
from io import StringIO
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from .script_sdk import ScriptSDK
# 禁止导入的模块
FORBIDDEN_MODULES = {
'os', 'subprocess', 'shutil', 'pathlib',
'socket', 'ftplib', 'telnetlib', 'smtplib',
'pickle', 'shelve', 'marshal',
'ctypes', 'multiprocessing',
'__builtins__', 'builtins',
'importlib', 'imp',
'code', 'codeop', 'compile',
}
# 允许的内置函数
ALLOWED_BUILTINS = {
'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
'callable', 'chr', 'complex', 'dict', 'dir', 'divmod', 'enumerate',
'filter', 'float', 'format', 'frozenset', 'getattr', 'hasattr', 'hash',
'hex', 'id', 'int', 'isinstance', 'issubclass', 'iter', 'len', 'list',
'map', 'max', 'min', 'next', 'object', 'oct', 'ord', 'pow', 'print',
'range', 'repr', 'reversed', 'round', 'set', 'setattr', 'slice',
'sorted', 'str', 'sum', 'tuple', 'type', 'vars', 'zip',
'True', 'False', 'None',
'Exception', 'BaseException', 'ValueError', 'TypeError', 'KeyError',
'IndexError', 'AttributeError', 'RuntimeError', 'StopIteration',
}
class ScriptExecutor:
    """Execute user-supplied Python scripts in a restricted ``exec`` namespace.

    Scripts run with a reduced set of builtins, a whitelist-only ``import``,
    the ScriptSDK helpers and a few pre-imported standard-library modules as
    globals.

    NOTE(review): ``exec`` with restricted builtins is NOT a security boundary.
    The exposed ``getattr``/``type``/``object`` builtins and dunder attribute
    access (e.g. ``().__class__.__base__.__subclasses__()``) let a script reach
    unrestricted objects. Scripts from untrusted authors need an OS-isolated
    worker process.
    """

    def __init__(self, db: Session):
        # Database session passed to every ScriptSDK created by execute().
        self.db = db

    def execute(
        self,
        script_content: str,
        task_id: int,
        tenant_id: Optional[str] = None,
        trace_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        timeout: int = 300
    ) -> Tuple[bool, str, str, Optional[Dict]]:
        """Execute a script.

        Args:
            script_content: Python source code of the script.
            task_id: Task ID, passed to the SDK and exposed as ``task_id``.
            tenant_id: Tenant ID, passed to the SDK and exposed as ``tenant_id``.
            trace_id: Trace ID, passed to the SDK and exposed as ``trace_id``.
            params: Input parameters handed to the SDK (``None`` -> ``{}``).
            timeout: Timeout in seconds.
                NOTE(review): accepted for API compatibility but NOT enforced;
                a script that never terminates blocks the calling thread.
                Reliable enforcement needs a separate process.

        Returns:
            (success, output, error, result)
            output: SDK output followed by anything written to sys.stdout,
                on both success and failure.
            error: '' on success; otherwise the safety-check message, or the
                exception type, message and traceback.
            result: the script's return value, read from the global
                ``__result__`` (falling back to ``result``) and normalised to
                a dict such as {'content': '...', 'title': '...'}; None if the
                script set neither or the run failed.
        """
        sdk = ScriptSDK(
            db=self.db,
            task_id=task_id,
            tenant_id=tenant_id,
            trace_id=trace_id,
            params=params or {}
        )

        # Static check before anything runs, so a rejected script has no
        # partial side effects.
        check_result = self._check_script_safety(script_content)
        if check_result:
            return False, '', f"脚本安全检查失败: {check_result}", None

        safe_globals = self._create_safe_globals(sdk)

        # NOTE: sys.stdout/sys.stderr are process-wide, so while a script runs
        # writes from other threads are captured here as well.
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        stdout_capture = StringIO()
        stderr_capture = StringIO()  # captured but currently discarded
        try:
            sys.stdout = stdout_capture
            sys.stderr = stderr_capture

            compiled = compile(script_content, '<script>', 'exec')
            exec(compiled, safe_globals)

            output = self._collect_output(sdk, stdout_capture)
            result = self._extract_result(safe_globals)
            return True, output, '', result
        except KeyboardInterrupt:
            # A genuine interrupt of the host process must not be swallowed.
            raise
        except BaseException as e:
            # BaseException, not just Exception: scripts can raise
            # BaseException (it is in their builtins) or SystemExit directly,
            # and that must be reported as a failed run instead of escaping
            # to the caller.
            error_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
            # Include captured stdout here too, consistent with the success path.
            return False, self._collect_output(sdk, stdout_capture), error_msg, None
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

    @staticmethod
    def _collect_output(sdk, stdout_capture: StringIO) -> str:
        """Join the SDK output and captured stdout, skipping empty parts."""
        return '\n'.join(filter(None, [sdk.get_output(), stdout_capture.getvalue()]))

    @staticmethod
    def _extract_result(script_globals: Dict[str, Any]) -> Optional[Dict]:
        """Read the script's return value from its globals and normalise it.

        ``__result__`` takes precedence; ``result`` is used only when
        ``__result__`` is missing or None. Strings and other non-dict values
        are wrapped as ``{'content': ...}``; dicts and None are returned as is.
        """
        result = script_globals.get('__result__')
        if result is None and 'result' in script_globals:
            result = script_globals.get('result')
        if isinstance(result, str):
            return {'content': result}
        if result is not None and not isinstance(result, dict):
            return {'content': str(result)}
        return result

    def _check_script_safety(self, script_content: str) -> Optional[str]:
        """Statically reject scripts that use obviously forbidden features.

        The script is parsed into an AST rather than substring-matched, so
        names and literals that merely contain a forbidden word are accepted.
        Examples: ``re.compile(`` on the provided ``re`` module, an identifier
        such as ``retrieval(``, or the text "open(" inside a string.

        Rejected:
          * ``import X`` / ``from X import ...`` where the top-level package
            of X is in FORBIDDEN_MODULES;
          * direct calls to ``eval``, ``exec``, ``compile`` or ``open``;
          * any reference to the name or attribute ``__import__``.

        This is a first line of defence only; see the class docstring.

        Returns:
            An error message naming the forbidden construct, or None. Scripts
            that fail to parse return None so that execute() reports the
            SyntaxError together with its traceback.
        """
        import ast

        try:
            tree = ast.parse(script_content, '<script>', 'exec')
        except (SyntaxError, ValueError):
            # ValueError: older Pythons reject NUL bytes with ValueError.
            return None

        forbidden_calls = ('eval', 'exec', 'compile', 'open')
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if alias.name.split('.')[0] in FORBIDDEN_MODULES:
                        return f"禁止使用: import {alias.name}"
            elif isinstance(node, ast.ImportFrom):
                if node.module and node.module.split('.')[0] in FORBIDDEN_MODULES:
                    return f"禁止使用: from {node.module}"
            elif isinstance(node, ast.Call):
                func = node.func
                if isinstance(func, ast.Name) and func.id in forbidden_calls:
                    return f"禁止使用: {func.id}("
            elif isinstance(node, ast.Name) and node.id == '__import__':
                return "禁止使用: __import__"
            elif isinstance(node, ast.Attribute) and node.attr == '__import__':
                return "禁止使用: __import__"
        return None

    def _create_safe_globals(self, sdk: ScriptSDK) -> Dict[str, Any]:
        """Build the globals dict the script is executed with.

        Contents:
          * ``__builtins__``: the names in ALLOWED_BUILTINS, a whitelist-only
            ``__import__`` and common exception types;
          * the SDK helpers (log, print, ai, dingtalk, http_get, ...);
          * the current task context (task_id, tenant_id, trace_id);
          * pre-imported standard-library modules and helpers.
        """
        import base64
        import builtins
        import collections
        import datetime as datetime_module
        import hashlib
        import json
        import math
        import random
        import re
        import time
        from datetime import date, datetime, timedelta
        from urllib.parse import quote, unquote, urlencode

        # Modules a script may import: import name -> already-loaded module.
        allowed_modules = {
            'json': json,
            're': re,
            'math': math,
            'random': random,
            'hashlib': hashlib,
            'base64': base64,
            'time': time,
            'datetime': datetime_module,
            'collections': collections,
        }

        def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
            """Replacement for ``__import__`` that only serves allowed_modules.

            The parameters mirror the builtin's signature; only ``name`` is used.
            """
            if name in allowed_modules:
                return allowed_modules[name]
            raise ImportError(f"不允许导入模块: {name}。已内置可用: {', '.join(allowed_modules.keys())}")

        # Look names up on the `builtins` module itself. This avoids depending
        # on whether this module's __builtins__ is a dict or a module, which is
        # a CPython implementation detail.
        safe_builtins = {
            name: getattr(builtins, name)
            for name in ALLOWED_BUILTINS
            if hasattr(builtins, name)
        }
        # Without an __import__ entry every `import` statement in the script
        # fails with KeyError: '__import__'.
        safe_builtins['__import__'] = safe_import
        # Common exception types are always available. ImportError lets scripts
        # catch safe_import's rejection.
        for exc_type in (Exception, ValueError, TypeError, KeyError, IndexError, ImportError):
            safe_builtins[exc_type.__name__] = exc_type

        return {
            '__builtins__': safe_builtins,
            '__name__': '__main__',
            # SDK helpers, available as plain globals
            'log': sdk.log,
            'print': sdk.print,
            'ai': sdk.ai,
            'dingtalk': sdk.dingtalk,
            'wecom': sdk.wecom,
            'http_get': sdk.http_get,
            'http_post': sdk.http_post,
            'db_query': sdk.db_query,
            'get_var': sdk.get_var,
            'set_var': sdk.set_var,
            'del_var': sdk.del_var,
            'get_param': sdk.get_param,
            'get_params': sdk.get_params,
            'get_tenants': sdk.get_tenants,
            'get_tenant_config': sdk.get_tenant_config,
            'get_all_tenant_configs': sdk.get_all_tenant_configs,
            'get_secret': sdk.get_secret,
            # Current execution context
            'task_id': sdk.task_id,
            'tenant_id': sdk.tenant_id,
            'trace_id': sdk.trace_id,
            # Pre-imported standard-library helpers
            'json': json,
            're': re,
            'math': math,
            'random': random,
            'hashlib': hashlib,
            'base64': base64,
            'datetime': datetime,
            'date': date,
            'timedelta': timedelta,
            'time': time,
            'urlencode': urlencode,
            'quote': quote,
            'unquote': unquote,
        }

    def test_script(
        self,
        script_content: str,
        task_id: int = 0,
        tenant_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run a script once for debugging and report how long it took.

        Uses a trace_id of the form ``test-<unix timestamp>``.

        Returns:
            {
                "success": bool,
                "output": str,
                "error": str,
                "duration_ms": int,
                "result": dict or None
            }
        """
        import time

        start_time = datetime.now()
        # perf_counter is monotonic, so the duration is immune to wall-clock jumps.
        started = time.perf_counter()
        success, output, error, result = self.execute(
            script_content=script_content,
            task_id=task_id,
            tenant_id=tenant_id,
            trace_id=f"test-{start_time.timestamp()}",
            params=params
        )
        duration_ms = int((time.perf_counter() - started) * 1000)
        return {
            "success": success,
            "output": output,
            "error": error,
            "duration_ms": duration_ms,
            "result": result
        }