00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Component/ModelManager.H"
00039 #include "Learn/DecisionTree.H"
00040 #include "rutz/rand.h"
00041 #include "rutz/trace.h"
00042 #include "Util/SortUtil.H"
00043 #include "Util/Assert.H"
00044 #include <math.h>
00045 #include <fcntl.h>
00046 #include <limits>
00047 #include <string>
00048 #include <stdio.h>
00049
00050
00051
00052
00053
// Forward declaration of the synthetic-data generator defined after main().
// Appends sampleDim samples to data (one inner vector per dimension, so
// data.size() must equal dimMeanIn.size()) and one label per sample to labels.
// NOTE(review): the definition reads only dimMeanIn.size(); the mean and
// variance values themselves are currently not used when drawing samples.
00054 void makeData(const std::vector<float>& dimMeanIn,const std::vector<float>& dimVarIn, const std::vector<float>& dimMeanOut,const std::vector<float>& dimVarOut, const uint sampleDim, std::vector<std::vector<float> >& data, std::vector<int>& labels);
00055
00056 int main(const int argc, const char **argv)
00057 {
00058
00059 MYLOGVERB = LOG_INFO;
00060 ModelManager manager("Test Decision Tree");
00061
00062
00063
00064 DecisionTree dt(3);
00065
00066 if (manager.parseCommandLine(
00067 (const int)argc, (const char**)argv, "", 0, 0) == false)
00068 return 0;
00069
00070 manager.start();
00071 uint nDim=4;
00072 std::vector<std::vector<float> > traindata(nDim);
00073 std::vector<int> trainlabels;
00074 std::vector<float> dimMeanIn(nDim), dimMeanOut(nDim), dimVarIn(nDim,1.0F), dimVarOut(nDim,1.0F);
00075 for(uint i=0;i<nDim;i++)
00076 {
00077 dimMeanIn[i] = nDim-i;
00078 dimMeanOut[i] = -(nDim-i);
00079 }
00080 makeData(dimMeanIn,dimVarIn,dimMeanOut,dimVarOut,50,traindata,trainlabels);
00081
00082 dt.train(traindata,trainlabels);
00083 dt.printTree();
00084 std::vector<int> trainResults = dt.predict(traindata);
00085
00086 int numCorrect=0;
00087 for(uint i=0;i<trainlabels.size();i++)
00088 {
00089 if(trainResults[i]==trainlabels[i]) numCorrect++;
00090 printf("Train Guess %d [Ground Truth %d]\n",trainResults[i],trainlabels[i]);
00091 }
00092 printf("Training Accuracy:[Correct/Total]=[%d/%Zu]:%f\n",numCorrect,trainlabels.size(),numCorrect/float(trainlabels.size()));
00093
00094 std::vector<std::vector<float> > testdata(nDim);
00095 std::vector<int> testlabels;
00096
00097 makeData(dimMeanIn,dimVarIn,dimMeanOut,dimVarOut,10,testdata,testlabels);
00098
00099 std::vector<int> testResults = dt.predict(testdata);
00100 numCorrect=0;
00101 for(uint i=0;i<testlabels.size();i++)
00102 {
00103 if(testResults[i]==testlabels[i]) numCorrect++;
00104 printf("Guess %d [Ground Truth %d]\n",testResults[i],testlabels[i]);
00105 }
00106 printf("Accuracy:[Correct/Total]=[%d/%Zu]:%f\n",numCorrect,testlabels.size(),numCorrect/float(testlabels.size()));
00107 manager.stop();
00108
00109 }
00110
00111 void makeData(const std::vector<float>& dimMeanIn,const std::vector<float>& dimVarIn, const std::vector<float>& dimMeanOut,const std::vector<float>& dimVarOut, const uint sampleDim, std::vector<std::vector<float> >& data, std::vector<int>& labels)
00112 {
00113
00114 rutz::urand rgen(time((time_t*)0)+getpid());
00115
00116
00117
00118
00119
00120
00121
00122
00123 const uint dataDim=(uint) dimMeanIn.size();
00124
00125 ASSERT(data.size()==dataDim);
00126 for(uint i=0;i<sampleDim;i++)
00127 {
00128 int l=rgen.idraw(2)*2-1;
00129 printf("data[][%u]: l=%d; ",i,l);
00130 for(uint j=0;j<dataDim;j++)
00131 {
00132 if(l==1)
00133 data[j].push_back(rgen.fdraw_range(0.0,0.5));
00134 else
00135 data[j].push_back(rgen.fdraw_range(-0.5,0.0));
00136 printf("%f, ",data[j][i]);
00137 }
00138 printf("\n");
00139 labels.push_back(l);
00140 }
00141 }
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151