# memory_index.py
"""
内存索引模块
将图片元数据加载到内存,支持快速元数据查询
"""

import logging
import sqlite3
import threading

# NOTE(review): basicConfig() configures the root logger at import time (it is a
# no-op if the root logger already has handlers). Library modules usually leave
# this to the application entry point; kept as-is because callers may rely on
# INFO-level output appearing without their own logging setup.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by InMemoryIndex.
logger = logging.getLogger(__name__)


class InMemoryIndex:
    """In-memory index of image metadata with lock-protected reads and writes.

    Entries live in ``img_metadata`` keyed by image id. Each entry is a dict
    with the keys ``"id"``, ``"path"``, ``"design_no"`` and ``"image_url"``;
    ``"path"`` holds ``image_url`` when it is truthy and the local path
    otherwise.
    """

    def __init__(self):
        """Create an empty index."""
        # {img_id: {"id": ..., "path": ..., "design_no": ..., "image_url": ...}}
        self.img_metadata = {}
        # Re-entrant lock guarding img_metadata against concurrent updates.
        self._lock = threading.RLock()

    @staticmethod
    def _make_entry(img_id, design_no, image_url, image_path):
        """Build one metadata entry; ``path`` prefers ``image_url`` over the local path."""
        return {
            "id": img_id,
            "path": image_url if image_url else image_path,
            "design_no": design_no,
            "image_url": image_url,
        }

    def load_from_db(self, db_path):
        """Load every row of the SQLite ``images`` table into the index.

        Rows are first collected into a temporary dict and then merged into
        ``img_metadata`` with a single locked ``update``. Readers are therefore
        not blocked for the duration of the query, and a failure part-way
        through the query leaves the index unchanged. Existing entries are
        kept; rows with the same id overwrite them.

        Args:
            db_path: Path to the SQLite database file.

        Returns:
            bool: True on success; False if the file does not exist or the
            load failed (the error is logged).
        """
        import os
        import time

        try:
            start = time.perf_counter()

            # sqlite3.connect() silently creates an empty database when the
            # file does not exist; refuse instead of leaving a stray file.
            if not os.path.isfile(db_path):
                logger.error(f"加载内存索引失败: 数据库文件不存在: {db_path}")
                return False

            logger.info("加载图片元数据到内存...")
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.execute(
                    "SELECT id, path, design_no, image_url FROM images"
                )
                loaded = {
                    img_id: self._make_entry(img_id, design_no, image_url, path)
                    for img_id, path, design_no, image_url in cursor
                }
            finally:
                # Close the connection even if the query or iteration raises.
                conn.close()

            with self._lock:
                self.img_metadata.update(loaded)
                total = len(self.img_metadata)

            logger.info(f"✓ 已加载 {total} 条元数据")

            elapsed = time.perf_counter() - start
            logger.info(f"内存索引加载完成,耗时 {elapsed:.2f}秒")
            logger.info(f"内存占用估算: {self.estimate_memory_usage():.1f}MB")
            return True

        except Exception as e:
            logger.error(f"加载内存索引失败: {e}", exc_info=True)
            return False

    def add_or_update(self, img_id, design_no, image_url, image_path=None):
        """Insert or replace the entry for ``img_id`` under the lock.

        Args:
            img_id: Image id (design_id).
            design_no: Design number.
            image_url: Image URL; used as ``path`` when truthy.
            image_path: Optional local path, used as ``path`` when
                ``image_url`` is falsy.
        """
        entry = self._make_entry(img_id, design_no, image_url, image_path)
        with self._lock:
            self.img_metadata[img_id] = entry
        logger.debug("内存索引已更新: img_id=%s", img_id)

    def remove(self, img_id):
        """Delete the entry for ``img_id`` if present (no error if missing).

        Args:
            img_id: Image id; matched exactly, without the int/str fallback
                that ``get_metadata`` applies.
        """
        with self._lock:
            removed = self.img_metadata.pop(img_id, None)
        if removed is not None:
            logger.debug("内存索引已删除: img_id=%s", img_id)

    def estimate_memory_usage(self):
        """Return a rough memory estimate for the index, in MB.

        Heuristic only: assumes about 200 bytes per entry (path string plus
        design_no). This is not a measured size.
        """
        size = len(self.img_metadata) * 200
        return size / 1024 / 1024

    def get_metadata(self, img_id):
        """Look up an entry, tolerating int/str mismatches in the id type.

        The lookup order is the id as given, then ``str(img_id)`` for non-str
        ids, then ``int(img_id)`` for decimal-digit strings. This covers
        callers that pass an int id while the stored key is a str, and the
        reverse.

        Args:
            img_id: Image id (int or str).

        Returns:
            dict or None: The stored entry (the same object, not a copy), or
            None if no key matches.
        """
        with self._lock:
            metadata = self.img_metadata.get(img_id)
            if metadata is not None:
                return metadata

            if isinstance(img_id, str):
                # isdecimal(), not isdigit(): characters such as "²" satisfy
                # isdigit() but make int() raise ValueError.
                if img_id.isdecimal():
                    return self.img_metadata.get(int(img_id))
                return None

            return self.img_metadata.get(str(img_id))

    def get_stats(self):
        """Return the entry count and the heuristic memory estimate (MB)."""
        with self._lock:
            return {
                "total_images": len(self.img_metadata),
                "memory_mb": self.estimate_memory_usage(),
            }