3C134381851B082DA38E7880464684BF 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. 
  2. #include "SQLiteVecManager.h"
  3. #include <stdexcept>
  4. #include <iostream>
  5. #include <cmath>
  6. #include <algorithm>
  7. SQLiteVecManager::SQLiteVecManager(const std::string& databaseName) : dbName(databaseName), db(nullptr), useVecExtension(false) {
  8. int rc = sqlite3_open(dbName.c_str(), &db);
  9. if (rc) {
  10. throw std::runtime_error("Can't open database: " + std::string(sqlite3_errmsg(db)));
  11. }
  12. }
  13. SQLiteVecManager::~SQLiteVecManager() {
  14. if (db) {
  15. sqlite3_close(db);
  16. }
  17. }
  18. bool SQLiteVecManager::initializeDatabase(int vectorDimension) {
  19. char* errMsg = 0;
  20. // 尝试启用sqlite-vec扩展
  21. const char* enableVec = "SELECT load_extension('C:/Users/27920/Desktop/test/dll/release/vec0.dll');";
  22. int rc = sqlite3_exec(db, enableVec, 0, 0, &errMsg);
  23. if (rc == SQLITE_OK) {
  24. std::cerr << "success to create vec0 table: " << std::endl;
  25. useVecExtension = true;
  26. // 创建vec0虚拟表
  27. std::string sql = "CREATE VIRTUAL TABLE IF NOT EXISTS image_features USING vec0("
  28. "id INTEGER PRIMARY KEY, "
  29. "image_path TEXT, "
  30. "feature_vector FLOAT[" + std::to_string(vectorDimension) + "]);";
  31. rc = sqlite3_exec(db, sql.c_str(), 0, 0, &errMsg);
  32. if (rc != SQLITE_OK) {
  33. std::cerr << "Failed to create vec0 table: " << errMsg << std::endl;
  34. sqlite3_free(errMsg);
  35. return false;
  36. }
  37. } else {
  38. // 使用传统表结构
  39. sqlite3_free(errMsg);
  40. const char* sql = R"(
  41. CREATE TABLE IF NOT EXISTS image_features (
  42. id INTEGER PRIMARY KEY AUTOINCREMENT,
  43. image_path TEXT UNIQUE NOT NULL,
  44. feature_vector BLOB NOT NULL
  45. );
  46. )";
  47. rc = sqlite3_exec(db, sql, 0, 0, &errMsg);
  48. if (rc != SQLITE_OK) {
  49. std::cerr << "SQL error: " << errMsg << std::endl;
  50. sqlite3_free(errMsg);
  51. return false;
  52. }
  53. }
  54. return true;
  55. }
  56. bool SQLiteVecManager::addFeatureVector(const std::vector<float>& features, const std::string& imagePath) {
  57. if (useVecExtension) {
  58. const char* sql = "INSERT INTO image_features(id, image_path, feature_vector) VALUES (?, ?, vec_f32(?));";
  59. sqlite3_stmt* stmt;
  60. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  61. if (rc != SQLITE_OK) {
  62. return false;
  63. }
  64. sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
  65. sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
  66. std::string blobData = vectorToBlob(features);
  67. sqlite3_bind_blob(stmt, 3, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
  68. rc = sqlite3_step(stmt);
  69. sqlite3_finalize(stmt);
  70. return rc == SQLITE_DONE;
  71. } else {
  72. const char* sql = "INSERT OR REPLACE INTO image_features (id, image_path, feature_vector) VALUES (?, ?, ?);";
  73. sqlite3_stmt* stmt;
  74. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  75. if (rc != SQLITE_OK) {
  76. return false;
  77. }
  78. sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
  79. sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
  80. std::string blobData = vectorToBlob(features);
  81. sqlite3_bind_blob(stmt, 3, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
  82. rc = sqlite3_step(stmt);
  83. sqlite3_finalize(stmt);
  84. return rc == SQLITE_DONE;
  85. }
  86. }
  87. std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVectors(const std::vector<float>& queryVector, int k) {
  88. std::vector<std::pair<std::string, float>> results;
  89. if (useVecExtension) {
  90. // 使用sqlite-vec的向量搜索功能
  91. std::string blobData = vectorToBlob(queryVector);
  92. std::string sql = "SELECT image_path, distance FROM image_features "
  93. "WHERE feature_vector MATCH vec_f32(?) "
  94. "ORDER BY distance "
  95. "LIMIT " + std::to_string(k) + ";";
  96. sqlite3_stmt* stmt;
  97. int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, NULL);
  98. if (rc != SQLITE_OK) {
  99. return results;
  100. }
  101. sqlite3_bind_blob(stmt, 1, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
  102. while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
  103. const char* imagePath = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
  104. float distance = static_cast<float>(sqlite3_column_double(stmt, 1));
  105. results.emplace_back(std::string(imagePath), distance);
  106. }
  107. sqlite3_finalize(stmt);
  108. } else {
  109. // 使用传统方式计算相似度
  110. const char* sql = "SELECT image_path, feature_vector FROM image_features;";
  111. sqlite3_stmt* stmt;
  112. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  113. if (rc != SQLITE_OK) {
  114. return results;
  115. }
  116. // 计算查询向量的模长
  117. float queryNorm = 0.0f;
  118. for (float val : queryVector) {
  119. queryNorm += val * val;
  120. }
  121. queryNorm = sqrt(queryNorm);
  122. while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
  123. const char* imagePath = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
  124. const void* blobData = sqlite3_column_blob(stmt, 1);
  125. int blobSize = sqlite3_column_bytes(stmt, 1);
  126. std::string blobStr(static_cast<const char*>(blobData), blobSize);
  127. std::vector<float> storedFeatures = blobToVector(blobStr);
  128. if (storedFeatures.size() == queryVector.size()) {
  129. // 计算余弦相似度
  130. float dotProduct = 0.0f;
  131. float storedNorm = 0.0f;
  132. for (size_t i = 0; i < queryVector.size(); ++i) {
  133. dotProduct += queryVector[i] * storedFeatures[i];
  134. storedNorm += storedFeatures[i] * storedFeatures[i];
  135. }
  136. storedNorm = sqrt(storedNorm);
  137. float similarity = 0.0f;
  138. if (queryNorm > 0.0f && storedNorm > 0.0f) {
  139. similarity = dotProduct / (queryNorm * storedNorm);
  140. }
  141. results.emplace_back(std::string(imagePath), similarity);
  142. }
  143. }
  144. sqlite3_finalize(stmt);
  145. // 按相似度降序排列
  146. std::sort(results.begin(), results.end(),
  147. [](const auto& a, const auto& b) { return a.second > b.second; });
  148. // 返回topK结果
  149. if (static_cast<int>(results.size()) > k) {
  150. results.resize(k);
  151. }
  152. }
  153. return results;
  154. }
  155. void SQLiteVecManager::saveDatabase() {
  156. sqlite3_exec(db, "PRAGMA optimize;", 0, 0, 0);
  157. }
  158. bool SQLiteVecManager::loadDatabase() {
  159. const char* sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='image_features';";
  160. sqlite3_stmt* stmt;
  161. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  162. if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) {
  163. sqlite3_finalize(stmt);
  164. return true;
  165. }
  166. sqlite3_finalize(stmt);
  167. return false;
  168. }
  169. int SQLiteVecManager::getFeatureCount() const {
  170. const char* sql = "SELECT COUNT(*) FROM image_features;";
  171. sqlite3_stmt* stmt;
  172. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  173. if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) {
  174. int count = sqlite3_column_int(stmt, 0);
  175. sqlite3_finalize(stmt);
  176. return count;
  177. }
  178. sqlite3_finalize(stmt);
  179. return 0;
  180. }
  181. bool SQLiteVecManager::isEmpty() const {
  182. return getFeatureCount() == 0;
  183. }
  184. std::string SQLiteVecManager::vectorToBlob(const std::vector<float>& vec) {
  185. return std::string(reinterpret_cast<const char*>(vec.data()), vec.size() * sizeof(float));
  186. }
  187. std::vector<float> SQLiteVecManager::blobToVector(const std::string& blob) {
  188. const float* data = reinterpret_cast<const float*>(blob.data());
  189. size_t count = blob.size() / sizeof(float);
  190. return std::vector<float>(data, data + count);
  191. }
  192. float SQLiteVecManager::calculateCosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
  193. if (vec1.size() != vec2.size() || vec1.empty()) {
  194. return 0.0f;
  195. }
  196. float dotProduct = 0.0f;
  197. float norm1 = 0.0f;
  198. float norm2 = 0.0f;
  199. for (size_t i = 0; i < vec1.size(); ++i) {
  200. dotProduct += vec1[i] * vec2[i];
  201. norm1 += vec1[i] * vec1[i];
  202. norm2 += vec2[i] * vec2[i];
  203. }
  204. if (norm1 == 0.0f || norm2 == 0.0f) {
  205. return 0.0f;
  206. }
  207. return dotProduct / (std::sqrt(norm1) * std::sqrt(norm2));
  208. }