- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
354 lines
13 KiB
Python
354 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
MySQL binlog rollback tool.

Performs database rollback operations for the Kaopeilian (考培练) system.

Features:
1. Parse binlog files
2. Generate reverse SQL statements
3. Execute the data rollback
4. Support time-range and table filtering

Usage:
    python scripts/binlog_rollback_tool.py --help
|
||
"""
|
||
|
||
import asyncio
|
||
import argparse
|
||
import subprocess
|
||
import tempfile
|
||
import os
|
||
import re
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
from typing import List, Dict, Optional
|
||
import aiomysql
|
||
import logging
|
||
|
||
# Logging setup: INFO level, each record rendered as
# "<timestamp> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by everything in this script.
logger = logging.getLogger(__name__)
|
||
|
||
class BinlogRollbackTool:
    """Derive reverse SQL from MySQL binlog files and optionally apply it.

    The tool shells out to the ``mysqlbinlog`` client to decode binlog files,
    scans the decoded text for INSERT / UPDATE / DELETE statements and builds
    the statements needed to undo them. Only INSERT statements are reversed
    automatically (into DELETE); UPDATE and DELETE statements are emitted as
    ``--`` comments that have to be handled by hand.
    """

    def __init__(self,
                 host: str = "localhost",
                 port: int = 3306,
                 user: str = "root",
                 password: str = "root",
                 database: str = "kaopeilian",
                 binlog_dir: str = "/var/lib/mysql"):
        """Store the connection settings; no connection is opened here.

        Args:
            host: MySQL host name.
            port: MySQL TCP port.
            user: MySQL user name.
            password: MySQL password.
            database: Schema to connect to and to filter binlog events by.
            binlog_dir: Directory that holds the binlog files handed to
                ``mysqlbinlog``. Defaults to the previously hard-coded
                ``/var/lib/mysql``.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.binlog_dir = binlog_dir
        self.connection = None

    def _binlog_path(self, binlog_name: str) -> str:
        """Return the filesystem path of *binlog_name* inside ``binlog_dir``."""
        return os.path.join(self.binlog_dir, binlog_name)

    async def connect(self):
        """Open the aiomysql connection and keep it on ``self.connection``.

        Raises:
            Exception: Whatever ``aiomysql.connect`` raised; it is logged and
                re-raised unchanged.
        """
        try:
            self.connection = await aiomysql.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.database,
                charset='utf8mb4'
            )
            logger.info(f"✅ 成功连接到数据库 {self.database}")
        except Exception as e:
            logger.error(f"❌ 数据库连接失败: {e}")
            raise

    async def close(self):
        """Close the connection if one is open; safe to call repeatedly."""
        if self.connection:
            self.connection.close()
            # Drop the handle so a second close() is a no-op.
            self.connection = None
            logger.info("🔒 数据库连接已关闭")

    async def get_binlog_files(self) -> List[Dict]:
        """Return the server's binlog files as reported by ``SHOW BINARY LOGS``.

        Returns:
            One dict per row with the keys ``name`` (column 0), ``size``
            (column 1) and ``encrypted`` (column 2, or ``False`` when the
            server returns only two columns).
        """
        cursor = await self.connection.cursor()
        try:
            await cursor.execute("SHOW BINARY LOGS")
            rows = await cursor.fetchall()
        finally:
            # Release the cursor even when the query fails.
            await cursor.close()

        binlog_files = [
            {
                'name': row[0],
                'size': row[1],
                'encrypted': row[2] if len(row) > 2 else False
            }
            for row in rows
        ]

        logger.info(f"📋 找到 {len(binlog_files)} 个Binlog文件")
        return binlog_files

    async def get_binlog_position_by_time(self, target_time: datetime) -> Optional[str]:
        """Find the binlog file that covers *target_time*.

        Each file is passed to ``mysqlbinlog`` restricted to the one-second
        window starting at *target_time*; the first file for which the command
        exits with status 0 and prints any output is returned.

        NOTE(review): mysqlbinlog may print header lines even when no event
        falls inside the window, in which case the first readable file would
        always be selected - confirm against the deployed client.
        NOTE(review): ``subprocess.run`` blocks the event loop while it runs.

        Args:
            target_time: Point in time to look for.

        Returns:
            The binlog file name, or ``None`` when no file matched.
        """
        binlog_files = await self.get_binlog_files()

        window_start = target_time.strftime('%Y-%m-%d %H:%M:%S')
        window_stop = (target_time + timedelta(seconds=1)).strftime('%Y-%m-%d %H:%M:%S')

        for binlog_file in binlog_files:
            try:
                # Decode only the one-second window around the target time.
                cmd = [
                    'mysqlbinlog',
                    '--start-datetime', window_start,
                    '--stop-datetime', window_stop,
                    self._binlog_path(binlog_file["name"])
                ]

                result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
                if result.returncode == 0 and result.stdout.strip():
                    logger.info(f"📍 在 {binlog_file['name']} 中找到时间点 {target_time}")
                    return binlog_file['name']

            except Exception as e:
                # Best effort: a file that cannot be decoded is skipped.
                logger.warning(f"⚠️ 解析 {binlog_file['name']} 时出错: {e}")
                continue

        logger.warning(f"⚠️ 未找到时间点 {target_time} 对应的Binlog位置")
        return None

    def parse_binlog_to_sql(self,
                            binlog_file: str,
                            start_time: Optional[datetime] = None,
                            stop_time: Optional[datetime] = None,
                            tables: Optional[List[str]] = None) -> str:
        """Decode one binlog file into text with ``mysqlbinlog``.

        Args:
            binlog_file: File name relative to ``binlog_dir``.
            start_time: Optional lower bound passed as ``--start-datetime``.
            stop_time: Optional upper bound passed as ``--stop-datetime``.
            tables: Optional table names, each passed as ``--table``.

        Returns:
            The command's standard output, or ``""`` when the command exits
            non-zero, times out (300 s) or cannot be started. Failures are
            logged rather than raised.
        """
        cmd = ['mysqlbinlog', '--base64-output=decode-rows', '-v']

        if start_time:
            cmd.extend(['--start-datetime', start_time.strftime('%Y-%m-%d %H:%M:%S')])

        if stop_time:
            cmd.extend(['--stop-datetime', stop_time.strftime('%Y-%m-%d %H:%M:%S')])

        # Restrict the output to the configured schema.
        cmd.extend(['--database', self.database])

        # NOTE(review): `--table` does not appear to be an option of Oracle's
        # mysqlbinlog (MariaDB's client has -T/--table) - confirm against the
        # deployed client, otherwise any table filter makes the command fail.
        if tables:
            for table in tables:
                cmd.extend(['--table', table])

        cmd.append(self._binlog_path(binlog_file))

        logger.info(f"🔍 执行命令: {' '.join(cmd)}")

        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode != 0:
                logger.error(f"❌ mysqlbinlog执行失败: {result.stderr}")
                return ""

            return result.stdout

        except subprocess.TimeoutExpired:
            logger.error("❌ mysqlbinlog执行超时")
            return ""
        except Exception as e:
            logger.error(f"❌ mysqlbinlog执行异常: {e}")
            return ""

    @staticmethod
    def _split_value_rows(values: str) -> List[List[str]]:
        """Split the text inside ``VALUES ( ... )`` into rows of raw literals.

        Commas inside quoted strings or nested parentheses do not split a
        field. A closing parenthesis at nesting depth zero ends the current
        row, which is how the tuples of a multi-row INSERT are separated.
        """
        rows: List[List[str]] = []
        fields: List[str] = []
        buf: List[str] = []
        depth = 0
        quote = ""
        i = 0
        length = len(values)

        while i < length:
            ch = values[i]
            if quote:
                buf.append(ch)
                if ch == "\\" and i + 1 < length:
                    # Backslash escape: keep the escaped character verbatim.
                    buf.append(values[i + 1])
                    i += 2
                    continue
                if ch == quote:
                    if i + 1 < length and values[i + 1] == quote:
                        # Doubled quote is an escaped quote, not a terminator.
                        buf.append(quote)
                        i += 2
                        continue
                    quote = ""
            elif ch in ("'", '"'):
                quote = ch
                buf.append(ch)
            elif ch == "(":
                depth += 1
                buf.append(ch)
            elif ch == ")":
                if depth > 0:
                    depth -= 1
                    buf.append(ch)
                else:
                    # End of one tuple; continue after the next opening paren.
                    fields.append("".join(buf).strip())
                    rows.append(fields)
                    fields, buf = [], []
                    next_open = values.find("(", i + 1)
                    if next_open == -1:
                        return rows
                    i = next_open + 1
                    continue
            elif ch == "," and depth == 0:
                fields.append("".join(buf).strip())
                buf = []
            else:
                buf.append(ch)
            i += 1

        fields.append("".join(buf).strip())
        rows.append(fields)
        return rows

    @staticmethod
    def _where_literal(raw: str) -> Optional[str]:
        """Turn one raw VALUES token into a literal for a WHERE comparison.

        Returns ``None`` for an unquoted NULL (such columns are left out of
        the WHERE clause), the token unchanged when it is already a quoted
        string, and otherwise the token wrapped in single quotes with
        backslashes and single quotes escaped.
        """
        if raw.upper() == "NULL":
            return None
        if len(raw) >= 2 and raw[0] == raw[-1] and raw[0] in ("'", '"'):
            return raw
        escaped = raw.replace("\\", "\\\\").replace("'", "''")
        return f"'{escaped}'"

    def generate_reverse_sql(self, binlog_sql: str) -> List[str]:
        """Build the statements that undo the changes found in *binlog_sql*.

        * ``INSERT INTO `t` (cols) VALUES (...);`` becomes one
          ``DELETE FROM `t` WHERE ...;`` per inserted row, comparing every
          non-NULL column with its inserted value.
        * UPDATE and DELETE statements cannot be reversed from the statement
          text alone, so they are emitted as ``--`` comments for manual
          handling.

        The patterns only match statements whose table name is
        backtick-quoted and that fit on one line.

        NOTE(review): the result is grouped by statement type (INSERT, then
        UPDATE, then DELETE) rather than in reverse log order - confirm this
        is acceptable when statements depend on each other.

        Args:
            binlog_sql: Decoded binlog text.

        Returns:
            Reverse statements and manual-handling comments.
        """
        reverse_sqls: List[str] = []

        # INSERT -> DELETE
        insert_pattern = r'INSERT INTO `([^`]+)` \(([^)]+)\) VALUES \((.+)\);'
        for match in re.finditer(insert_pattern, binlog_sql, re.MULTILINE):
            table = match.group(1)
            column_list = [col.strip().strip('`') for col in match.group(2).split(',')]

            for value_list in BinlogRollbackTool._split_value_rows(match.group(3)):
                if len(value_list) != len(column_list):
                    # Never guess: a truncated WHERE clause could delete
                    # unrelated rows, so hand the statement to the operator.
                    reverse_sqls.append(f"-- 需要手动处理INSERT语句: {match.group(0)}")
                    continue

                where_conditions = []
                for col, raw in zip(column_list, value_list):
                    literal = BinlogRollbackTool._where_literal(raw)
                    if literal is not None:
                        where_conditions.append(f"`{col}` = {literal}")

                if where_conditions:
                    delete_sql = f"DELETE FROM `{table}` WHERE {' AND '.join(where_conditions)};"
                    reverse_sqls.append(delete_sql)

        # UPDATE -> manual: the previous column values are not available in
        # the statement text.
        update_pattern = r'UPDATE `([^`]+)` SET (.+) WHERE (.+);'
        for match in re.finditer(update_pattern, binlog_sql, re.MULTILINE):
            table = match.group(1)
            set_clause = match.group(2)
            where_clause = match.group(3)
            reverse_sqls.append(f"-- 需要手动处理UPDATE语句: UPDATE `{table}` SET {set_clause} WHERE {where_clause};")

        # DELETE -> manual: the deleted row contents are not available in the
        # statement text.
        delete_pattern = r'DELETE FROM `([^`]+)` WHERE (.+);'
        for match in re.finditer(delete_pattern, binlog_sql, re.MULTILINE):
            table = match.group(1)
            where_clause = match.group(2)
            reverse_sqls.append(f"-- 需要手动处理DELETE语句: INSERT INTO `{table}` ... WHERE {where_clause};")

        return reverse_sqls

    async def execute_rollback_sql(self, sql_statements: List[str], dry_run: bool = True) -> bool:
        """Run the reverse statements inside a single transaction.

        Args:
            sql_statements: Statements to run; entries starting with ``--``
                are logged and skipped.
            dry_run: When true (the default) the statements are only logged
                and nothing is sent to the server.

        Returns:
            ``True`` when there was nothing to do, the dry run finished, or
            the transaction was committed; ``False`` when a statement failed.
        """
        if not sql_statements:
            logger.warning("⚠️ 没有需要执行的SQL语句")
            return True

        if dry_run:
            logger.info("🔍 模拟执行模式 - 以下SQL语句将被执行:")
            for i, sql in enumerate(sql_statements, 1):
                logger.info(f"{i:3d}. {sql}")
            return True

        cursor = await self.connection.cursor()

        try:
            await cursor.execute("START TRANSACTION")
            logger.info("🔄 开始回滚事务")

            for i, sql in enumerate(sql_statements, 1):
                if sql.strip().startswith('--'):
                    logger.info(f"⏭️ 跳过注释: {sql}")
                    continue

                try:
                    await cursor.execute(sql)
                    logger.info(f"✅ 执行成功 ({i}/{len(sql_statements)}): {sql[:100]}...")
                except Exception as e:
                    logger.error(f"❌ 执行失败 ({i}/{len(sql_statements)}): {sql}")
                    logger.error(f" 错误信息: {e}")
                    raise

            await cursor.execute("COMMIT")
            logger.info("✅ 回滚事务提交成功")
            return True

        except Exception as e:
            # Undo the partial work. ROLLBACK itself can fail (for example on
            # a dropped connection); keep the bool contract in that case too.
            try:
                await cursor.execute("ROLLBACK")
            except Exception as rollback_error:
                logger.error(f"❌ 回滚事务失败: {e}")
                logger.error(f"❌ ROLLBACK执行失败: {rollback_error}")
            else:
                logger.error(f"❌ 回滚事务失败,已回滚: {e}")
            return False
        finally:
            await cursor.close()

    async def rollback_by_time(self,
                               target_time: datetime,
                               tables: Optional[List[str]] = None,
                               dry_run: bool = True) -> bool:
        """Undo the changes recorded from *target_time* onwards.

        Locates the binlog file covering *target_time*, decodes it starting
        at that time, derives the reverse statements and hands them to
        ``execute_rollback_sql``.

        NOTE(review): only the single located binlog file is decoded; events
        in later files are not considered - confirm this is intended.

        Args:
            target_time: Point in time to roll back to.
            tables: Optional table filter passed to ``parse_binlog_to_sql``.
            dry_run: Forwarded to ``execute_rollback_sql``.

        Returns:
            ``False`` when no binlog file was found, decoding produced no
            output, or execution failed; ``True`` otherwise (including when
            there was nothing to reverse).
        """
        logger.info(f"🎯 开始回滚到时间点: {target_time}")

        binlog_file = await self.get_binlog_position_by_time(target_time)
        if not binlog_file:
            logger.error("❌ 未找到对应的Binlog文件")
            return False

        binlog_sql = self.parse_binlog_to_sql(
            binlog_file=binlog_file,
            start_time=target_time,
            tables=tables
        )

        if not binlog_sql:
            logger.error("❌ 解析Binlog失败")
            return False

        reverse_sqls = self.generate_reverse_sql(binlog_sql)

        if not reverse_sqls:
            logger.warning("⚠️ 未生成反向SQL语句")
            return True

        return await self.execute_rollback_sql(reverse_sqls, dry_run)
