generation.py 6.17 KB
"""图像生成 Worker + Gemini 模型常量。

ImageGenerationWorker 是 QThread,由 TaskQueueManager 拉起执行单条生成任务。
任务参数(prompt / 参考图 / aspect_ratio / image_size / model)从队列传入,
完成后通过 finished/error/progress 信号回报。
"""
import base64
import logging
import os
from typing import Optional

from PySide6.QtCore import QThread, Signal
from google import genai
from google.genai import types


# Generation mode -> Gemini model ID. Single source of truth for this mapping
# (it replaces two copy-pasted get_selected_model implementations).
#   "极速模式" (fast): Nano Banana 2 (Gemini 3.1 Flash Image); follows
#                      instructions better than 2.5-flash-image.
#   "慢速模式" (slow): Nano Banana Pro (Gemini 3 Pro Image Preview).
MODEL_BY_MODE = {
    "极速模式": "gemini-3.1-flash-image-preview",
    "慢速模式": "gemini-3-pro-image-preview",
}

# Model ID of the slow/Pro mode. Within this module it is only the default
# `model` of ImageGenerationWorker; the worker does not branch on it (both
# models accept image_size).
MODEL_PRO = MODEL_BY_MODE["慢速模式"]

# Aspect ratios supported only by Nano Banana 2 (Flash); Pro does not accept
# them, so selecting one should prompt the user to switch to "极速模式".
# frozenset: a read-only lookup table shared module-wide must not be mutated.
FLASH_ONLY_ASPECT_RATIOS = frozenset({"1:4", "4:1", "1:8", "8:1"})


