# preflight.py (5.99 KB)
"""
启动门禁:保证审计日志上传的所有前置条件都成立。
任一失败即阻止应用进入主流程,对用户只显示一句"应用启动失败,请联系 @柴进"。
详细错误脱敏后写入 logs/preflight_error.log。
"""
from __future__ import annotations

import logging
import re
import sys
import traceback
from datetime import datetime
from pathlib import Path
from typing import Tuple

import pymysql

from config_util import load_config_safe


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Keys that preflight_check requires (with truthy values) inside
# config["db_config"].
REQUIRED_DB_FIELDS = (
    "host",
    "port",
    "user",
    "password",
    "database",
)

# Audit tables that must be queryable before the app may start.
REQUIRED_TABLES = (
    "nano_banana_user_use_log",
    "nano_banana_user_log",
)

# Columns expected on nano_banana_user_use_log.
REQUIRED_USE_LOG_COLUMNS = (
    "user_name",
    "device_name",
    "prompt",
    "result_path",
    "status",
    "error_message",
    "model",
    "duration_ms",
    "finish_reason",
)

# Columns expected on nano_banana_user_log.
REQUIRED_LOGIN_LOG_COLUMNS = (
    "user_name",
    "local_ip",
    "public_ip",
    "device_name",
    "login_time",
)


def preflight_check(config_path: Path, audit_queue_path: Path) -> Tuple[bool, str, dict]:
    """Verify the prerequisites for audit-log upload before the app starts.

    The checks run in order and stop at the first failure:

    1. The config loads through ``load_config_safe``.
    2. ``db_config`` is a dict with a truthy value for every key in
       ``REQUIRED_DB_FIELDS``.
    3. A MySQL connection can be opened and answers ``SELECT 1``.
    4. Every table in ``REQUIRED_TABLES`` can be queried.
    5. Both audit tables expose their required columns.
    6. The directory that holds ``audit_queue_path`` is writable.

    Args:
        config_path: Path handed to ``load_config_safe``.
        audit_queue_path: Path of the local audit queue file; only its parent
            directory is created (if needed) and probed for writability.

    Returns:
        ``(ok, error_detail, config)``:

        - ``ok`` is True when every check passed.
        - ``error_detail`` is ``""`` on success, otherwise a detailed
          description that has NOT been scrubbed of secrets (the caller is
          expected to scrub it before persisting it).
        - ``config`` is whatever ``load_config_safe`` returned, or ``{}``
          when ``load_config_safe`` itself raised. On failure it may be an
          incomplete or default config, depending on ``load_config_safe``.
    """
    # 1. config.json
    try:
        config, load_err = load_config_safe(config_path)
    except Exception:
        return False, f"config load crashed:\n{traceback.format_exc()}", {}

    if load_err:
        return False, f"config load error: {load_err}", config

    # 2. db_config must be a dict carrying every required field.
    db = config.get("db_config")
    if not db or not isinstance(db, dict):
        return False, "config.json 缺少 db_config 字段或格式错误", config

    missing = [k for k in REQUIRED_DB_FIELDS if not db.get(k)]
    if missing:
        return False, f"db_config 缺少字段: {missing}", config

    # 3. MySQL connection. int(db["port"]) sits inside the try so a
    # non-numeric port is reported as a connect failure instead of raising.
    try:
        conn = pymysql.connect(
            host=db["host"],
            port=int(db["port"]),
            user=db["user"],
            password=db["password"],
            database=db["database"],
            connect_timeout=5,
            read_timeout=5,
            write_timeout=5,
            charset="utf8mb4",
        )
    except Exception as e:
        return False, f"MySQL connect 失败: {type(e).__name__}: {e}", config

    try:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            cur.fetchone()

            # 4. Tables exist and are readable.
            for table in REQUIRED_TABLES:
                try:
                    cur.execute(f"SELECT 1 FROM `{table}` LIMIT 1")
                    cur.fetchone()
                except Exception as e:
                    return False, f"审计表 {table} 不可用: {type(e).__name__}: {e}", config

            # 5. Required columns exist.
            ok, col_err = _check_columns(cur, db["database"], "nano_banana_user_use_log",
                                         REQUIRED_USE_LOG_COLUMNS)
            if not ok:
                return False, col_err, config
            ok, col_err = _check_columns(cur, db["database"], "nano_banana_user_log",
                                         REQUIRED_LOGIN_LOG_COLUMNS)
            if not ok:
                return False, col_err, config
    except Exception as e:
        # SELECT 1 or the INFORMATION_SCHEMA lookup failed (dropped
        # connection, timeout, missing privilege, ...). Report it through the
        # normal failure tuple rather than letting the exception escape.
        return False, f"MySQL 预检查询失败: {type(e).__name__}: {e}", config
    finally:
        # Best-effort close; a failure here must not mask the real result.
        try:
            conn.close()
        except Exception:
            pass

    # 6. Local queue directory is writable: create, write and remove a probe.
    try:
        audit_queue_path.parent.mkdir(parents=True, exist_ok=True)
        probe = audit_queue_path.parent / ".preflight_probe"
        probe.write_text("ok", encoding="utf-8")
        probe.unlink()
    except Exception as e:
        return False, f"审计队列目录不可写 ({audit_queue_path.parent}): {e}", config

    return True, "", config


