temp_clean.py
9.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
"""
任务队列系统
提供异步图像生成任务的队列管理和 UI 组件
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional, List, Dict
from queue import Queue
from threading import Lock
import uuid
import logging
import io
from PySide6.QtCore import QObject, Signal, QTimer, Qt
from PySide6.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QLabel,
QPushButton, QListWidget, QListWidgetItem, QDialog, QScrollArea, QFrame
)
from PySide6.QtGui import QPixmap, QMouseEvent
from PIL import Image
class TaskType(Enum):
    """Kinds of work that can be submitted to the task queue."""

    # Generate an image from a prompt.
    IMAGE_GENERATION = "image_gen"
    # Style-design job.
    STYLE_DESIGN = "style_design"
class TaskStatus(Enum):
    """Lifecycle states a task can be in."""

    # Waiting in the queue; has not been started yet.
    PENDING = "pending"
    # Currently being processed.
    RUNNING = "running"
    # Finished with a result.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Withdrawn before it was started.
    CANCELLED = "cancelled"
@dataclass
class Task:
    """Data model for one queued generation task.

    Holds the submitted parameters, lifecycle timestamps, and the outcome
    (result bytes or error message) of a single task.

    The API key and the binary payloads are excluded from the generated
    ``__repr__`` so that logging or printing a task neither leaks the secret
    nor dumps whole images into the log. They remain ordinary attributes and
    ordinary ``__init__`` parameters.
    """

    # Identity
    id: str
    type: TaskType
    status: TaskStatus

    # Input parameters
    prompt: str
    api_key: str = field(repr=False)  # secret: keep out of repr()/logs
    reference_images: List[str]
    aspect_ratio: str
    image_size: str
    model: str

    # Timestamps
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Outcome
    result_bytes: Optional[bytes] = field(default=None, repr=False)
    error_message: Optional[str] = None

    # UI-related
    thumbnail: Optional[bytes] = field(default=None, repr=False)
    progress: float = 0.0
class TaskQueueManager(QObject):
    """Singleton manager for the lifecycle of image-generation tasks.

    Submitted tasks are stored by id and their ids wait in a FIFO queue. The
    manager starts one worker at a time; when that worker reports completion
    or failure, the next runnable task is started. State changes are
    announced through the Qt signals declared below.
    """

    # Signals
    task_added = Signal(Task)
    task_started = Signal(str)  # task_id
    # task_id, image_bytes, prompt, ref_images, aspect_ratio, image_size, model
    task_completed = Signal(str, bytes, str, list, str, str, str)
    task_failed = Signal(str, str)  # task_id, error_message
    task_progress = Signal(str, float, str)  # task_id, progress, status_text

    _instance = None
    _lock = Lock()

    # Statuses after which a task will never run (again).
    _TERMINAL_STATUSES = (
        TaskStatus.COMPLETED,
        TaskStatus.FAILED,
        TaskStatus.CANCELLED,
    )
    # Bounding box that thumbnails are shrunk to fit in.
    _THUMBNAIL_SIZE = (50, 50)
    # Image modes Pillow's PNG writer accepts; anything else is converted.
    _PNG_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        # Check, take the lock, then check again so that two threads racing
        # on the first call cannot both create an instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialise manager state once; later constructions are no-ops."""
        # __init__ runs on every TaskQueueManager() call even though __new__
        # hands back the same object, so guard against re-initialisation.
        if hasattr(self, '_initialized'):
            return
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self._tasks: Dict[str, Task] = {}
        self._queue = Queue()
        self._current_worker = None
        # Superseded workers that were still running when replaced; see
        # _process_next for why they are kept referenced.
        self._retired_workers: list = []
        self._max_queue_size = 10
        self._max_history_size = 10  # keep only the 10 most recent finished tasks
        self._initialized = True
        self.logger.info("TaskQueueManager 初始化完成")

    def submit_task(
        self,
        task_type: TaskType,
        prompt: str,
        api_key: str,
        reference_images: List[str],
        aspect_ratio: str,
        image_size: str,
        model: str
    ) -> str:
        """Create a task, enqueue it, and start processing if idle.

        Args:
            task_type: Kind of task.
            prompt: Image description.
            api_key: API key handed to the worker.
            reference_images: Reference image paths (copied; may be empty/None).
            aspect_ratio: Requested aspect ratio.
            image_size: Requested image size.
            model: Model name.

        Returns:
            The unique id of the new task.

        Raises:
            RuntimeError: The queue already holds the maximum number of
                waiting tasks.
        """
        # Count tasks that are really waiting. Queue.qsize() would also count
        # ids of cancelled tasks that have not been dequeued yet.
        if self.get_pending_count() >= self._max_queue_size:
            raise RuntimeError(f"任务队列已满 (最大 {self._max_queue_size} 个)")

        task = Task(
            id=str(uuid.uuid4()),
            type=task_type,
            status=TaskStatus.PENDING,
            prompt=prompt,
            api_key=api_key,
            # Copy so later mutation by the caller cannot affect the task.
            reference_images=list(reference_images) if reference_images else [],
            aspect_ratio=aspect_ratio,
            image_size=image_size,
            model=model,
            created_at=datetime.now()
        )
        self._tasks[task.id] = task
        self._queue.put(task.id)
        self.logger.info(f"任务已提交: {task.id[:8]} - {prompt[:30]}")
        self.task_added.emit(task)

        # Kick off processing when no worker is currently running.
        if self._current_worker is None or not self._current_worker.isRunning():
            self._process_next()
        return task.id

    def _dequeue_runnable_task(self) -> Optional[Task]:
        """Pop queued ids until one refers to a task that is still PENDING.

        Ids of cancelled or already-removed tasks are discarded.

        Returns:
            The next runnable task, or None when the queue is exhausted.
        """
        from queue import Empty

        while True:
            try:
                task_id = self._queue.get_nowait()
            except Empty:
                return None
            task = self._tasks.get(task_id)
            if task is not None and task.status == TaskStatus.PENDING:
                return task
            self.logger.debug(f"跳过不可执行的任务: {task_id[:8]}")

    def _process_next(self):
        """Start a worker for the next runnable task, if there is one."""
        # Forget superseded workers that have stopped by now.
        self._retired_workers = [w for w in self._retired_workers if w.isRunning()]

        task = self._dequeue_runnable_task()
        if task is None:
            self.logger.debug("队列为空,无任务处理")
            return

        task_id = task.id
        task.status = TaskStatus.RUNNING
        task.started_at = datetime.now()
        self.logger.info(f"开始处理任务: {task_id[:8]}")

        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from image_generator import ImageGenerationWorker

        # The previous worker may still be winding down when its completion
        # callback lands here. Keep it referenced until it has stopped:
        # destroying a running QThread aborts the program.
        # NOTE(review): assumes the worker is a QThread - confirm.
        previous = self._current_worker
        if previous is not None and previous.isRunning():
            self._retired_workers.append(previous)

        self._current_worker = ImageGenerationWorker(
            task.api_key,
            task.prompt,
            task.reference_images,
            task.aspect_ratio,
            task.image_size,
            task.model
        )

        # Bind the worker's signals to this task's id.
        self._current_worker.finished.connect(
            lambda img_bytes, prompt, ref_imgs, ar, size, model:
            self._on_task_completed(task_id, img_bytes, prompt, ref_imgs, ar, size, model)
        )
        self._current_worker.error.connect(
            lambda error: self._on_task_failed(task_id, error)
        )
        # The worker reports only a status text, so a fixed 0.5 is forwarded
        # as the progress value.
        self._current_worker.progress.connect(
            lambda status: self.task_progress.emit(task_id, 0.5, status)
        )

        self.task_started.emit(task_id)
        self._current_worker.start()

    def _on_task_completed(self, task_id: str, image_bytes: bytes, prompt: str,
                           reference_images: list, aspect_ratio: str, image_size: str, model: str):
        """Record a successful result, notify listeners, and move on.

        Args:
            task_id: Id of the finished task.
            image_bytes: Generated image data.
            prompt: Prompt reported back by the worker.
            reference_images: Reference images reported back by the worker.
            aspect_ratio: Aspect ratio reported back by the worker.
            image_size: Image size reported back by the worker.
            model: Model name reported back by the worker.
        """
        task = self._tasks.get(task_id)
        if not task:
            self.logger.error(f"任务 {task_id[:8]} 不存在")
            # Still advance, otherwise the remaining queue would stall.
            self._process_next()
            return

        task.status = TaskStatus.COMPLETED
        task.completed_at = datetime.now()
        task.result_bytes = image_bytes
        task.progress = 1.0

        # The thumbnail is best-effort: a failure here must not turn a
        # successful generation into an error.
        try:
            task.thumbnail = self._create_thumbnail(image_bytes)
        except Exception as e:
            self.logger.warning(f"生成缩略图失败: {e}")

        # Fall back to the creation time if the start time was never set.
        started = task.started_at or task.created_at
        elapsed = (task.completed_at - started).total_seconds()
        self.logger.info(f"任务完成: {task_id[:8]} - 耗时 {elapsed:.1f}s")
        self.task_completed.emit(task_id, image_bytes, prompt, reference_images,
                                 aspect_ratio, image_size, model)

        # Trim history, then continue with the next queued task.
        self._cleanup_old_tasks()
        self._process_next()

    def _on_task_failed(self, task_id: str, error: str):
        """Record a failure, notify listeners, and move on.

        Args:
            task_id: Id of the failed task.
            error: Error message reported by the worker.
        """
        task = self._tasks.get(task_id)
        if not task:
            self.logger.error(f"任务 {task_id[:8]} 不存在")
            # Still advance, otherwise the remaining queue would stall.
            self._process_next()
            return

        task.status = TaskStatus.FAILED
        task.completed_at = datetime.now()
        task.error_message = error
        self.logger.error(f"任务失败: {task_id[:8]} - {error}")
        self.task_failed.emit(task_id, error)

        # Trim history, then continue with the next queued task.
        self._cleanup_old_tasks()
        self._process_next()

    def _cleanup_old_tasks(self):
        """Drop all but the most recent finished tasks.

        Completed, failed and cancelled tasks that carry a ``completed_at``
        timestamp share one history of ``_max_history_size`` entries; older
        ones are removed from ``_tasks``.
        """
        finished = sorted(
            (
                t for t in self._tasks.values()
                if t.status in self._TERMINAL_STATUSES and t.completed_at
            ),
            key=lambda t: t.completed_at,
            reverse=True,
        )
        # Everything past the newest N entries is stale.
        for stale in finished[self._max_history_size:]:
            del self._tasks[stale.id]
            self.logger.debug(f"清理旧任务: {stale.id[:8]}")

    def _create_thumbnail(self, image_bytes: bytes) -> bytes:
        """Build a PNG thumbnail that fits within 50x50 pixels.

        Args:
            image_bytes: Encoded source image.

        Returns:
            PNG-encoded thumbnail bytes.

        Raises:
            Exception: Whatever Pillow raises for unreadable image data; the
                caller treats thumbnail creation as best-effort.
        """
        with Image.open(io.BytesIO(image_bytes)) as source:
            # PNG cannot store every mode (e.g. CMYK); normalise those to RGB.
            thumb = source if source.mode in self._PNG_MODES else source.convert("RGB")
            thumb.thumbnail(self._THUMBNAIL_SIZE)
            encoded = io.BytesIO()
            thumb.save(encoded, format='PNG')
        return encoded.getvalue()

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with this id, or None if it is not known."""
        return self._tasks.get(task_id)

    def get_all_tasks(self) -> List[Task]:
        """Return a list of all tasks currently tracked."""
        return list(self._tasks.values())

    def get_pending_count(self) -> int:
        """Return the number of tasks whose status is PENDING."""
        return sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)

    def get_running_count(self) -> int:
        """Return the number of tasks whose status is RUNNING."""
        return sum(1 for t in self._tasks.values() if t.status == TaskStatus.RUNNING)

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that has not started yet.

        Only PENDING tasks can be cancelled. The id stays in the queue and is
        skipped when it is dequeued.

        Args:
            task_id: Id of the task to cancel.

        Returns:
            True if the task was cancelled, False if it is unknown or no
            longer pending.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.PENDING:
            return False
        task.status = TaskStatus.CANCELLED
        # Timestamp it so history cleanup can age the entry out.
        task.completed_at = datetime.now()
        self.logger.info(f"任务已取消: {task_id[:8]}")
        return True