/*!@file src/Features/test-DecisionTree.C */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
// by the University of Southern California (USC) and the iLab at USC. //
// See http://iLab.usc.edu for information about this project.         //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and     //
// Laurent Itti, California Institute of Technology, 2001 (patent      //
// pending; application number 09/912,225 filed July 23, 2001; see     //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).    //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.      //
//                                                                     //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can  //
// redistribute it and/or modify it under the terms of the GNU General //
// Public License as published by the Free Software Foundation; either //
// version 2 of the License, or (at your option) any later version.    //
//                                                                     //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the  //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR     //
// PURPOSE.  See the GNU General Public License for more details.      //
//                                                                     //
// You should have received a copy of the GNU General Public License   //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write  //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,  //
// Boston, MA 02111-1307 USA.                                          //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Dan Parks <danielfp@usc.edu>
// $HeadURL$
// $Id$
//

#include "Component/ModelManager.H"
#include "Learn/DecisionTree.H"
#include "rutz/rand.h"
#include "rutz/trace.h"
#include "Util/SortUtil.H"
#include "Util/Assert.H"
#include <math.h>
#include <fcntl.h>
#include <limits>
#include <string>
#include <stdio.h>
// #include <boost/random/normal_distribution.hpp>
// #include <boost/random/mersenne_twister.hpp>
// #include <boost/random/variate_generator.hpp>


void makeData(const std::vector<float>& dimMeanIn, const std::vector<float>& dimVarIn,
              const std::vector<float>& dimMeanOut, const std::vector<float>& dimVarOut,
              const uint sampleDim,
              std::vector<std::vector<float> >& data, std::vector<int>& labels);

int main(const int argc, const char **argv)
{
  MYLOGVERB = LOG_INFO;
  ModelManager manager("Test Decision Tree");

  // Create the decision tree classifier
  DecisionTree dt(3);

  if (manager.parseCommandLine(
        (const int)argc, (const char**)argv, "", 0, 0) == false)
    return 0;

  manager.start();
  uint nDim = 4;
  std::vector<std::vector<float> > traindata(nDim);
  std::vector<int> trainlabels;
  std::vector<float> dimMeanIn(nDim), dimMeanOut(nDim), dimVarIn(nDim,1.0F), dimVarOut(nDim,1.0F);
  for (uint i = 0; i < nDim; i++)
  {
    dimMeanIn[i]  = nDim - i;
    dimMeanOut[i] = -(nDim - i);
  }
  makeData(dimMeanIn, dimVarIn, dimMeanOut, dimVarOut, 50, traindata, trainlabels);

  // Train the classifier on the training set
  dt.train(traindata, trainlabels);
  dt.printTree();
  std::vector<int> trainResults = dt.predict(traindata);

  // Validate on the training set
  int numCorrect = 0;
  for (uint i = 0; i < trainlabels.size(); i++)
  {
    if (trainResults[i] == trainlabels[i]) numCorrect++;
    printf("Train Guess %d [Ground Truth %d]\n", trainResults[i], trainlabels[i]);
  }
  printf("Training Accuracy:[Correct/Total]=[%d/%zu]:%f\n", numCorrect, trainlabels.size(),
         numCorrect/float(trainlabels.size()));

  std::vector<std::vector<float> > testdata(nDim);
  std::vector<int> testlabels;
  // Create new test data from the same distribution as the training set
  makeData(dimMeanIn, dimVarIn, dimMeanOut, dimVarOut, 10, testdata, testlabels);

  // Classify the test set
  std::vector<int> testResults = dt.predict(testdata);
  numCorrect = 0;
  for (uint i = 0; i < testlabels.size(); i++)
  {
    if (testResults[i] == testlabels[i]) numCorrect++;
    printf("Guess %d [Ground Truth %d]\n", testResults[i], testlabels[i]);
  }
  printf("Accuracy:[Correct/Total]=[%d/%zu]:%f\n", numCorrect, testlabels.size(),
         numCorrect/float(testlabels.size()));
  manager.stop();

  return 0;
}

void makeData(const std::vector<float>& dimMeanIn, const std::vector<float>& dimVarIn,
              const std::vector<float>& dimMeanOut, const std::vector<float>& dimVarOut,
              const uint sampleDim,
              std::vector<std::vector<float> >& data, std::vector<int>& labels)
{
  // Create uniform random number generator, seeded from the time and pid
  rutz::urand rgen(time((time_t*)0) + getpid());

  // // Create mersenne twister generator, attached to a Normal Distribution
  // boost::mt19937 igen(time((time_t*)0)+getpid());
  // boost::variate_generator<boost::mt19937, boost::normal_distribution<> >
  //   gen(igen,
  //       boost::normal_distribution<>(0.0,1.0));
  // double randVar = gen();

  // Create data and labels
  const uint dataDim = (uint) dimMeanIn.size();
  // const uint numCategories=2;
  ASSERT(data.size() == dataDim);
  for (uint i = 0; i < sampleDim; i++)
  {
    // Randomly assign this sample to class -1 or +1
    int l = rgen.idraw(2)*2 - 1;
    printf("data[][%u]: l=%d; ", i, l);
    for (uint j = 0; j < dataDim; j++)
    {
      // Class +1 is drawn uniformly from [0,0.5] and class -1 from [-0.5,0]
      // in every dimension (the mean/variance scaling is commented out)
      if (l == 1)
        data[j].push_back(rgen.fdraw_range(0.0, 0.5)); //*dimVarIn[j]+dimMeanIn[j]);
      else
        data[j].push_back(rgen.fdraw_range(-0.5, 0.0)); //*dimVarOut[j]+dimMeanOut[j]);
      printf("%f, ", data[j][i]);
    }
    printf("\n");
    labels.push_back(l);
  }
}


// ######################################################################
/* So things look consistent in everyone's emacs... */
/* Local Variables: */
/* indent-tabs-mode: nil */
/* End: */
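
// ######################################################################
// Note on the synthetic data: main() fills in per-dimension means and
// variances (dimMeanIn/dimVarIn/dimMeanOut/dimVarOut), but makeData()
// currently ignores them because the scaling terms are commented out; the
// two classes are separated only by sign, drawn uniformly from [0,0.5] and
// [-0.5,0]. A minimal sketch of the parameterized sampling, based on the
// scaling terms commented out in makeData() above:
//
//   if (l == 1)
//     data[j].push_back(rgen.fdraw_range(0.0, 0.5) * dimVarIn[j] + dimMeanIn[j]);
//   else
//     data[j].push_back(rgen.fdraw_range(-0.5, 0.0) * dimVarOut[j] + dimMeanOut[j]);
//
// With dimMeanIn[i] = nDim-i and dimMeanOut[i] = -(nDim-i) as set in main(),
// this would separate the classes by per-dimension offset rather than by the
// sign of a fixed uniform range.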