feature_extractor.py 11.2 KB
"""
核心特征提取模块(CNN+RANSAC优化版)
- CNN向量: MobileNetV3,语义特征
- ORB特征: 保留关键点,支持RANSAC
"""

import logging
import cv2
import numpy as np
from PIL import Image, ImageOps

# 延迟导入torch(避免启动时加载)
_torch_model = None
_torch_transforms = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FeatureExtractor:
    """图片特征提取器(准确度优先)"""

    def __init__(self, orb_max_features=1200, cnn_enabled=True):
        """
        初始化特征提取器

        Args:
            orb_max_features: ORB 最大特征点数 (默认1200)
            cnn_enabled: 是否启用CNN特征提取
        """
        self.orb_max_features = orb_max_features
        self.cnn_enabled = cnn_enabled

        # 初始化ORB检测器
        self.orb = cv2.ORB_create(nfeatures=orb_max_features)

        # 延迟加载CNN模型
        if self.cnn_enabled:
            self._init_cnn_model()

        logger.info(f"特征提取器初始化完成: ORB={orb_max_features}, CNN={cnn_enabled}")

    def _init_cnn_model(self):
        """延迟初始化CNN模型(只在需要时加载)"""
        global _torch_model, _torch_transforms

        if _torch_model is not None:
            return  # 已加载

        try:
            import torch
            import torchvision.models as models
            import torchvision.transforms as transforms

            # 加载MobileNetV3-Small(轻量级模型)
            logger.info("加载MobileNetV3-Small模型...")
            model = models.mobilenet_v3_small(weights='DEFAULT')
            model.classifier = torch.nn.Identity()  # 去掉分类头
            model.eval()

            # 移动到CPU(你的环境没有GPU)
            device = torch.device("cpu")
            model = model.to(device)

            # 图像预处理
            _torch_transforms = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            _torch_model = model
            logger.info("CNN模型加载完成")

        except Exception as e:
            logger.error(f"CNN模型加载失败: {e}")
            self.cnn_enabled = False


    # ========== 新功能4: CNN特征提取 ==========

    def extract_cnn_feature(self, img_path):
        """
        提取CNN向量特征(MobileNetV3 + 灰度转换)

        Args:
            img_path: 图片路径

        Returns:
            np.ndarray: 576维float32向量(L2归一化)
        """
        if not self.cnn_enabled or _torch_model is None:
            return None

        try:
            import torch

            # 读取图片并转灰度(消除颜色影响,借鉴你的思路)
            img = Image.open(img_path).convert("L")

            # 转回3通道(CNN需要RGB)
            img_rgb = ImageOps.colorize(img, black="black", white="white")

            # 预处理
            img_tensor = _torch_transforms(img_rgb).unsqueeze(0)

            # 推理
            with torch.no_grad():
                features = _torch_model(img_tensor).squeeze().cpu().numpy()

            # L2归一化(用于余弦相似度)
            features = features / (np.linalg.norm(features) + 1e-8)

            return features.astype('float32')

        except Exception as e:
            logger.error(f"CNN特征提取失败 {img_path}: {e}")
            return None

    # ========== 新功能5: ORB特征(保留关键点信息) ==========

    def compute_orb_with_keypoints(self, img_path):
        """
        计算ORB特征并保留关键点信息(用于RANSAC验证)

        Args:
            img_path: 图片路径

        Returns:
            tuple: (keypoints, descriptors) 或 (None, None)
                - keypoints: list of cv2.KeyPoint
                - descriptors: numpy.ndarray (N x 32, uint8)
        """
        try:
            # 读取灰度图(使用PIL避免中文路径问题)
            if isinstance(img_path, str):
                pil_img = Image.open(img_path).convert("L")
                img = np.array(pil_img)
            else:
                img = np.array(img_path.convert("L"))

            if img is None or img.size == 0:
                logger.error(f"无法读取图片 {img_path}")
                return None, None

            # 检测ORB特征点和描述子
            keypoints, descriptors = self.orb.detectAndCompute(img, None)

            if descriptors is None or len(descriptors) == 0:
                logger.warning(f"图片无ORB特征 {img_path}")
                return None, None

            return keypoints, descriptors

        except Exception as e:
            logger.error(f"计算ORB失败 {img_path}: {e}")
            return None, None

    def compute_orb(self, img_path):
        """只返回描述子(向后兼容)"""
        _, descriptors = self.compute_orb_with_keypoints(img_path)
        return descriptors

    # ========== 新功能6: 真正的RANSAC几何验证 ==========

    def compute_orb_ransac_score(self, query_kp_desc, cand_kp_desc,
                                  min_inliers=15, ransac_thresh=4.0, confidence=0.995):
        """
        计算ORB RANSAC匹配得分(真正的几何验证)

        Args:
            query_kp_desc: (keypoints1, descriptors1) 查询图特征
            cand_kp_desc: (keypoints2, descriptors2) 候选图特征
            min_inliers: 最小内点数阈值(提高到15,更严格)
            ransac_thresh: RANSAC重投影误差阈值
            confidence: RANSAC置信度

        Returns:
            int: RANSAC内点数(0表示几何不一致)
        """
        kp1, desc1 = query_kp_desc
        kp2, desc2 = cand_kp_desc

        if desc1 is None or desc2 is None:
            return 0

        # Step 1: 暴力匹配
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)

        try:
            matches = bf.knnMatch(desc1, desc2, k=2)
        except Exception as e:
            logger.error(f"ORB匹配失败: {e}")
            return 0

        # Step 2: Lowe's ratio test
        good_matches = []
        for match_pair in matches:
            if len(match_pair) == 2:
                m, n = match_pair
                if m.distance < 0.75 * n.distance:
                    good_matches.append(m)

        if len(good_matches) < 8:
            return 0

        # Step 3: 提取匹配点坐标
        pts1 = np.float32([kp1[m.queryIdx].pt for m in good_matches])
        pts2 = np.float32([kp2[m.trainIdx].pt for m in good_matches])

        # Step 4: RANSAC估计单应性矩阵
        try:
            H, mask = cv2.findHomography(pts1, pts2, cv2.RANSAC,
                                         ransacReprojThreshold=ransac_thresh,
                                         confidence=confidence)
        except Exception as e:
            logger.error(f"RANSAC失败: {e}")
            return 0

        if H is None or mask is None:
            return 0

        # Step 5: 统计内点数
        inliers = int(np.sum(mask))

        # Step 6: 验证几何合理性
        if inliers >= min_inliers:
            # 检查单应性矩阵是否合理(不能太扭曲)
            try:
                det = np.linalg.det(H[0:2, 0:2])
                if 0.1 < abs(det) < 10:  # 缩放范围[0.1, 10]
                    return inliers
            except:
                pass

        return 0

    # ========== 统一接口:提取所有特征 ==========

    def extract_all_features(self, img_path):
        """
        提取图片的所有特征(统一接口 - 优化版:只读盘一次!)

        Args:
            img_path: 图片路径

        Returns:
            dict: {
                "cnn_vector": np.ndarray,  # CNN向量(576维)
                "orb_kp": list,            # ORB关键点
                "orb_desc": np.ndarray     # ORB描述子
            } 或 None
        """
        try:
            # ========== 关键优化:只读盘一次!==========
            # 先检查图片是否可读,避免损坏文件导致整个流程失败
            try:
                pil_img_gray = Image.open(img_path).convert("L")  # 一次性读取为灰度图
            except (OSError, IOError) as e:
                # 图片文件损坏,跳过此文件
                logger.warning(f"跳过损坏的图片文件: {img_path} - {str(e)[:100]}")
                return None
            except Exception as e:
                # 其他读取错误(如权限问题、路径不存在等)
                logger.error(f"无法读取图片文件 {img_path}: {str(e)[:100]}")
                return None

            # 1. CNN向量 (复用PIL对象)
            cnn_vector = None
            if self.cnn_enabled and _torch_model is not None:
                try:
                    import torch
                    # 转回3通道(CNN需要RGB)
                    img_rgb = ImageOps.colorize(pil_img_gray, black="black", white="white")
                    img_tensor = _torch_transforms(img_rgb).unsqueeze(0)
                    with torch.no_grad():
                        features = _torch_model(img_tensor).squeeze().cpu().numpy()
                    cnn_vector = (features / (np.linalg.norm(features) + 1e-8)).astype('float32')
                except Exception as e:
                    logger.error(f"CNN特征提取失败: {e}")

            # 2. ORB特征 (复用numpy数组)
            img_array = np.array(pil_img_gray)
            keypoints, descriptors = self.orb.detectAndCompute(img_array, None)
            if descriptors is None or len(descriptors) == 0:
                logger.warning(f"图片无ORB特征 {img_path}")
                keypoints, descriptors = None, None

            return {
                "cnn_vector": cnn_vector,
                "orb_kp": keypoints,
                "orb_desc": descriptors
            }

        except Exception as e:
            logger.error(f"特征提取失败 {img_path}: {e}", exc_info=True)
            return None




def match_orb_features(desc1, desc2, ratio_threshold=0.75):
    """
    匹配两组 ORB 描述子(Lowe's ratio test)

    Args:
        desc1: 第一组描述子
        desc2: 第二组描述子
        ratio_threshold: Lowe's ratio test 阈值

    Returns:
        list: 好的匹配点列表
    """
    if desc1 is None or desc2 is None:
        return []

    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)

    try:
        matches = bf.knnMatch(desc1, desc2, k=2)
    except Exception as e:
        logger.error(f"ORB匹配失败: {e}")
        return []

    good_matches = []
    for match_pair in matches:
        if len(match_pair) == 2:
            m, n = match_pair
            if m.distance < ratio_threshold * n.distance:
                good_matches.append(m)

    return good_matches


def compute_orb_score(query_desc, candidate_desc, min_inliers=10):
    """
    计算ORB匹配得分(简化版,向后兼容)

    Args:
        query_desc: 查询图片的ORB描述子
        candidate_desc: 候选图片的ORB描述子
        min_inliers: 最小内点数阈值

    Returns:
        int: 匹配点数量
    """
    good_matches = match_orb_features(query_desc, candidate_desc)
    return len(good_matches)