test.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. // test.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
  2. //
  3. #include <iostream>
  4. #include <windows.h>
  5. #include <io.h>
  6. #include <ShellAPI.h>
  7. #include <atltypes.h>
  8. #include <direct.h>
  9. #include "tlhelp32.h"
  10. #include "YoloFeatureExtractor.h"
  11. #include "ImageProcessor.h"
  12. #include "SQLiteVecManager.h"
  13. #include <filesystem>
  14. #include <opencv2/core.hpp>
  15. #include <opencv2/core/ocl.hpp>
  16. #include <opencv2/highgui.hpp>
  17. #include <opencv2/imgcodecs.hpp>
  18. #include <opencv2/imgproc.hpp>
  19. void checkBuild()
  20. {
  21. //查询opencv编译时配置
  22. std::cout << cv::getBuildInformation() << std::endl;
  23. }
  24. void checkSimd()
  25. {
  26. //查询opencv线程
  27. int numTh = cv::getNumThreads(); //默认值是cpu的逻辑线程数
  28. int numCore = cv::getNumberOfCPUs();
  29. std::cout << "getNumThreads=" << numTh << std::endl;
  30. std::cout << "getNumberOfCPUs=" << numCore << std::endl;
  31. //查询opencv当前是否开启了并行优化功能
  32. bool opt = cv::useOptimized(); //默认值是true
  33. std::cout << "useOptimized=" << opt << std::endl;
  34. //查询opencv当前是否支持具体的CPU指令集
  35. bool check1 = cv::checkHardwareSupport(CV_CPU_SSE4_1);
  36. bool check2 = cv::checkHardwareSupport(CV_CPU_SSE4_2);
  37. bool check3 = cv::checkHardwareSupport(CV_CPU_AVX2);
  38. std::cout << "CV_CPU_SSE4_1=" << check1 << std::endl;
  39. std::cout << "CV_CPU_SSE4_2=" << check2 << std::endl;
  40. std::cout << "CV_CPU_AVX2=" << check3 << std::endl;
  41. //查询完整的硬件支持清单
  42. std::cout << "HardwareSupport:" << std::endl;
  43. std::cout << "CV_CPU_MMX: " << cv::checkHardwareSupport(CV_CPU_MMX) << std::endl;
  44. std::cout << "CV_CPU_SSE: " << cv::checkHardwareSupport(CV_CPU_SSE) << std::endl;
  45. std::cout << "CV_CPU_SSE2: " << cv::checkHardwareSupport(CV_CPU_SSE2) << std::endl;
  46. std::cout << "CV_CPU_SSE3: " << cv::checkHardwareSupport(CV_CPU_SSE3) << std::endl;
  47. std::cout << "CV_CPU_SSSE3: " << cv::checkHardwareSupport(CV_CPU_SSSE3) << std::endl;
  48. std::cout << "CV_CPU_SSE4_1: " << cv::checkHardwareSupport(CV_CPU_SSE4_1) << std::endl;
  49. std::cout << "CV_CPU_SSE4_2: " << cv::checkHardwareSupport(CV_CPU_SSE4_2) << std::endl;
  50. std::cout << "CV_CPU_POPCNT: " << cv::checkHardwareSupport(CV_CPU_POPCNT) << std::endl;
  51. std::cout << "CV_CPU_FP16: " << cv::checkHardwareSupport(CV_CPU_FP16) << std::endl;
  52. std::cout << "CV_CPU_AVX: " << cv::checkHardwareSupport(CV_CPU_AVX) << std::endl;
  53. std::cout << "CV_CPU_AVX2: " << cv::checkHardwareSupport(CV_CPU_AVX2) << std::endl;
  54. std::cout << "CV_CPU_FMA3: " << cv::checkHardwareSupport(CV_CPU_FMA3) << std::endl;
  55. std::cout << "CV_CPU_AVX_512F: " << cv::checkHardwareSupport(CV_CPU_AVX_512F) << std::endl;
  56. std::cout << "CV_CPU_AVX_512BW: " << cv::checkHardwareSupport(CV_CPU_AVX_512BW) << std::endl;
  57. std::cout << "CV_CPU_AVX_512CD: " << cv::checkHardwareSupport(CV_CPU_AVX_512CD) << std::endl;
  58. std::cout << "CV_CPU_AVX_512DQ: " << cv::checkHardwareSupport(CV_CPU_AVX_512DQ) << std::endl;
  59. std::cout << "CV_CPU_AVX_512ER: " << cv::checkHardwareSupport(CV_CPU_AVX_512ER) << std::endl;
  60. std::cout << "CV_CPU_AVX_512IFMA512: " << cv::checkHardwareSupport(CV_CPU_AVX_512IFMA512) << std::endl;
  61. std::cout << "CV_CPU_AVX_512IFMA: " << cv::checkHardwareSupport(CV_CPU_AVX_512IFMA) << std::endl;
  62. std::cout << "CV_CPU_AVX_512PF: " << cv::checkHardwareSupport(CV_CPU_AVX_512PF) << std::endl;
  63. std::cout << "CV_CPU_AVX_512VBMI: " << cv::checkHardwareSupport(CV_CPU_AVX_512VBMI) << std::endl;
  64. std::cout << "CV_CPU_AVX_512VL: " << cv::checkHardwareSupport(CV_CPU_AVX_512VL) << std::endl;
  65. std::cout << "CV_CPU_NEON: " << cv::checkHardwareSupport(CV_CPU_NEON) << std::endl;
  66. std::cout << "CV_CPU_VSX: " << cv::checkHardwareSupport(CV_CPU_VSX) << std::endl;
  67. std::cout << "CV_CPU_AVX512_SKX: " << cv::checkHardwareSupport(CV_CPU_AVX512_SKX) << std::endl;
  68. std::cout << "CV_HARDWARE_MAX_FEATURE: " << cv::checkHardwareSupport(CV_HARDWARE_MAX_FEATURE) << std::endl;
  69. std::cout << std::endl;
  70. //cv::setUseOptimized(false);
  71. //cv::setNumThreads(1);
  72. }
  73. void checkOpenCL() //Open Computing Language:开放计算语言,可以附加在主机处理器的CPU或GPU上执行
  74. {
  75. std::vector<cv::ocl::PlatformInfo> info;
  76. cv::ocl::getPlatfomsInfo(info);
  77. cv::ocl::PlatformInfo sdk = info.at(0);
  78. int number = sdk.deviceNumber();
  79. if (number < 1)
  80. {
  81. std::cout << "Number of devices:" << number << std::endl;
  82. return;
  83. }
  84. std::cout << "***********SDK************" << std::endl;
  85. std::cout << "Name:" << sdk.name() << std::endl;
  86. std::cout << "Vendor:" << sdk.vendor() << std::endl;
  87. std::cout << "Version:" << sdk.version() << std::endl;
  88. std::cout << "Version:" << sdk.version() << std::endl;
  89. std::cout << "Number of devices:" << number << std::endl;
  90. for (int i = 0; i < number; i++)
  91. {
  92. std::cout << std::endl;
  93. cv::ocl::Device device;
  94. sdk.getDevice(device, i);
  95. std::cout << "***********Device " << i + 1 << "***********" << std::endl;
  96. std::cout << "Vendor Id:" << device.vendorID() << std::endl;
  97. std::cout << "Vendor name:" << device.vendorName() << std::endl;
  98. std::cout << "Name:" << device.name() << std::endl;
  99. std::cout << "Driver version:" << device.vendorID() << std::endl;
  100. if (device.isAMD())
  101. std::cout << "Is AMD device" << std::endl;
  102. if (device.isIntel())
  103. std::cout << "Is Intel device" << std::endl;
  104. if (device.isNVidia())
  105. std::cout << "Is NVidia device" << std::endl;
  106. std::cout << "Global Memory size:" << device.globalMemSize() << std::endl;
  107. std::cout << "Memory cache size:" << device.globalMemCacheSize() << std::endl;
  108. std::cout << "Memory cache type:" << device.globalMemCacheType() << std::endl;
  109. std::cout << "Local Memory size:" << device.localMemSize() << std::endl;
  110. std::cout << "Local Memory type:" << device.localMemType() << std::endl;
  111. std::cout << "Max Clock frequency:" << device.maxClockFrequency() << std::endl;
  112. }
  113. }
  114. std::wstring getExePath()
  115. {
  116. wchar_t exeFullPath[MAX_PATH]; // Full path
  117. std::wstring strPath = L"";
  118. GetModuleFileName(NULL, exeFullPath, MAX_PATH);
  119. strPath = (std::wstring)exeFullPath; // Get full path of the file
  120. return strPath;
  121. }
  122. std::wstring GetProgramDir()
  123. {
  124. wchar_t exeFullPath[MAX_PATH]; // Full path
  125. std::wstring strPath = L"";
  126. GetModuleFileName(NULL, exeFullPath, MAX_PATH);
  127. strPath = (std::wstring)exeFullPath; // Get full path of the file
  128. int pos = strPath.find_last_of('\\', strPath.length());
  129. return strPath.substr(0, pos); // Return the directory without the file name
  130. }
  131. int main()
  132. {
  133. try
  134. {
  135. // 检查 OpenCL 是否可用
  136. checkBuild();
  137. checkSimd();
  138. //checkOpenCL();
  139. // 检查OpenCL是否可用
  140. if (!cv::ocl::haveOpenCL())
  141. {
  142. std::cout << "opencl不可用" << std::endl;
  143. }
  144. else
  145. {
  146. std::cout << "opencl可用" << std::endl;
  147. // 启用OpenCL
  148. //cv::ocl::setUseOpenCL(true);
  149. }
  150. // 设置路径(可根据实际情况修改)
  151. std::wstring wsExePath = getExePath();
  152. std::wstring wsProgramDir = GetProgramDir();
  153. std::filesystem::path mainDir = wsProgramDir;
  154. #ifdef _WIN64
  155. std::string sMainDir = mainDir.parent_path().parent_path().string();
  156. #else
  157. std::string sMainDir = mainDir.parent_path().string();
  158. #endif // WIN32
  159. std::string modelPath = sMainDir + "/best.onnx"; // YOLO2026模型路径
  160. std::string classesPath = sMainDir + "/cls.names"; // 类别文件路径
  161. std::string galleryDir = sMainDir + "/images"; // 图库目录路径
  162. std::string searchImagePath = sMainDir + "/3.jpg"; // 搜索图片路径
  163. std::string databasePath = sMainDir + "/image_features.db"; // SQLite数据库路径
  164. std::filesystem::remove(databasePath);
  165. std::cout << "=== YOLO2026图像检索系统 (SQLite-Vec版本) ===" << std::endl;
  166. std::cout << "模型路径: " << modelPath << std::endl;
  167. std::cout << "图库路径: " << galleryDir << std::endl;
  168. std::cout << "查询图片: " << searchImagePath << std::endl;
  169. std::cout << "数据库路径: " << databasePath << std::endl;
  170. std::cout << "=========================================" << std::endl;
  171. // 初始化特征提取器
  172. std::cout << "正在初始化YOLO2026特征提取器..." << std::endl;
  173. YoloFeatureExtractor extractor(modelPath, classesPath);
  174. extractor.initOpenCL(); // 启用OpenCL加速(如果可用)
  175. // 获取图库中的所有图片
  176. std::cout << "正在扫描图库中的图片..." << std::endl;
  177. std::vector<std::string> galleryImages = ImageProcessor::getImagesInDirectory(galleryDir);
  178. std::cout << "找到 " << galleryImages.size() << " 张图片" << std::endl;
  179. if (galleryImages.empty())
  180. {
  181. std::cerr << "图库中没有找到有效的图片文件!" << std::endl;
  182. return -1;
  183. }
  184. // 初始化SQLite-Vec数据库
  185. std::cout << "正在初始化SQLite-Vec数据库..." << std::endl;
  186. SQLiteVecManager vecManager(databasePath);
  187. auto start_time = std::chrono::high_resolution_clock::now();
  188. // 检查数据库是否已存在数据
  189. bool databaseExists = vecManager.loadDatabase();
  190. std::cout << "数据库加载结果:" << std::boolalpha << databaseExists << std::endl;
  191. std::cout << "数据集数量:" << vecManager.getFeatureCount() << std::endl;
  192. if (databaseExists && vecManager.getFeatureCount() > 0)
  193. {
  194. std::cout << "数据库已存在 " << vecManager.getFeatureCount() << " 条特征记录" << std::endl;
  195. }
  196. else
  197. {
  198. std::cout << "重新构建数据库..." << std::endl;
  199. // 提取图库图片特征并向量
  200. std::cout << "开始提取图库图片特征..." << std::endl;
  201. int featureDimension = 0;
  202. int processedCount = 0;
  203. // 初始化数据库表结构
  204. vecManager.initializeDatabase(45); // 假设特征维度为1000,实际会在第一次处理时确定
  205. for (size_t i = 0; i < galleryImages.size(); ++i)
  206. {
  207. try
  208. {
  209. std::vector<float> features = extractor.extractFeatures(galleryImages[i]);
  210. if (!features.empty())
  211. {
  212. if (featureDimension == 0)
  213. {
  214. featureDimension = static_cast<int>(features.size());
  215. std::cout << "特征维度: " << featureDimension << std::endl;
  216. // 重新初始化数据库以匹配实际维度
  217. vecManager.initializeDatabase(featureDimension);
  218. }
  219. vecManager.addFeatureVector(features, galleryImages[i]);
  220. processedCount++;
  221. std::cout << "[" << (i + 1) << "/" << galleryImages.size() << "] 已处理: "
  222. << galleryImages[i] << std::endl;
  223. }
  224. }
  225. catch (const std::exception & e)
  226. {
  227. std::cerr << "处理失败 [" << (i + 1) << "]: " << galleryImages[i]
  228. << " (" << e.what() << ")" << std::endl;
  229. }
  230. }
  231. auto end_time = std::chrono::high_resolution_clock::now();
  232. auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);
  233. std::cout << "特征提取完成,耗时: " << duration.count() << " 秒" << std::endl;
  234. std::cout << "成功处理图片数量: " << processedCount << std::endl;
  235. std::cout << "数据库记录数量: " << vecManager.getFeatureCount() << std::endl;
  236. if (processedCount == 0)
  237. {
  238. std::cerr << "没有成功提取任何图片特征!" << std::endl;
  239. return -1;
  240. }
  241. // 保存数据库
  242. vecManager.saveDatabase();
  243. std::cout << "数据库保存完成" << std::endl;
  244. }
  245. // 提取查询图片特征
  246. std::cout << "正在提取查询图片特征: " << searchImagePath << std::endl;
  247. std::vector<float> queryFeatures = extractor.extractFeatures(searchImagePath);
  248. if (queryFeatures.empty())
  249. {
  250. std::cerr << "无法提取查询图片特征!" << std::endl;
  251. return -1;
  252. }
  253. std::cout << "查询图片特征提取完成,维度: " << queryFeatures.size() << std::endl;
  254. // 进行相似性搜索
  255. std::cout << "正在进行相似性搜索..." << std::endl;
  256. std::vector<std::pair<std::string, float>> searchResults = vecManager.searchSimilarVectors(queryFeatures, 5);
  257. // 显示搜索结果
  258. std::cout << "\n=== 搜索结果 ===" << std::endl;
  259. std::cout << "查询图片: " << searchImagePath << std::endl;
  260. std::cout << "最相似的图片:" << std::endl;
  261. for (size_t i = 0; i < searchResults.size(); ++i)
  262. {
  263. std::cout << (i + 1) << ". " << searchResults[i].first
  264. << " (相似度: " << searchResults[i].second << ")" << std::endl;
  265. }
  266. std::cout << "\n图像检索完成!" << std::endl;
  267. }
  268. catch (const std::exception & e)
  269. {
  270. std::string a = e.what();
  271. std::cerr << "程序执行出错: " << e.what() << std::endl;
  272. return -1;
  273. }
  274. sqlite3_sleep(100000);
  275. return 0;
  276. }