faiss_manager.py 14.4 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434
"""
FAISS索引管理器(增量更新 + 墓碑标记)
支持CNN向量的快速检索和增量维护
"""

import logging
import pickle
import os
import sqlite3
import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FAISSManager:
    """FAISS索引管理器(HNSW + 墓碑标记)"""

    def __init__(self, index_path="./data/faiss_cnn.index",
                 mapping_path="./data/faiss_id_mapping.pkl",
                 tombstone_path="./data/faiss_tombstones.pkl",
                 vector_dim=576):
        """
        初始化FAISS管理器

        Args:
            index_path: FAISS索引文件路径
            mapping_path: ID映射文件路径(FAISS索引位置 -> image_id)
            tombstone_path: 墓碑标记文件路径(记录已删除的ID)
            vector_dim: CNN向量维度(MobileNetV3-Small输出576维)
        """
        self.index_path = index_path
        self.mapping_path = mapping_path
        self.tombstone_path = tombstone_path
        self.vector_dim = vector_dim

        self.index = None  # faiss.IndexHNSWFlat实例
        self.id_mapping = []  # [img_id1, img_id2, ...] FAISS索引位置对应的image_id
        self.tombstones = set()  # {img_id1, img_id2, ...} 已删除但未压缩的ID

        self._ensure_faiss()

    def _ensure_faiss(self):
        """确保FAISS已安装"""
        try:
            import faiss
            self.faiss = faiss
            logger.info("FAISS库加载成功")
        except ImportError:
            logger.error("FAISS库未安装,请运行: pip install faiss-cpu")
            raise

    def build_index(self, db_path, metric="cosine", M=32, efConstruction=200):
        """
        从数据库构建FAISS索引(全量)

        Args:
            db_path: 数据库路径
            metric: 相似度度量 ("cosine" 或 "L2")
            M: HNSW参数(连接数,越大精度越高但内存越大)
            efConstruction: HNSW构建参数(越大精度越高但构建越慢)

        Returns:
            bool: 是否成功
        """
        try:
            import time
            start = time.time()

            # 1. 从数据库加载CNN向量
            logger.info("从数据库加载CNN向量...")
            conn = sqlite3.connect(db_path)
            cursor = conn.execute("SELECT id, cnn_vector FROM images WHERE cnn_vector IS NOT NULL")

            vectors = []
            img_ids = []

            for row in cursor:
                img_id, cnn_blob = row
                if cnn_blob:
                    try:
                        vec = np.frombuffer(cnn_blob, dtype=np.float32)
                        if len(vec) == self.vector_dim:
                            vectors.append(vec)
                            img_ids.append(img_id)
                        else:
                            logger.warning(f"CNN向量维度不匹配: {len(vec)} != {self.vector_dim} (img_id={img_id})")
                    except Exception as e:
                        logger.error(f"解析CNN向量失败 (img_id={img_id}): {e}")

            conn.close()

            if len(vectors) == 0:
                logger.warning("未找到任何CNN向量")
                return False

            logger.info(f"已加载 {len(vectors)} 个CNN向量")

            # 2. 创建FAISS索引
            if metric == "cosine":
                # 余弦相似度:先L2归一化,再用内积
                self.index = self.faiss.IndexHNSWFlat(self.vector_dim, M, self.faiss.METRIC_INNER_PRODUCT)
            else:
                # 欧氏距离
                self.index = self.faiss.IndexHNSWFlat(self.vector_dim, M, self.faiss.METRIC_L2)

            self.index.hnsw.efConstruction = efConstruction

            # 3. 添加向量到索引
            vectors_np = np.array(vectors, dtype=np.float32)

            if metric == "cosine":
                # L2归一化(确保余弦相似度正确)
                norms = np.linalg.norm(vectors_np, axis=1, keepdims=True)
                vectors_np = vectors_np / (norms + 1e-8)

            self.index.add(vectors_np)

            # 4. 保存ID映射
            self.id_mapping = img_ids
            self.tombstones = set()

            # 5. 持久化索引
            self._save_index()

            elapsed = time.time() - start
            logger.info(f"FAISS索引构建完成,耗时 {elapsed:.2f}秒")
            logger.info(f"索引大小: {len(self.id_mapping)} 个向量")

            return True

        except Exception as e:
            logger.error(f"构建FAISS索引失败: {e}", exc_info=True)
            return False

    def incremental_add(self, img_id, cnn_vector):
        """
        增量添加单个向量到索引

        Args:
            img_id: 图片ID
            cnn_vector: CNN向量(numpy数组,576维)

        Returns:
            bool: 是否成功
        """
        try:
            # 如果索引未初始化,自动创建
            if self.index is None:
                logger.info("首次添加向量,自动创建 FAISS 索引")
                self.index = self.faiss.IndexHNSWFlat(self.vector_dim, 32)
                self.faiss.ParameterSpace().set_index_parameter(self.index, "efConstruction", 200)
                # 对于余弦相似度,使用 L2 归一化
                logger.info("FAISS 索引已创建")

            if cnn_vector is None or len(cnn_vector) != self.vector_dim:
                logger.warning(f"无效CNN向量: {cnn_vector}")
                return False

            # L2归一化(用于余弦相似度)
            vec = cnn_vector.astype(np.float32).reshape(1, -1)
            norm = np.linalg.norm(vec)
            if norm > 1e-8:
                vec = vec / norm

            # 添加到索引
            self.index.add(vec)

            # 更新映射
            self.id_mapping.append(img_id)

            # 从墓碑集合中移除(如果之前被删除过)
            if img_id in self.tombstones:
                self.tombstones.remove(img_id)

            return True

        except Exception as e:
            logger.error(f"增量添加失败 (img_id={img_id}): {e}")
            return False

    def mark_delete(self, img_id):
        """
        墓碑标记删除(不立即压缩索引)

        Args:
            img_id: 要删除的图片ID
        """
        if img_id in self.id_mapping:
            self.tombstones.add(img_id)
            logger.info(f"墓碑标记删除: img_id={img_id}")
        else:
            logger.warning(f"尝试删除不存在的ID: {img_id}")

    def search(self, query_vector, top_k=2000):
        """
        搜索最相似的向量(过滤墓碑标记)

        Args:
            query_vector: 查询向量(576维)
            top_k: 返回Top-K结果

        Returns:
            list: [(img_id, similarity_score), ...]
        """
        if self.index is None or self.index.ntotal == 0:
            logger.warning("索引为空")
            return []

        try:
            # L2归一化查询向量
            vec = query_vector.astype(np.float32).reshape(1, -1)
            norm = np.linalg.norm(vec)
            if norm > 1e-8:
                vec = vec / norm

            # 设置适当的搜索深度以提高召回率
            # 根据请求的top_k动态调整efSearch
            original_ef = self.index.hnsw.efSearch
            optimal_ef = max(200, min(top_k * 4, 800))  # 200-800之间
            self.index.hnsw.efSearch = optimal_ef

            # 搜索(多取一些以补偿墓碑过滤)
            search_k = min(top_k * 2, self.index.ntotal)
            distances, indices = self.index.search(vec, search_k)

            # 恢复原始efSearch值
            self.index.hnsw.efSearch = original_ef

            # 过滤墓碑 + 构建结果
            results = []
            for dist, idx in zip(distances[0], indices[0]):
                if idx < 0 or idx >= len(self.id_mapping):
                    continue

                img_id = self.id_mapping[idx]

                # 跳过墓碑标记的ID
                if img_id in self.tombstones:
                    continue

                # 转换距离为相似度
                # 注意:索引使用METRIC_INNER_PRODUCT(内积度量)
                # 对于L2归一化的向量,内积就是余弦相似度
                # dist值越大表示越相似(范围约为[0, 1],1表示完全相同)
                similarity = float(dist)
                results.append((img_id, similarity))

                if len(results) >= top_k:
                    break

            return results

        except Exception as e:
            logger.error(f"FAISS搜索失败: {e}")
            return []

    def compact_index(self, db_path):
        """
        压缩索引(移除墓碑标记的向量,重建索引)

        Args:
            db_path: 数据库路径

        Returns:
            bool: 是否成功
        """
        if len(self.tombstones) == 0:
            logger.info("无需压缩,墓碑集合为空")
            return True

        logger.info(f"开始压缩索引,移除 {len(self.tombstones)} 个墓碑标记")

        try:
            # 1. 从数据库重新加载有效的CNN向量
            conn = sqlite3.connect(db_path)
            cursor = conn.execute("SELECT id, cnn_vector FROM images WHERE cnn_vector IS NOT NULL")

            vectors = []
            img_ids = []

            for row in cursor:
                img_id, cnn_blob = row

                # 跳过墓碑标记的ID
                if img_id in self.tombstones:
                    continue

                if cnn_blob:
                    try:
                        vec = np.frombuffer(cnn_blob, dtype=np.float32)
                        if len(vec) == self.vector_dim:
                            vectors.append(vec)
                            img_ids.append(img_id)
                    except Exception as e:
                        logger.error(f"解析CNN向量失败 (img_id={img_id}): {e}")

            conn.close()

            if len(vectors) == 0:
                logger.warning("压缩后无有效向量")
                return False

            # 2. 重建索引
            logger.info(f"重建索引,{len(self.id_mapping)} -> {len(vectors)} 个向量")

            # 保留原索引参数(使用默认值,避免API兼容性问题)
            M = 32  # HNSW连接数
            efConstruction = 200  # 构建搜索深度

            self.index = self.faiss.IndexHNSWFlat(self.vector_dim, M, self.faiss.METRIC_INNER_PRODUCT)
            self.index.hnsw.efConstruction = efConstruction

            # 3. 添加向量
            vectors_np = np.array(vectors, dtype=np.float32)
            norms = np.linalg.norm(vectors_np, axis=1, keepdims=True)
            vectors_np = vectors_np / (norms + 1e-8)
            self.index.add(vectors_np)

            # 4. 更新映射和清空墓碑
            self.id_mapping = img_ids
            self.tombstones.clear()

            # 5. 持久化
            self._save_index()

            logger.info("索引压缩完成")
            return True

        except Exception as e:
            logger.error(f"压缩索引失败: {e}", exc_info=True)
            return False

    def load_index(self):
        """
        从磁盘加载索引

        Returns:
            bool: 是否成功
        """
        try:
            if not os.path.exists(self.index_path):
                logger.warning(f"索引文件不存在: {self.index_path}")
                return False

            import time
            start = time.time()

            # 加载FAISS索引
            self.index = self.faiss.read_index(self.index_path)

            # 加载ID映射
            if os.path.exists(self.mapping_path):
                with open(self.mapping_path, "rb") as f:
                    self.id_mapping = pickle.load(f)
            else:
                logger.warning(f"映射文件不存在: {self.mapping_path}")
                self.id_mapping = []

            # 加载墓碑标记
            if os.path.exists(self.tombstone_path):
                with open(self.tombstone_path, "rb") as f:
                    self.tombstones = pickle.load(f)
            else:
                self.tombstones = set()

            elapsed = time.time() - start
            logger.info(f"FAISS索引加载完成,耗时 {elapsed:.2f}秒")
            logger.info(f"向量数: {self.index.ntotal}, 映射数: {len(self.id_mapping)}, 墓碑数: {len(self.tombstones)}")

            return True

        except Exception as e:
            logger.error(f"加载索引失败: {e}", exc_info=True)
            return False

    def _save_index(self):
        """持久化索引到磁盘"""
        try:
            # 确保目录存在
            os.makedirs(os.path.dirname(self.index_path), exist_ok=True)

            # 保存FAISS索引
            self.faiss.write_index(self.index, self.index_path)

            # 保存ID映射
            with open(self.mapping_path, "wb") as f:
                pickle.dump(self.id_mapping, f)

            # 保存墓碑标记
            with open(self.tombstone_path, "wb") as f:
                pickle.dump(self.tombstones, f)

            logger.info("FAISS索引已保存")

        except Exception as e:
            logger.error(f"保存索引失败: {e}")

    def rebuild_index(self, db_manager):
        """
        重建索引(清理墓碑标记的向量)

        这是 compact_index 的别名方法,用于保持API兼容性

        Args:
            db_manager: 数据库管理器实例或数据库路径

        Returns:
            bool: 是否成功
        """
        # 如果传入的是数据库管理器对象,获取其路径
        if hasattr(db_manager, 'db_path'):
            db_path = db_manager.db_path
        else:
            # 假设传入的是数据库路径字符串
            db_path = db_manager

        logger.info("开始重建索引(清理墓碑)...")
        result = self.compact_index(db_path)

        if result:
            logger.info("索引重建完成,内存状态已更新")
            # 提示:如果是多进程架构,其他进程需要调用 load_index() 来获取最新索引

        return result

    def get_stats(self):
        """获取索引统计信息"""
        return {
            "total_vectors": self.index.ntotal if self.index else 0,
            "mapping_size": len(self.id_mapping),
            "tombstone_count": len(self.tombstones),
            "effective_vectors": len(self.id_mapping) - len(self.tombstones)
        }