Browse Source

AI学习的功能完成了

张洋 23 hours ago
parent
commit
b4e6bcccbc

+ 15 - 6
zhipuzi_pos_windows/ai/SQLiteVecManager.cpp

@@ -135,9 +135,9 @@ bool SQLiteVecManager::addFeatureVector(const std::vector<float>& features, cons
 	return rc == SQLITE_DONE;
 }
 
-std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVectors(const std::vector<float> & queryVector, int k)
+std::vector<FeatureRecord> SQLiteVecManager::searchSimilarVectors(const std::vector<float> & queryVector, int k)
 {
-	std::vector<std::pair<std::string, float>> results;
+	std::vector<FeatureRecord> results;
 
 	std::cout << "使用sqlite-vec扩展进行向量搜索" << std::endl;
 
@@ -162,17 +162,26 @@ std::vector<std::pair<std::string, float>> SQLiteVecManager::searchSimilarVector
 
 	while ((rc = sqlite3_step(stmt)) == SQLITE_ROW)
 	{
-		const char * imagePath = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 0));
+		const char* food_id = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 0));
+		const char* food_name = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
+		const char* image_name = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
+		const char* image_path = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 3));
 
-		float distance = static_cast<float>(sqlite3_column_double(stmt, 1));
+		float distance = static_cast<float>(sqlite3_column_double(stmt, 4));
 		float similarity = distanceToSimilarity(distance);
 
-		results.emplace_back(std::string(imagePath), similarity);
+		FeatureRecord record;
+		record.foodId = food_id ? food_id : "";
+		record.foodName = food_name ? food_name : "";
+		record.imageName = image_name ? image_name : "";
+		record.imagePath = image_path ? image_path : "";
+		record.similarity = similarity;
+
+		results.push_back(record);
 	}
 
 	sqlite3_finalize(stmt);
 
-
 	return results;
 }
 

+ 10 - 1
zhipuzi_pos_windows/ai/SQLiteVecManager.h

@@ -5,6 +5,15 @@
 #include "../sqlite3/sqlite3.h"
 #include "../sqlite3/sqlite-vec.h"
 
