search_engine.py 13.3 KB
"""
搜索引擎模块(CNN+RANSAC优化版)
CNN召回 + RANSAC验证 + 置信度评分
"""

import sqlite3
import numpy as np
import cv2
import logging
import threading
from collections import Counter
from .feature_extractor import FeatureExtractor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImageSearchEngine:
    """图片搜索引擎(CNN+RANSAC优化版)"""

    def __init__(self, db_path, config, memory_index=None, faiss_manager=None):
        """
        初始化搜索引擎

        Args:
            db_path: SQLite 数据库路径
            config: 配置字典
            memory_index: 内存索引实例(InMemoryIndex)
            faiss_manager: FAISS管理器实例(FAISSManager)
        """
        self.db_path = db_path
        self.config = config
        self.memory_index = memory_index
        self.faiss_manager = faiss_manager

        # 线程本地存储(每个线程独立的数据库连接)
        self._thread_local = threading.local()
        logger.info(f"搜索引擎将为每个线程创建独立数据库连接: {db_path}")

        self.extractor = FeatureExtractor(
            orb_max_features=config.get("orb_max_features", 1200),
            cnn_enabled=config.get("cnn_enabled", True)
        )

        # 加载配置参数
        search_config = config.get("search", {})
        self.cnn_top_k = search_config.get("cnn_top_k", 2000)
        self.max_candidates = search_config.get("max_candidates", 200)
        self.top_results = search_config.get("top_results", 20)

        # RANSAC配置(从search.ransac读取)
        ransac_config = search_config.get("ransac", {})
        self.min_orb_inliers = ransac_config.get("min_inliers", 15)
        self.ransac_reproj_thresh = ransac_config.get("reproj_threshold", 4.0)
        self.ransac_confidence = ransac_config.get("confidence", 0.995)

        # 融合权重(从search.fusion_weights读取)
        weights_config = search_config.get("fusion_weights", {})
        # 权重需要乘以100,因为打分逻辑中会 * 100
        self.weights = {
            "cnn": weights_config.get("cnn", 0.2) * 100,
            "ransac": weights_config.get("ransac", 0.8) * 100
        }

        logger.info(f"搜索引擎初始化完成(CNN+RANSAC优化版)")
        logger.info(f"权重: {self.weights}")

    def _get_db_conn(self):
        """
        获取当前线程的数据库连接(线程安全)

        Returns:
            sqlite3.Connection: 当前线程的数据库连接
        """
        if not hasattr(self._thread_local, 'conn'):
            # 为当前线程创建新连接
            self._thread_local.conn = sqlite3.connect(self.db_path)
            logger.debug(f"为线程 {threading.current_thread().name} 创建数据库连接")
        return self._thread_local.conn

    def search(self, query_img_path, top_k=None):
        """
        搜索相似图片(CNN优先 + RANSAC几何验证)

        流程(优化版):
        1. 优先CNN向量召回(语义特征,对变色裁切最鲁棒)
        2. 辅助pHash召回(保留但权重极低)
        3. Top候选进行RANSAC几何验证(并行化)
        4. 融合打分 + 置信度评估
        5. 返回Top-K结果

        Args:
            query_img_path: 查询图片路径
            top_k: 返回结果数量(None则使用配置的top_results)

        Returns:
            list: [{
                "path": str,
                "score": float,
                "confidence": str,  # "high", "medium", "low"
                "details": {
                    "cnn_sim": float,
                    "ransac_inliers": int
                }
            }, ...]
        """
        if top_k is None:
            top_k = self.top_results

        # 提取查询图片的所有特征
        logger.info(f"开始搜索: {query_img_path}")
        query_features = self.extractor.extract_all_features(query_img_path)

        if query_features is None:
            logger.error(f"查询图片特征提取失败: {query_img_path}")
            return []

        # ========== 阶段1: CNN召回 ==========
        logger.info("【阶段1】CNN召回...")

        cnn_candidates = set()
        cnn_similarity_cache = {}  # 缓存CNN相似度 {img_id: similarity}

        # CNN向量召回(唯一通道)
        if self.faiss_manager and query_features["cnn_vector"] is not None:
            cnn_results = self.faiss_manager.search(
                query_features["cnn_vector"],
                top_k=self.cnn_top_k
            )
            for img_id, sim in cnn_results:
                cnn_candidates.add(img_id)
                cnn_similarity_cache[img_id] = sim  # 缓存相似度
            logger.info(f"  ✓ CNN召回: {len(cnn_candidates)} 个候选(语义特征,对变色裁切最鲁棒)")

        if not cnn_candidates:
            logger.warning("未召回任何候选图片(可能是FAISS索引未初始化)")
            return []

        logger.info(f"【候选】总候选数: {len(cnn_candidates)}")

        # ========== 阶段2: 快速打分(CNN优先) ==========
        logger.info("【阶段2】快速打分...")

        scored_candidates = []

        for img_id in cnn_candidates:
            # 从内存索引获取元数据
            if not self.memory_index:
                continue

            metadata = self.memory_index.img_metadata.get(str(img_id))
            if not metadata:
                continue

            # 计算各维度得分
            details = {}

            # 2.1 CNN相似度得分(主要得分)
            cnn_sim = 0.0
            if img_id in cnn_similarity_cache:
                cnn_sim = cnn_similarity_cache[img_id]

            details["cnn_sim"] = cnn_sim
            # CNN得分:直接使用相似度乘以权重,归一化到0-100
            cnn_score = cnn_sim * 100 * self.weights["cnn"]

            # 初步总分(CNN主导)
            initial_score = cnn_score

            scored_candidates.append({
                "img_id": img_id,
                "path": metadata["path"],
                "initial_score": initial_score,
                "details": details,
                "ransac_inliers": 0  # 待RANSAC验证
            })

        # 按初步得分排序
        scored_candidates.sort(key=lambda x: x["initial_score"], reverse=True)

        # 调试:打印阶段2结果
        logger.info(f"【阶段2完成】候选数量: {len(scored_candidates)}")
        if scored_candidates:
            logger.info(f"  最高分: {scored_candidates[0]['initial_score']:.4f}")
            logger.info(f"  最低分: {scored_candidates[-1]['initial_score']:.4f}")
            logger.info(f"  前3名: {[c['img_id'] for c in scored_candidates[:3]]}")

        # 截取Top候选进行RANSAC验证
        top_for_ransac = min(100, len(scored_candidates))
        logger.info(f"【阶段3】对Top{top_for_ransac}进行RANSAC几何验证(容忍变色裁切)...")

        # ========== 阶段3: RANSAC几何验证(并行化) ==========
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # 使用4线程并行处理RANSAC验证
        with ThreadPoolExecutor(max_workers=4) as executor:
            # 准备任务列表
            futures = {}
            for i, candidate in enumerate(scored_candidates[:top_for_ransac]):
                img_id = candidate["img_id"]

                # 提交RANSAC任务到线程池
                future = executor.submit(
                    self._ransac_verify_single,
                    img_id,
                    query_features["orb_kp"],
                    query_features["orb_desc"]
                )
                futures[future] = candidate

            # 收集结果
            for future in as_completed(futures):
                candidate = futures[future]
                try:
                    ransac_inliers = future.result(timeout=5)  # 单个RANSAC最多5秒
                    candidate["ransac_inliers"] = ransac_inliers
                    candidate["details"]["ransac_inliers"] = ransac_inliers
                except Exception as e:
                    logger.warning(f"RANSAC验证超时或失败 (img_id={candidate['img_id']}): {e}")
                    candidate["ransac_inliers"] = 0
                    candidate["details"]["ransac_inliers"] = 0

        # ========== 阶段4: 最终打分 + 置信度评估 ==========
        logger.info("【阶段4】最终融合打分...")

        # 调试:检查RANSAC后的候选
        with_ransac = [c for c in scored_candidates[:top_for_ransac] if c.get("ransac_inliers", 0) > 0]
        logger.info(f"RANSAC验证完成: {len(with_ransac)}/{top_for_ransac} 个候选有匹配点")
        if with_ransac:
            logger.info(f"  最高RANSAC匹配数: {max(c.get('ransac_inliers', 0) for c in with_ransac)}")

        final_results = []

        for candidate in scored_candidates[:self.max_candidates]:
            # 加入RANSAC得分
            ransac_score = candidate["ransac_inliers"] * self.weights["ransac"]

            final_score = candidate["initial_score"] + ransac_score

            # 置信度评估
            confidence = self._assess_confidence(candidate)

            final_results.append({
                "img_id": candidate["img_id"],
                "path": candidate["path"],
                "score": final_score,
                "confidence": confidence,
                "details": candidate["details"]
            })

        # 按最终得分排序
        final_results.sort(key=lambda x: x["score"], reverse=True)

        logger.info(f"【完成】返回Top{top_k}结果")
        return final_results[:top_k]


    def _load_orb_features(self, img_id):
        """
        从数据库加载图片的ORB特征(包含关键点信息)

        注意:在并行RANSAC验证时,每个线程创建独立连接以确保线程安全

        Args:
            img_id: 图片ID

        Returns:
            tuple: (keypoints, descriptors) 或 (None, None)
        """
        # 在并行环境中,使用独立连接确保线程安全
        conn = sqlite3.connect(self.db_path)
        cursor = conn.execute(
            "SELECT orb_desc, orb_keypoints FROM images WHERE id = ?",
            (img_id,)
        )

        row = cursor.fetchone()
        conn.close()

        if not row or not row[0]:
            return None, None

        orb_blob, orb_kp_blob = row

        # 反序列化描述子
        try:
            orb_desc = np.frombuffer(orb_blob, dtype=np.uint8).reshape(-1, 32)

            # 反序列化关键点(使用pickle,与database.py的序列化方式一致)
            if orb_kp_blob:
                import pickle
                kp_list = pickle.loads(orb_kp_blob)
                orb_kp = [
                    cv2.KeyPoint(
                        x=float(kp_dict['pt'][0]),
                        y=float(kp_dict['pt'][1]),
                        size=float(kp_dict['size']),
                        angle=float(kp_dict['angle'])
                    )
                    for kp_dict in kp_list
                ]
            else:
                logger.warning(f"图片 {img_id} 缺少ORB关键点,跳过RANSAC验证")
                return None, None

            return orb_kp, orb_desc

        except Exception as e:
            logger.error(f"加载ORB特征失败 (img_id={img_id}): {e}")
            return None, None

    def _ransac_verify_single(self, img_id, query_kp, query_desc):
        """
        对单个候选图进行RANSAC验证(线程安全版本)

        Args:
            img_id: 候选图ID
            query_kp: 查询图的ORB关键点
            query_desc: 查询图的ORB描述子

        Returns:
            int: RANSAC内点数(0表示几何不一致)
        """
        # 从数据库加载候选图的ORB特征
        cand_orb_kp, cand_orb_desc = self._load_orb_features(img_id)

        if cand_orb_kp is None or cand_orb_desc is None:
            return 0

        # RANSAC验证
        try:
            ransac_inliers = self.extractor.compute_orb_ransac_score(
                (query_kp, query_desc),
                (cand_orb_kp, cand_orb_desc),
                min_inliers=self.min_orb_inliers,
                ransac_thresh=self.ransac_reproj_thresh,
                confidence=self.ransac_confidence
            )
            return ransac_inliers
        except Exception as e:
            logger.error(f"RANSAC计算失败 (img_id={img_id}): {e}")
            return 0

    def _assess_confidence(self, candidate):
        """
        评估搜索结果的置信度(珠宝图片优化版)

        针对变色和裁切场景优化的判断标准:
        - high: RANSAC内点>=15(几何一致性强)或 CNN相似度>0.8
        - medium: RANSAC内点>=8(降低要求)或 CNN相似度>0.7
        - low: 其他

        Args:
            candidate: 候选结果字典

        Returns:
            str: "high", "medium", "low"
        """
        details = candidate["details"]
        ransac_inliers = details.get("ransac_inliers", 0)
        cnn_sim = details.get("cnn_sim", 0.0)

        # High confidence(几何一致性或语义确认)
        if ransac_inliers >= 15 and cnn_sim > 0.8:
            return "high"

        # Medium confidence(降低要求以提升召回)
        if ransac_inliers >= 8 and cnn_sim > 0.7:
            return "medium"

        # Low confidence
        return "low"