# feature_extractor.py
"""
核心特征提取模块(CNN+RANSAC优化版)
- CNN向量: MobileNetV3,语义特征
- ORB特征: 保留关键点,支持RANSAC
"""
import logging
import cv2
import numpy as np
from PIL import Image, ImageOps
# 延迟导入torch(避免启动时加载)
_torch_model = None
_torch_transforms = None
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class FeatureExtractor:
    """Image feature extractor (accuracy first)."""

    def __init__(self, orb_max_features=1200, cnn_enabled=True):
        """
        Initialize the feature extractor.

        Args:
            orb_max_features: maximum number of ORB keypoints (default 1200)
            cnn_enabled: whether to enable CNN feature extraction
        """
        self.orb_max_features = orb_max_features
        self.cnn_enabled = cnn_enabled
        # Initialize the ORB detector
        self.orb = cv2.ORB_create(nfeatures=orb_max_features)
        # Lazily load the CNN model
        if self.cnn_enabled:
            self._init_cnn_model()
        logger.info(f"Feature extractor initialized: ORB={orb_max_features}, CNN={cnn_enabled}")
    def _init_cnn_model(self):
        """Lazily initialize the CNN model (loaded only when needed)."""
        global _torch_model, _torch_transforms
        if _torch_model is not None:
            return  # already loaded
        try:
            import torch
            import torchvision.models as models
            import torchvision.transforms as transforms

            # Load MobileNetV3-Small (lightweight model)
            logger.info("Loading MobileNetV3-Small model...")
            model = models.mobilenet_v3_small(weights='DEFAULT')
            model.classifier = torch.nn.Identity()  # drop the classification head
            model.eval()
            # Run on CPU (no GPU assumed in this environment)
            device = torch.device("cpu")
            model = model.to(device)
            # Image preprocessing
            _torch_transforms = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            _torch_model = model
            logger.info("CNN model loaded")
        except Exception as e:
            logger.error(f"Failed to load CNN model: {e}")
            self.cnn_enabled = False
    # ========== New feature 4: CNN feature extraction ==========
    def extract_cnn_feature(self, img_path):
        """
        Extract a CNN feature vector (MobileNetV3 + grayscale conversion).

        Args:
            img_path: image path
        Returns:
            np.ndarray: 576-dim float32 vector (L2-normalized), or None
        """
        if not self.cnn_enabled or _torch_model is None:
            return None
        try:
            import torch

            # Read the image as grayscale (removes the influence of color)
            img = Image.open(img_path).convert("L")
            # Convert back to 3 channels (the CNN expects RGB input)
            img_rgb = ImageOps.colorize(img, black="black", white="white")
            # Preprocess
            img_tensor = _torch_transforms(img_rgb).unsqueeze(0)
            # Inference
            with torch.no_grad():
                features = _torch_model(img_tensor).squeeze().cpu().numpy()
            # L2 normalization (for cosine similarity)
            features = features / (np.linalg.norm(features) + 1e-8)
            return features.astype('float32')
        except Exception as e:
            logger.error(f"CNN feature extraction failed for {img_path}: {e}")
            return None
    # ========== New feature 5: ORB features (keypoints are kept) ==========
    def compute_orb_with_keypoints(self, img_path):
        """
        Compute ORB features and keep the keypoints (needed for RANSAC verification).

        Args:
            img_path: image path, or a PIL.Image object
        Returns:
            tuple: (keypoints, descriptors) or (None, None)
                - keypoints: list of cv2.KeyPoint
                - descriptors: numpy.ndarray (N x 32, uint8)
        """
        try:
            # Read as grayscale via PIL (avoids OpenCV issues with non-ASCII paths)
            if isinstance(img_path, str):
                pil_img = Image.open(img_path).convert("L")
                img = np.array(pil_img)
            else:
                img = np.array(img_path.convert("L"))
            if img is None or img.size == 0:
                logger.error(f"Failed to read image {img_path}")
                return None, None
            # Detect ORB keypoints and compute descriptors
            keypoints, descriptors = self.orb.detectAndCompute(img, None)
            if descriptors is None or len(descriptors) == 0:
                logger.warning(f"No ORB features found in {img_path}")
                return None, None
            return keypoints, descriptors
        except Exception as e:
            logger.error(f"ORB computation failed for {img_path}: {e}")
            return None, None

    def compute_orb(self, img_path):
        """Return only the descriptors (backward compatible)."""
        _, descriptors = self.compute_orb_with_keypoints(img_path)
        return descriptors
    # ========== New feature 6: real RANSAC geometric verification ==========
    def compute_orb_ransac_score(self, query_kp_desc, cand_kp_desc,
                                 min_inliers=15, ransac_thresh=4.0, confidence=0.995):
        """
        Compute an ORB RANSAC match score (actual geometric verification).

        Args:
            query_kp_desc: (keypoints1, descriptors1) of the query image
            cand_kp_desc: (keypoints2, descriptors2) of the candidate image
            min_inliers: minimum inlier count (raised to 15 for stricter matching)
            ransac_thresh: RANSAC reprojection error threshold
            confidence: RANSAC confidence
        Returns:
            int: number of RANSAC inliers (0 means geometrically inconsistent)
        """
        kp1, desc1 = query_kp_desc
        kp2, desc2 = cand_kp_desc
        if desc1 is None or desc2 is None:
            return 0
        # Step 1: brute-force matching
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
        try:
            matches = bf.knnMatch(desc1, desc2, k=2)
        except Exception as e:
            logger.error(f"ORB matching failed: {e}")
            return 0
        # Step 2: Lowe's ratio test
        good_matches = []
        for match_pair in matches:
            if len(match_pair) == 2:
                m, n = match_pair
                if m.distance < 0.75 * n.distance:
                    good_matches.append(m)
        if len(good_matches) < 8:
            return 0
        # Step 3: collect matched point coordinates
        pts1 = np.float32([kp1[m.queryIdx].pt for m in good_matches])
        pts2 = np.float32([kp2[m.trainIdx].pt for m in good_matches])
        # Step 4: estimate a homography with RANSAC
        try:
            H, mask = cv2.findHomography(pts1, pts2, cv2.RANSAC,
                                         ransacReprojThreshold=ransac_thresh,
                                         confidence=confidence)
        except Exception as e:
            logger.error(f"RANSAC failed: {e}")
            return 0
        if H is None or mask is None:
            return 0
        # Step 5: count inliers
        inliers = int(np.sum(mask))
        # Step 6: sanity-check the geometry
        if inliers >= min_inliers:
            # The homography should not be too distorted
            try:
                det = np.linalg.det(H[0:2, 0:2])
                if 0.1 < abs(det) < 10:  # scale range [0.1, 10]
                    return inliers
            except Exception:
                pass
        return 0
    # ========== Unified interface: extract all features ==========
    def extract_all_features(self, img_path):
        """
        Extract all features of an image (unified interface, optimized: the file is read only once).

        Args:
            img_path: image path
        Returns:
            dict: {
                "cnn_vector": np.ndarray,  # CNN vector (576-dim)
                "orb_kp": list,            # ORB keypoints
                "orb_desc": np.ndarray     # ORB descriptors
            } or None
        """
        try:
            # ========== Key optimization: read the file only once ==========
            # Check that the image is readable first, so a corrupted file
            # does not abort the whole pipeline.
            try:
                pil_img_gray = Image.open(img_path).convert("L")  # read once as grayscale
            except (OSError, IOError) as e:
                # Corrupted image file: skip it
                logger.warning(f"Skipping corrupted image file: {img_path} - {str(e)[:100]}")
                return None
            except Exception as e:
                # Other read errors (permissions, missing path, ...)
                logger.error(f"Failed to read image file {img_path}: {str(e)[:100]}")
                return None
            # 1. CNN vector (reuses the PIL object)
            cnn_vector = None
            if self.cnn_enabled and _torch_model is not None:
                try:
                    import torch

                    # Convert back to 3 channels (the CNN expects RGB input)
                    img_rgb = ImageOps.colorize(pil_img_gray, black="black", white="white")
                    img_tensor = _torch_transforms(img_rgb).unsqueeze(0)
                    with torch.no_grad():
                        features = _torch_model(img_tensor).squeeze().cpu().numpy()
                    cnn_vector = (features / (np.linalg.norm(features) + 1e-8)).astype('float32')
                except Exception as e:
                    logger.error(f"CNN feature extraction failed: {e}")
            # 2. ORB features (reuses the numpy array)
            img_array = np.array(pil_img_gray)
            keypoints, descriptors = self.orb.detectAndCompute(img_array, None)
            if descriptors is None or len(descriptors) == 0:
                logger.warning(f"No ORB features found in {img_path}")
                keypoints, descriptors = None, None
            return {
                "cnn_vector": cnn_vector,
                "orb_kp": keypoints,
                "orb_desc": descriptors
            }
        except Exception as e:
            logger.error(f"Feature extraction failed for {img_path}: {e}", exc_info=True)
            return None
def match_orb_features(desc1, desc2, ratio_threshold=0.75):
    """
    Match two sets of ORB descriptors (Lowe's ratio test).

    Args:
        desc1: first set of descriptors
        desc2: second set of descriptors
        ratio_threshold: Lowe's ratio test threshold
    Returns:
        list: good matches
    """
    if desc1 is None or desc2 is None:
        return []
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    try:
        matches = bf.knnMatch(desc1, desc2, k=2)
    except Exception as e:
        logger.error(f"ORB matching failed: {e}")
        return []
    good_matches = []
    for match_pair in matches:
        if len(match_pair) == 2:
            m, n = match_pair
            if m.distance < ratio_threshold * n.distance:
                good_matches.append(m)
    return good_matches


def compute_orb_score(query_desc, candidate_desc, min_inliers=10):
    """
    Compute a simple ORB match score (simplified version, backward compatible).

    Note: min_inliers is kept only for API compatibility and is not used here;
    the score is simply the number of ratio-test matches.

    Args:
        query_desc: ORB descriptors of the query image
        candidate_desc: ORB descriptors of the candidate image
        min_inliers: minimum inlier count (unused, kept for compatibility)
    Returns:
        int: number of good matches
    """
    good_matches = match_orb_features(query_desc, candidate_desc)
    return len(good_matches)
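

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how the pieces above are meant to be
# combined. The image paths "query.jpg" and "candidate.jpg" are hypothetical
# placeholders, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = FeatureExtractor(orb_max_features=1200, cnn_enabled=True)

    # Extract all features for two images (each file is read from disk only once).
    query = extractor.extract_all_features("query.jpg")
    candidate = extractor.extract_all_features("candidate.jpg")

    if query and candidate:
        # Semantic similarity: cosine similarity of the L2-normalized CNN vectors.
        if query["cnn_vector"] is not None and candidate["cnn_vector"] is not None:
            cnn_sim = float(np.dot(query["cnn_vector"], candidate["cnn_vector"]))
            print(f"CNN cosine similarity: {cnn_sim:.4f}")

        # Geometric consistency: RANSAC inlier count over ORB matches.
        inliers = extractor.compute_orb_ransac_score(
            (query["orb_kp"], query["orb_desc"]),
            (candidate["orb_kp"], candidate["orb_desc"]),
        )
        print(f"RANSAC inliers: {inliers}")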