|
||
|
||
async def main():
    """Command-line entry point: parse arguments and run the requested action.

    With ``--list-binlogs`` the server's binlog files are printed and nothing
    else happens. Otherwise ``--time`` is mandatory and a rollback to that
    time is performed; it is a dry run unless ``--execute`` is given.

    The timestamp is validated before any database connection is opened, and
    ``--time`` is only demanded when a rollback is actually requested so that
    ``--list-binlogs`` can be used on its own.
    """
    parser = argparse.ArgumentParser(description='MySQL Binlog 回滚工具')
    parser.add_argument('--host', default='localhost', help='MySQL主机地址')
    parser.add_argument('--port', type=int, default=3306, help='MySQL端口')
    parser.add_argument('--user', default='root', help='MySQL用户名')
    parser.add_argument('--password', default='root', help='MySQL密码')
    parser.add_argument('--database', default='kaopeilian', help='数据库名')
    parser.add_argument('--time', help='回滚到的时间点 (格式: YYYY-MM-DD HH:MM:SS)')
    parser.add_argument('--tables', nargs='*', help='指定要回滚的表名')
    parser.add_argument('--execute', action='store_true', help='实际执行回滚(默认只模拟)')
    parser.add_argument('--list-binlogs', action='store_true', help='列出所有Binlog文件')

    args = parser.parse_args()

    target_time = None
    if not args.list_binlogs:
        if not args.time:
            # parser.error prints the usage text and exits with status 2.
            parser.error('必须提供 --time 参数 (格式: YYYY-MM-DD HH:MM:SS),或使用 --list-binlogs')
        try:
            target_time = datetime.strptime(args.time, '%Y-%m-%d %H:%M:%S')
        except ValueError:
            logger.error("❌ 时间格式错误,请使用: YYYY-MM-DD HH:MM:SS")
            return

    tool = BinlogRollbackTool(
        host=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        database=args.database
    )

    try:
        await tool.connect()

        if args.list_binlogs:
            binlog_files = await tool.get_binlog_files()
            print("\n📋 Binlog文件列表:")
            for i, file_info in enumerate(binlog_files, 1):
                print(f"{i:2d}. {file_info['name']} ({file_info['size']} bytes)")
            return

        # Dry run unless --execute was given explicitly.
        dry_run = not args.execute
        success = await tool.rollback_by_time(
            target_time=target_time,
            tables=args.tables,
            dry_run=dry_run
        )

        if success:
            if dry_run:
                logger.info("🔍 模拟执行完成,使用 --execute 参数实际执行回滚")
            else:
                logger.info("✅ 回滚操作完成")
        else:
            logger.error("❌ 回滚操作失败")

    except Exception as e:
        # Top-level boundary: report the failure instead of a traceback.
        logger.error(f"❌ 程序执行异常: {e}")
    finally:
        await tool.close()
|
||
|
||
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    cli_coroutine = main()
    asyncio.run(cli_coroutine)
|