#include "SQLiteVecManager.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace {

/// Deleter that lets std::unique_ptr own a prepared statement, so the
/// statement is finalized on every exit path (including exceptions).
struct StatementFinalizer {
    void operator()(sqlite3_stmt* stmt) const noexcept { sqlite3_finalize(stmt); }
};

using StatementPtr = std::unique_ptr<sqlite3_stmt, StatementFinalizer>;

/// Compiles `sql` against `db`; returns an empty pointer when compilation fails.
StatementPtr prepareStatement(sqlite3* db, const char* sql) {
    sqlite3_stmt* raw = nullptr;
    if (sqlite3_prepare_v2(db, sql, -1, &raw, nullptr) != SQLITE_OK) {
        sqlite3_finalize(raw);  // harmless no-op when raw is still null
        return StatementPtr{};
    }
    return StatementPtr{raw};
}

}  // namespace

/// Opens (or creates) the database file and registers the sqlite-vec extension
/// on the connection.
///
/// @param databaseName path handed to sqlite3_open.
/// @throws std::runtime_error if the database cannot be opened or the
///         extension cannot be initialized. The connection is closed before
///         throwing, because the destructor does not run for an object whose
///         constructor threw.
SQLiteVecManager::SQLiteVecManager(const std::string& databaseName)
    : dbName(databaseName), db(nullptr), useVecExtension(false) {
    int rc = sqlite3_open(dbName.c_str(), &db);
    if (rc != SQLITE_OK) {
        // Copy the message first: it is owned by the handle we are about to close.
        const std::string reason = sqlite3_errmsg(db);
        sqlite3_close(db);
        db = nullptr;
        throw std::runtime_error("Can't open database: " + reason);
    }

    // Initialize the sqlite-vec extension.
    char* errMsg = nullptr;
    rc = sqlite3_vec_init(db, &errMsg, nullptr);
    if (rc != SQLITE_OK) {
        // Capture the reason before releasing errMsg and closing the handle.
        const std::string reason = errMsg ? errMsg : sqlite3_errmsg(db);
        std::cerr << "初始化 sqlite-vec 扩展失败: " << (errMsg ? errMsg : "未知错误") << std::endl;
        sqlite3_free(errMsg);
        sqlite3_close(db);
        db = nullptr;
        throw std::runtime_error("Can't 初始化 sqlite-vec: " + reason);
    }

    std::cout << "sqlite-vec 扩展初始化成功!" << std::endl;
    useVecExtension = true;
}

/// Closes the database connection if one is open.
SQLiteVecManager::~SQLiteVecManager() {
    if (db) {
        sqlite3_close(db);
        db = nullptr;
    }
}

/// Creates the `image_features` table if it does not exist yet.
///
/// With the extension enabled a vec0 virtual table is created whose vector
/// column has `vectorDimension` floats; otherwise a plain table storing the
/// vector as a BLOB is created.
///
/// @param vectorDimension number of floats per feature vector (must be > 0
///        when the vec0 table is used).
/// @return true on success, false if the dimension is invalid or the DDL fails.
bool SQLiteVecManager::initializeDatabase(int vectorDimension) {
    std::string sql;
    const char* failurePrefix = nullptr;

    if (useVecExtension) {
        std::cout << "使用sqlite-vec扩展进行向量存储和搜索" << std::endl;
        if (vectorDimension <= 0) {
            std::cerr << "Invalid vector dimension: " << vectorDimension << std::endl;
            return false;
        }
        // Create the vec0 virtual table.
        // NOTE(review): vec0 may not accept AUTOINCREMENT / UNIQUE / NOT NULL in
        // column definitions - confirm against the sqlite-vec documentation.
        sql = "CREATE VIRTUAL TABLE IF NOT EXISTS image_features USING vec0("
              "id INTEGER PRIMARY KEY AUTOINCREMENT,"
              "image_path TEXT UNIQUE NOT NULL,"
              "feature_vector FLOAT[" + std::to_string(vectorDimension) + "]);";
        failurePrefix = "Failed to create vec0 table: ";
    } else {
        // Fall back to a conventional table layout.
        std::cout << "使用传统表结构扩展进行向量存储和搜索" << std::endl;
        sql = R"(
            CREATE TABLE IF NOT EXISTS image_features (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_path TEXT UNIQUE NOT NULL,
                feature_vector BLOB NOT NULL
            );
        )";
        failurePrefix = "SQL error: ";
    }

    char* errMsg = nullptr;
    if (sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errMsg) != SQLITE_OK) {
        // errMsg can stay null on failure; never stream a null char*.
        std::cerr << failurePrefix << (errMsg ? errMsg : sqlite3_errmsg(db)) << std::endl;
        sqlite3_free(errMsg);
        return false;
    }
    return true;
}

