SQLiteVecManager.cpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  #include "../pch/pch.h"
  #include "SQLiteVecManager.h"
  #include <algorithm>
  #include <cmath>
  #include <cstring>
  #include <iostream>
  #include <stdexcept>
  #include "../tool/debuglog.h"
  8. SQLiteVecManager::SQLiteVecManager(const std::string & databaseName) : dbName(databaseName), db(nullptr)
  9. {
  10. int rc = SQLITE_OK;
  11. sqlite3_stmt* stmt;
  12. rc = sqlite3_open(dbName.c_str(), &db);
  13. assert(rc == SQLITE_OK);
  14. if (rc != SQLITE_OK)
  15. {
  16. std::string err = "Can't open database: " + std::string(sqlite3_errmsg(db));
  17. throw std::runtime_error("Can't open database: " + std::string(sqlite3_errmsg(db)));
  18. }
  19. // 初始化 sqlite-vec 扩展
  20. char* errMsg = 0;
  21. rc = sqlite3_vec_init(db, &errMsg, 0);
  22. assert(rc == SQLITE_OK);
  23. if (rc != SQLITE_OK)
  24. {
  25. std::string err = "Can't init vec: " + std::string(sqlite3_errmsg(db));
  26. throw std::runtime_error("Can't init vec: " + std::string(sqlite3_errmsg(db)));
  27. }
  28. rc = sqlite3_prepare_v2(db, "SELECT sqlite_version(), vec_version()", -1, &stmt, NULL);
  29. assert(rc == SQLITE_OK);
  30. rc = sqlite3_step(stmt);
  31. DEBUG_HELPER::debug_printf("sqlite_version=%s, vec_version=%s\n", sqlite3_column_text(stmt, 0), sqlite3_column_text(stmt, 1));
  32. sqlite3_finalize(stmt);
  33. }
  34. SQLiteVecManager::~SQLiteVecManager()
  35. {
  36. if (db)
  37. {
  38. sqlite3_close(db);
  39. }
  40. }
  41. bool SQLiteVecManager::initializeDatabase(int vectorDimension)
  42. {
  43. int rc;
  44. char * errMsg = 0;
  45. std::cout << "使用sqlite-vec扩展进行向量存储和搜索" << std::endl;
  46. // 创建vec0虚拟表
  47. std::string sql = R"(CREATE VIRTUAL TABLE IF NOT EXISTS image_features USING vec0(
  48. id INTEGER PRIMARY KEY AUTOINCREMENT,
  49. image_path TEXT UNIQUE NOT NULL,
  50. feature_vector FLOAT[)" + std::to_string(vectorDimension) + "] distance_metric=cosine)";
  51. rc = sqlite3_exec(db, sql.c_str(), 0, 0, &errMsg);
  52. if (rc != SQLITE_OK)
  53. {
  54. std::string err = std::string("Failed to create vec0 table: ") + errMsg;
  55. sqlite3_free(errMsg);
  56. return false;
  57. }
  58. return true;
  59. }
  60. bool SQLiteVecManager::addFeatureVector(const std::vector<float> & features, const std::string & imagePath)
  61. {
  62. const char * sql = "INSERT INTO image_features(id, image_path, feature_vector) VALUES (?, ?, ?);";
  63. sqlite3_stmt * stmt;
  64. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  65. if (rc != SQLITE_OK)
  66. {
  67. std::string err = sqlite3_errmsg(db);
  68. std::cerr << "Failed to insert feature vector: " << err << std::endl;
  69. return false;
  70. }
  71. sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
  72. sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
  73. //std::string blobData = vectorToBlob(features);
  74. sqlite3_bind_blob(stmt, 3, features.data(), features.size() * sizeof(float), SQLITE_STATIC);
  75. rc = sqlite3_step(stmt);
  76. if (rc != SQLITE_DONE) {
  77. std::string err = sqlite3_errmsg(db);
  78. std::cerr << "Failed to insert feature vector: " << err << std::endl;
  79. sqlite3_finalize(stmt);
  80. return false;
  81. }
  82. sqlite3_finalize(stmt);
  83. return rc == SQLITE_DONE;
  84. }
  85. std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVectors(const std::vector<float> & queryVector, int k)
  86. {
  87. std::vector<std::pair<std::string, float>> results;
  88. std::cout << "使用sqlite-vec扩展进行向量搜索" << std::endl;
  89. // 使用sqlite-vec的向量搜索功能
  90. std::string blobData = vectorToBlob(queryVector);
  91. std::string sql =
  92. "SELECT image_path, distance "
  93. "FROM image_features "
  94. "WHERE feature_vector MATCH ?1 "
  95. "ORDER BY distance "
  96. "LIMIT " + std::to_string(k) + ";";
  97. sqlite3_stmt * stmt;
  98. int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, NULL);
  99. if (rc != SQLITE_OK)
  100. {
  101. std::string err = sqlite3_errmsg(db);
  102. std::cerr << "Failed to insert feature vector: " << err << std::endl;
  103. return results;
  104. }
  105. sqlite3_bind_blob(stmt, 1, queryVector.data(), queryVector.size() * sizeof(float), SQLITE_STATIC);
  106. while ((rc = sqlite3_step(stmt)) == SQLITE_ROW)
  107. {
  108. const char * imagePath = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 0));
  109. float distance = static_cast<float>(sqlite3_column_double(stmt, 1));
  110. float similarity = distanceToSimilarity(distance);
  111. results.emplace_back(std::string(imagePath), similarity);
  112. }
  113. sqlite3_finalize(stmt);
  114. return results;
  115. }
  116. void SQLiteVecManager::saveDatabase()
  117. {
  118. sqlite3_exec(db, "PRAGMA optimize;", 0, 0, 0);
  119. }
  120. bool SQLiteVecManager::loadDatabase()
  121. {
  122. const char * sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='image_features';";
  123. sqlite3_stmt * stmt;
  124. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  125. if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW)
  126. {
  127. sqlite3_finalize(stmt);
  128. return true;
  129. }
  130. sqlite3_finalize(stmt);
  131. return false;
  132. }
  133. int SQLiteVecManager::getFeatureCount() const
  134. {
  135. const char * sql = "SELECT COUNT(*) FROM image_features;";
  136. sqlite3_stmt * stmt;
  137. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  138. if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW)
  139. {
  140. int count = sqlite3_column_int(stmt, 0);
  141. sqlite3_finalize(stmt);
  142. return count;
  143. }
  144. sqlite3_finalize(stmt);
  145. return 0;
  146. }
  147. bool SQLiteVecManager::isEmpty() const
  148. {
  149. return getFeatureCount() == 0;
  150. }
  151. std::string SQLiteVecManager::vectorToBlob(const std::vector<float> & vec)
  152. {
  153. std::string vecStr;
  154. for (std::size_t i = 0; i < vec.size(); i++) {
  155. vecStr += std::to_string(vec[i]);
  156. if (i != vec.size() - 1) {
  157. vecStr += ",";
  158. }
  159. }
  160. return vecStr;
  161. }
  162. std::vector<float> SQLiteVecManager::blobToVector(const std::string & blob)
  163. {
  164. const float * data = reinterpret_cast<const float *>(blob.data());
  165. size_t count = blob.size() / sizeof(float);
  166. return std::vector<float>(data, data + count);
  167. }
  168. float SQLiteVecManager::calculateCosineSimilarity(const std::vector<float> & vec1, const std::vector<float> & vec2)
  169. {
  170. if (vec1.size() != vec2.size() || vec1.empty())
  171. {
  172. return 0.0f;
  173. }
  174. float dotProduct = 0.0f;
  175. float norm1 = 0.0f;
  176. float norm2 = 0.0f;
  177. for (size_t i = 0; i < vec1.size(); ++i)
  178. {
  179. dotProduct += vec1[i] * vec2[i];
  180. norm1 += vec1[i] * vec1[i];
  181. norm2 += vec2[i] * vec2[i];
  182. }
  183. if (norm1 == 0.0f || norm2 == 0.0f)
  184. {
  185. return 0.0f;
  186. }
  187. return dotProduct / (std::sqrt(norm1) * std::sqrt(norm2));
  188. }
  189. float SQLiteVecManager::distanceToSimilarity(float distance)
  190. {
  191. return 1.0f - distance;
  192. }