12d6934e by 柴进

refactor: 业务核心从 image_generator.py 拆到 core/

5082 行的怪兽塞了 12 个类(数据库/历史/Worker/词库/UI),现在按职责拆开:

  core/paths.py       get_app_data_path + save_png_with_validation + 启动迁移
  core/database.py    DatabaseManager + hash_password
  core/history.py     HistoryItem + HistoryListModel + HistoryManager
  core/generation.py  ImageGenerationWorker + Gemini 模型常量
  core/jewelry.py     DEFAULT_JEWELRY_LIBRARY + JewelryLibraryManager + PromptAssembler

image_generator.py 5082 → 3781 行,剩下全是 QWidget UI 类
(LoginDialog / DraggableThumbnail / DragDropScrollArea / ImageGeneratorWindow /
StyleDesignerTab + utils + main),task #19 QML 全量切换后整体删除。

外部消费者改 import:
  task_queue.py / temp_clean.py: from image_generator → from core.generation
  image_generator.py 顶部 from core.* 引入,LoginDialog 等内部代码无感知

冒烟测试:image_generator/task_queue/core 各自 import 通过,类身份正确。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1fd96b05
"""业务核心模块。
从 image_generator.py 拆出,保留 QWidget UI 类在原文件供 task #19 整删。
QML 桥层和外部模块直接 import core.* 而不是 image_generator。
"""
"""数据库连接 + 用户认证。
DatabaseManager 只负责 MySQL 连接和 user table 的 SHA256 密码核验。
db_config 由 config_util 加载并传入;本模块不读 config 文件本身。
"""
import hashlib
import logging
import pymysql
def hash_password(password: str) -> str:
"""使用 SHA256 哈希密码"""
return hashlib.sha256(password.encode('utf-8')).hexdigest()
class DatabaseManager:
"""数据库连接管理类"""
def __init__(self, db_config):
self.config = db_config
self.logger = logging.getLogger(__name__)
def authenticate(self, username, password):
"""
验证用户凭证
返回: (success: bool, message: str)
"""
try:
self.logger.info(f"开始用户认证: {username}")
password_hash = hash_password(password)
self.logger.debug(f"连接数据库: {self.config['host']}:{self.config.get('port', 3306)}")
conn = pymysql.connect(
host=self.config['host'],
port=self.config.get('port', 3306),
user=self.config['user'],
password=self.config['password'],
database=self.config['database'],
connect_timeout=5
)
try:
with conn.cursor() as cursor:
sql = f"SELECT * FROM {self.config['table']} WHERE user_name=%s AND passwd=%s AND status='active'"
cursor.execute(sql, (username, password_hash))
result = cursor.fetchone()
if result:
self.logger.info(f"用户认证成功: {username}")
return True, "认证成功"
else:
self.logger.warning(f"用户认证失败: {username} - 用户名或密码错误")
return False, "用户名或密码错误"
finally:
conn.close()
except pymysql.OperationalError as e:
error_msg = "无法连接到服务器,请检查网络连接"
self.logger.error(f"数据库连接失败: {e}")
return False, error_msg
except Exception as e:
error_msg = f"认证失败: {str(e)}"
self.logger.error(f"认证过程异常: {e}")
return False, error_msg
"""图像生成 Worker + Gemini 模型常量。
ImageGenerationWorker 是 QThread,由 TaskQueueManager 拉起执行单条生成任务。
任务参数(prompt / 参考图 / aspect_ratio / image_size / model)从队列传入,
完成后通过 finished/error/progress 信号回报。
"""
import base64
import logging
import os
from typing import Optional
from PySide6.QtCore import QThread, Signal
from google import genai
from google.genai import types
# 生成模式 -> Gemini 模型 ID 映射(单一真相源,消除原先两处 get_selected_model 复制粘贴)
# 极速模式:Nano Banana 2 (Gemini 3.1 Flash Image), 指令遵循强于 2.5-flash-image
# 慢速模式:Nano Banana Pro (Gemini 3 Pro Image Preview)
MODEL_BY_MODE = {
"极速模式": "gemini-3.1-flash-image-preview",
"慢速模式": "gemini-3-pro-image-preview",
}
MODEL_PRO = MODEL_BY_MODE["慢速模式"] # 用于 Worker 中判断是否支持 image_size 参数
# Nano Banana 2 (Flash) 独占的宽高比 —— Pro 不支持,选中这些时需提示切换到极速模式
FLASH_ONLY_ASPECT_RATIOS = {"1:4", "4:1", "1:8", "8:1"}
class ImageGenerationWorker(QThread):
"""Worker thread for image generation"""
finished = Signal(bytes, str, list, str, str,
str) # image_bytes, prompt, reference_images, aspect_ratio, image_size, model
error = Signal(str)
progress = Signal(str)
def __init__(self, api_key, prompt, images, aspect_ratio, image_size, model=MODEL_PRO):
super().__init__()
self.logger = logging.getLogger(__name__)
self.api_key = api_key
self.prompt = prompt
self.images = images
self.aspect_ratio = aspect_ratio
self.image_size = image_size
self.model = model
# 审计元信息:供 TaskQueueManager 在信号回调中读取
self.finish_reason: Optional[str] = None
self.logger.info(f"图片生成任务初始化 - 模型: {model}, 尺寸: {image_size}, 宽高比: {aspect_ratio}")
def _extract_finish_reason(self, response) -> Optional[str]:
"""从 Gemini 响应提取 finish_reason,失败返回 None(不抛异常)。"""
try:
fr = response.candidates[0].finish_reason
if fr is None:
return None
name = getattr(fr, "name", None)
return name if name else str(fr)
except Exception:
return None
def run(self):
"""Execute image generation in background thread"""
try:
self.logger.info("开始图片生成任务")
if not self.prompt:
self.logger.error("图片描述为空")
self.error.emit("请输入图片描述!")
return
if not self.api_key:
self.logger.error("API密钥为空")
self.error.emit("未找到API密钥,请在config.json中配置!")
return
self.progress.emit("正在连接 Gemini API...")
self.logger.debug("正在连接 Gemini API")
client = genai.Client(api_key=self.api_key)
content_parts = [self.prompt]
for img_path in self.images:
with open(img_path, 'rb') as f:
img_data = f.read()
mime_type = "image/png"
if img_path.lower().endswith(('.jpg', '.jpeg')):
mime_type = "image/jpeg"
content_parts.append(
types.Part.from_bytes(
data=img_data,
mime_type=mime_type
)
)
self.progress.emit("正在生成图片...")
# 当前使用的两个模型都支持 aspect_ratio + image_size:
# - gemini-3.1-flash-image-preview (Nano Banana 2): 512/1K/2K/4K + 14 种 ratio
# - gemini-3-pro-image-preview (Nano Banana Pro): 1K/2K/4K
config = types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
image_config=types.ImageConfig(
aspect_ratio=self.aspect_ratio,
image_size=self.image_size
)
)
response = client.models.generate_content(
model=self.model,
contents=content_parts,
config=config
)
self.finish_reason = self._extract_finish_reason(response)
text_fragments = []
parts = response.parts or []
for part in parts:
if hasattr(part, 'inline_data') and part.inline_data:
if isinstance(part.inline_data.data, bytes):
image_bytes = part.inline_data.data
else:
image_bytes = base64.b64decode(part.inline_data.data)
reference_images_bytes = []
for img_path in self.images:
if img_path and os.path.exists(img_path):
with open(img_path, 'rb') as f:
reference_images_bytes.append(f.read())
else:
reference_images_bytes.append(b'')
self.logger.info(
f"图片生成成功 - 模型: {self.model}, 尺寸: {self.image_size}, "
f"finish_reason={self.finish_reason}"
)
self.finished.emit(image_bytes, self.prompt, reference_images_bytes,
self.aspect_ratio, self.image_size, self.model)
return
if getattr(part, 'text', None):
text_fragments.append(part.text)
detail = " | ".join(t for t in text_fragments if t).strip()
error_msg = f"响应中没有图片数据 (finish_reason={self.finish_reason})"
if detail:
error_msg += f"\n模型说明: {detail}"
self.logger.error(error_msg)
self.error.emit(error_msg)
except Exception as e:
error_msg = f"图片生成异常: {e}"
self.logger.error(error_msg, exc_info=True)
self.error.emit(error_msg)
"""应用数据路径 + 图片格式校验工具。
跨平台数据目录探测(macOS .app 包外存储 / Windows APPDATA / 开发环境同目录),
PNG/JPEG 格式回正(Pillow 重写防止伪装 MIME)。
"""
import io
import logging
import os
import platform
import shutil
import sys
from pathlib import Path
def _migrate_data_from_app_bundle(target_path: Path):
"""将 .app 内部的旧数据迁移到外部目录(仅 macOS 打包环境)"""
if not (getattr(sys, 'frozen', False) and platform.system() == "Darwin"):
return
old_path = Path(sys.executable).parent / "images"
if not old_path.exists() or old_path == target_path:
return
old_files = list(old_path.rglob("*"))
if not old_files:
return
try:
target_path.mkdir(parents=True, exist_ok=True)
migrated = 0
for src_file in old_path.rglob("*"):
if src_file.is_file():
rel_path = src_file.relative_to(old_path)
dst_file = target_path / rel_path
if not dst_file.exists():
dst_file.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(str(src_file), str(dst_file))
migrated += 1
print(f"已从 .app 内部迁移 {migrated} 个文件到: {target_path}")
except Exception as e:
print(f"数据迁移失败(不影响使用): {e}")
def get_app_data_path() -> Path:
"""获取应用数据存储路径 - 智能选择"""
def get_candidate_paths():
system = platform.system()
candidates = []
if getattr(sys, 'frozen', False) and system == "Darwin":
candidates.append(Path.home() / "Library/Application Support/ZB100ImageGenerator/images")
elif getattr(sys, 'frozen', False):
candidates.append(Path(sys.executable).parent / "images")
else:
# 开发环境:保持和老路径一致 —— 项目根目录下的 images/
# __file__ 在 core/,需要往上一级
candidates.append(Path(__file__).resolve().parent.parent / "images")
if system == "Darwin":
candidates.append(Path.home() / "Library/Application Support/ZB100ImageGenerator/images")
candidates.append(Path.home() / "Documents/ZB100ImageGenerator/images")
elif system == "Windows":
candidates.append(Path(os.environ.get("APPDATA", "")) / "ZB100ImageGenerator/images")
candidates.append(Path.home() / "Documents/ZB100ImageGenerator/images")
else:
candidates.append(Path.home() / ".config/zb100imagegenerator/images")
candidates.append(Path.home() / "Documents/ZB100ImageGenerator/images")
return candidates
def test_path_write_access(path: Path) -> bool:
try:
path.mkdir(parents=True, exist_ok=True)
test_file = path / ".write_test"
test_file.write_text("test")
test_file.unlink()
return True
except (PermissionError, OSError) as e:
print(f"路径 {path} 无写入权限: {e}")
return False
except Exception as e:
print(f"路径 {path} 测试失败: {e}")
return False
candidates = get_candidate_paths()
for path in candidates:
if test_path_write_access(path):
_migrate_data_from_app_bundle(path)
print(f"使用图片存储路径: {path}")
return path
fallback_path = get_candidate_paths()[0]
try:
fallback_path.mkdir(parents=True, exist_ok=True)
print(f"使用备选路径: {fallback_path}")
return fallback_path
except Exception as e:
print(f"警告: 无法创建存储路径,将在当前目录操作: {e}")
return Path.cwd() / "images"
def save_png_with_validation(file_path: str, image_bytes: bytes) -> bool:
"""使用 Pillow 验证并重写 PNG/JPEG,确保 MIME 与扩展名一致。
返回 True 表示 Pillow 处理成功;False 表示 Pillow 不可用或处理失败,
调用方应回退到原始 write_bytes。
"""
try:
from PIL import Image
with Image.open(io.BytesIO(image_bytes)) as img:
file_format = img.format
if file_format == 'JPEG':
logger = logging.getLogger(__name__)
logger.info(f"检测到伪装PNG的JPEG文件,实际格式: {file_format}")
save_format = 'PNG' if file_path.lower().endswith('.png') else 'JPEG'
if file_format and file_format != save_format:
logger = logging.getLogger(__name__)
logger.info(f"执行格式转换: {file_format} -> {save_format}")
if save_format == 'PNG':
if img.mode not in ['RGBA', 'RGB', 'L']:
if img.mode == 'P':
img = img.convert('RGBA')
elif img.mode == 'LA':
img = img.convert('RGBA')
else:
img = img.convert('RGBA')
elif save_format == 'JPEG':
if img.mode in ['RGBA', 'P']:
img = img.convert('RGB')
elif img.mode == 'L':
img = img.convert('RGB')
img.save(file_path, save_format, optimize=True)
logger = logging.getLogger(__name__)
logger.info(f"图片格式验证成功: {file_path}, 保存格式: {save_format}")
return True
except ImportError:
logger = logging.getLogger(__name__)
logger.warning("Pillow库不可用,使用原始保存方法")
return False
except Exception as e:
logger = logging.getLogger(__name__)
logger.warning(f"Pillow处理失败,使用原始保存方法: {e}")
return False
......@@ -192,7 +192,7 @@ class TaskQueueManager(QObject):
self.logger.info(f"开始处理任务: {task_id[:8]}")
# 导入 ImageGenerationWorker
from image_generator import ImageGenerationWorker
from core.generation import ImageGenerationWorker
# 创建 worker
self._current_worker = ImageGenerationWorker(
......
......@@ -177,7 +177,7 @@ class TaskQueueManager(QObject):
self.logger.info(f"开始处理任务: {task_id[:8]}")
# 导入 ImageGenerationWorker
from image_generator import ImageGenerationWorker
from core.generation import ImageGenerationWorker
# 创建 worker
self._current_worker = ImageGenerationWorker(
......