auth.py 7.22 KB
"""AuthBridge — 登录认证 + 当前用户 + 记住凭据。

替代旧 LoginDialog。QML LoginScreen 调 auth.login(user, pwd)(首次输入)
或 auth.loginWithSavedPassword(user)("记住密码"路径,跳过 hash 直查 db)。
登录成功后 QML 调 auth.saveCredentials(...) 写回 config.json。

PoC 模式(db_config = None):接受任意非空用户名密码,便于无 db 环境调 UI。
"""
import logging
import platform
import socket
from pathlib import Path
from typing import Optional

from PySide6.QtCore import Property, QObject, Signal, Slot

from core.database import DatabaseManager, hash_password


class AuthBridge(QObject):
    """QML-facing authentication bridge: login, current user, remembered credentials.

    Properties: ``loggedIn``, ``currentUser``, ``lastUser``, ``hasSavedPassword``.
    Slots: ``login``, ``loginWithSavedPassword``, ``saveCredentials``, ``logout``,
    ``deviceName``.

    When ``db_config`` is falsy no ``DatabaseManager`` is created and any non-empty
    credentials are accepted (PoC mode, for running the UI without a database).
    """

    loggedInChanged = Signal()
    currentUserChanged = Signal()
    loginFailed = Signal(str)  # error_message
    # Emitted when saveCredentials() rewrites the remembered values, so QML
    # bindings on lastUser / hasSavedPassword refresh (e.g. after logout()).
    lastUserChanged = Signal()
    savedPasswordChanged = Signal()

    # Endpoints tried in order by _get_public_ip(); presumably each answers with
    # the caller's public IP as plain text (the response is validated anyway).
    _PUBLIC_IP_APIS = ("https://api.ipify.org", "https://ifconfig.me", "https://ipinfo.io/ip")

    def __init__(self, db_config: Optional[dict] = None, audit_logger=None,
                 last_user: str = "", saved_password_hash: str = "",
                 config_path: Optional[Path] = None, parent=None):
        """Create the bridge.

        db_config: passed to DatabaseManager; falsy -> PoC mode (no database).
        audit_logger: optional object whose ``log_login(...)`` is called on success.
        last_user / saved_password_hash: values read from config.json at startup.
        config_path: where saveCredentials() writes; None disables saving.
        """
        super().__init__(parent)
        self._logger = logging.getLogger(__name__)
        self._db_config = db_config
        self._db = DatabaseManager(db_config) if db_config else None
        self._audit = audit_logger
        self._logged_in = False
        self._current_user = ""
        self._last_user = last_user or ""
        self._saved_password_hash = saved_password_hash or ""
        self._config_path = config_path

    @Property(bool, notify=loggedInChanged)
    def loggedIn(self) -> bool:
        """True after a successful login until logout()."""
        return self._logged_in

    @Property(str, notify=currentUserChanged)
    def currentUser(self) -> str:
        """Username of the logged-in user ("" when logged out)."""
        return self._current_user

    @Property(str, notify=lastUserChanged)
    def lastUser(self) -> str:
        """Remembered username (initially from config.json); QML pre-fills the username field."""
        return self._last_user

    @Property(bool, notify=savedPasswordChanged)
    def hasSavedPassword(self) -> bool:
        """Whether a password hash is remembered; QML uses it to show a ••• placeholder."""
        return bool(self._saved_password_hash)

    @Slot(str, str, result=bool)
    def login(self, username: str, password: str) -> bool:
        """Log in with a plaintext password (first entry, or the password field was edited).

        Returns True on success; on failure emits ``loginFailed(message)`` and returns False.
        """
        username = (username or "").strip()
        if not username or not password:
            self.loginFailed.emit("用户名和密码不能为空")
            return False

        # PoC mode: without a database any non-empty credentials are accepted.
        if self._db is None:
            self._on_login_success(username)
            return True

        return self._authenticate_db(username, self._db.authenticate, password, "登录失败")

    @Slot(str, result=bool)
    def loginWithSavedPassword(self, username: str) -> bool:
        """Log in with the locally remembered password hash ("remember password" path).

        The stored hash is sent to the database as-is; the password is not re-hashed.
        Returns True on success; on failure emits ``loginFailed(message)`` and returns False.
        """
        username = (username or "").strip()
        if not username:
            self.loginFailed.emit("用户名不能为空")
            return False
        if not self._saved_password_hash:
            self.loginFailed.emit("没有已保存的密码,请输入")
            return False

        if self._db is None:
            self._on_login_success(username)
            return True

        return self._authenticate_db(
            username, self._db.authenticate_with_hash, self._saved_password_hash, "已存密码登录失败"
        )

    @Slot(str, bool, bool)
    def saveCredentials(self, password: str, remember_user: bool, remember_password: bool) -> None:
        """Persist last_user / saved_password_hash to config.json according to the checkboxes.

        Called by QML after a successful login.

        password: the plaintext the user typed. It is non-empty only when the
                  password field was edited, otherwise "".
        remember_password=True with an empty ``password`` keeps the existing hash
        instead of clearing it.

        NOTE: loginWithSavedPassword() passes the stored hash straight to
        authenticate_with_hash(), so the hash works like a password. Treat
        config.json as sensitive.
        """
        if self._config_path is None:
            return
        from config_util import load_config_safe, save_config

        cfg, _ = load_config_safe(self._config_path)

        cfg["last_user"] = self._current_user if remember_user else ""

        if remember_password:
            if password:
                cfg["saved_password_hash"] = hash_password(password)
            # Empty password = user did not change it; keep the old hash untouched.
        else:
            cfg["saved_password_hash"] = ""

        if not save_config(self._config_path, cfg):
            self._logger.warning("saveCredentials 写盘失败: %s", self._config_path)
            return

        self._apply_saved_credentials(cfg["last_user"], cfg.get("saved_password_hash", ""))

    @Slot()
    def logout(self) -> None:
        """Clear the session and notify QML."""
        self._logged_in = False
        self._current_user = ""
        self.loggedInChanged.emit()
        self.currentUserChanged.emit()

    @Slot(result=str)
    def deviceName(self) -> str:
        """Host name of this machine, used by the audit log and ImageGenBridge; "unknown" on failure."""
        try:
            return socket.gethostname() or platform.node() or "unknown"
        except Exception:
            return "unknown"

    # ---- internals ------------------------------------------------------

    def _authenticate_db(self, username: str, auth_fn, secret: str, failure_label: str) -> bool:
        """Run a database auth call and turn its outcome into signals.

        auth_fn(username, secret) must return ``(ok, message)``. Any exception it
        raises is logged and reported through ``loginFailed``. Without this, the
        exception would escape the QML slot and the UI would receive no error.
        """
        try:
            ok, msg = auth_fn(username, secret)
        except Exception:
            self._logger.exception("%s: %s - 数据库异常", failure_label, username)
            self.loginFailed.emit("登录时发生错误,请检查数据库连接")
            return False

        if not ok:
            self._logger.warning("%s: %s - %s", failure_label, username, msg)
            self.loginFailed.emit(msg)
            return False

        self._on_login_success(username)
        return True

    def _apply_saved_credentials(self, last_user: str, password_hash: str) -> None:
        """Update the remembered values and emit notify signals only for those that changed."""
        if last_user != self._last_user:
            self._last_user = last_user
            self.lastUserChanged.emit()
        if password_hash != self._saved_password_hash:
            self._saved_password_hash = password_hash
            self.savedPasswordChanged.emit()

    def _on_login_success(self, username: str) -> None:
        """Mark the session as logged in, notify QML, then write a best-effort audit record.

        NOTE(review): the audit path performs blocking network lookups
        (_get_public_ip) in the calling thread — presumably the GUI thread
        when invoked from QML; consider moving it to a worker.
        """
        self._current_user = username
        self._logged_in = True
        self.currentUserChanged.emit()
        self.loggedInChanged.emit()
        self._logger.info("登录成功: %s", username)

        if self._audit is not None:
            try:
                self._audit.log_login(
                    user_name=username,
                    local_ip=self._get_local_ip(),
                    public_ip=self._get_public_ip(),
                    device_name=self.deviceName(),
                )
            except Exception:
                # Audit is best-effort: a failure here must not undo the login.
                self._logger.exception("audit log_login 失败(不影响登录)")

    @staticmethod
    def _get_local_ip() -> Optional[str]:
        """Return the local IPv4 address used for outbound traffic, or None on failure.

        connect() on a UDP socket sends no packets; it only selects a route, whose
        source address getsockname() then reports.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.settimeout(0.5)
                s.connect(("8.8.8.8", 80))
                return s.getsockname()[0]
        except Exception:
            return None

    def _get_public_ip(self) -> Optional[str]:
        """Fetch the public IP once at login (same behavior as the old LoginDialog.get_public_ip).

        Tries each endpoint in ``_PUBLIC_IP_APIS`` with a 3s requests timeout.
        The timeout applies to the connect and read phases separately, so the
        worst case is roughly 9s or more when all three endpoints are slow. A
        response is accepted only if it parses as an IPv4/IPv6 address. Returns
        None if ``requests`` is unavailable or every endpoint fails. The result
        is used only for the audit log.
        """
        try:
            import requests
        except ImportError:
            return None
        import ipaddress

        for api in self._PUBLIC_IP_APIS:
            try:
                resp = requests.get(api, timeout=3)
            except Exception:
                # Best-effort: any failure on one endpoint just moves on to the next.
                continue
            if resp.status_code != 200:
                continue
            candidate = resp.text.strip()
            try:
                ipaddress.ip_address(candidate)
            except ValueError:
                continue  # Not an IP (e.g. an HTML error page); try the next endpoint.
            return candidate
        return None