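# taskqueue.py 8.13 KB
"""TaskQueueBridge — bridge layer for the task-queue sidebar.

The QML sidebar ListView binds directly to ``taskQueue.model`` (a
QAbstractListModel). The bridge listens to the TaskQueueManager singleton's
signals and updates the model incrementally, keyed by task_id.

The list is intended to hold only the most recent N entries (matching
TaskQueueManager's own _max_history_size), with older tasks disappearing as
TaskQueueManager._cleanup_old_tasks prunes them.
NOTE(review): nothing in this module removes rows from the model — confirm
how pruned tasks are meant to leave the list.
"""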
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List

from PySide6.QtCore import (
    Property, QAbstractListModel, QModelIndex, QObject, Qt, Signal, Slot,
)

from core.generation import MODEL_BY_MODE


class _TaskListModel(QAbstractListModel):
    """Task list model exposed to the QML ListView.

    Roles: taskId / prompt / status / progress / statusText / elapsed.

    ``_ids`` holds the task ids in display order (most recently inserted
    first) and ``_rows`` maps each task id to its record dict.
    """
    TaskIdRole = Qt.UserRole + 1
    PromptRole = Qt.UserRole + 2
    StatusRole = Qt.UserRole + 3
    ProgressRole = Qt.UserRole + 4
    StatusTextRole = Qt.UserRole + 5
    ElapsedRole = Qt.UserRole + 6

    # Record key -> role. Lets upsert() tell views exactly which roles an
    # update touched instead of invalidating every role of the row.
    _ROLE_BY_FIELD = {
        "task_id": TaskIdRole,
        "prompt": PromptRole,
        "status": StatusRole,
        "progress": ProgressRole,
        "status_text": StatusTextRole,
        "elapsed": ElapsedRole,
    }

    def __init__(self, parent=None):
        super().__init__(parent)
        self._ids: List[str] = []
        self._rows: Dict[str, dict] = {}

    def rowCount(self, parent=QModelIndex()) -> int:
        """Return the number of task rows; a valid parent has no children."""
        if parent.isValid():
            return 0
        return len(self._ids)

    def roleNames(self):
        """Map each custom role to the property name used by QML delegates."""
        return {
            _TaskListModel.TaskIdRole: b"taskId",
            _TaskListModel.PromptRole: b"prompt",
            _TaskListModel.StatusRole: b"status",
            _TaskListModel.ProgressRole: b"progress",
            _TaskListModel.StatusTextRole: b"statusText",
            _TaskListModel.ElapsedRole: b"elapsed",
        }

    def data(self, index: QModelIndex, role: int = Qt.DisplayRole):
        """Return the value stored for ``role`` in the row at ``index``.

        Returns None for an invalid or out-of-range index and for roles this
        model does not define. Missing record keys fall back to an empty
        string, or to 0.0 for progress.
        """
        if not index.isValid():
            return None
        row = index.row()
        if row < 0 or row >= len(self._ids):
            return None
        record = self._rows.get(self._ids[row], {})
        if role == _TaskListModel.TaskIdRole:
            return record.get("task_id", "")
        if role == _TaskListModel.PromptRole:
            return record.get("prompt", "")
        if role == _TaskListModel.StatusRole:
            return record.get("status", "")
        if role == _TaskListModel.ProgressRole:
            return float(record.get("progress", 0.0))
        if role == _TaskListModel.StatusTextRole:
            return record.get("status_text", "")
        if role == _TaskListModel.ElapsedRole:
            return record.get("elapsed", "")
        return None

    # ---- Incremental operations (called by the bridge) -------------------

    def upsert(self, task_id: str, **fields) -> None:
        """Insert a new task row or merge ``fields`` into an existing one.

        An unknown ``task_id`` is inserted at row 0. For a known id the
        record is updated in place and ``dataChanged`` is emitted for that
        row, restricted to the roles of the fields that were passed.
        """
        record = self._rows.get(task_id)
        if record is None:
            self.beginInsertRows(QModelIndex(), 0, 0)
            self._ids.insert(0, task_id)
            self._rows[task_id] = {"task_id": task_id, **fields}
            self.endInsertRows()
            return

        record.update(fields)
        cell = self.index(self._ids.index(task_id), 0)
        # An empty role list means "all roles changed" to Qt, so fields with
        # no mapped role still result in a (conservative) full-row refresh.
        roles = [
            self._ROLE_BY_FIELD[key]
            for key in fields
            if key in self._ROLE_BY_FIELD
        ]
        self.dataChanged.emit(cell, cell, roles)


class TaskQueueBridge(QObject):
    """QObject bridge between the task queue manager and the QML sidebar.

    Mirrors the manager's task signals into a ``_TaskListModel`` and emits
    the notify signals backing the ``pendingCount`` / ``runningCount``
    properties.
    """

    pendingCountChanged = Signal()
    runningCountChanged = Signal()
    # Sidebar task click -> lets the target tab restore prompt / reference
    # images / settings / result image (counterpart of the old
    # _load_task_to_main_window).
    taskLoadRequested = Signal("QVariantMap")

    def __init__(self, task_queue_manager, parent=None):
        super().__init__(parent)
        self._logger = logging.getLogger(__name__)
        self._tqm = task_queue_manager
        self._model = _TaskListModel(self)

        # Wire each manager signal to the handler that mirrors it in the model.
        wiring = (
            (self._tqm.task_added, self._on_task_added),
            (self._tqm.task_started, self._on_task_started),
            (self._tqm.task_completed, self._on_task_completed),
            (self._tqm.task_failed, self._on_task_failed),
            (self._tqm.task_progress, self._on_progress),
        )
        for signal, handler in wiring:
            signal.connect(handler)

    # ---- Properties -----------------------------------------------------

    @Property(QObject, constant=True)
    def model(self):
        """The list model backing the sidebar ListView."""
        return self._model

    @Property(int, notify=pendingCountChanged)
    def pendingCount(self) -> int:
        """Pending-task count as reported by the manager."""
        return self._tqm.get_pending_count()

    @Property(int, notify=runningCountChanged)
    def runningCount(self) -> int:
        """Running-task count as reported by the manager."""
        return self._tqm.get_running_count()

    # ---- Slots ----------------------------------------------------------

    @Slot(str)
    def cancelTask(self, task_id: str) -> None:
        """Ask the manager to cancel ``task_id`` and refresh both counters."""
        self._tqm.cancel_task(task_id)
        self.pendingCountChanged.emit()
        self.runningCountChanged.emit()

    @Slot(str)
    def loadTask(self, task_id: str) -> None:
        """Emit ``taskLoadRequested`` with a QML-friendly dict for a task.

        Logs a warning and emits nothing when the manager has no such task.

        Payload keys:
          taskId, type ("image_gen" | "style_design"),
          prompt, referenceImages (list[str], only paths that still exist
          on disk, in POSIX form),
          aspectRatio, imageSize,
          mode (the MODEL_BY_MODE key whose value equals the task's model,
          falling back to "慢速模式"),
          resultPath (str; set only for completed tasks, otherwise "")
        """
        from task_queue import TaskStatus, TaskType

        task = self._tqm.get_task(task_id)
        if task is None:
            self._logger.warning(f"loadTask: 任务不存在 {task_id[:8]}")
            return

        if task.type == TaskType.STYLE_DESIGN:
            type_str = "style_design"
        else:
            type_str = "image_gen"

        # The task records a model id; reverse-map it to the mode label so
        # the UI can restore its ComboBox text.
        mode = next(
            (label for label, model_id in MODEL_BY_MODE.items()
             if model_id == task.model),
            "慢速模式",
        )

        # Keep only reference images still present on disk (older tasks may
        # point at files that have since been deleted).
        valid_refs = []
        for raw in (task.reference_images or []):
            if not raw:
                continue
            try:
                candidate = Path(raw)
                if candidate.exists():
                    valid_refs.append(candidate.as_posix())
            except Exception:
                # Best effort: an entry that cannot be checked is skipped.
                continue

        result_path = ""
        if task.status == TaskStatus.COMPLETED:
            result_path = task.result_path

        payload = {
            "taskId": task_id,
            "type": type_str,
            "prompt": task.prompt or "",
            "referenceImages": valid_refs,
            "aspectRatio": task.aspect_ratio or "",
            "imageSize": task.image_size or "",
            "mode": mode,
            "resultPath": result_path or "",
        }
        self._logger.info(
            f"loadTask emit: {task_id[:8]} type={type_str} "
            f"refs={len(valid_refs)} hasResult={bool(result_path)}"
        )
        self.taskLoadRequested.emit(payload)

    # ---- Manager signals -> incremental model updates --------------------

    def _on_task_added(self, task) -> None:
        """Add a row for a newly queued task in its initial pending state."""
        self._model.upsert(
            task.id,
            prompt=task.prompt,
            status="pending",
            progress=0.0,
            status_text="等待中",
            elapsed="",
        )
        self.pendingCountChanged.emit()

    def _on_task_started(self, task_id: str) -> None:
        """Mark the task as running and refresh both counters."""
        self._model.upsert(task_id, status="running", status_text="生成中…")
        self.pendingCountChanged.emit()
        self.runningCountChanged.emit()

    def _on_progress(self, task_id: str, progress: float, status_text: str) -> None:
        """Store the latest progress value and status text for the task."""
        self._model.upsert(task_id, progress=progress, status_text=status_text)

    def _on_task_completed(self, task_id, *_args) -> None:
        """Mark the task as completed; extra signal arguments are ignored."""
        self._model.upsert(
            task_id,
            status="completed",
            progress=1.0,
            status_text="已完成",
            elapsed=self._format_elapsed(task_id),
        )
        self.runningCountChanged.emit()

    def _on_task_failed(self, task_id: str, error: str) -> None:
        """Mark the task as failed, showing ``error`` or a default label."""
        self._model.upsert(
            task_id,
            status="failed",
            status_text=error or "失败",
            elapsed=self._format_elapsed(task_id),
        )
        self.pendingCountChanged.emit()
        self.runningCountChanged.emit()

    def _format_elapsed(self, task_id: str) -> str:
        """Return the task's run time as e.g. ``"12.3s"``.

        Returns "" when the task is unknown or either timestamp is unset.
        """
        task = self._tqm.get_task(task_id)
        if not task or not task.started_at or not task.completed_at:
            return ""
        secs = (task.completed_at - task.started_at).total_seconds()
        return f"{secs:.1f}s"