# database.py 12 KB
"""
SQLite database initialization and management for Design Image Search
Design 图像搜索的 SQLite 数据库初始化和管理模块
"""

import sqlite3
import logging
import os
from contextlib import contextmanager

# Module-level logger, named after this module, used for all messages below.
logger = logging.getLogger(__name__)


class DatabaseManager:
    """SQLite storage for design-image features and sync bookkeeping.

    Two tables live in one SQLite file:

    * ``images`` - one row per design image: source URL, local path and the
      serialized CNN / ORB features.
    * ``sync_status`` - a single row (``id = 1``) describing the last sync.

    Every operation opens its own short-lived connection through
    :meth:`get_connection`.
    """

    # DDL for the ``images`` table. Shared by first-time creation and by the
    # legacy-schema rebuild so the two definitions cannot drift apart.
    _IMAGES_TABLE_SQL = """
        CREATE TABLE images (
            id TEXT PRIMARY KEY,              -- design_id
            path TEXT,                        -- 本地图片路径
            design_no TEXT,                   -- 设计款号
            image_url TEXT,                   -- 原始图片 URL
            cnn_vector BLOB,                  -- CNN 向量 (576 维)
            orb_desc BLOB,                    -- ORB 描述子
            orb_keypoints BLOB,               -- ORB 关键点
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    """

    # Column names of the current ``images`` schema, in declaration order.
    _IMAGE_COLUMNS = (
        'id', 'path', 'design_no', 'image_url', 'cnn_vector',
        'orb_desc', 'orb_keypoints', 'created_at', 'updated_at',
    )

    # Text layouts accepted when reading ``last_sync_time`` back. The first
    # two are what this class writes; the others cover ISO-8601 strings
    # ("T" separator and/or UTC offset) handed in by callers.
    _SYNC_TIME_FORMATS = (
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d %H:%M:%S.%f',
        '%Y-%m-%dT%H:%M:%S',
        '%Y-%m-%dT%H:%M:%S.%f',
        '%Y-%m-%d %H:%M:%S%z',
        '%Y-%m-%d %H:%M:%S.%f%z',
        '%Y-%m-%dT%H:%M:%S%z',
        '%Y-%m-%dT%H:%M:%S.%f%z',
    )

    def __init__(self, db_path="./data/design_images.db"):
        """Create the manager and make sure the database is ready for use.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._ensure_database_exists()
        self._initialize_tables()

    def _ensure_database_exists(self):
        """Create the directory that will contain the database file, if any."""
        db_dir = os.path.dirname(self.db_path)
        # A bare file name has no directory component; nothing to create then.
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
            logger.info(f"创建数据库目录: {db_dir}")

    def _initialize_tables(self):
        """Create the schema, or migrate a legacy one, then ensure indexes.

        An existing ``images`` table without ``design_no`` is upgraded: a
        legacy ``item_no`` column is renamed via a table rebuild, otherwise
        ``design_no`` is simply added.

        Raises:
            sqlite3.Error: If any schema statement fails.
        """
        with self.get_connection() as conn:
            table_exists = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='images'"
            ).fetchone()

            if not table_exists:
                conn.execute(self._IMAGES_TABLE_SQL)
            else:
                # Row index 1 of PRAGMA table_info is the column name.
                columns = [info[1] for info in conn.execute("PRAGMA table_info(images)")]
                if 'design_no' not in columns:
                    if 'item_no' in columns:
                        self._migrate_item_no_to_design_no(conn, columns)
                    else:
                        conn.execute("ALTER TABLE images ADD COLUMN design_no TEXT")

            # Single-row table: the CHECK constraint pins the only row to id 1.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS sync_status (
                    id INTEGER PRIMARY KEY CHECK (id = 1),
                    last_sync_time DATETIME,          -- 上次同步时间
                    last_sync_count INTEGER DEFAULT 0, -- 上次同步记录数
                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)

            conn.execute("CREATE INDEX IF NOT EXISTS idx_images_updated_at ON images(updated_at)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_images_design_no ON images(design_no)")

            # Seed the status row so later UPDATEs always have a target.
            conn.execute("""
                INSERT OR IGNORE INTO sync_status (id, last_sync_time)
                VALUES (1, NULL)
            """)

            conn.commit()
            logger.info("数据库表初始化完成")

    def _migrate_item_no_to_design_no(self, conn, old_columns):
        """Rebuild ``images`` so the legacy ``item_no`` becomes ``design_no``.

        The rebuild runs inside one explicit transaction: if any step fails
        the database is rolled back to its pre-migration state instead of
        being left with an empty ``images`` and an orphaned ``images_old``.

        Args:
            conn: Open connection with no transaction in progress.
            old_columns: Column names currently present in ``images``.

        Raises:
            sqlite3.Error: If the rebuild fails (after rolling back).
        """
        logger.info("重命名 item_no 列为 design_no")

        # Copy only columns the legacy table really has; anything missing
        # falls back to the new table's defaults. Names come from the fixed
        # _IMAGE_COLUMNS list, so interpolating them into SQL is safe.
        pairs = [
            (column, 'item_no' if column == 'design_no' else column)
            for column in self._IMAGE_COLUMNS
        ]
        pairs = [(new, old) for new, old in pairs if old in old_columns]
        target_list = ", ".join(new for new, _ in pairs)
        source_list = ", ".join(old for _, old in pairs)

        conn.execute("BEGIN")
        try:
            conn.execute("ALTER TABLE images RENAME TO images_old")
            conn.execute(self._IMAGES_TABLE_SQL)
            conn.execute(
                f"INSERT INTO images ({target_list}) "
                f"SELECT {source_list} FROM images_old"
            )
            conn.execute("DROP TABLE images_old")
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        logger.info("表结构更新完成")

    @contextmanager
    def get_connection(self):
        """Yield a new connection that is always closed on exit.

        Rows come back as ``sqlite3.Row`` (access by index or column name).
        Nothing is committed automatically: callers must call ``commit()``
        themselves before leaving the ``with`` block.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def save_image_features(self, design_id, design_no, image_url, image_path,
                            cnn_vector, orb_keypoints, orb_desc):
        """Insert or update the stored features of one design image.

        An existing row keeps its ``created_at``; only ``updated_at`` is
        refreshed.

        Args:
            design_id: Design ID (primary key of the row).
            design_no: Design number.
            image_url: Original image URL.
            image_path: Local image path.
            cnn_vector: CNN vector (array-like), stored as float32 bytes,
                or None.
            orb_keypoints: Iterable of keypoint objects exposing ``pt``,
                ``size``, ``angle``, ``response``, ``octave`` and
                ``class_id`` (e.g. ``cv2.KeyPoint``), or None.
            orb_desc: ORB descriptors (array-like), stored as uint8 bytes,
                or None.

        Returns:
            bool: True on success, False if anything failed (the error is
            logged).
        """
        try:
            import pickle
            import numpy as np

            cnn_blob = None
            if cnn_vector is not None:
                cnn_blob = np.asarray(cnn_vector).astype(np.float32).tobytes()

            orb_desc_blob = None
            if orb_desc is not None:
                orb_desc_blob = np.asarray(orb_desc).astype(np.uint8).tobytes()

            orb_kp_blob = None
            if orb_keypoints is not None:
                # Keypoint objects are flattened to plain dicts before
                # pickling. NOTE: the blob must only ever be unpickled from
                # a database file you trust.
                orb_kp_blob = pickle.dumps([
                    {
                        'pt': kp.pt,
                        'size': kp.size,
                        'angle': kp.angle,
                        'response': kp.response,
                        'octave': kp.octave,
                        'class_id': kp.class_id
                    }
                    for kp in orb_keypoints
                ])

            with self.get_connection() as conn:
                # UPDATE first, INSERT only when no row matched. Unlike
                # INSERT OR REPLACE this does not delete the old row, so
                # created_at survives re-saving the same design.
                updated = conn.execute("""
                    UPDATE images
                    SET design_no = ?, image_url = ?, path = ?, cnn_vector = ?,
                        orb_keypoints = ?, orb_desc = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = ?
                """, (design_no, image_url, image_path,
                      cnn_blob, orb_kp_blob, orb_desc_blob, design_id))
                if updated.rowcount == 0:
                    conn.execute("""
                        INSERT INTO images
                        (id, design_no, image_url, path, cnn_vector, orb_keypoints, orb_desc, updated_at)
                        VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
                    """, (design_id, design_no, image_url, image_path,
                          cnn_blob, orb_kp_blob, orb_desc_blob))
                conn.commit()
                return True

        except Exception as e:
            logger.error(f"保存图片特征失败 (design_id={design_id}): {e}")
            return False

    def get_last_sync_time(self):
        """Return the time of the last sync.

        Returns:
            datetime or None: The stored time; None when no sync has been
            recorded or the lookup failed. Text that matches none of the
            known layouts yields ``datetime(1970, 1, 1)`` (and a warning).
            A non-text stored value is returned as is.
        """
        try:
            with self.get_connection() as conn:
                row = conn.execute(
                    "SELECT last_sync_time FROM sync_status WHERE id = 1"
                ).fetchone()

            raw = row['last_sync_time'] if row else None
            if not raw:
                return None
            if not isinstance(raw, str):
                return raw

            from datetime import datetime
            for fmt in self._SYNC_TIME_FORMATS:
                try:
                    return datetime.strptime(raw, fmt)
                except ValueError:
                    continue

            # Falling back to the epoch makes callers treat everything as
            # new, so make that decision visible in the log.
            logger.warning(f"无法解析上次同步时间 {raw!r}, 使用 1970-01-01")
            return datetime(1970, 1, 1)
        except Exception as e:
            logger.error(f"获取上次同步时间失败: {e}")
            return None

    def update_sync_time(self, sync_time, sync_count=0):
        """Record the time and size of the latest sync.

        Failures are logged, not raised.

        Args:
            sync_time: Sync time (``datetime`` or an already formatted value).
            sync_count: Number of records synced.
        """
        try:
            from datetime import datetime

            # Serialize explicitly rather than through sqlite3's implicit
            # datetime adapter (deprecated since Python 3.12); the text is
            # the same as that adapter would produce.
            stored_time = sync_time
            if isinstance(sync_time, datetime):
                stored_time = sync_time.isoformat(" ")

            with self.get_connection() as conn:
                updated = conn.execute("""
                    UPDATE sync_status
                    SET last_sync_time = ?, last_sync_count = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (stored_time, sync_count))
                if updated.rowcount == 0:
                    # The status row is missing; recreate it instead of
                    # silently dropping the update.
                    conn.execute("""
                        INSERT INTO sync_status (id, last_sync_time, last_sync_count)
                        VALUES (1, ?, ?)
                    """, (stored_time, sync_count))
                conn.commit()
                logger.info(f"更新同步时间: {sync_time}, 记录数: {sync_count}")
        except Exception as e:
            logger.error(f"更新同步时间失败: {e}")

    def get_image_count(self):
        """Return the number of rows in ``images``.

        Returns:
            int: Image count, or 0 if the query failed (the error is logged).
        """
        try:
            with self.get_connection() as conn:
                row = conn.execute("SELECT COUNT(*) as count FROM images").fetchone()
                return row['count'] if row else 0
        except Exception as e:
            logger.error(f"获取图片总数失败: {e}")
            return 0

    def get_stats(self):
        """Return summary statistics about the database.

        Returns:
            dict: ``total_images``, ``cnn_features``, ``orb_features``,
            ``last_sync_time`` (raw stored value) and ``last_sync_count``.
            All-zero / None values are returned if the query failed.
        """
        try:
            with self.get_connection() as conn:
                # COUNT(column) counts only non-NULL values, so one pass
                # over the table yields all three figures.
                counts = conn.execute("""
                    SELECT COUNT(*) AS total,
                           COUNT(cnn_vector) AS cnn_count,
                           COUNT(orb_desc) AS orb_count
                    FROM images
                """).fetchone()

                sync_row = conn.execute(
                    "SELECT last_sync_time, last_sync_count FROM sync_status WHERE id = 1"
                ).fetchone()

                return {
                    "total_images": counts['total'],
                    "cnn_features": counts['cnn_count'],
                    "orb_features": counts['orb_count'],
                    "last_sync_time": sync_row['last_sync_time'] if sync_row else None,
                    "last_sync_count": sync_row['last_sync_count'] if sync_row else 0
                }
        except Exception as e:
            logger.error(f"获取数据库统计信息失败: {e}")
            return {
                "total_images": 0,
                "cnn_features": 0,
                "orb_features": 0,
                "last_sync_time": None,
                "last_sync_count": 0
            }