def _check_columns(cur, db_name: str, table: str, required: tuple[str, ...]) -> Tuple[bool, str]:
    cur.execute(
        "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
        "WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
        (db_name, table),
    )
    existing = {row[0] for row in cur.fetchall()}
    missing = [c for c in required if c not in existing]
    if missing:
        return False, f"表 {table} 缺少列: {missing}(请运行 migrations/2026-04-21_add_audit_log_columns.sql)"
    return True, ""


def handle_preflight_failure(detail: str, logs_dir: Path) -> None:
    """Record a preflight failure, tell the user in one sentence, and exit.

    Steps, in order:

    1. Append the scrubbed ``detail`` (via ``_scrub``) under a timestamp
       header to ``logs_dir / "preflight_error.log"``.
    2. Show a single-line critical message box; if that is not possible,
       print the same sentence to stderr.
    3. Call ``sys.exit(1)``.

    A QApplication should already exist when this is called; one is created
    here only as a last resort.

    Args:
        detail: Unscrubbed failure description.
        logs_dir: Directory for ``preflight_error.log``; created if missing.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    # Write the scrubbed log first so the detail is recorded even when the Qt
    # layer cannot be imported or shown.
    try:
        logs_dir.mkdir(parents=True, exist_ok=True)
        err_log = logs_dir / "preflight_error.log"
        with open(err_log, "a", encoding="utf-8") as f:
            f.write(f"\n===== {datetime.now().isoformat(timespec='seconds')} =====\n")
            f.write(_scrub(detail))
            f.write("\n")
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        # The exception here comes from the filesystem write, not from the
        # (possibly sensitive) preflight detail.
        logger.warning("failed to write preflight_error.log", exc_info=True)

    # User-facing part: exactly one sentence.
    try:
        # Imported inside the guarded section so a missing or broken PySide6
        # falls through to the stderr message and the exit below.
        from PySide6.QtWidgets import QMessageBox, QApplication

        app = QApplication.instance()
        if app is None:
            # Extreme case: preflight failed before any QApplication existed.
            # Keep the reference in `app` for the lifetime of the dialog.
            app = QApplication(sys.argv)
        box = QMessageBox()
        box.setIcon(QMessageBox.Critical)
        box.setWindowTitle("启动失败")
        box.setText("应用启动失败,请联系 @柴进")
        box.setStandardButtons(QMessageBox.Ok)
        box.exec()
    except Exception:
        # Worst case: even the dialog could not be shown.
        print("应用启动失败,请联系 @柴进", file=sys.stderr)

    sys.exit(1)


_SCRUB_PATTERNS = [
    (re.compile(r'("password"\s*:\s*)"[^"]*"'), r'\1"***"'),
    (re.compile(r'("api_key"\s*:\s*)"[^"]*"'), r'\1"***"'),
    (re.compile(r"(password\s*=\s*)\S+"), r"\1***"),
    (re.compile(r"(api_key\s*=\s*)\S+"), r"\1***"),
]


def _scrub(detail: str) -> str:
    """从详情里擦除 password / api_key。"""
    out = detail
    for pat, repl in _SCRUB_PATTERNS:
        out = pat.sub(repl, out)
    return out