# search_engine.py
"""
搜索引擎模块(CNN+RANSAC优化版)
CNN召回 + RANSAC验证 + 置信度评分
"""
import sqlite3
import numpy as np
import cv2
import logging
import threading
from collections import Counter
from .feature_extractor import FeatureExtractor
# NOTE(review): basicConfig() runs at import time and configures the root
# logger (only if it has no handlers yet), so merely importing this module can
# change logging for the host application. Kept as-is because callers may rely
# on it; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
class ImageSearchEngine:
    """Image similarity search: CNN recall, ORB/RANSAC verification, confidence.

    Pipeline implemented by :meth:`search`:

    1. Recall candidates from the FAISS index using the query's CNN vector.
    2. Score every candidate from its CNN similarity.
    3. Geometrically verify the best candidates with ORB + RANSAC in a
       thread pool.
    4. Fuse both scores, attach a confidence label and return the top hits.
    """

    # How many of the best CNN-scored candidates are RANSAC-verified.
    # Overridable through config["search"]["ransac_top_n"].
    ransac_top_n = 100
    # Worker threads used for RANSAC verification.
    # Overridable through config["search"]["ransac_workers"].
    ransac_workers = 4

    def __init__(self, db_path, config, memory_index=None, faiss_manager=None):
        """Set up the engine from a configuration dictionary.

        Args:
            db_path: Path of the SQLite database holding the ORB features.
            config: Configuration dict. Keys read here: ``orb_max_features``,
                ``cnn_enabled`` and the nested ``search`` section
                (``cnn_top_k``, ``max_candidates``, ``top_results``,
                ``ransac_top_n``, ``ransac_workers``, ``ransac`` and
                ``fusion_weights``).
            memory_index: Object exposing an ``img_metadata`` mapping keyed by
                ``str(img_id)`` whose values contain a ``"path"`` entry.
            faiss_manager: Object exposing ``search(vector, top_k=...)`` that
                yields ``(img_id, similarity)`` pairs.
        """
        self.db_path = db_path
        self.config = config
        self.memory_index = memory_index
        self.faiss_manager = faiss_manager

        # Per-thread storage for the connections handed out by _get_db_conn().
        self._thread_local = threading.local()
        logger.info(f"搜索引擎将为每个线程创建独立数据库连接: {db_path}")

        self.extractor = FeatureExtractor(
            orb_max_features=config.get("orb_max_features", 1200),
            cnn_enabled=config.get("cnn_enabled", True),
        )

        search_config = config.get("search", {})
        self.cnn_top_k = search_config.get("cnn_top_k", 2000)
        self.max_candidates = search_config.get("max_candidates", 200)
        self.top_results = search_config.get("top_results", 20)
        # Fall back to the class-level defaults when the keys are absent.
        self.ransac_top_n = max(
            0, int(search_config.get("ransac_top_n", type(self).ransac_top_n))
        )
        self.ransac_workers = max(
            1, int(search_config.get("ransac_workers", type(self).ransac_workers))
        )

        # RANSAC parameters come from search.ransac.
        ransac_config = search_config.get("ransac", {})
        self.min_orb_inliers = ransac_config.get("min_inliers", 15)
        self.ransac_reproj_thresh = ransac_config.get("reproj_threshold", 4.0)
        self.ransac_confidence = ransac_config.get("confidence", 0.995)

        # Fusion weights come from search.fusion_weights and are stored
        # pre-multiplied by 100. Note the CNN term in scoring multiplies the
        # similarity by 100 again, so the two terms are on different scales:
        # CNN contributes sim * 100 * weight, RANSAC contributes
        # inliers * weight.
        weights_config = search_config.get("fusion_weights", {})
        self.weights = {
            "cnn": weights_config.get("cnn", 0.2) * 100,
            "ransac": weights_config.get("ransac", 0.8) * 100,
        }

        logger.info("搜索引擎初始化完成(CNN+RANSAC优化版)")
        logger.info(f"权重: {self.weights}")

    def _get_db_conn(self):
        """Return the calling thread's SQLite connection, creating it lazily.

        Returns:
            sqlite3.Connection: Connection cached in thread-local storage.
        """
        if not hasattr(self._thread_local, 'conn'):
            self._thread_local.conn = sqlite3.connect(self.db_path)
            logger.debug(f"为线程 {threading.current_thread().name} 创建数据库连接")
        return self._thread_local.conn

    def search(self, query_img_path, top_k=None):
        """Find indexed images similar to the query image.

        Args:
            query_img_path: Path of the query image.
            top_k: Number of results to return; ``None`` uses the configured
                ``top_results``.

        Returns:
            list[dict]: Results sorted by descending score. Each entry has
            ``img_id``, ``path``, ``score`` (float), ``confidence``
            (``"high"``, ``"medium"`` or ``"low"``) and ``details`` with
            ``cnn_sim`` (float) and ``ransac_inliers`` (0 for candidates that
            were not verified). An empty list is returned when feature
            extraction fails or nothing is recalled.
        """
        if top_k is None:
            top_k = self.top_results

        logger.info(f"开始搜索: {query_img_path}")
        query_features = self.extractor.extract_all_features(query_img_path)
        if query_features is None:
            logger.error(f"查询图片特征提取失败: {query_img_path}")
            return []

        # ---- Stage 1: CNN recall ----
        logger.info("【阶段1】CNN召回...")
        cnn_similarity = self._recall_cnn_candidates(query_features)
        if not cnn_similarity:
            logger.warning("未召回任何候选图片(可能是FAISS索引未初始化)")
            return []
        logger.info(f"【候选】总候选数: {len(cnn_similarity)}")

        # ---- Stage 2: CNN-based scoring ----
        logger.info("【阶段2】快速打分...")
        scored_candidates = self._score_candidates(cnn_similarity)
        logger.info(f"【阶段2完成】候选数量: {len(scored_candidates)}")
        if scored_candidates:
            logger.info(f" 最高分: {scored_candidates[0]['initial_score']:.4f}")
            logger.info(f" 最低分: {scored_candidates[-1]['initial_score']:.4f}")
            logger.info(f" 前3名: {[c['img_id'] for c in scored_candidates[:3]]}")

        # ---- Stage 3: RANSAC verification of the best candidates ----
        top_for_ransac = min(self.ransac_top_n, len(scored_candidates))
        logger.info(f"【阶段3】对Top{top_for_ransac}进行RANSAC几何验证(容忍变色裁切)...")
        verified_slice = scored_candidates[:top_for_ransac]
        self._verify_candidates(verified_slice, query_features)

        # ---- Stage 4: score fusion and confidence ----
        logger.info("【阶段4】最终融合打分...")
        with_ransac = [c for c in verified_slice if c["ransac_inliers"] > 0]
        logger.info(f"RANSAC验证完成: {len(with_ransac)}/{top_for_ransac} 个候选有匹配点")
        if with_ransac:
            best_inliers = max(c["ransac_inliers"] for c in with_ransac)
            logger.info(f" 最高RANSAC匹配数: {best_inliers}")

        ransac_weight = self.weights["ransac"]
        final_results = [
            {
                "img_id": candidate["img_id"],
                "path": candidate["path"],
                "score": candidate["initial_score"]
                + candidate["ransac_inliers"] * ransac_weight,
                "confidence": self._assess_confidence(candidate),
                "details": candidate["details"],
            }
            for candidate in scored_candidates[:self.max_candidates]
        ]
        final_results.sort(key=lambda result: result["score"], reverse=True)

        logger.info(f"【完成】返回Top{top_k}结果")
        return final_results[:top_k]

    def _recall_cnn_candidates(self, query_features):
        """Return ``{img_id: similarity}`` recalled from FAISS, in FAISS order."""
        cnn_similarity = {}
        cnn_vector = query_features.get("cnn_vector")
        if self.faiss_manager and cnn_vector is not None:
            cnn_results = self.faiss_manager.search(cnn_vector, top_k=self.cnn_top_k)
            # A dict keeps the order returned by FAISS, so candidates with
            # equal scores are ranked deterministically.
            for img_id, sim in cnn_results:
                cnn_similarity[img_id] = sim
            logger.info(f" ✓ CNN召回: {len(cnn_similarity)} 个候选(语义特征,对变色裁切最鲁棒)")
        return cnn_similarity

    def _score_candidates(self, cnn_similarity):
        """Build candidate records scored by CNN similarity, best first."""
        # Without the in-memory index no candidate can be resolved to a path.
        if not self.memory_index:
            return []

        img_metadata = self.memory_index.img_metadata
        cnn_weight = self.weights["cnn"]
        scored_candidates = []
        for img_id, cnn_sim in cnn_similarity.items():
            metadata = img_metadata.get(str(img_id))
            if not metadata:
                continue
            scored_candidates.append({
                "img_id": img_id,
                "path": metadata["path"],
                "initial_score": cnn_sim * 100 * cnn_weight,
                # ransac_inliers starts at 0 so that candidates which are
                # never verified still expose the documented key.
                "details": {"cnn_sim": cnn_sim, "ransac_inliers": 0},
                "ransac_inliers": 0,
            })

        scored_candidates.sort(key=lambda c: c["initial_score"], reverse=True)
        return scored_candidates

    def _verify_candidates(self, candidates, query_features):
        """Run RANSAC on ``candidates`` in a thread pool, updating them in place."""
        if not candidates:
            return

        from concurrent.futures import ThreadPoolExecutor, as_completed

        query_kp = query_features.get("orb_kp")
        query_desc = query_features.get("orb_desc")

        with ThreadPoolExecutor(max_workers=self.ransac_workers) as executor:
            futures = {
                executor.submit(
                    self._ransac_verify_single,
                    candidate["img_id"],
                    query_kp,
                    query_desc,
                ): candidate
                for candidate in candidates
            }
            # as_completed() only yields finished futures, so result() never
            # blocks here; there is no per-task time limit.
            for future in as_completed(futures):
                candidate = futures[future]
                try:
                    ransac_inliers = future.result()
                except Exception as e:
                    logger.warning(f"RANSAC验证超时或失败 (img_id={candidate['img_id']}): {e}")
                    ransac_inliers = 0
                candidate["ransac_inliers"] = ransac_inliers
                candidate["details"]["ransac_inliers"] = ransac_inliers

    def _load_orb_features(self, img_id):
        """Load one image's ORB keypoints and descriptors from SQLite.

        A short-lived connection is opened per call because this method runs
        concurrently in worker threads.

        Args:
            img_id: Image id used as the ``images.id`` lookup key.

        Returns:
            tuple: ``(keypoints, descriptors)``, or ``(None, None)`` when the
            row or its blobs are missing, the query fails, or the blobs cannot
            be decoded.
        """
        # sqlite3 cannot bind numpy scalar types, so convert them to native
        # Python values first.
        if isinstance(img_id, np.generic):
            img_id = img_id.item()

        try:
            conn = sqlite3.connect(self.db_path)
            try:
                row = conn.execute(
                    "SELECT orb_desc, orb_keypoints FROM images WHERE id = ?",
                    (img_id,),
                ).fetchone()
            finally:
                # Close even when the query raises, so no connection leaks.
                conn.close()
        except sqlite3.Error as e:
            logger.error(f"加载ORB特征失败 (img_id={img_id}): {e}")
            return None, None

        if not row or not row[0]:
            return None, None

        orb_blob, orb_kp_blob = row
        try:
            # Descriptors are stored as raw bytes, 32 bytes per keypoint.
            orb_desc = np.frombuffer(orb_blob, dtype=np.uint8).reshape(-1, 32)

            if not orb_kp_blob:
                logger.warning(f"图片 {img_id} 缺少ORB关键点,跳过RANSAC验证")
                return None, None

            # NOTE(security): pickle.loads can execute arbitrary code. This is
            # only safe while the database is written exclusively by this
            # project's own indexing code; never point db_path at an
            # untrusted file.
            import pickle
            kp_list = pickle.loads(orb_kp_blob)
            orb_kp = [
                cv2.KeyPoint(
                    x=float(kp_dict['pt'][0]),
                    y=float(kp_dict['pt'][1]),
                    size=float(kp_dict['size']),
                    angle=float(kp_dict['angle']),
                )
                for kp_dict in kp_list
            ]
            return orb_kp, orb_desc
        except Exception as e:
            logger.error(f"加载ORB特征失败 (img_id={img_id}): {e}")
            return None, None

    def _ransac_verify_single(self, img_id, query_kp, query_desc):
        """RANSAC-verify one candidate against the query's ORB features.

        Args:
            img_id: Id of the candidate image.
            query_kp: ORB keypoints of the query image.
            query_desc: ORB descriptors of the query image.

        Returns:
            int: Inlier count reported by the extractor; 0 when either side
            has no usable ORB features or the computation fails.
        """
        # Without query features nothing can match; skip the database load.
        if query_kp is None or query_desc is None:
            return 0

        cand_orb_kp, cand_orb_desc = self._load_orb_features(img_id)
        if cand_orb_kp is None or cand_orb_desc is None:
            return 0

        try:
            return self.extractor.compute_orb_ransac_score(
                (query_kp, query_desc),
                (cand_orb_kp, cand_orb_desc),
                min_inliers=self.min_orb_inliers,
                ransac_thresh=self.ransac_reproj_thresh,
                confidence=self.ransac_confidence,
            )
        except Exception as e:
            logger.error(f"RANSAC计算失败 (img_id={img_id}): {e}")
            return 0

    def _assess_confidence(self, candidate):
        """Label a candidate as ``"high"``, ``"medium"`` or ``"low"``.

        Both conditions of a tier must hold:

        - high: at least 15 RANSAC inliers and CNN similarity above 0.8
        - medium: at least 8 RANSAC inliers and CNN similarity above 0.7
        - low: everything else, including candidates never RANSAC-verified

        NOTE(review): earlier documentation described these tiers with "or".
        The code has always required both conditions, and that behavior is
        kept; confirm which was intended.

        Args:
            candidate: Candidate dict whose ``"details"`` entry may contain
                ``ransac_inliers`` and ``cnn_sim``.

        Returns:
            str: ``"high"``, ``"medium"`` or ``"low"``.
        """
        details = candidate["details"]
        ransac_inliers = details.get("ransac_inliers", 0)
        cnn_sim = details.get("cnn_sim", 0.0)

        if ransac_inliers >= 15 and cnn_sim > 0.8:
            return "high"
        if ransac_inliers >= 8 and cnn_sim > 0.7:
            return "medium"
        return "low"