imagegen.py 4.95 KB
"""ImageGenBridge — 图片生成 tab 的输入输出桥。

QML 调 imageGen.submitTask(prompt, refs, aspect, size, mode) → 桥层把 mode 中文
名翻译成 Gemini 模型 ID,调 TaskQueueManager.submit_task。任务完成后桥层
监听 TaskQueueManager.task_completed 信号,把 image_bytes 落 HistoryManager.save_generation,
然后把 result_path 通过 taskCompleted 信号转出(QML 只拿到文件路径,不传 bytes)。
"""
import logging
from typing import Optional

from PySide6.QtCore import Property, QObject, Signal, Slot

from core.generation import MODEL_BY_MODE, MODEL_PRO


class ImageGenBridge(QObject):
    """QML-facing bridge for the image-generation tab.

    Turns QML submissions into ``TaskQueueManager.submit_task`` calls and
    re-emits the queue's signals in a QML-friendly shape. Finished images are
    persisted through the history manager, so QML only ever receives a file
    path, never raw image bytes.
    """

    apiKeyChanged = Signal()
    busyChanged = Signal()

    taskSubmitted = Signal(str)                    # task_id
    taskCompleted = Signal(str, str, str, str)     # task_id, result_path, prompt, model
    taskFailed = Signal(str, str)                  # task_id, error_message
    taskProgress = Signal(str, float, str)         # task_id, progress, status_text

    def __init__(self, task_queue_manager, history_manager, auth_bridge,
                 api_key: str = "", parent: Optional[QObject] = None):
        """Wire the bridge to the task queue and history store.

        Args:
            task_queue_manager: queue that runs tasks; its ``task_added``,
                ``task_progress``, ``task_completed`` and ``task_failed``
                signals are forwarded by this bridge.
            history_manager: provides ``save_generation(...)`` and
                ``base_path``, used to persist finished images.
            auth_bridge: supplies ``currentUser`` and ``deviceName()`` for
                task attribution; may be falsy, in which case empty strings
                are sent instead.
            api_key: initial API key attached to every submitted task.
            parent: optional Qt parent object.
        """
        super().__init__(parent)
        self._logger = logging.getLogger(__name__)
        self._tqm = task_queue_manager
        self._history = history_manager
        self._auth = auth_bridge
        self._api_key = api_key

        # Forward TaskQueueManager signals -> QML-friendly bridge signals.
        self._tqm.task_added.connect(self._on_task_added)
        self._tqm.task_progress.connect(self._on_progress)
        self._tqm.task_completed.connect(self._on_completed)
        self._tqm.task_failed.connect(self._on_failed)

    # ---- Properties -----------------------------------------------------

    @Property(str, notify=apiKeyChanged)
    def apiKey(self) -> str:
        """API key attached to newly submitted tasks."""
        return self._api_key

    @Property(bool, notify=busyChanged)
    def busy(self) -> bool:
        """True while the task queue reports at least one running task.

        Computed live on every read; ``busyChanged`` only tells QML when to
        re-read it.
        """
        return self._tqm.get_running_count() > 0

    # ---- Slots ----------------------------------------------------------

    @Slot(str)
    def setApiKey(self, key: str) -> None:
        """Replace the API key, notifying QML only on an actual change."""
        if key != self._api_key:
            self._api_key = key
            self.apiKeyChanged.emit()

    @Slot(str, list, str, str, str, result=str)
    def submitTask(self, prompt: str, reference_images: list,
                   aspect_ratio: str, image_size: str, mode: str) -> str:
        """Submit one image-generation task and return its task_id.

        Args:
            prompt: prompt text (Chinese in normal use).
            reference_images: local paths of reference images (list[str]);
                a falsy value is treated as an empty list.
            aspect_ratio: e.g. '1:1' / '2:3' / ...
            image_size: '1K' / '2K' / '4K'.
            mode: mode display name, '极速模式' (fast) or '慢速模式' (slow).
                It is mapped to a model ID via ``MODEL_BY_MODE``. Unknown
                names fall back to ``MODEL_PRO``, and a warning is logged.

        Returns:
            The task_id returned by ``TaskQueueManager.submit_task``.

        Raises:
            Exceptions from ``submit_task`` are not caught here; the original
            contract documents them as RuntimeError.
        """
        # Local import so the bridge layer does not load the Qt UI at cold start.
        from task_queue import TaskType

        if mode in MODEL_BY_MODE:
            model = MODEL_BY_MODE[mode]
        else:
            # Keep the historical fallback, but make a QML/mapping mismatch
            # visible instead of silently switching models.
            self._logger.warning(f"未知生成模式 {mode!r},回退到 {MODEL_PRO}")
            model = MODEL_PRO

        task_id = self._tqm.submit_task(
            task_type=TaskType.IMAGE_GENERATION,
            prompt=prompt,
            api_key=self._api_key,
            reference_images=list(reference_images or []),
            aspect_ratio=aspect_ratio,
            image_size=image_size,
            model=model,
            user_name=self._auth.currentUser if self._auth else "",
            device_name=self._auth.deviceName() if self._auth else "",
        )
        self._logger.info(f"提交生成任务: {task_id[:8]} - mode={mode}")
        return task_id

    @Slot(str)
    def cancelTask(self, task_id: str) -> None:
        """Ask the queue to cancel ``task_id`` and tell QML to re-read ``busy``."""
        self._tqm.cancel_task(task_id)
        self.busyChanged.emit()

    # ---- Internal signal forwarding ----------------------------------------

    def _on_task_added(self, task) -> None:
        """Forward a newly queued task's id and refresh ``busy``."""
        self.taskSubmitted.emit(task.id)
        self.busyChanged.emit()

    def _on_progress(self, task_id: str, progress: float, status_text: str) -> None:
        """Forward progress updates.

        Also refreshes ``busy``. A task that was only queued when
        ``task_added`` fired becomes "running" without any other
        notification, and progress is the first signal that follows.
        """
        self.taskProgress.emit(task_id, progress, status_text)
        self.busyChanged.emit()

    def _on_completed(self, task_id: str, image_bytes: bytes, prompt: str,
                      reference_images: list, aspect_ratio: str,
                      image_size: str, model: str) -> None:
        """Persist a finished image to history and report its path to QML.

        If persisting fails, the task is reported via ``taskFailed`` instead,
        even though generation itself succeeded.
        """
        # Persist to history; QML only ever sees the resulting file path.
        try:
            timestamp = self._history.save_generation(
                image_bytes=image_bytes,
                prompt=prompt,
                reference_images=reference_images,
                aspect_ratio=aspect_ratio,
                image_size=image_size,
                model=model,
            )
            # NOTE(review): assumes save_generation writes
            # <base_path>/<timestamp>/generated.png. The path is rebuilt here
            # rather than returned by HistoryManager, so keep the two in sync.
            result_path = str(self._history.base_path / timestamp / "generated.png")
        except Exception as e:
            self._logger.error(f"保存历史失败 {task_id[:8]}: {e}", exc_info=True)
            self.taskFailed.emit(task_id, f"图片生成成功但保存历史失败: {e}")
            self.busyChanged.emit()
            return

        self.taskCompleted.emit(task_id, result_path, prompt, model)
        self.busyChanged.emit()

    def _on_failed(self, task_id: str, error: str) -> None:
        """Forward a task failure and refresh ``busy``."""
        self.taskFailed.emit(task_id, error)
        self.busyChanged.emit()