/// Stores one feature vector together with the path of the image it describes.
///
/// @param features  the feature vector; an empty vector is rejected.
/// @param imagePath path of the image.
/// @return true if the row was written, false on any SQLite error or when
///         `features` is empty.
bool SQLiteVecManager::addFeatureVector(const std::vector<float>& features, const std::string& imagePath) {
    if (features.empty()) {
        return false;
    }

    // Declared before the statement so the buffers bound with SQLITE_STATIC
    // outlive it (locals are destroyed in reverse order of declaration).
    const std::string blobData = vectorToBlob(features);
    const int blobBytes = static_cast<int>(blobData.size());

    if (useVecExtension) {
        StatementPtr stmt = prepareStatement(
            db, "INSERT INTO image_features(id, image_path, feature_vector) VALUES (?, ?, vec_f32(?));");
        if (!stmt) {
            return false;
        }
        // NOTE(review): the id is derived from the current row count, which is
        // only collision-free while rows are never deleted - confirm.
        const bool bound =
            sqlite3_bind_int(stmt.get(), 1, getFeatureCount() + 1) == SQLITE_OK &&
            sqlite3_bind_text(stmt.get(), 2, imagePath.c_str(), -1, SQLITE_STATIC) == SQLITE_OK &&
            sqlite3_bind_blob(stmt.get(), 3, blobData.data(), blobBytes, SQLITE_STATIC) == SQLITE_OK;
        return bound && sqlite3_step(stmt.get()) == SQLITE_DONE;
    }

    // Plain table: let AUTOINCREMENT pick the id. Deriving it from COUNT(*)
    // together with OR REPLACE could overwrite an unrelated row after an
    // earlier replacement had left a gap in the id sequence.
    StatementPtr stmt = prepareStatement(
        db, "INSERT OR REPLACE INTO image_features (image_path, feature_vector) VALUES (?, ?);");
    if (!stmt) {
        return false;
    }
    const bool bound =
        sqlite3_bind_text(stmt.get(), 1, imagePath.c_str(), -1, SQLITE_STATIC) == SQLITE_OK &&
        sqlite3_bind_blob(stmt.get(), 2, blobData.data(), blobBytes, SQLITE_STATIC) == SQLITE_OK;
    return bound && sqlite3_step(stmt.get()) == SQLITE_DONE;
}

/// Returns up to `k` stored images most similar to `queryVector`.
///
/// With the extension enabled the search is delegated to the vec0 table and
/// each distance is mapped through distanceToSimilarity(). Otherwise every row
/// is read and ranked by cosine similarity, highest first; rows whose stored
/// vector length differs from the query are skipped.
///
/// @param queryVector the vector to compare against.
/// @param k           maximum number of results.
/// @return (image path, similarity) pairs; empty when `k <= 0`, the query is
///         empty, or the statement cannot be prepared.
std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVectors(
    const std::vector<float>& queryVector, int k) {
    std::vector<std::pair<std::string, float>> results;
    // A negative k would otherwise reach resize() as a huge unsigned value.
    if (k <= 0 || queryVector.empty()) {
        return results;
    }

    if (useVecExtension) {
        std::cout << "使用sqlite-vec扩展进行向量搜索" << std::endl;

        // Use the vector search provided by sqlite-vec.
        const std::string blobData = vectorToBlob(queryVector);
        const std::string sql = "SELECT image_path, distance FROM image_features "
                                "WHERE feature_vector MATCH vec_f32(?) "
                                "ORDER BY distance "
                                "LIMIT " + std::to_string(k) + ";";

        StatementPtr stmt = prepareStatement(db, sql.c_str());
        if (!stmt) {
            return results;
        }
        if (sqlite3_bind_blob(stmt.get(), 1, blobData.data(), static_cast<int>(blobData.size()),
                              SQLITE_STATIC) != SQLITE_OK) {
            return results;
        }

        while (sqlite3_step(stmt.get()) == SQLITE_ROW) {
            const unsigned char* pathText = sqlite3_column_text(stmt.get(), 0);
            if (!pathText) {
                continue;  // NULL column: constructing std::string from it would be UB
            }
            const float distance = static_cast<float>(sqlite3_column_double(stmt.get(), 1));
            results.emplace_back(reinterpret_cast<const char*>(pathText), distanceToSimilarity(distance));
        }
        return results;
    }

    std::cout << "使用传统方法进行向量搜索" << std::endl;

    // Compute similarity by hand over every stored row.
    StatementPtr stmt = prepareStatement(db, "SELECT image_path, feature_vector FROM image_features;");
    if (!stmt) {
        return results;
    }

    // Norm of the query vector, computed once rather than per row.
    float queryNormSquared = 0.0f;
    for (const float value : queryVector) {
        queryNormSquared += value * value;
    }
    const float queryNorm = std::sqrt(queryNormSquared);

    while (sqlite3_step(stmt.get()) == SQLITE_ROW) {
        const unsigned char* pathText = sqlite3_column_text(stmt.get(), 0);
        const void* blobData = sqlite3_column_blob(stmt.get(), 1);
        const int blobSize = sqlite3_column_bytes(stmt.get(), 1);
        if (!pathText || !blobData || blobSize <= 0) {
            continue;
        }

        const std::vector<float> storedFeatures = blobToVector(
            std::string(static_cast<const char*>(blobData), static_cast<std::size_t>(blobSize)));
        if (storedFeatures.size() != queryVector.size()) {
            continue;
        }

        // Cosine similarity.
        float dotProduct = 0.0f;
        float storedNormSquared = 0.0f;
        for (std::size_t i = 0; i < queryVector.size(); ++i) {
            dotProduct += queryVector[i] * storedFeatures[i];
            storedNormSquared += storedFeatures[i] * storedFeatures[i];
        }
        const float storedNorm = std::sqrt(storedNormSquared);

        float similarity = 0.0f;
        if (queryNorm > 0.0f && storedNorm > 0.0f) {
            similarity = dotProduct / (queryNorm * storedNorm);
        }
        results.emplace_back(reinterpret_cast<const char*>(pathText), similarity);
    }
    stmt.reset();  // finalize before the in-memory ranking

    // Only the best k entries need ordering (descending similarity).
    const std::size_t topCount = std::min(results.size(), static_cast<std::size_t>(k));
    std::partial_sort(results.begin(), results.begin() + static_cast<std::ptrdiff_t>(topCount), results.end(),
                      [](const auto& a, const auto& b) { return a.second > b.second; });
    results.resize(topCount);
    return results;
}

