SQLiteVecManager.cpp 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. 
  2. #include "SQLiteVecManager.h"
  3. #include <stdexcept>
  4. #include <iostream>
  5. #include <cmath>
  6. #include <algorithm>
  7. SQLiteVecManager::SQLiteVecManager(const std::string& databaseName) : dbName(databaseName), db(nullptr), useVecExtension(false) {
  8. int rc = sqlite3_open(dbName.c_str(), &db);
  9. if (rc) {
  10. throw std::runtime_error("Can't open database: " + std::string(sqlite3_errmsg(db)));
  11. }
  12. // 初始化 sqlite-vec 扩展
  13. char* errMsg = 0;
  14. rc = sqlite3_vec_init(db, &errMsg, 0);
  15. if (rc != SQLITE_OK) {
  16. std::cerr << "初始化 sqlite-vec 扩展失败: " << (errMsg ? errMsg : "未知错误") << std::endl;
  17. if (errMsg) sqlite3_free(errMsg);
  18. sqlite3_close(db);
  19. throw std::runtime_error("Can't 初始化 sqlite-vec: " + std::string(sqlite3_errmsg(db)));
  20. }
  21. std::cout << "sqlite-vec 扩展初始化成功!" << std::endl;
  22. useVecExtension = true;
  23. }
  24. SQLiteVecManager::~SQLiteVecManager() {
  25. if (db) {
  26. sqlite3_close(db);
  27. }
  28. }
  29. bool SQLiteVecManager::initializeDatabase(int vectorDimension) {
  30. int rc;
  31. char* errMsg = 0;
  32. if (useVecExtension == true) {
  33. std::cout << "使用sqlite-vec扩展进行向量存储和搜索" << std::endl;
  34. // 创建vec0虚拟表
  35. std::string sql = "CREATE VIRTUAL TABLE IF NOT EXISTS image_features USING vec0("
  36. "id INTEGER PRIMARY KEY AUTOINCREMENT,"
  37. "image_path TEXT UNIQUE NOT NULL,"
  38. "feature_vector FLOAT[" + std::to_string(vectorDimension) + "]);";
  39. rc = sqlite3_exec(db, sql.c_str(), 0, 0, &errMsg);
  40. if (rc != SQLITE_OK) {
  41. std::cerr << "Failed to create vec0 table: " << errMsg << std::endl;
  42. sqlite3_free(errMsg);
  43. return false;
  44. }
  45. } else {
  46. // 使用传统表结构
  47. std::cout << "使用传统表结构扩展进行向量存储和搜索" << std::endl;
  48. sqlite3_free(errMsg);
  49. const char* sql = R"(
  50. CREATE TABLE IF NOT EXISTS image_features (
  51. id INTEGER PRIMARY KEY AUTOINCREMENT,
  52. image_path TEXT UNIQUE NOT NULL,
  53. feature_vector BLOB NOT NULL
  54. );
  55. )";
  56. rc = sqlite3_exec(db, sql, 0, 0, &errMsg);
  57. if (rc != SQLITE_OK) {
  58. std::cerr << "SQL error: " << errMsg << std::endl;
  59. sqlite3_free(errMsg);
  60. return false;
  61. }
  62. }
  63. return true;
  64. }
  65. bool SQLiteVecManager::addFeatureVector(const std::vector<float>& features, const std::string& imagePath) {
  66. if (useVecExtension) {
  67. const char* sql = "INSERT INTO image_features(id, image_path, feature_vector) VALUES (?, ?, vec_f32(?));";
  68. sqlite3_stmt* stmt;
  69. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  70. if (rc != SQLITE_OK) {
  71. return false;
  72. }
  73. sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
  74. sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
  75. std::string blobData = vectorToBlob(features);
  76. sqlite3_bind_blob(stmt, 3, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
  77. rc = sqlite3_step(stmt);
  78. sqlite3_finalize(stmt);
  79. return rc == SQLITE_DONE;
  80. } else {
  81. const char* sql = "INSERT OR REPLACE INTO image_features (id, image_path, feature_vector) VALUES (?, ?, ?);";
  82. sqlite3_stmt* stmt;
  83. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  84. if (rc != SQLITE_OK) {
  85. return false;
  86. }
  87. sqlite3_bind_int(stmt, 1, static_cast<int>(getFeatureCount() + 1));
  88. sqlite3_bind_text(stmt, 2, imagePath.c_str(), -1, SQLITE_STATIC);
  89. std::string blobData = vectorToBlob(features);
  90. sqlite3_bind_blob(stmt, 3, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
  91. rc = sqlite3_step(stmt);
  92. sqlite3_finalize(stmt);
  93. return rc == SQLITE_DONE;
  94. }
  95. }
  96. std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVectors(const std::vector<float>& queryVector, int k) {
  97. std::vector<std::pair<std::string, float>> results;
  98. if (useVecExtension) {
  99. std::cout << "使用sqlite-vec扩展进行向量搜索" << std::endl;
  100. // 使用sqlite-vec的向量搜索功能
  101. std::string blobData = vectorToBlob(queryVector);
  102. std::string sql = "SELECT image_path, distance FROM image_features "
  103. "WHERE feature_vector MATCH vec_f32(?) "
  104. "ORDER BY distance "
  105. "LIMIT " + std::to_string(k) + ";";
  106. sqlite3_stmt* stmt;
  107. int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, NULL);
  108. if (rc != SQLITE_OK) {
  109. return results;
  110. }
  111. sqlite3_bind_blob(stmt, 1, blobData.data(), static_cast<int>(blobData.size()), SQLITE_STATIC);
  112. while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
  113. const char* imagePath = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
  114. float distance = static_cast<float>(sqlite3_column_double(stmt, 1));
  115. float similarity = distanceToSimilarity(distance);
  116. results.emplace_back(std::string(imagePath), similarity);
  117. }
  118. sqlite3_finalize(stmt);
  119. } else {
  120. std::cout << "使用传统方法进行向量搜索" << std::endl;
  121. // 使用传统方式计算相似度
  122. const char* sql = "SELECT image_path, feature_vector FROM image_features;";
  123. sqlite3_stmt* stmt;
  124. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  125. if (rc != SQLITE_OK) {
  126. return results;
  127. }
  128. // 计算查询向量的模长
  129. float queryNorm = 0.0f;
  130. for (float val : queryVector) {
  131. queryNorm += val * val;
  132. }
  133. queryNorm = sqrt(queryNorm);
  134. while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
  135. const char* imagePath = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
  136. const void* blobData = sqlite3_column_blob(stmt, 1);
  137. int blobSize = sqlite3_column_bytes(stmt, 1);
  138. std::string blobStr(static_cast<const char*>(blobData), blobSize);
  139. std::vector<float> storedFeatures = blobToVector(blobStr);
  140. if (storedFeatures.size() == queryVector.size()) {
  141. // 计算余弦相似度
  142. float dotProduct = 0.0f;
  143. float storedNorm = 0.0f;
  144. for (size_t i = 0; i < queryVector.size(); ++i) {
  145. dotProduct += queryVector[i] * storedFeatures[i];
  146. storedNorm += storedFeatures[i] * storedFeatures[i];
  147. }
  148. storedNorm = sqrt(storedNorm);
  149. float similarity = 0.0f;
  150. if (queryNorm > 0.0f && storedNorm > 0.0f) {
  151. similarity = dotProduct / (queryNorm * storedNorm);
  152. }
  153. results.emplace_back(std::string(imagePath), similarity);
  154. }
  155. }
  156. sqlite3_finalize(stmt);
  157. // 按相似度降序排列
  158. std::sort(results.begin(), results.end(),
  159. [](const auto& a, const auto& b) { return a.second > b.second; });
  160. // 返回topK结果
  161. if (static_cast<int>(results.size()) > k) {
  162. results.resize(k);
  163. }
  164. }
  165. return results;
  166. }
  167. void SQLiteVecManager::saveDatabase() {
  168. sqlite3_exec(db, "PRAGMA optimize;", 0, 0, 0);
  169. }
  170. bool SQLiteVecManager::loadDatabase() {
  171. const char* sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='image_features';";
  172. sqlite3_stmt* stmt;
  173. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  174. if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) {
  175. sqlite3_finalize(stmt);
  176. return true;
  177. }
  178. sqlite3_finalize(stmt);
  179. return false;
  180. }
  181. int SQLiteVecManager::getFeatureCount() const {
  182. const char* sql = "SELECT COUNT(*) FROM image_features;";
  183. sqlite3_stmt* stmt;
  184. int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
  185. if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) {
  186. int count = sqlite3_column_int(stmt, 0);
  187. sqlite3_finalize(stmt);
  188. return count;
  189. }
  190. sqlite3_finalize(stmt);
  191. return 0;
  192. }
  193. bool SQLiteVecManager::isEmpty() const {
  194. return getFeatureCount() == 0;
  195. }
  196. std::string SQLiteVecManager::vectorToBlob(const std::vector<float>& vec) {
  197. return std::string(reinterpret_cast<const char*>(vec.data()), vec.size() * sizeof(float));
  198. }
  199. std::vector<float> SQLiteVecManager::blobToVector(const std::string& blob) {
  200. const float* data = reinterpret_cast<const float*>(blob.data());
  201. size_t count = blob.size() / sizeof(float);
  202. return std::vector<float>(data, data + count);
  203. }
  204. float SQLiteVecManager::calculateCosineSimilarity(const std::vector<float>& vec1, const std::vector<float>& vec2) {
  205. if (vec1.size() != vec2.size() || vec1.empty()) {
  206. return 0.0f;
  207. }
  208. float dotProduct = 0.0f;
  209. float norm1 = 0.0f;
  210. float norm2 = 0.0f;
  211. for (size_t i = 0; i < vec1.size(); ++i) {
  212. dotProduct += vec1[i] * vec2[i];
  213. norm1 += vec1[i] * vec1[i];
  214. norm2 += vec2[i] * vec2[i];
  215. }
  216. if (norm1 == 0.0f || norm2 == 0.0f) {
  217. return 0.0f;
  218. }
  219. return dotProduct / (std::sqrt(norm1) * std::sqrt(norm2));
  220. }
  221. float SQLiteVecManager::distanceToSimilarity(float distance) {
  222. return 1.0f - distance;
  223. }