class ImageGenerationWorker(QThread):
    """Worker thread that runs a single Gemini image-generation request.

    The task parameters are stored at construction time; ``run`` executes on
    the worker thread once the thread is started. It reports the outcome
    through exactly one of the ``finished`` / ``error`` signals and may emit
    any number of ``progress`` status messages before that.
    """

    # image_bytes, prompt, reference_images (list of bytes, one entry per input
    # path, in input order), aspect_ratio, image_size, model
    finished = Signal(bytes, str, list, str, str, str)
    error = Signal(str)
    progress = Signal(str)

    # Reference-image file extension -> MIME type sent to Gemini. Extensions not
    # listed fall back to image/png, which was the previous unconditional default.
    _MIME_BY_EXTENSION = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".heic": "image/heic",
        ".heif": "image/heif",
    }
    _DEFAULT_MIME_TYPE = "image/png"

    def __init__(self, api_key, prompt, images, aspect_ratio, image_size, model=MODEL_PRO):
        """Store the task parameters; no file or network access happens here.

        Args:
            api_key: Gemini API key. An empty value is reported from ``run``.
            prompt: Text prompt. An empty value is reported from ``run``.
            images: Iterable of reference-image file paths (may be empty or None).
            aspect_ratio: Passed through to ``types.ImageConfig(aspect_ratio=...)``.
            image_size: Passed through to ``types.ImageConfig(image_size=...)``.
            model: Gemini model ID; defaults to ``MODEL_PRO``.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.api_key = api_key
        self.prompt = prompt
        self.images = images
        self.aspect_ratio = aspect_ratio
        self.image_size = image_size
        self.model = model

        # Audit metadata: read by TaskQueueManager from its signal callbacks.
        self.finish_reason: Optional[str] = None

        self.logger.info(
            "图片生成任务初始化 - 模型: %s, 尺寸: %s, 宽高比: %s",
            model, image_size, aspect_ratio,
        )

    def _extract_finish_reason(self, response) -> Optional[str]:
        """Return the first candidate's finish_reason name, or None.

        Best-effort audit helper: any failure (no candidates, unexpected
        response shape) yields None instead of raising, so that a missing
        finish_reason can never turn a successful generation into an error.
        """
        try:
            fr = response.candidates[0].finish_reason
            if fr is None:
                return None
            # Enum members expose `.name`; fall back to str() for anything else.
            name = getattr(fr, "name", None)
            return name if name else str(fr)
        except Exception:
            return None

    @classmethod
    def _guess_mime_type(cls, img_path) -> str:
        """Map a reference-image path to a MIME type by its extension (case-insensitive)."""
        ext = os.path.splitext(img_path)[1].lower()
        return cls._MIME_BY_EXTENSION.get(ext, cls._DEFAULT_MIME_TYPE)

    def _load_reference_images(self) -> list[tuple[str, bytes]]:
        """Read every reference image exactly once, preserving input order.

        Returns:
            A list of ``(path, data)`` pairs. ``self.images`` being None is
            treated as "no reference images".

        Raises:
            OSError (or TypeError for a non-path entry) if a file cannot be
            read; ``run`` reports it through the ``error`` signal.
        """
        loaded = []
        for img_path in self.images or ():
            with open(img_path, 'rb') as f:
                loaded.append((img_path, f.read()))
        return loaded

    def run(self):
        """Execute the generation request on the worker thread.

        Validates the prompt and API key, sends the prompt plus reference
        images to ``self.model``, and emits ``finished`` with the first inline
        image in the response. If the response contains no image, ``error`` is
        emitted with the finish_reason and any text the model returned. This is
        the thread's top-level boundary, so every exception is caught, logged,
        and reported through ``error``.
        """
        try:
            self.logger.info("开始图片生成任务")

            if not self.prompt:
                self.logger.error("图片描述为空")
                self.error.emit("请输入图片描述!")
                return

            if not self.api_key:
                self.logger.error("API密钥为空")
                self.error.emit("未找到API密钥,请在config.json中配置!")
                return

            self.progress.emit("正在连接 Gemini API...")
            self.logger.debug("正在连接 Gemini API")

            client = genai.Client(api_key=self.api_key)

            # Read each file once: the same bytes are sent to the model and
            # echoed back via `finished`, so they cannot diverge if a file
            # changes on disk while the request is in flight.
            references = self._load_reference_images()
            content_parts = [self.prompt]
            for img_path, img_data in references:
                content_parts.append(
                    types.Part.from_bytes(
                        data=img_data,
                        mime_type=self._guess_mime_type(img_path),
                    )
                )
            reference_images_bytes = [img_data for _, img_data in references]

            self.progress.emit("正在生成图片...")

            # Both models currently in use accept aspect_ratio + image_size:
            #   - gemini-3.1-flash-image-preview (Nano Banana 2): 512/1K/2K/4K + 14 ratios
            #   - gemini-3-pro-image-preview (Nano Banana Pro):   1K/2K/4K
            config = types.GenerateContentConfig(
                response_modalities=["TEXT", "IMAGE"],
                image_config=types.ImageConfig(
                    aspect_ratio=self.aspect_ratio,
                    image_size=self.image_size,
                ),
            )

            response = client.models.generate_content(
                model=self.model,
                contents=content_parts,
                config=config,
            )
            self.finish_reason = self._extract_finish_reason(response)

            # Return the first inline image; collect any text parts seen before
            # it so a missing image can be explained to the user.
            text_fragments = []
            for part in response.parts or []:
                inline = getattr(part, 'inline_data', None)
                # Skip inline parts without payload instead of crashing in
                # b64decode(None); they fall through to the "no image" report.
                if inline and inline.data:
                    raw = inline.data
                    # The SDK usually returns raw bytes; accept base64 text too.
                    image_bytes = raw if isinstance(raw, bytes) else base64.b64decode(raw)

                    self.logger.info(
                        "图片生成成功 - 模型: %s, 尺寸: %s, finish_reason=%s",
                        self.model, self.image_size, self.finish_reason,
                    )
                    self.finished.emit(image_bytes, self.prompt, reference_images_bytes,
                                       self.aspect_ratio, self.image_size, self.model)
                    return
                text = getattr(part, 'text', None)
                if text:
                    text_fragments.append(text)

            detail = " | ".join(text_fragments).strip()
            error_msg = f"响应中没有图片数据 (finish_reason={self.finish_reason})"
            if detail:
                error_msg += f"\n模型说明: {detail}"
            self.logger.error(error_msg)
            self.error.emit(error_msg)

        except Exception as e:
            error_msg = f"图片生成异常: {e}"
            self.logger.error(error_msg, exc_info=True)
            self.error.emit(error_msg)