| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- #include "../pch/pch.h"
- #include "test.h"
- #include "YoloFeatureManager.h"
- #include "ImageProcessor.h"
- #include "SQLiteVecManager.h"
- #include "../tool/CSetting.h"
- #include "../tool/debuglog.h"
- #include <opencv2/opencv.hpp>
- #include "../tool/debuglog.h"
- /*
- int AITest()
- {
- try
- {
- // 设置路径(可根据实际情况修改)
- std::wstring wsExePath = CSystem::getExePath();
- std::wstring wsProgramDir = CSystem::GetProgramDir();
- std::filesystem::path mainDir = wsProgramDir;
- std::string sMainDir = mainDir.string();
- //用于测试的图片目录
- std::string galleryDir = (mainDir.parent_path().parent_path().parent_path().parent_path() /"res"/"images").string(); // 图库目录路径
- std::string modelPath = sMainDir + "/ai/best_448.onnx"; // YOLO2026模型路径
- std::string searchImagePath = sMainDir + "/3.jpg"; // 搜索图片路径
- std::string databasePath = sMainDir + "/image_features.db"; // SQLite数据库路径
- std::filesystem::remove(databasePath);
- std::cout << "=== YOLO2026图像检索系统 (SQLite-Vec版本) ===" << std::endl;
- std::cout << "模型路径: " << modelPath << std::endl;
- std::cout << "图库路径: " << galleryDir << std::endl;
- std::cout << "查询图片: " << searchImagePath << std::endl;
- std::cout << "数据库路径: " << databasePath << std::endl;
- std::cout << "=========================================" << std::endl;
- // 初始化特征提取器
- std::cout << "正在初始化YOLO2026特征提取器..." << std::endl;
- YoloFeatureManager extractor;
- extractor.loadModel(modelPath);
- // 获取图库中的所有图片
- std::cout << "正在扫描图库中的图片..." << std::endl;
- std::vector<std::string> galleryImages = ImageProcessor::getImagesInDirectory(galleryDir);
- std::cout << "找到 " << galleryImages.size() << " 张图片" << std::endl;
- if (galleryImages.empty())
- {
- std::cerr << "图库中没有找到有效的图片文件!" << std::endl;
- return -1;
- }
- // 初始化SQLite-Vec数据库
- std::cout << "正在初始化SQLite-Vec数据库..." << std::endl;
- SQLiteVecManager vecManager(databasePath);
- auto start_time = std::chrono::high_resolution_clock::now();
- // 检查数据库是否已存在数据
- bool databaseExists = vecManager.loadDatabase();
- std::cout << "数据库加载结果:" << std::boolalpha << databaseExists << std::endl;
- std::cout << "数据集数量:" << vecManager.getFeatureCount() << std::endl;
- if (databaseExists && vecManager.getFeatureCount() > 0)
- {
- std::cout << "数据库已存在 " << vecManager.getFeatureCount() << " 条特征记录" << std::endl;
- }
- else
- {
- std::cout << "重新构建数据库..." << std::endl;
- // 提取图库图片特征并向量
- std::cout << "开始提取图库图片特征..." << std::endl;
- int featureDimension = 0;
- int processedCount = 0;
- for (size_t i = 0; i < galleryImages.size(); ++i)
- {
- try
- {
- std::vector<float> features = extractor.extractFeatures(galleryImages[i]);
- if (!features.empty())
- {
- if (featureDimension == 0)
- {
- featureDimension = static_cast<int>(features.size());
- DEBUG_HELPER::debug_printf("特征维度: %d\n", featureDimension);
- // 初始化数据库表结构
- vecManager.initializeDatabase(featureDimension);
- }
- vecManager.addFeatureVector(features, galleryImages[i]);
- processedCount++;
- DEBUG_HELPER::debug_printf("已处理 [%d/%d]: %s\n", (i + 1), static_cast<int>(galleryImages.size()), galleryImages[i].c_str());
- }
- }
- catch (const std::exception & e)
- {
- DEBUG_HELPER::debug_printf("处理失败 [%d/%d]: %s (%s)\n", (i + 1), static_cast<int>(galleryImages.size()), galleryImages[i].c_str(), e.what());
- }
- }
- auto end_time = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);
- std::cout << "特征提取完成,耗时: " << duration.count() << " 秒" << std::endl;
- std::cout << "成功处理图片数量: " << processedCount << std::endl;
- std::cout << "数据库记录数量: " << vecManager.getFeatureCount() << std::endl;
- if (processedCount == 0)
- {
- std::cerr << "没有成功提取任何图片特征!" << std::endl;
- return -1;
- }
- // 保存数据库
- vecManager.saveDatabase();
- std::cout << "数据库保存完成" << std::endl;
- }
- // 提取查询图片特征
- std::cout << "正在提取查询图片特征: " << searchImagePath << std::endl;
- std::vector<float> queryFeatures = extractor.extractFeatures(searchImagePath);
- if (queryFeatures.empty())
- {
- std::cerr << "无法提取查询图片特征!" << std::endl;
- return -1;
- }
- std::cout << "查询图片特征提取完成,维度: " << queryFeatures.size() << std::endl;
- // 进行相似性搜索
- std::cout << "正在进行相似性搜索..." << std::endl;
- std::vector<std::pair<std::string, float>> searchResults = vecManager.searchSimilarVectors(queryFeatures, 20);
- // 显示搜索结果
- std::cout << "\n=== 搜索结果 ===" << std::endl;
- std::cout << "查询图片: " << searchImagePath << std::endl;
- std::cout << "最相似的图片:" << std::endl;
- for (size_t i = 0; i < searchResults.size(); ++i)
- {
- std::cout << (i + 1) << ". " << searchResults[i].first
- << " (相似度: " << searchResults[i].second << ")" << std::endl;
- }
- std::cout << "\n图像检索完成!" << std::endl;
- }
- catch (const std::exception & e)
- {
- std::string err = e.what();
- std::cerr << "程序执行出错: " << err << std::endl;
- return -1;
- }
- return 0;
- }
- */
|