使用SVM对多类多维数据进行分类

最近,本人要做个小东西,使用SVM对8类三维数据进行分类,搜索网上,发现大伙讨论的都是二维数据的二分类问题,遂决定自己研究一番。本人首先参考了opencv的tutorial,这也是二维数据的二分类问题。然后通过学习研究,发现别有洞天,遂实现之前的目标。在这里将代码贴出来,这里实现了对三维数据进行三类划分,以供大家相互学习。

  1. #include “stdafx.h”  
  2. #include <iostream>  
  3. #include <opencv2/core/core.hpp>  
  4. #include <opencv2/highgui/highgui.hpp>  
  5. #include <opencv2/ml/ml.hpp>  
  6.   
  7. using namespace cv;  
  8. using namespace std;  
  9.   
  10. int main()  
  11. {  
  12.   
  13.     //——————— 1. Set up training data randomly —————————————  
  14.     Mat trainData(100, 3, CV_32FC1);  
  15.     Mat labels   (100, 1, CV_32FC1);  
  16.   
  17.     RNG rng(100); // Random value generation class  
  18.   
  19.     // Generate random points for the class 1  
  20.     Mat trainClass = trainData.rowRange(0, 40);  
  21.     // The x coordinate of the points is in [0, 0.4)  
  22.     Mat c = trainClass.colRange(0, 1);  
  23.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));  
  24.     // The y coordinate of the points is in [0, 0.4)  
  25.     c = trainClass.colRange(1, 2);  
  26.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));  
  27.     // The z coordinate of the points is in [0, 0.4)  
  28.     c = trainClass.colRange(2, 3);  
  29.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));  
  30.   
  31.     // Generate random points for the class 2  
  32.     trainClass = trainData.rowRange(60, 100);  
  33.     // The x coordinate of the points is in [0.6, 1]  
  34.     c = trainClass.colRange(0, 1);  
  35.     rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));  
  36.     // The y coordinate of the points is in [0.6, 1)  
  37.     c = trainClass.colRange(1, 2);  
  38.     rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));  
  39.      // The z coordinate of the points is in [0.6, 1]  
  40.     c = trainClass.colRange(2, 3);  
  41.     rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));  
  42.   
  43.       
  44.   
  45.     // Generate random points for the classes 3  
  46.     trainClass = trainData.rowRange(  40, 60);  
  47.     // The x coordinate of the points is in [0.4, 0.6)  
  48.     c = trainClass.colRange(0,1);  
  49.     rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));  
  50.     // The y coordinate of the points is in [0.4, 0.6)  
  51.     c = trainClass.colRange(1,2);  
  52.     rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));  
  53.     // The z coordinate of the points is in [0.4, 0.6)  
  54.     c = trainClass.colRange(2,3);  
  55.     rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));  
  56.   
  57.   
  58.   
  59.     //————————- Set up the labels for the classes ———————————  
  60.     labels.rowRange( 0,  40).setTo(1);  // Class 1  
  61.     labels.rowRange(60, 100).setTo(2);  // Class 2  
  62.     labels.rowRange(40, 60).setTo(3);  // Class 3  
  63.   
  64.   
  65.     //———————— 2. Set up the support vector machines parameters ——————–  
  66.     CvSVMParams params;  
  67.     params.svm_type    = SVM::C_SVC;  
  68.     params.C           = 0.1;  
  69.     params.kernel_type = SVM::LINEAR;  
  70.     params.term_crit   = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);  
  71.   
  72.     //———————— 3. Train the svm —————————————————-  
  73.     cout << ”Starting training process” << endl;  
  74.     CvSVM svm;  
  75.     svm.train(trainData, labels, Mat(), Mat(), params);  
  76.     cout << ”Finished training process” << endl;  
  77.   
  78.      Mat sampleMat = (Mat_<float>(1,3) << 50, 50,10);  
  79.      float response = svm.predict(sampleMat);  
  80.      cout<<response<<endl;  
  81.   
  82.      sampleMat = (Mat_<float>(1,3) << 50, 50,100);  
  83.      response = svm.predict(sampleMat);  
  84.      cout<<response<<endl;  
  85.   
  86.      sampleMat = (Mat_<float>(1,3) << 50, 50,60);  
  87.      response = svm.predict(sampleMat);  
  88.      cout<<response<<endl;  
  89.       
  90.     waitKey(0);  
  91. }  

使用SVM对多类多维数据进行分类