增加定时重建的相关代码
Showing
4 changed files
with
102 additions
and
4 deletions
| ... | @@ -19,6 +19,8 @@ from fastapi.middleware.cors import CORSMiddleware | ... | @@ -19,6 +19,8 @@ from fastapi.middleware.cors import CORSMiddleware |
| 19 | from fastapi.responses import JSONResponse | 19 | from fastapi.responses import JSONResponse |
| 20 | import numpy as np | 20 | import numpy as np |
| 21 | from dotenv import load_dotenv | 21 | from dotenv import load_dotenv |
| 22 | from apscheduler.schedulers.background import BackgroundScheduler | ||
| 23 | from apscheduler.triggers.cron import CronTrigger | ||
| 22 | 24 | ||
| 23 | # 尝试导入不同的 JWT 库 | 25 | # 尝试导入不同的 JWT 库 |
| 24 | try: | 26 | try: |
| ... | @@ -50,6 +52,7 @@ db_manager = None | ... | @@ -50,6 +52,7 @@ db_manager = None |
| 50 | search_engine = None | 52 | search_engine = None |
| 51 | data_sync = None | 53 | data_sync = None |
| 52 | sync_thread = None | 54 | sync_thread = None |
| 55 | scheduler = None | ||
| 53 | 56 | ||
| 54 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' | 57 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' |
| 55 | 58 | ||
| ... | @@ -77,10 +80,20 @@ def load_config(): | ... | @@ -77,10 +80,20 @@ def load_config(): |
| 77 | return config | 80 | return config |
| 78 | 81 | ||
| 79 | 82 | ||
| 83 | def scheduled_sync(): | ||
| 84 | """定时任务:每天 0:00 和 12:00 执行增量同步""" | ||
| 85 | try: | ||
| 86 | logger.info("🔄 定时任务:开始增量同步") | ||
| 87 | result = data_sync.sync_once() | ||
| 88 | logger.info(f"✅ 定时同步完成: {result}") | ||
| 89 | except Exception as e: | ||
| 90 | logger.error(f"❌ 定时同步失败: {e}", exc_info=True) | ||
| 91 | |||
| 92 | |||
| 80 | @asynccontextmanager | 93 | @asynccontextmanager |
| 81 | async def lifespan(app: FastAPI): | 94 | async def lifespan(app: FastAPI): |
| 82 | """应用生命周期管理""" | 95 | """应用生命周期管理""" |
| 83 | global config, db_manager, search_engine, data_sync, sync_thread | 96 | global config, db_manager, search_engine, data_sync, sync_thread, scheduler |
| 84 | 97 | ||
| 85 | logger.info("启动 Design Image Search 服务...") | 98 | logger.info("启动 Design Image Search 服务...") |
| 86 | 99 | ||
| ... | @@ -134,12 +147,26 @@ async def lifespan(app: FastAPI): | ... | @@ -134,12 +147,26 @@ async def lifespan(app: FastAPI): |
| 134 | sync_thread.start() | 147 | sync_thread.start() |
| 135 | logger.info("后台数据同步线程已启动") | 148 | logger.info("后台数据同步线程已启动") |
| 136 | 149 | ||
| 150 | # 启动定时任务(每天 0:00 和 12:00) | ||
| 151 | scheduler = BackgroundScheduler() | ||
| 152 | scheduler.add_job( | ||
| 153 | func=scheduled_sync, | ||
| 154 | trigger=CronTrigger(hour='0,12', minute='0'), | ||
| 155 | id='scheduled_sync', | ||
| 156 | replace_existing=True | ||
| 157 | ) | ||
| 158 | scheduler.start() | ||
| 159 | logger.info("定时任务已启动(每天 0:00 和 12:00 执行)") | ||
| 160 | |||
| 137 | logger.info("Design Image Search 服务启动完成") | 161 | logger.info("Design Image Search 服务启动完成") |
| 138 | 162 | ||
| 139 | yield | 163 | yield |
| 140 | 164 | ||
| 141 | # 清理代码 | 165 | # 清理代码 |
| 142 | logger.info("正在关闭 Design Image Search 服务...") | 166 | logger.info("正在关闭 Design Image Search 服务...") |
| 167 | if scheduler and scheduler.running: | ||
| 168 | scheduler.shutdown() | ||
| 169 | logger.info("定时任务已关闭") | ||
| 143 | 170 | ||
| 144 | 171 | ||
| 145 | # 创建 FastAPI 应用 | 172 | # 创建 FastAPI 应用 |
| ... | @@ -431,6 +458,48 @@ async def trigger_sync(token: Dict = Depends(verify_token)): | ... | @@ -431,6 +458,48 @@ async def trigger_sync(token: Dict = Depends(verify_token)): |
| 431 | ) | 458 | ) |
| 432 | 459 | ||
| 433 | 460 | ||
| 461 | @app.post("/admin/rebuild-all") | ||
| 462 | async def rebuild_all(token: Dict = Depends(verify_token)): | ||
| 463 | """ | ||
| 464 | 全量重建(临时将 last_sync_time 设为 1970,复用增量同步逻辑) | ||
| 465 | |||
| 466 | Args: | ||
| 467 | token: JWT 认证信息(自动注入) | ||
| 468 | |||
| 469 | Returns: | ||
| 470 | Dict: 重建结果 | ||
| 471 | """ | ||
| 472 | try: | ||
| 473 | from datetime import datetime | ||
| 474 | |||
| 475 | logger.info("🔄 手动触发全量重建") | ||
| 476 | |||
| 477 | # 临时保存当前的 last_sync_time | ||
| 478 | original_sync_time = db_manager.get_last_sync_time() | ||
| 479 | |||
| 480 | # 临时设置为 1970-01-01(获取所有历史数据) | ||
| 481 | db_manager.update_sync_time(datetime(1970, 1, 1), 0) | ||
| 482 | |||
| 483 | try: | ||
| 484 | # 调用增量同步(但会处理所有数据) | ||
| 485 | result = data_sync.sync_once() | ||
| 486 | |||
| 487 | return { | ||
| 488 | "success": True, | ||
| 489 | "message": "全量重建完成", | ||
| 490 | "result": result | ||
| 491 | } | ||
| 492 | except Exception as e: | ||
| 493 | # 失败时恢复原来的 sync_time | ||
| 494 | if original_sync_time: | ||
| 495 | db_manager.update_sync_time(original_sync_time, 0) | ||
| 496 | raise e | ||
| 497 | |||
| 498 | except Exception as e: | ||
| 499 | logger.error(f"❌ 全量重建失败: {e}", exc_info=True) | ||
| 500 | raise HTTPException(500, f"重建失败: {str(e)}") | ||
| 501 | |||
| 502 | |||
| 434 | # 配置 CORS | 503 | # 配置 CORS |
| 435 | app.add_middleware( | 504 | app.add_middleware( |
| 436 | CORSMiddleware, | 505 | CORSMiddleware, | ... | ... |
| ... | @@ -396,6 +396,28 @@ class FAISSManager: | ... | @@ -396,6 +396,28 @@ class FAISSManager: |
| 396 | except Exception as e: | 396 | except Exception as e: |
| 397 | logger.error(f"保存索引失败: {e}") | 397 | logger.error(f"保存索引失败: {e}") |
| 398 | 398 | ||
| 399 | def rebuild_index(self, db_manager): | ||
| 400 | """ | ||
| 401 | 重建索引(清理墓碑标记的向量) | ||
| 402 | |||
| 403 | 这是 compact_index 的别名方法,用于保持API兼容性 | ||
| 404 | |||
| 405 | Args: | ||
| 406 | db_manager: 数据库管理器实例或数据库路径 | ||
| 407 | |||
| 408 | Returns: | ||
| 409 | bool: 是否成功 | ||
| 410 | """ | ||
| 411 | # 如果传入的是数据库管理器对象,获取其路径 | ||
| 412 | if hasattr(db_manager, 'db_path'): | ||
| 413 | db_path = db_manager.db_path | ||
| 414 | else: | ||
| 415 | # 假设传入的是数据库路径字符串 | ||
| 416 | db_path = db_manager | ||
| 417 | |||
| 418 | logger.info("开始重建索引(清理墓碑)...") | ||
| 419 | return self.compact_index(db_path) | ||
| 420 | |||
| 399 | def get_stats(self): | 421 | def get_stats(self): |
| 400 | """获取索引统计信息""" | 422 | """获取索引统计信息""" |
| 401 | return { | 423 | return { | ... | ... |
| ... | @@ -14,6 +14,7 @@ Pillow>=10.0.0 | ... | @@ -14,6 +14,7 @@ Pillow>=10.0.0 |
| 14 | 14 | ||
| 15 | # Database | 15 | # Database |
| 16 | PyMySQL>=1.1.0 | 16 | PyMySQL>=1.1.0 |
| 17 | # sqlite3 is built-in to Python, no need to install | ||
| 17 | 18 | ||
| 18 | # JWT Authentication | 19 | # JWT Authentication |
| 19 | python-jose[cryptography]>=3.3.0 | 20 | python-jose[cryptography]>=3.3.0 |
| ... | @@ -26,7 +27,10 @@ requests>=2.31.0 | ... | @@ -26,7 +27,10 @@ requests>=2.31.0 |
| 26 | # Utilities | 27 | # Utilities |
| 27 | python-dotenv>=1.0.0 | 28 | python-dotenv>=1.0.0 |
| 28 | pyyaml>=6.0.1 | 29 | pyyaml>=6.0.1 |
| 29 | numpy>=1.24.0,<2.0 | 30 | numpy>=1.24.0,<2.0.0 |
| 30 | 31 | ||
| 31 | # Logging & Monitoring | 32 | # Logging & Monitoring |
| 32 | structlog>=23.1.0 | 33 | structlog>=23.1.0 |
| 34 | |||
| 35 | # Task Scheduling | ||
| 36 | apscheduler==3.10.4 | ... | ... |
| ... | @@ -4,7 +4,7 @@ uvicorn[standard]>=0.24.0 | ... | @@ -4,7 +4,7 @@ uvicorn[standard]>=0.24.0 |
| 4 | python-multipart>=0.0.6 | 4 | python-multipart>=0.0.6 |
| 5 | 5 | ||
| 6 | # Machine Learning & Computer Vision | 6 | # Machine Learning & Computer Vision |
| 7 | # 使用版本范围,避免CUDA依赖(镜像源会自动提供CPU版本) | 7 | # Use version ranges to avoid CUDA dependencies (CPU version from mirror) |
| 8 | torch>=2.0.0,<2.1.0 | 8 | torch>=2.0.0,<2.1.0 |
| 9 | torchvision>=0.15.0,<0.16.0 | 9 | torchvision>=0.15.0,<0.16.0 |
| 10 | faiss-cpu>=1.7.4 | 10 | faiss-cpu>=1.7.4 |
| ... | @@ -26,7 +26,10 @@ requests>=2.31.0 | ... | @@ -26,7 +26,10 @@ requests>=2.31.0 |
| 26 | # Utilities | 26 | # Utilities |
| 27 | python-dotenv>=1.0.0 | 27 | python-dotenv>=1.0.0 |
| 28 | pyyaml>=6.0.1 | 28 | pyyaml>=6.0.1 |
| 29 | numpy>=1.24.0 | 29 | numpy>=1.24.0,<2.0.0 |
| 30 | 30 | ||
| 31 | # Logging & Monitoring | 31 | # Logging & Monitoring |
| 32 | structlog>=23.1.0 | 32 | structlog>=23.1.0 |
| 33 | |||
| 34 | # Task Scheduling | ||
| 35 | apscheduler==3.10.4 | ... | ... |
-
Please register or sign in to post a comment