test.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. // test.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
  2. //
  3. #include <iostream>
  4. #include <windows.h>
  5. #include <io.h>
  6. #include <ShellAPI.h>
  7. #include <atltypes.h>
  8. #include <direct.h>
  9. #include "tlhelp32.h"
  10. #include "YoloFeatureExtractor.h"
  11. #include "ImageProcessor.h"
  12. #include "SQLiteVecManager.h"
  13. #include <filesystem>
  14. #include <opencv2/core.hpp>
  15. #include <opencv2/core/ocl.hpp>
  16. #include <opencv2/highgui.hpp>
  17. #include <opencv2/imgcodecs.hpp>
  18. #include <opencv2/imgproc.hpp>
  19. #include <opencv2/opencv.hpp>
  20. #include <iostream>
  21. using namespace cv;
  22. using namespace std;
  23. #include <iostream>
  24. void checkBuild()
  25. {
  26. //查询opencv编译时配置
  27. std::cout << cv::getBuildInformation() << std::endl;
  28. }
  29. void checkSimd()
  30. {
  31. //查询opencv线程
  32. int numTh = cv::getNumThreads(); //默认值是cpu的逻辑线程数
  33. int numCore = cv::getNumberOfCPUs();
  34. std::cout << "getNumThreads=" << numTh << std::endl;
  35. std::cout << "getNumberOfCPUs=" << numCore << std::endl;
  36. //查询opencv当前是否开启了并行优化功能
  37. bool opt = cv::useOptimized(); //默认值是true
  38. std::cout << "useOptimized=" << opt << std::endl;
  39. //查询opencv当前是否支持具体的CPU指令集
  40. bool check1 = cv::checkHardwareSupport(CV_CPU_SSE4_1);
  41. bool check2 = cv::checkHardwareSupport(CV_CPU_SSE4_2);
  42. bool check3 = cv::checkHardwareSupport(CV_CPU_AVX2);
  43. std::cout << "CV_CPU_SSE4_1=" << check1 << std::endl;
  44. std::cout << "CV_CPU_SSE4_2=" << check2 << std::endl;
  45. std::cout << "CV_CPU_AVX2=" << check3 << std::endl;
  46. //查询完整的硬件支持清单
  47. std::cout << "HardwareSupport:" << std::endl;
  48. std::cout << "CV_CPU_MMX: " << cv::checkHardwareSupport(CV_CPU_MMX) << std::endl;
  49. std::cout << "CV_CPU_SSE: " << cv::checkHardwareSupport(CV_CPU_SSE) << std::endl;
  50. std::cout << "CV_CPU_SSE2: " << cv::checkHardwareSupport(CV_CPU_SSE2) << std::endl;
  51. std::cout << "CV_CPU_SSE3: " << cv::checkHardwareSupport(CV_CPU_SSE3) << std::endl;
  52. std::cout << "CV_CPU_SSSE3: " << cv::checkHardwareSupport(CV_CPU_SSSE3) << std::endl;
  53. std::cout << "CV_CPU_SSE4_1: " << cv::checkHardwareSupport(CV_CPU_SSE4_1) << std::endl;
  54. std::cout << "CV_CPU_SSE4_2: " << cv::checkHardwareSupport(CV_CPU_SSE4_2) << std::endl;
  55. std::cout << "CV_CPU_POPCNT: " << cv::checkHardwareSupport(CV_CPU_POPCNT) << std::endl;
  56. std::cout << "CV_CPU_FP16: " << cv::checkHardwareSupport(CV_CPU_FP16) << std::endl;
  57. std::cout << "CV_CPU_AVX: " << cv::checkHardwareSupport(CV_CPU_AVX) << std::endl;
  58. std::cout << "CV_CPU_AVX2: " << cv::checkHardwareSupport(CV_CPU_AVX2) << std::endl;
  59. std::cout << "CV_CPU_FMA3: " << cv::checkHardwareSupport(CV_CPU_FMA3) << std::endl;
  60. std::cout << "CV_CPU_AVX_512F: " << cv::checkHardwareSupport(CV_CPU_AVX_512F) << std::endl;
  61. std::cout << "CV_CPU_AVX_512BW: " << cv::checkHardwareSupport(CV_CPU_AVX_512BW) << std::endl;
  62. std::cout << "CV_CPU_AVX_512CD: " << cv::checkHardwareSupport(CV_CPU_AVX_512CD) << std::endl;
  63. std::cout << "CV_CPU_AVX_512DQ: " << cv::checkHardwareSupport(CV_CPU_AVX_512DQ) << std::endl;
  64. std::cout << "CV_CPU_AVX_512ER: " << cv::checkHardwareSupport(CV_CPU_AVX_512ER) << std::endl;
  65. std::cout << "CV_CPU_AVX_512IFMA512: " << cv::checkHardwareSupport(CV_CPU_AVX_512IFMA512) << std::endl;
  66. std::cout << "CV_CPU_AVX_512IFMA: " << cv::checkHardwareSupport(CV_CPU_AVX_512IFMA) << std::endl;
  67. std::cout << "CV_CPU_AVX_512PF: " << cv::checkHardwareSupport(CV_CPU_AVX_512PF) << std::endl;
  68. std::cout << "CV_CPU_AVX_512VBMI: " << cv::checkHardwareSupport(CV_CPU_AVX_512VBMI) << std::endl;
  69. std::cout << "CV_CPU_AVX_512VL: " << cv::checkHardwareSupport(CV_CPU_AVX_512VL) << std::endl;
  70. std::cout << "CV_CPU_NEON: " << cv::checkHardwareSupport(CV_CPU_NEON) << std::endl;
  71. std::cout << "CV_CPU_VSX: " << cv::checkHardwareSupport(CV_CPU_VSX) << std::endl;
  72. std::cout << "CV_CPU_AVX512_SKX: " << cv::checkHardwareSupport(CV_CPU_AVX512_SKX) << std::endl;
  73. std::cout << "CV_HARDWARE_MAX_FEATURE: " << cv::checkHardwareSupport(CV_HARDWARE_MAX_FEATURE) << std::endl;
  74. std::cout << std::endl;
  75. //cv::setUseOptimized(false);
  76. //cv::setNumThreads(1);
  77. }
  78. void checkOpenCL() //Open Computing Language:开放计算语言,可以附加在主机处理器的CPU或GPU上执行
  79. {
  80. std::vector<cv::ocl::PlatformInfo> info;
  81. cv::ocl::getPlatfomsInfo(info);
  82. cv::ocl::PlatformInfo sdk = info.at(0);
  83. int number = sdk.deviceNumber();
  84. if (number < 1)
  85. {
  86. std::cout << "Number of devices:" << number << std::endl;
  87. return;
  88. }
  89. std::cout << "***********SDK************" << std::endl;
  90. std::cout << "Name:" << sdk.name() << std::endl;
  91. std::cout << "Vendor:" << sdk.vendor() << std::endl;
  92. std::cout << "Version:" << sdk.version() << std::endl;
  93. std::cout << "Version:" << sdk.version() << std::endl;
  94. std::cout << "Number of devices:" << number << std::endl;
  95. for (int i = 0; i < number; i++)
  96. {
  97. std::cout << std::endl;
  98. cv::ocl::Device device;
  99. sdk.getDevice(device, i);
  100. std::cout << "***********Device " << i + 1 << "***********" << std::endl;
  101. std::cout << "Vendor Id:" << device.vendorID() << std::endl;
  102. std::cout << "Vendor name:" << device.vendorName() << std::endl;
  103. std::cout << "Name:" << device.name() << std::endl;
  104. std::cout << "Driver version:" << device.vendorID() << std::endl;
  105. if (device.isAMD())
  106. std::cout << "Is AMD device" << std::endl;
  107. if (device.isIntel())
  108. std::cout << "Is Intel device" << std::endl;
  109. if (device.isNVidia())
  110. std::cout << "Is NVidia device" << std::endl;
  111. std::cout << "Global Memory size:" << device.globalMemSize() << std::endl;
  112. std::cout << "Memory cache size:" << device.globalMemCacheSize() << std::endl;
  113. std::cout << "Memory cache type:" << device.globalMemCacheType() << std::endl;
  114. std::cout << "Local Memory size:" << device.localMemSize() << std::endl;
  115. std::cout << "Local Memory type:" << device.localMemType() << std::endl;
  116. std::cout << "Max Clock frequency:" << device.maxClockFrequency() << std::endl;
  117. }
  118. }
  119. std::wstring getExePath()
  120. {
  121. wchar_t exeFullPath[MAX_PATH]; // Full path
  122. std::wstring strPath = L"";
  123. GetModuleFileName(NULL, exeFullPath, MAX_PATH);
  124. strPath = (std::wstring)exeFullPath; // Get full path of the file
  125. return strPath;
  126. }
  127. std::wstring GetProgramDir()
  128. {
  129. wchar_t exeFullPath[MAX_PATH]; // Full path
  130. std::wstring strPath = L"";
  131. GetModuleFileName(NULL, exeFullPath, MAX_PATH);
  132. strPath = (std::wstring)exeFullPath; // Get full path of the file
  133. int pos = strPath.find_last_of('\\', strPath.length());
  134. return strPath.substr(0, pos); // Return the directory without the file name
  135. }
  136. int main()
  137. {
  138. try
  139. {
  140. // 检查 OpenCL 是否可用
  141. //checkBuild();
  142. //checkSimd();
  143. //checkOpenCL();
  144. // 检查OpenCL是否可用
  145. if (!cv::ocl::haveOpenCL())
  146. {
  147. std::cout << "opencl不可用" << std::endl;
  148. }
  149. else
  150. {
  151. std::cout << "opencl可用" << std::endl;
  152. // 启用OpenCL
  153. //cv::ocl::setUseOpenCL(true);
  154. }
  155. // 设置路径(可根据实际情况修改)
  156. std::wstring wsExePath = getExePath();
  157. std::wstring wsProgramDir = GetProgramDir();
  158. std::filesystem::path mainDir = wsProgramDir;
  159. #ifdef _WIN64
  160. std::string sMainDir = mainDir.string();
  161. #else
  162. std::string sMainDir = mainDir.parent_path().string();
  163. #endif // WIN32
  164. std::string modelPath = sMainDir + "/best_2026n_448.onnx"; // YOLO2026模型路径
  165. std::string classesPath = sMainDir + "/cls.names"; // 类别文件路径
  166. std::string galleryDir = sMainDir + "/images"; // 图库目录路径
  167. std::string searchImagePath = sMainDir + "/3.jpg"; // 搜索图片路径
  168. std::string databasePath = sMainDir + "/image_features.db"; // SQLite数据库路径
  169. std::string openvino_modelPath = sMainDir + "\\best.xml"; // YOLO2026模型路径
  170. std::string openvino_configPath = sMainDir + "\\best.bin"; // YOLO2026模型路径
  171. std::filesystem::remove(databasePath);
  172. std::cout << "=== YOLO2026图像检索系统 (SQLite-Vec版本) ===" << std::endl;
  173. std::cout << "模型路径: " << modelPath << std::endl;
  174. std::cout << "图库路径: " << galleryDir << std::endl;
  175. std::cout << "查询图片: " << searchImagePath << std::endl;
  176. std::cout << "数据库路径: " << databasePath << std::endl;
  177. std::cout << "=========================================" << std::endl;
  178. // 初始化特征提取器
  179. std::cout << "正在初始化YOLO2026特征提取器..." << std::endl;
  180. //Sleep(10000);
  181. //YoloFeatureExtractor extractor(openvino_modelPath, openvino_configPath, classesPath);
  182. YoloFeatureExtractor extractor(modelPath, classesPath);
  183. extractor.initOpenCL(); // 启用OpenCL加速(如果可用)
  184. // 获取图库中的所有图片
  185. std::cout << "正在扫描图库中的图片..." << std::endl;
  186. std::vector<std::string> galleryImages = ImageProcessor::getImagesInDirectory(galleryDir);
  187. std::cout << "找到 " << galleryImages.size() << " 张图片" << std::endl;
  188. if (galleryImages.empty())
  189. {
  190. std::cerr << "图库中没有找到有效的图片文件!" << std::endl;
  191. return -1;
  192. }
  193. // 初始化SQLite-Vec数据库
  194. std::cout << "正在初始化SQLite-Vec数据库..." << std::endl;
  195. SQLiteVecManager vecManager(databasePath);
  196. auto start_time = std::chrono::high_resolution_clock::now();
  197. // 检查数据库是否已存在数据
  198. bool databaseExists = vecManager.loadDatabase();
  199. std::cout << "数据库加载结果:" << std::boolalpha << databaseExists << std::endl;
  200. std::cout << "数据集数量:" << vecManager.getFeatureCount() << std::endl;
  201. if (databaseExists && vecManager.getFeatureCount() > 0)
  202. {
  203. std::cout << "数据库已存在 " << vecManager.getFeatureCount() << " 条特征记录" << std::endl;
  204. }
  205. else
  206. {
  207. std::cout << "重新构建数据库..." << std::endl;
  208. // 提取图库图片特征并向量
  209. std::cout << "开始提取图库图片特征..." << std::endl;
  210. int featureDimension = 0;
  211. int processedCount = 0;
  212. // 初始化数据库表结构
  213. vecManager.initializeDatabase(45); // 假设特征维度为1000,实际会在第一次处理时确定
  214. for (size_t i = 0; i < galleryImages.size(); ++i)
  215. {
  216. try
  217. {
  218. std::vector<float> features = extractor.extractFeatures(galleryImages[i]);
  219. if (!features.empty())
  220. {
  221. if (featureDimension == 0)
  222. {
  223. featureDimension = static_cast<int>(features.size());
  224. std::cout << "特征维度: " << featureDimension << std::endl;
  225. // 重新初始化数据库以匹配实际维度
  226. vecManager.initializeDatabase(featureDimension);
  227. }
  228. vecManager.addFeatureVector(features, galleryImages[i]);
  229. processedCount++;
  230. std::cout << "[" << (i + 1) << "/" << galleryImages.size() << "] 已处理: "
  231. << galleryImages[i] << std::endl;
  232. }
  233. }
  234. catch (const std::exception & e)
  235. {
  236. std::cerr << "处理失败 [" << (i + 1) << "]: " << galleryImages[i]
  237. << " (" << e.what() << ")" << std::endl;
  238. }
  239. }
  240. auto end_time = std::chrono::high_resolution_clock::now();
  241. auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);
  242. std::cout << "特征提取完成,耗时: " << duration.count() << " 秒" << std::endl;
  243. std::cout << "成功处理图片数量: " << processedCount << std::endl;
  244. std::cout << "数据库记录数量: " << vecManager.getFeatureCount() << std::endl;
  245. if (processedCount == 0)
  246. {
  247. std::cerr << "没有成功提取任何图片特征!" << std::endl;
  248. return -1;
  249. }
  250. // 保存数据库
  251. vecManager.saveDatabase();
  252. std::cout << "数据库保存完成" << std::endl;
  253. }
  254. // 提取查询图片特征
  255. std::cout << "正在提取查询图片特征: " << searchImagePath << std::endl;
  256. std::vector<float> queryFeatures = extractor.extractFeatures(searchImagePath);
  257. if (queryFeatures.empty())
  258. {
  259. std::cerr << "无法提取查询图片特征!" << std::endl;
  260. return -1;
  261. }
  262. std::cout << "查询图片特征提取完成,维度: " << queryFeatures.size() << std::endl;
  263. // 进行相似性搜索
  264. std::cout << "正在进行相似性搜索..." << std::endl;
  265. std::vector<std::pair<std::string, float>> searchResults = vecManager.searchSimilarVectors(queryFeatures, 5);
  266. // 显示搜索结果
  267. std::cout << "\n=== 搜索结果 ===" << std::endl;
  268. std::cout << "查询图片: " << searchImagePath << std::endl;
  269. std::cout << "最相似的图片:" << std::endl;
  270. for (size_t i = 0; i < searchResults.size(); ++i)
  271. {
  272. std::cout << (i + 1) << ". " << searchResults[i].first
  273. << " (相似度: " << searchResults[i].second << ")" << std::endl;
  274. }
  275. std::cout << "\n图像检索完成!" << std::endl;
  276. }
  277. catch (const std::exception & e)
  278. {
  279. std::string a = e.what();
  280. std::cerr << "程序执行出错: " << e.what() << std::endl;
  281. return -1;
  282. }
  283. sqlite3_sleep(100000);
  284. return 0;
  285. }