+struct FeatureRecord
+{
+	std::string foodId;
+	std::string foodName; //注意这个是用户输入的商品名字,不是从数据库里查询出来的名字,这个名字是用来展示给用户看的,方便用户知道这个图片对应的是什么商品的
+	std::string imageName;
+	std::string imagePath;
+	float similarity;
+};
+
 class SQLiteVecManager
 {
 private:
@@ -25,7 +34,7 @@ public:
 	bool initializeDatabase(int vectorDimension);
 
 	bool addFeatureVector(const std::vector<float> & features, const std::string& foodId, const std::string& foodName, const std::string& imageName, const std::string & imagePath);
-	std::vector<std::pair<std::string, float>> searchSimilarVectors(const std::vector<float> & queryVector, int k = 5);
+	std::vector<FeatureRecord> searchSimilarVectors(const std::vector<float> & queryVector, int k = 5);
 	void saveDatabase();
 	bool loadDatabase();
 	int getFeatureCount() const;

+ 10 - 29
zhipuzi_pos_windows/ai/YoloFeatureManager.cpp

@@ -117,14 +117,20 @@ std::string YoloFeatureManager::getClassName(std::size_t classId) const
 
 std::vector<float> YoloFeatureManager::extractFeatures(const std::string & imagePath)
 {
+	cv::Mat image = cv::imread(imagePath);
+
+	return extractFeatures(image);
+}
+
+std::vector<float> YoloFeatureManager::extractFeatures(cv::Mat& image)
+{
 	try
 	{
 		auto time_1 = std::chrono::high_resolution_clock::now();
 
-		cv::Mat image = cv::imread(imagePath);
 		if (image.empty())
 		{
-			throw std::runtime_error("Could not load image: " + imagePath);
+			throw std::runtime_error("Could not load image");
 		}
 
 		// 转换为blob(归一化+通道转换)
@@ -162,7 +168,7 @@ std::vector<float> YoloFeatureManager::extractFeatures(const std::string & image
 
 		// 将Mat格式的特征转换为vector<float>(方便后续计算/存储)
 		std::vector<float> feature_vector;
-		feature_vector.assign((float *)featureMat.data, (float *)featureMat.data + featureMat.total());
+		feature_vector.assign((float*)featureMat.data, (float*)featureMat.data + featureMat.total());
 
 		//进行时间统计
 		auto time_3 = std::chrono::high_resolution_clock::now();
@@ -181,7 +187,7 @@ std::vector<float> YoloFeatureManager::extractFeatures(const std::string & image
 
 		return feature_vector;
 	}
-	catch (const std::exception & e)
+	catch (const std::exception& e)
 	{
 		std::string aa = std::string(e.what());
 		DEBUG_LOG(("提取特征失败: " + std::string(e.what())).c_str());
@@ -365,31 +371,6 @@ std::string YoloFeatureManager::Class(cv::Mat & image)
 	
 }
 
-std::string YoloFeatureManager::ClassFromVideoCapture()
-{
-	try
-	{
-		cv::Mat image;
-		CVideoCaptureWorker::GetInstance()->GetFrame(image);
-		if (image.empty())
-		{
-			//DEBUG_LOG("从摄像头获取帧失败");
-			return "Unknown";
-		}
-
-		std::string name = this->Class(image);
-
-		return name;
-	}
-	catch (const std::exception& e)
-	{
-		std::string aa = std::string(e.what());
-		DEBUG_LOG(("YOLO分类失败: " + std::string(e.what())).c_str());
-		return {};
-	}
-	
-}
-
 void YoloFeatureManager::drawChineseText(cv::Mat & img, const wchar_t * text, cv::Point pos, cv::Scalar color, int fontSize)
 {
 	// 1. 检查输入有效性

+ 2 - 2
zhipuzi_pos_windows/ai/YoloFeatureManager.h

@@ -36,6 +36,8 @@ public:
 	// 提取特征
 	std::vector<float> extractFeatures(const std::string & imagePath);
 
+	std::vector<float> extractFeatures(cv::Mat& image);
+
 	// 执行探测
 	void Detection(const std::string& imagePath);
 
@@ -46,8 +48,6 @@ public:
 	//根据摄像头读取的帧,识别出对应类别的name(英文的)
 	std::string Class(cv::Mat& image);
 
-	std::string ClassFromVideoCapture();
-
 private:
 	// 寻找置信度最高的类别
 	int getTopClass(const cv::Mat& output);

+ 1 - 3
zhipuzi_pos_windows/page/CDiandanPageUI.cpp

@@ -326,10 +326,8 @@ void CDiandanPageUI::InitFoodShow()
 			return;
 		}
 
-		std::string a = CLewaimaiString::ANSIToUTF8(ai_name);
-
 		CSqlite3 sqlite;
-		std::vector<CFood> foodlist = sqlite.GetFoodByFoodname(a);
+		std::vector<CFood> foodlist = sqlite.GetFoodByFoodname(ai_name);
 
 		for (std::vector<CFood>::iterator it = foodlist.begin(); it != foodlist.end(); it++)
 		{

+ 67 - 6
zhipuzi_pos_windows/worker/CDiandanAIShibieWorker.cpp

@@ -2,6 +2,9 @@
 #include "CDiandanAIShibieWorker.h"
 #include "CChengzhongWorker.h"
 #include "../tool/CAppEnv.h"
+#include "../ai/SQLiteVecManager.h"
+#include "CVideoCaptureWorker.h"
+#include "../tool/CSqlite3.h"
 
 CDiandanAIShibieWorker::CDiandanAIShibieWorker()
 {
@@ -10,7 +13,7 @@ CDiandanAIShibieWorker::CDiandanAIShibieWorker()
 
 CDiandanAIShibieWorker::~CDiandanAIShibieWorker()
 {
-	int a = 1;
+	
 }
 
 //启动工作线程
@@ -18,7 +21,6 @@ void CDiandanAIShibieWorker::StartWork()
 {
 	long int threadId = GetCurrentThreadId();
 
-	//默认先关闭
 	m_is_work = true;
 
 	//创建一个新线程,专门处理AI识别的结果,避免因为AI识别的结果处理比较慢,导致界面卡顿
@@ -48,7 +50,7 @@ void CDiandanAIShibieWorker::HandleDiandanAIShibie()
 			}
 
 			float weight = atof(CChengzhongWorker::GetInstance()->GetWeight().c_str());
-			if (weight < 0.01)
+			if (weight < -0.01)
 			{
 				//说明没有重量,没放东西到秤上面,那么就不识别
 				Sleep(100);
@@ -58,23 +60,82 @@ void CDiandanAIShibieWorker::HandleDiandanAIShibie()
 
 			auto time_1 = std::chrono::high_resolution_clock::now();
 
-			m_ai_shibie_foodname = YoloFeatureManager::GetInstance()->ClassFromVideoCapture();
+			cv::Mat image;
+			CVideoCaptureWorker::GetInstance()->GetFrame(image);
+			if (image.empty())
+			{
+				//DEBUG_LOG("从摄像头获取帧失败");
+				Sleep(2000);
 
+				continue;
+			}
+
+			m_ai_shibie_foodname = YoloFeatureManager::GetInstance()->Class(image);
 			if (m_ai_shibie_foodname != "Unknown")
 			{
 				std::cout << "检测到类别: " << m_ai_shibie_foodname << std::endl;
+
+				m_ai_shibie_foodname = CLewaimaiString::ANSIToUTF8(m_ai_shibie_foodname);
 			}
 			else
 			{
 				std::cout << "未检测到任何类别。" << std::endl;
+
+				//开始调用向量数据库检索
+				if (SQLiteVecManager::GetInstance()->getFeatureCount() > 0)
+				{
+					std::vector<float> feature_vector = YoloFeatureManager::GetInstance()->extractFeatures(image);
+					std::vector<FeatureRecord> searchResults = SQLiteVecManager::GetInstance()->searchSimilarVectors(feature_vector, 5);
+
+					if (!searchResults.empty())
+					{
+						std::cout << "向量数据库检索结果:" << std::endl;
+						for (const auto& result : searchResults)
+						{
+							std::string food_id = result.foodId;
+							std::string food_name = result.foodName;
+							std::string image_name = result.imageName;
+							std::string image_path = result.imagePath;
+							float similarity = result.similarity;
+
+							std::cout << "食品ID: " << food_id << ", 食品名称: " << food_name << ", 图片名称: " << image_name << ", 图片路径: " << image_path << ", 相似度: " << similarity << std::endl;
+
+							CSqlite3 sqlite;
+
+							CFood newFood;
+							bool ret = sqlite.GetFoodById(food_id, newFood);
+							if (!ret)
+							{
+								std::cout << "该相似的商品已被删除" << std::endl;
+								Sleep(100);
+
+								continue;
+							}
+
+							//UTF8格式,sqlite里面存的都是UTF8格式的字符串
+							m_ai_shibie_foodname = newFood.name;
+						}
+					}
+					else
+					{
+						std::cout << "向量数据库中没有相似的特征。" << std::endl;
+						Sleep(100);
+
+						continue;
+					}
+				}
+				else
+				{
+					std::cout << "向量数据库中没有任何特征。" << std::endl;
+					Sleep(100);
+					continue;
+				}
 			}
 
 			auto time_2 = std::chrono::high_resolution_clock::now();
 
 			auto duration_1 = std::chrono::duration_cast<std::chrono::milliseconds>(time_2 - time_1);
 			std::wstring msg = L"all time: " + std::to_wstring(duration_1.count()) + L" 毫秒";
-			//DEBUG_LOG(msg.c_str());
-			//LOG_INFO(msg);
 
 			//主线程里面去处理界面刷新
 			if (m_hwnd != NULL)