| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
-
- #include "SQLiteVecManager.h"
- #include <stdexcept>
- #include <iostream>
- #include <cmath>
- #include <algorithm>
- SQLiteVecManager::SQLiteVecManager(const std::string& databaseName) : dbName(databaseName), db(nullptr), useVecExtension(false) {
- int rc = sqlite3_open(dbName.c_str(), &db);
- if (rc) {
- throw std::runtime_error("Can't open database: " + std::string(sqlite3_errmsg(db)));
- }
- // 初始化 sqlite-vec 扩展
- char* errMsg = 0;
- rc = sqlite3_vec_init(db, &errMsg, 0);
- if (rc != SQLITE_OK) {
- std::cerr << "初始化 sqlite-vec 扩展失败: " << (errMsg ? errMsg : "未知错误") << std::endl;
- if (errMsg) sqlite3_free(errMsg);
- sqlite3_close(db);
-
- throw std::runtime_error("Can't 初始化 sqlite-vec: " + std::string(sqlite3_errmsg(db)));
- }
- std::cout << "sqlite-vec 扩展初始化成功!" << std::endl;
- useVecExtension = true;
- }
- SQLiteVecManager::~SQLiteVecManager() {
- if (db) {
- sqlite3_close(db);
- }
- }
- bool SQLiteVecManager::initializeDatabase(int vectorDimension) {
- int rc;
- char* errMsg = 0;
-
- if (useVecExtension == true) {
- std::cout << "使用sqlite-vec扩展进行向量存储和搜索" << std::endl;
- // 创建vec0虚拟表
- std::string sql = "CREATE VIRTUAL TABLE IF NOT EXISTS image_features USING vec0("
- "id INTEGER PRIMARY KEY AUTOINCREMENT,"
- "image_path TEXT UNIQUE NOT NULL,"
- "feature_vector FLOAT[" + std::to_string(vectorDimension) + "]);";
-
- rc = sqlite3_exec(db, sql.c_str(), 0, 0, &errMsg);
- if (rc != SQLITE_OK) {
- std::cerr << "Failed to create vec0 table: " << errMsg << std::endl;
- sqlite3_free(errMsg);
- return false;
- }
- } else {
- // 使用传统表结构
- std::cout << "使用传统表结构扩展进行向量存储和搜索" << std::endl;
- sqlite3_free(errMsg);
- const char* sql = R"(
- CREATE TABLE IF NOT EXISTS image_features (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- image_path TEXT UNIQUE NOT NULL,
- feature_vector BLOB NOT NULL
- );
- )";
- rc = sqlite3_exec(db, sql, 0, 0, &errMsg);
- if (rc != SQLITE_OK) {
- std::cerr << "SQL error: " << errMsg << std::endl;
- sqlite3_free(errMsg);
- return false;
- }
- }
-
- return true;
- }
- bool SQLiteVecManager::addFeatureVector(const std::vector<float>& features, const std::string& imagePath) {
- if (useVecExtension) {
- const char* sql = "INSERT INTO image_features(id, image_path, feature_vector) VALUES (?, ?, vec_f32(?));";
-
- sqlite3_stmt* stmt;
- int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
- if (rc != SQLITE_OK) {
- return false;
- }
-
- sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
- sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
-
- std::string blobData = vectorToBlob(features);
- sqlite3_bind_blob(stmt, 3, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
-
- rc = sqlite3_step(stmt);
- sqlite3_finalize(stmt);
-
- return rc == SQLITE_DONE;
- } else {
- const char* sql = "INSERT OR REPLACE INTO image_features (id, image_path, feature_vector) VALUES (?, ?, ?);";
-
- sqlite3_stmt* stmt;
- int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
- if (rc != SQLITE_OK) {
- return false;
- }
-
- sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
- sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
-
- std::string blobData = vectorToBlob(features);
- sqlite3_bind_blob(stmt, 3, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
-
- rc = sqlite3_step(stmt);
- sqlite3_finalize(stmt);
-
- return rc == SQLITE_DONE;
- }
- }
- std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVectors(const std::vector<float>& queryVector, int k) {
- std::vector<std::pair<std::string, float>> results;
-
- if (useVecExtension) {
- std::cout << "使用sqlite-vec扩展进行向量搜索" << std::endl;
- // 使用sqlite-vec的向量搜索功能
- std::string blobData = vectorToBlob(queryVector);
-
- std::string sql = "SELECT image_path, distance FROM image_features "
- "WHERE feature_vector MATCH vec_f32(?) "
- "ORDER BY distance "
- "LIMIT " + std::to_string(k) + ";";
-
- sqlite3_stmt* stmt;
- int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, NULL);
- if (rc != SQLITE_OK) {
- return results;
- }
-
- sqlite3_bind_blob(stmt, 1, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
-
- while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
- const char* imagePath = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
- float distance = static_cast<float>(sqlite3_column_double(stmt, 1));
- float similarity = distanceToSimilarity(distance);
- results.emplace_back(std::string(imagePath), similarity);
- }
-
- sqlite3_finalize(stmt);
- } else {
- std::cout << "使用传统方法进行向量搜索" << std::endl;
- // 使用传统方式计算相似度
- const char* sql = "SELECT image_path, feature_vector FROM image_features;";
- sqlite3_stmt* stmt;
- int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
-
- if (rc != SQLITE_OK) {
- return results;
- }
-
- // 计算查询向量的模长
- float queryNorm = 0.0f;
- for (float val : queryVector) {
- queryNorm += val * val;
- }
- queryNorm = sqrt(queryNorm);
-
- while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
- const char* imagePath = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
- const void* blobData = sqlite3_column_blob(stmt, 1);
- int blobSize = sqlite3_column_bytes(stmt, 1);
-
- std::string blobStr(static_cast<const char*>(blobData), blobSize);
- std::vector<float> storedFeatures = blobToVector(blobStr);
-
- if (storedFeatures.size() == queryVector.size()) {
- // 计算余弦相似度
- float dotProduct = 0.0f;
- float storedNorm = 0.0f;
-
- for (size_t i = 0; i < queryVector.size(); ++i) {
- dotProduct += queryVector[i] * storedFeatures[i];
- storedNorm += storedFeatures[i] * storedFeatures[i];
- }
-
- storedNorm = sqrt(storedNorm);
-
- float similarity = 0.0f;
- if (queryNorm > 0.0f && storedNorm > 0.0f) {
- similarity = dotProduct / (queryNorm * storedNorm);
- }
-
- results.emplace_back(std::string(imagePath), similarity);
- }
- }
-
- sqlite3_finalize(stmt);
-
- // 按相似度降序排列
- std::sort(results.begin(), results.end(),
- [](const auto& a, const auto& b) { return a.second > b.second; });
-
- // 返回topK结果
- if (static_cast<int>(results.size()) > k) {
- results.resize(k);
- }
- }
-
- return results;
- }
- void SQLiteVecManager::saveDatabase() {
- sqlite3_exec(db, "PRAGMA optimize;", 0, 0, 0);
- }
- bool SQLiteVecManager::loadDatabase() {
- const char* sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='image_features';";
- sqlite3_stmt* stmt;
- int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
-
- if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) {
- sqlite3_finalize(stmt);
- return true;
- }
-
- sqlite3_finalize(stmt);
- return false;
- }
- int SQLiteVecManager::getFeatureCount() const {
- const char* sql = "SELECT COUNT(*) FROM image_features;";
- sqlite3_stmt* stmt;
- int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
-
- if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) {
- int count = sqlite3_column_int(stmt, 0);
- sqlite3_finalize(stmt);
- return count;
- }
-
- sqlite3_finalize(stmt);
- return 0;
- }
- bool SQLiteVecManager::isEmpty() const {
- return getFeatureCount() == 0;
- }
- std::string SQLiteVecManager::vectorToBlob(const std::vector<float>& vec) {
- return std::string(reinterpret_cast<const char*>(vec.data()), vec.size() * sizeof(float));
- }
- std::vector<float> SQLiteVecManager::blobToVector(const std::string& blob) {
- const float* data = reinterpret_cast<const float*>(blob.data());
- size_t count = blob.size() / sizeof(float);
- return std::vector<float>(data, data + count);
- }
- float SQLiteVecManager::calculateCosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
- if (vec1.size() != vec2.size() || vec1.empty()) {
- return 0.0f;
- }
-
- float dotProduct = 0.0f;
- float norm1 = 0.0f;
- float norm2 = 0.0f;
-
- for (size_t i = 0; i < vec1.size(); ++i) {
- dotProduct += vec1[i] * vec2[i];
- norm1 += vec1[i] * vec1[i];
- norm2 += vec2[i] * vec2[i];
- }
-
- if (norm1 == 0.0f || norm2 == 0.0f) {
- return 0.0f;
- }
-
- return dotProduct / (std::sqrt(norm1) * std::sqrt(norm2));
- }
- float SQLiteVecManager::distanceToSimilarity(float distance) {
- return 1.0f - distance;
- }
|