/// Runs `PRAGMA optimize` on the connection. Best effort: the result is ignored.
void SQLiteVecManager::saveDatabase() {
    sqlite3_exec(db, "PRAGMA optimize;", nullptr, nullptr, nullptr);
}

/// Reports whether the `image_features` table exists in the open database.
///
/// @return true if sqlite_master lists the table, false otherwise or on error.
bool SQLiteVecManager::loadDatabase() {
    StatementPtr stmt = prepareStatement(
        db, "SELECT name FROM sqlite_master WHERE type='table' AND name='image_features';");
    return stmt && sqlite3_step(stmt.get()) == SQLITE_ROW;
}

/// Counts the rows in `image_features`.
///
/// @return the row count, or 0 if the query cannot be prepared or executed.
int SQLiteVecManager::getFeatureCount() const {
    StatementPtr stmt = prepareStatement(db, "SELECT COUNT(*) FROM image_features;");
    if (stmt && sqlite3_step(stmt.get()) == SQLITE_ROW) {
        return sqlite3_column_int(stmt.get(), 0);
    }
    return 0;
}

/// @return true when getFeatureCount() reports zero rows.
bool SQLiteVecManager::isEmpty() const {
    return getFeatureCount() == 0;
}

/// Serializes the floats of `vec` into a byte string in host byte order.
///
/// @return `vec.size() * sizeof(float)` raw bytes; empty for an empty vector.
std::string SQLiteVecManager::vectorToBlob(const std::vector<float>& vec) {
    if (vec.empty()) {
        return std::string();  // data() may be null for an empty vector
    }
    return std::string(reinterpret_cast<const char*>(vec.data()), vec.size() * sizeof(float));
}

/// Inverse of vectorToBlob(): reinterprets the bytes of `blob` as floats.
/// Trailing bytes that do not form a whole float are ignored.
std::vector<float> SQLiteVecManager::blobToVector(const std::string& blob) {
    std::vector<float> values(blob.size() / sizeof(float));
    if (!values.empty()) {
        // memcpy instead of casting the char buffer to float*: the string
        // storage carries no float alignment guarantee.
        std::memcpy(values.data(), blob.data(), values.size() * sizeof(float));
    }
    return values;
}

/// Cosine similarity of two equally sized vectors.
///
/// @return the similarity, or 0 when the sizes differ, the vectors are empty,
///         or either vector has zero norm.
float SQLiteVecManager::calculateCosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
    if (vec1.size() != vec2.size() || vec1.empty()) {
        return 0.0f;
    }

    float dotProduct = 0.0f;
    float norm1 = 0.0f;
    float norm2 = 0.0f;
    for (std::size_t i = 0; i < vec1.size(); ++i) {
        dotProduct += vec1[i] * vec2[i];
        norm1 += vec1[i] * vec1[i];
        norm2 += vec2[i] * vec2[i];
    }

    if (norm1 == 0.0f || norm2 == 0.0f) {
        return 0.0f;
    }
    return dotProduct / (std::sqrt(norm1) * std::sqrt(norm2));
}

/// Maps a distance reported by the vec0 table to a similarity score.
///
/// NOTE(review): `1 - distance` matches the cosine similarity of the fallback
/// path only if the vec0 column uses a cosine distance metric - confirm which
/// metric the table is created with.
float SQLiteVecManager::distanceToSimilarity(float distance) {
    return 1.0f - distance;
}