test.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #include "../pch/pch.h"
  2. #include "test.h"
  3. #include "YoloFeatureManager.h"
  4. #include "ImageProcessor.h"
  5. #include "SQLiteVecManager.h"
  6. #include "../tool/CSetting.h"
  7. #include "../tool/debuglog.h"
  8. #include <opencv2/opencv.hpp>
  9. #include "../tool/debuglog.h"
  10. /*
  11. int AITest()
  12. {
  13. try
  14. {
  15. // 设置路径(可根据实际情况修改)
  16. std::wstring wsExePath = CSystem::getExePath();
  17. std::wstring wsProgramDir = CSystem::GetProgramDir();
  18. std::filesystem::path mainDir = wsProgramDir;
  19. std::string sMainDir = mainDir.string();
  20. //用于测试的图片目录
  21. std::string galleryDir = (mainDir.parent_path().parent_path().parent_path().parent_path() /"res"/"images").string(); // 图库目录路径
  22. std::string modelPath = sMainDir + "/ai/best_448.onnx"; // YOLO2026模型路径
  23. std::string searchImagePath = sMainDir + "/3.jpg"; // 搜索图片路径
  24. std::string databasePath = sMainDir + "/image_features.db"; // SQLite数据库路径
  25. std::filesystem::remove(databasePath);
  26. std::cout << "=== YOLO2026图像检索系统 (SQLite-Vec版本) ===" << std::endl;
  27. std::cout << "模型路径: " << modelPath << std::endl;
  28. std::cout << "图库路径: " << galleryDir << std::endl;
  29. std::cout << "查询图片: " << searchImagePath << std::endl;
  30. std::cout << "数据库路径: " << databasePath << std::endl;
  31. std::cout << "=========================================" << std::endl;
  32. // 初始化特征提取器
  33. std::cout << "正在初始化YOLO2026特征提取器..." << std::endl;
  34. YoloFeatureManager extractor;
  35. extractor.loadModel(modelPath);
  36. // 获取图库中的所有图片
  37. std::cout << "正在扫描图库中的图片..." << std::endl;
  38. std::vector<std::string> galleryImages = ImageProcessor::getImagesInDirectory(galleryDir);
  39. std::cout << "找到 " << galleryImages.size() << " 张图片" << std::endl;
  40. if (galleryImages.empty())
  41. {
  42. std::cerr << "图库中没有找到有效的图片文件!" << std::endl;
  43. return -1;
  44. }
  45. // 初始化SQLite-Vec数据库
  46. std::cout << "正在初始化SQLite-Vec数据库..." << std::endl;
  47. SQLiteVecManager vecManager(databasePath);
  48. auto start_time = std::chrono::high_resolution_clock::now();
  49. // 检查数据库是否已存在数据
  50. bool databaseExists = vecManager.loadDatabase();
  51. std::cout << "数据库加载结果:" << std::boolalpha << databaseExists << std::endl;
  52. std::cout << "数据集数量:" << vecManager.getFeatureCount() << std::endl;
  53. if (databaseExists && vecManager.getFeatureCount() > 0)
  54. {
  55. std::cout << "数据库已存在 " << vecManager.getFeatureCount() << " 条特征记录" << std::endl;
  56. }
  57. else
  58. {
  59. std::cout << "重新构建数据库..." << std::endl;
  60. // 提取图库图片特征并向量
  61. std::cout << "开始提取图库图片特征..." << std::endl;
  62. int featureDimension = 0;
  63. int processedCount = 0;
  64. for (size_t i = 0; i < galleryImages.size(); ++i)
  65. {
  66. try
  67. {
  68. std::vector<float> features = extractor.extractFeatures(galleryImages[i]);
  69. if (!features.empty())
  70. {
  71. if (featureDimension == 0)
  72. {
  73. featureDimension = static_cast<int>(features.size());
  74. DEBUG_HELPER::debug_printf("特征维度: %d\n", featureDimension);
  75. // 初始化数据库表结构
  76. vecManager.initializeDatabase(featureDimension);
  77. }
  78. vecManager.addFeatureVector(features, galleryImages[i]);
  79. processedCount++;
  80. DEBUG_HELPER::debug_printf("已处理 [%d/%d]: %s\n", (i + 1), static_cast<int>(galleryImages.size()), galleryImages[i].c_str());
  81. }
  82. }
  83. catch (const std::exception & e)
  84. {
  85. DEBUG_HELPER::debug_printf("处理失败 [%d/%d]: %s (%s)\n", (i + 1), static_cast<int>(galleryImages.size()), galleryImages[i].c_str(), e.what());
  86. }
  87. }
  88. auto end_time = std::chrono::high_resolution_clock::now();
  89. auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);
  90. std::cout << "特征提取完成,耗时: " << duration.count() << " 秒" << std::endl;
  91. std::cout << "成功处理图片数量: " << processedCount << std::endl;
  92. std::cout << "数据库记录数量: " << vecManager.getFeatureCount() << std::endl;
  93. if (processedCount == 0)
  94. {
  95. std::cerr << "没有成功提取任何图片特征!" << std::endl;
  96. return -1;
  97. }
  98. // 保存数据库
  99. vecManager.saveDatabase();
  100. std::cout << "数据库保存完成" << std::endl;
  101. }
  102. // 提取查询图片特征
  103. std::cout << "正在提取查询图片特征: " << searchImagePath << std::endl;
  104. std::vector<float> queryFeatures = extractor.extractFeatures(searchImagePath);
  105. if (queryFeatures.empty())
  106. {
  107. std::cerr << "无法提取查询图片特征!" << std::endl;
  108. return -1;
  109. }
  110. std::cout << "查询图片特征提取完成,维度: " << queryFeatures.size() << std::endl;
  111. // 进行相似性搜索
  112. std::cout << "正在进行相似性搜索..." << std::endl;
  113. std::vector<std::pair<std::string, float>> searchResults = vecManager.searchSimilarVectors(queryFeatures, 20);
  114. // 显示搜索结果
  115. std::cout << "\n=== 搜索结果 ===" << std::endl;
  116. std::cout << "查询图片: " << searchImagePath << std::endl;
  117. std::cout << "最相似的图片:" << std::endl;
  118. for (size_t i = 0; i < searchResults.size(); ++i)
  119. {
  120. std::cout << (i + 1) << ". " << searchResults[i].first
  121. << " (相似度: " << searchResults[i].second << ")" << std::endl;
  122. }
  123. std::cout << "\n图像检索完成!" << std::endl;
  124. }
  125. catch (const std::exception & e)
  126. {
  127. std::string err = e.what();
  128. std::cerr << "程序执行出错: " << err << std::endl;
  129. return -1;
  130. }
  131. return 0;
  132. }
  133. */