00001 /*!@file src/Features/test-GentleBoost.C */ 00002 00003 // //////////////////////////////////////////////////////////////////// // 00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005 // 00005 // by the University of Southern California (USC) and the iLab at USC. // 00006 // See http://iLab.usc.edu for information about this project. // 00007 // //////////////////////////////////////////////////////////////////// // 00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected // 00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency // 00010 // in Visual Environments, and Applications'' by Christof Koch and // 00011 // Laurent Itti, California Institute of Technology, 2001 (patent // 00012 // pending; application number 09/912,225 filed July 23, 2001; see // 00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status). // 00014 // //////////////////////////////////////////////////////////////////// // 00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit. // 00016 // // 00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can // 00018 // redistribute it and/or modify it under the terms of the GNU General // 00019 // Public License as published by the Free Software Foundation; either // 00020 // version 2 of the License, or (at your option) any later version. // 00021 // // 00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope // 00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the // 00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // 00025 // PURPOSE. See the GNU General Public License for more details. // 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Dan Parks <danielfp@usc.edu> 00034 // $HeadURL$ 00035 // $Id$ 00036 // 00037 00038 #include "Component/ModelManager.H" 00039 #include "Learn/GentleBoost.H" 00040 #include "rutz/rand.h" 00041 #include "rutz/trace.h" 00042 #include "Util/SortUtil.H" 00043 #include "Util/Assert.H" 00044 #include <math.h> 00045 #include <fcntl.h> 00046 #include <limits> 00047 #include <string> 00048 #include <stdio.h> 00049 00050 00051 void makeData(const int numCategories, const uint sampleDim, std::vector<std::vector<float> >& data, std::vector<int>& labels, bool printData); 00052 00053 int main(const int argc, const char **argv) 00054 { 00055 00056 MYLOGVERB = LOG_INFO; 00057 ModelManager manager("Test Decision Tree"); 00058 00059 00060 // Create log likelihood classifier and local binary patterns objects 00061 uint nDim=4; 00062 int numCategories=3; 00063 int maxIters=1; 00064 int maxTreeSize = 4; 00065 GentleBoost gb(maxTreeSize); 00066 std::string saveDataFile("tmp.dat"); 00067 std::string compareDataFile("tmp.cmp.dat"); 00068 00069 if (manager.parseCommandLine( 00070 (const int)argc, (const char**)argv, "", 0, 0) == false) 00071 return 0; 00072 00073 manager.start(); 00074 std::vector<std::vector<float> > traindata(nDim); 00075 std::vector<int> trainlabels; 00076 std::vector<float> dimMeanIn(nDim), dimMeanOut(nDim), dimVarIn(nDim,1.0F), dimVarOut(nDim,1.0F); 00077 for(uint i=0;i<nDim;i++) 00078 { 00079 dimMeanIn[i] = nDim-i; 00080 dimMeanOut[i] = -(nDim-i); 00081 } 00082 makeData(numCategories,1000,traindata,trainlabels,false); 00083 // Train the classifier on the training set 00084 gb.train(traindata,trainlabels,maxIters); 00085 gb.save(saveDataFile); 00086 // Do a cycle of saving and loading and compare to the original save file 00087 GentleBoost tmpGB; 00088 tmpGB.load(saveDataFile); 00089 tmpGB.save(compareDataFile); 00090 00091 std::map<int,std::vector<float> > trainPDF = gb.predictPDF(traindata); 00092 std::vector<int> trainResults = gb.getMostLikelyClass(trainPDF); 00093 // Validate on training set 00094 int numCorrect=0; 00095 for(uint i=0;i<trainlabels.size();i++) 00096 { 00097 if(trainResults[i]==trainlabels[i]) numCorrect++; 00098 //printf("Train Guess %d [Ground Truth %d]\n",trainResults[i],trainlabels[i]); 00099 } 00100 printf("Training Accuracy:[Correct/Total]=[%d/%Zu]:%f\n",numCorrect,trainlabels.size(),numCorrect/float(trainlabels.size())); 00101 gb.printAllTrees(); 00102 std::vector<std::vector<float> > testdata(nDim); 00103 std::vector<int> testlabels; 00104 // Create new data from same distribution as test set 00105 makeData(numCategories,10,testdata,testlabels,true); 00106 // Classify test set 00107 std::map<int,std::vector<float> > testPDF = gb.predictPDF(testdata); 00108 std::vector<int> testResults = gb.getMostLikelyClass(testPDF); 00109 numCorrect=0; 00110 for(uint i=0;i<testlabels.size();i++) 00111 { 00112 if(testResults[i]==testlabels[i]) numCorrect++; 00113 std::map<int,std::vector<float> >::iterator litr; 00114 printf("Guess %d [",testResults[i]); 00115 for(litr=testPDF.begin();litr!=testPDF.end();litr++) 00116 { 00117 printf("(%d)%f, ",litr->first,litr->second[i]); 00118 } 00119 printf("] *** Ground Truth %d\n",testlabels[i]); 00120 } 00121 printf("Accuracy:[Correct/Total]=[%d/%Zu]:%f\n",numCorrect,testlabels.size(),numCorrect/float(testlabels.size())); 00122 manager.stop(); 00123 00124 } 00125 00126 void makeData(const int numCategories, const uint sampleDim, std::vector<std::vector<float> >& data, std::vector<int>& labels, bool printData) 00127 { 00128 // Create uniform random number generator 00129 rutz::urand rgen(time((time_t*)0)+getpid()); 00130 ASSERT(data.size()>0); 00131 // Create data and labels 00132 const uint dataDim=(uint) data.size(); 00133 00134 for(uint i=0;i<sampleDim;i++) 00135 { 00136 int l=rgen.idraw(numCategories)+1; 00137 if(printData) printf("data[][%u]: l=%d; ",i,l); 00138 for(uint j=0;j<dataDim;j++) 00139 { 00140 data[j].push_back(rgen.fdraw_range(l-0.75,l+0.75));//*dimVarIn[j]+dimMeanIn[j]); 00141 if(printData) printf("%f, ",data[j][i]); 00142 } 00143 if(printData) printf("\n"); 00144 labels.push_back(l); 00145 } 00146 } 00147 00148 00149 // ###################################################################### 00150 /* So things look consistent in everyone's emacs... */ 00151 /* Local Variables: */ 00152 /* indent-tabs-mode: nil */ 00153 /* End: */ 00154 00155 00156