YoloFeatureManager.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. #include "../pch/pch.h"
  2. #include "YoloFeatureManager.h"
  3. #include <fstream>
  4. #include <algorithm>
  5. #include <iostream>
  6. #include <functional>
  7. #include <numeric>
  8. #include <sstream>
  9. #include "../tool/debuglog.h"
  10. #include "../worker/CVideoCaptureWorker.h"
  11. #include "YoloClassName.h"
  12. YoloFeatureManager::YoloFeatureManager()
  13. {
  14. inputWidth = 320;
  15. inputHeight = 320;
  16. }
  17. YoloFeatureManager::~YoloFeatureManager()
  18. {
  19. }
  20. void YoloFeatureManager::loadModel(const std::string & modelPath)
  21. {
  22. try
  23. {
  24. net = cv::dnn::readNetFromONNX(modelPath);
  25. CONF_THRESHOLD = 0.5f; // 可以根据需要调整置信度阈值
  26. NMS_THRESHOLD = 0.4f; // 可以根据需要调整NMS阈值
  27. FRUIT_VEGETABLE_COUNT = sizeof(FRUIT_VEGETABLE_NAMES) / sizeof(FRUIT_VEGETABLE_NAMES[0]);
  28. }
  29. catch (const std::exception& e)
  30. {
  31. std::string aa = std::string(e.what());
  32. DEBUG_LOG(("加载模型失败: " + std::string(e.what())).c_str());
  33. return;
  34. }
  35. }
  36. void YoloFeatureManager::loadModel(const std::string& modelPath, const std::string& configPath)
  37. {
  38. try
  39. {
  40. net = cv::dnn::readNetFromModelOptimizer(modelPath, configPath);
  41. // 设置目标设备 (可选: CPU, GPU, MYRIAD等)
  42. net.setPreferableBackend(cv::dnn::DNN_BACKEND_INFERENCE_ENGINE);
  43. net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU); // 或DNN_TARGET_MYRIAD等
  44. CONF_THRESHOLD = 0.5f; // 可以根据需要调整置信度阈值
  45. NMS_THRESHOLD = 0.4f; // 可以根据需要调整NMS阈值
  46. FRUIT_VEGETABLE_COUNT = sizeof(FRUIT_VEGETABLE_NAMES) / sizeof(FRUIT_VEGETABLE_NAMES[0]);
  47. }
  48. catch (const std::exception& e)
  49. {
  50. std::string aa = std::string(e.what());
  51. DEBUG_LOG(("加载模型失败: " + std::string(e.what())).c_str());
  52. return;
  53. }
  54. }
  55. // 寻找置信度最高的类别
  56. int YoloFeatureManager::getTopClass(const cv::Mat& output)
  57. {
  58. // 将输出展平为一维数组
  59. cv::Mat flatOutput = output.reshape(1, 1);
  60. double maxVal;
  61. cv::Point maxLoc;
  62. // 找到最大值的位置(即最高置信度类别索引)
  63. cv::minMaxLoc(flatOutput, nullptr, &maxVal, nullptr, &maxLoc);
  64. return maxLoc.x;
  65. }
  66. // 获取类别名称
  67. std::string YoloFeatureManager::getClassName(std::size_t classId) const
  68. {
  69. if (classId >= 0 && classId < FRUIT_VEGETABLE_COUNT)
  70. {
  71. std::string englishName = FRUIT_VEGETABLE_NAMES[classId];
  72. // 这里可以添加一个映射表,将英文名称映射到中文名称
  73. auto it = FRUIT_VEGETABLE_CN_MAP.find(englishName);
  74. if (it != FRUIT_VEGETABLE_CN_MAP.end())
  75. {
  76. return it->second; // 找到则返回中文名称
  77. }
  78. else
  79. {
  80. return "Unknown"; // 未找到则返回默认值
  81. }
  82. }
  83. return "Unknown";
  84. }
  85. std::vector<float> YoloFeatureManager::extractFeatures(const std::string & imagePath)
  86. {
  87. try
  88. {
  89. auto time_1 = std::chrono::high_resolution_clock::now();
  90. cv::Mat image = cv::imread(imagePath);
  91. if (image.empty())
  92. {
  93. throw std::runtime_error("Could not load image: " + imagePath);
  94. }
  95. // 转换为blob(归一化+通道转换)
  96. cv::Mat blob;
  97. cv::dnn::blobFromImage(image, blob, 1.0 / 255, cv::Size(inputWidth, inputHeight), cv::Scalar(0, 0, 0), true, false);
  98. net.setInput(blob);
  99. auto time_2 = std::chrono::high_resolution_clock::now();
  100. //获取模型的所有层名称(调试用)
  101. //std::vector<cv::String> layerNames = net.getLayerNames();
  102. // 获取Flatten层输出(yolo26s-cls的Flatten层名称为 "onnx_node!/model.10/Flatten",这是GAP后分类头前的一层)'
  103. // GAP层是onnx_node!/model.10/pool/GlobalAveragePool
  104. cv::Mat featureMat = net.forward("onnx_node!/model.10/Flatten");
  105. // 检查输出是否有效
  106. if (featureMat.empty())
  107. {
  108. throw std::runtime_error("模型前向传播未产生有效输出");
  109. }
  110. if (featureMat.type() != CV_32F)
  111. {
  112. throw std::runtime_error("Mat类型错误");
  113. }
  114. float norm_before = cv::norm(featureMat, cv::NORM_L2);
  115. DEBUG_HELPER::debug_printf("归一化前 norm:%.6f\n", norm_before);
  116. cv::normalize(featureMat, featureMat, 1.0, 0.0, cv::NORM_L2); //L2归一化
  117. float norm_after = cv::norm(featureMat, cv::NORM_L2);
  118. DEBUG_HELPER::debug_printf("归一化后 norm:%.6f\n", norm_after);
  119. // 将Mat格式的特征转换为vector<float>(方便后续计算/存储)
  120. std::vector<float> feature_vector;
  121. feature_vector.assign((float *)featureMat.data, (float *)featureMat.data + featureMat.total());
  122. //进行时间统计
  123. auto time_3 = std::chrono::high_resolution_clock::now();
  124. auto duration_1 = std::chrono::duration_cast<std::chrono::milliseconds>(time_2 - time_1);
  125. std::wstring msg = L"图片处理耗时: " + std::to_wstring(duration_1.count()) + L" 毫秒";
  126. DEBUG_LOG(msg.c_str());
  127. auto duration_2 = std::chrono::duration_cast<std::chrono::milliseconds>(time_3 - time_2);
  128. std::wstring msg2 = L"模型推理耗时: " + std::to_wstring(duration_2.count()) + L" 毫秒";
  129. DEBUG_LOG(msg2.c_str());
  130. auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(time_3 - time_1);
  131. std::wstring msg4 = L"总耗时: " + std::to_wstring(totalDuration.count()) + L" 毫秒";
  132. DEBUG_LOG(msg4.c_str());
  133. return feature_vector;
  134. }
  135. catch (const std::exception & e)
  136. {
  137. std::string aa = std::string(e.what());
  138. DEBUG_LOG(("提取特征失败: " + std::string(e.what())).c_str());
  139. return {};
  140. }
  141. }
  142. void YoloFeatureManager::Detection(const std::string & imagePath)
  143. {
  144. cv::Mat image = cv::imread(imagePath);
  145. if (image.empty())
  146. {
  147. throw std::runtime_error("Could not load image: " + imagePath);
  148. }
  149. // 构造输入blob(图像预处理)
  150. // 参数说明:输入图像、缩放因子、输入尺寸、均值归一化、是否交换RB通道、是否裁剪
  151. cv::Mat blob;
  152. cv::dnn::blobFromImage(image, blob, 1.0 / 255, cv::Size(inputWidth, inputHeight), cv::Scalar(0, 0, 0), true, false);
  153. // -------------------------- 4. 模型推理 --------------------------
  154. // 设置网络输入
  155. net.setInput(blob);
  156. // 获取输出层名称
  157. std::vector<std::string> outLayerNames = net.getUnconnectedOutLayersNames();
  158. // 前向推理
  159. std::vector<cv::Mat> outs;
  160. net.forward(outs, outLayerNames);
  161. // -------------------------- 5. 解析推理结果 --------------------------
  162. std::vector<cv::Rect> boxes; // 检测框
  163. std::vector<int> classIds; // 类别ID
  164. std::vector<float> confidences; // 置信度
  165. // 遍历所有输出层的结果
  166. for (const cv::Mat & out : outs)
  167. {
  168. float * data = (float *)out.data;
  169. // 遍历每个检测结果
  170. for (int i = 0; i < out.rows; i++, data += out.cols)
  171. {
  172. // 获取类别置信度
  173. cv::Mat scores = out.row(i).colRange(5, out.cols);
  174. cv::Point classIdPoint;
  175. double confidence;
  176. // 找到最大置信度对应的类别
  177. cv::minMaxLoc(scores, 0, &confidence, 0, &classIdPoint);
  178. // 过滤低置信度结果
  179. if (confidence > CONF_THRESHOLD)
  180. {
  181. // 解析检测框坐标(YOLO输出的是相对坐标,需转换为绝对坐标)
  182. int centerX = (int)(data[0] * image.cols);
  183. int centerY = (int)(data[1] * image.rows);
  184. int width = (int)(data[2] * image.cols);
  185. int height = (int)(data[3] * image.rows);
  186. // 计算检测框左上角坐标
  187. int left = centerX - width / 2;
  188. int top = centerY - height / 2;
  189. // 保存结果
  190. boxes.push_back(cv::Rect(left, top, width, height));
  191. classIds.push_back(classIdPoint.x);
  192. confidences.push_back((float)confidence);
  193. }
  194. }
  195. }
  196. // -------------------------- 6. 非极大值抑制(NMS) --------------------------
  197. std::vector<int> indices;
  198. cv::dnn::NMSBoxes(boxes, confidences, CONF_THRESHOLD, NMS_THRESHOLD, indices);
  199. // 提取NMS后的结果
  200. std::vector<cv::Rect> finalBoxes;
  201. std::vector<int> finalClassIds;
  202. std::vector<float> finalConfidences;
  203. for (int idx : indices)
  204. {
  205. finalBoxes.push_back(boxes[idx]);
  206. finalClassIds.push_back(classIds[idx]);
  207. finalConfidences.push_back(confidences[idx]);
  208. }
  209. // -------------------------- 7. 绘制并显示结果 --------------------------
  210. drawDetection(image, finalBoxes, finalClassIds, finalConfidences);
  211. // 显示检测结果
  212. cv::imshow("YOLO Detection Result", image);
  213. // 保存检测结果
  214. cv::imwrite("result.jpg", image);
  215. cv::waitKey(0);
  216. cv::destroyAllWindows();
  217. }
  218. // 绘制检测结果
  219. void YoloFeatureManager::drawDetection(cv::Mat & img, const std::vector<cv::Rect> & boxes, const std::vector<int> & classIds,
  220. const std::vector<float> & confidences)
  221. {
  222. // 生成随机颜色(每个类别一种颜色)
  223. std::vector<cv::Scalar> colors;
  224. srand(time(0));
  225. for (std::size_t i = 0; i < FRUIT_VEGETABLE_COUNT; i++)
  226. {
  227. int r = rand() % 256;
  228. int g = rand() % 256;
  229. int b = rand() % 256;
  230. colors.push_back(cv::Scalar(r, g, b));
  231. }
  232. // 绘制每个检测框
  233. for (size_t i = 0; i < boxes.size(); i++)
  234. {
  235. cv::Rect box = boxes[i];
  236. // 绘制矩形框
  237. cv::rectangle(img, box, colors[classIds[i]], 2);
  238. // 构造标签文本(类别 + 置信度)
  239. std::string label = FRUIT_VEGETABLE_NAMES[classIds[i]] + ": " + std::to_string(confidences[i]).substr(0, 4);
  240. int baseLine;
  241. cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  242. // 绘制标签背景
  243. cv::rectangle(img, cv::Point(box.x, box.y - labelSize.height),
  244. cv::Point(box.x + labelSize.width, box.y + baseLine),
  245. colors[classIds[i]], cv::FILLED);
  246. // 绘制标签文本
  247. cv::putText(img, label, cv::Point(box.x, box.y), cv::FONT_HERSHEY_SIMPLEX,
  248. 0.5, cv::Scalar(255, 255, 255), 1);
  249. }
  250. }
  251. std::string YoloFeatureManager::Class(cv::Mat & image)
  252. {
  253. try
  254. {
  255. std::string className = "";
  256. // ====================== 图像预处理 ======================
  257. // 转换为blob格式:归一化(0-1)、通道转换(BGR->RGB)、调整尺寸
  258. cv::Mat blob;
  259. cv::dnn::blobFromImage(image, blob, 1.0 / 255.0, cv::Size(inputWidth, inputHeight), cv::Scalar(0, 0, 0), true, false);
  260. net.setInput(blob);
  261. // ====================== 模型推理 ======================
  262. cv::Mat output = net.forward(); // 输出形状:1x1000(对应ImageNet 1000类)
  263. // ====================== 解析结果 ======================
  264. int topClassIdx = this->getTopClass(output);
  265. float topConfidence = output.at<float>(topClassIdx);
  266. // 只显示置信度高于阈值的结果
  267. if (topConfidence > 0.8)
  268. {
  269. className = this->getClassName(topClassIdx);
  270. }
  271. else
  272. {
  273. className = "Unknown";
  274. }
  275. // 在画面上绘制分类结果
  276. //std::wstring resultText = CLewaimaiString::ANSIToUnicode(className) + L" : " + std::to_wstring(round(topConfidence * 10000) / 100) + L"%";
  277. //this->drawChineseText(image, resultText.c_str(), cv::Point(20, 50), cv::Scalar(0, 255, 0), 24);
  278. //cv::imshow("yolo26n-cls 实时图像分类", image);
  279. //if (cv::waitKey(30) >= 0); // 按任意键退出
  280. return className;
  281. }
  282. catch (const std::exception& e)
  283. {
  284. std::string aa = std::string(e.what());
  285. DEBUG_LOG(("YOLO分类失败: " + std::string(e.what())).c_str());
  286. return {};
  287. }
  288. }
  289. std::string YoloFeatureManager::ClassFromVideoCapture()
  290. {
  291. try
  292. {
  293. cv::Mat image;
  294. CVideoCaptureWorker::GetInstance()->GetFrame(image);
  295. if (image.empty())
  296. {
  297. //DEBUG_LOG("从摄像头获取帧失败");
  298. return "Unknown";
  299. }
  300. std::string name = this->Class(image);
  301. return name;
  302. }
  303. catch (const std::exception& e)
  304. {
  305. std::string aa = std::string(e.what());
  306. DEBUG_LOG(("YOLO分类失败: " + std::string(e.what())).c_str());
  307. return {};
  308. }
  309. }
  310. void YoloFeatureManager::drawChineseText(cv::Mat & img, const wchar_t * text, cv::Point pos, cv::Scalar color, int fontSize)
  311. {
  312. // 1. 检查输入有效性
  313. if (img.empty() || text == nullptr || wcslen(text) == 0)
  314. {
  315. return;
  316. }
  317. if (img.type() != CV_8UC3)
  318. {
  319. // 仅支持 3 通道彩色图像
  320. cvtColor(img, img, cv::COLOR_GRAY2BGR);
  321. }
  322. // 2. 创建内存 DC 并关联临时位图(关键:基于图像的 DC 创建,而非屏幕 DC)
  323. HDC hScreenDC = GetDC(NULL);
  324. HDC hMemDC = CreateCompatibleDC(hScreenDC);
  325. // 创建与原图像尺寸、格式匹配的位图
  326. HBITMAP hMemBmp = CreateCompatibleBitmap(hScreenDC, img.cols, img.rows);
  327. // 保存原始位图句柄,用于后续恢复
  328. HBITMAP hOldBmp = (HBITMAP)SelectObject(hMemDC, hMemBmp);
  329. // 3. 将 OpenCV 图像数据复制到内存位图(保留原图像内容,而非黑色)
  330. BITMAPINFO bmi = { 0 };
  331. bmi.bmiHeader.biSize = sizeof(BITMAPINFOHEADER);
  332. bmi.bmiHeader.biWidth = img.cols;
  333. bmi.bmiHeader.biHeight = -img.rows; // 翻转 Y 轴(OpenCV 与 GDI 坐标方向相反)
  334. bmi.bmiHeader.biPlanes = 1;
  335. bmi.bmiHeader.biBitCount = 24;
  336. bmi.bmiHeader.biCompression = BI_RGB;
  337. // 将 OpenCV 图像写入内存位图
  338. SetDIBits(hScreenDC, hMemBmp, 0, img.rows, img.data, &bmi, DIB_RGB_COLORS);
  339. // 4. 设置中文字体(修复字体创建参数,增加容错)
  340. HFONT hFont = CreateFont(
  341. fontSize, 0, 0, 0, FW_NORMAL, 0, 0, 0,
  342. GB2312_CHARSET, OUT_DEFAULT_PRECIS, CLIP_DEFAULT_PRECIS,
  343. DEFAULT_QUALITY, DEFAULT_PITCH | FF_DONTCARE, L"黑体"
  344. );
  345. HFONT hOldFont = (HFONT)SelectObject(hMemDC, hFont);
  346. // 5. 设置文字绘制属性(背景透明、颜色正确)
  347. SetBkMode(hMemDC, TRANSPARENT);
  348. // OpenCV 是 BGR,GDI 是 RGB,需转换
  349. SetTextColor(hMemDC, RGB((int)color[2], (int)color[1], (int)color[0]));
  350. // 6. 绘制中文字符(确保坐标在图像范围内)
  351. int textLen = wcslen(text);
  352. if (pos.x >= 0 && pos.y >= 0 && pos.x < img.cols && pos.y < img.rows)
  353. {
  354. TextOutW(hMemDC, pos.x, pos.y, text, textLen);
  355. }
  356. // 7. 将绘制后的位图数据复制回 OpenCV 图像
  357. GetDIBits(hScreenDC, hMemBmp, 0, img.rows, img.data, &bmi, DIB_RGB_COLORS);
  358. // 8. 释放资源(关键:恢复原始句柄后再删除,避免内存泄漏)
  359. SelectObject(hMemDC, hOldFont); // 恢复原始字体
  360. DeleteObject(hFont); // 删除自定义字体
  361. SelectObject(hMemDC, hOldBmp); // 恢复原始位图
  362. DeleteObject(hMemBmp); // 删除内存位图
  363. DeleteDC(hMemDC); // 删除内存 DC
  364. ReleaseDC(NULL, hScreenDC); // 释放屏幕 DC
  365. }