/*!@file Learn/GentleBoostBinary.C GentleBoost 2-Class Classifier */
// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
// University of Southern California (USC) and the iLab at USC.         //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE. See the GNU General Public License for more details.        //
//                                                                      //
// You should have received a copy of the GNU General Public License    //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
// Boston, MA 02111-1307 USA.                                           //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Dan Parks <danielfp@usc.edu>
// $HeadURL$
// $Id$
//

#include "Learn/GentleBoostBinary.H"
#include "Util/Assert.H"
#include "Util/log.H"
#include "Util/sformat.H"
#include "Util/SortUtil.H"
#include <limits>
#include <math.h>
#include <stdio.h>



GentleBoostBinary::GentleBoostBinary(int maxTreeSize) :
  itsMaxTreeSize(maxTreeSize)
{
}


std::vector<float> GentleBoostBinary::predict(const std::vector<std::vector<float> >& data)
{
  return predict(data,itsWeights);
}

// Real-valued approximation of the answer by a committee of weak learners - 2-class problem
std::vector<float> GentleBoostBinary::predict(const std::vector<std::vector<float> >& data, std::vector<float> weights)
{
  ASSERT(weights.size()==itsNodes.size());
  ASSERT(data.size()>0);
  std::vector<float> pred(data[0].size());
  for(size_t i=0;i<itsNodes.size();i++)
  {
    std::vector<int> tmppred=itsNodes[i]->decide(data);
    for(size_t j=0;j<pred.size();j++)
    {
      pred[j]+=float(tmppred[j])*weights[i];
    }
  }
  return pred;
}


void GentleBoostBinary::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, int maxIters)
{
  std::vector<float> predictions;
  train(data,labels,maxIters,predictions);
}

// Train the committee; data is laid out as data[feature][sample] and labels are in {-1,+1}.
// If predictions is non-empty it initializes the sample weights (continuing from earlier
// boosting); it is updated in place with the committee predictions as training proceeds.
void GentleBoostBinary::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, int maxIters, std::vector<float>& predictions)
{
  ASSERT(data.size()>0);
  int nSamples = int(data[0].size());
  ASSERT(int(labels.size())==nSamples);
  std::vector<float> dataWeights;
  if(predictions.size()>0)
  {
    dataWeights = std::vector<float>(nSamples);
    for(size_t i=0;i<predictions.size();i++)
    {
      dataWeights[i]=exp(-(labels[i]*predictions[i]));
    }
  }
  else
  {
    dataWeights = std::vector<float>(nSamples,1.0F/float(nSamples));
    predictions = std::vector<float>(nSamples);
  }

  for(int iter=0;iter<maxIters;iter++)
  {
    rutz::shared_ptr<DecisionTree> learner(new DecisionTree(itsMaxTreeSize));
    learner->train(data,labels,dataWeights);
    std::deque<rutz::shared_ptr<DecisionNode> > curNodes = learner->getNodes();
    if(curNodes.size()==0)
    {
      LINFO("Training complete, only trivial cuts found");
      return;
    }
    for(size_t idx=0;idx<curNodes.size();idx++)
    {
      rutz::shared_ptr<DecisionNode> curNode = curNodes[idx];
      std::vector<int> curNodeOut = curNode->decide(data);
      float s1=0.0F,s2=0.0F;
      for(size_t i=0;i<curNodeOut.size();i++)
      {
        if(curNodeOut[i]==1)
        {
          if(labels[i]==1)
          {
            // Weighted sum of true positives
            s1 += dataWeights[i];
          }
          else if(labels[i]==-1)
          {
            // Weighted sum of false positives
            s2 += dataWeights[i];
          }
        }
        // Deviation from original, take into account true negatives/false negatives when evaluating weights
        else if(curNodeOut[i]==-1)
        {
          if(labels[i]==-1)
          {
            // Weighted sum of true negatives
            s1 += dataWeights[i];
          }
          else if(labels[i]==1)
          {
            // Weighted sum of false negatives
            s2 += dataWeights[i];
          }
        }
      }
      if(s1==0.0F && s2==0.0F)
        continue;
      // Node weight: weighted correct minus weighted incorrect, normalized to [-1,1]
      float alpha = (s1-s2)/(s1+s2);
      itsWeights.push_back(alpha);
      itsNodes.push_back(curNode);
      for(size_t i=0;i<predictions.size();i++)
      {
        predictions[i] += curNodeOut[i]*alpha;
      }
    }
    // Reweight samples by the exponential loss of the current committee and renormalize
    float sumDW=0;
    for(int i=0;i<nSamples;i++)
    {
      dataWeights[i] = exp(-1.0F * (labels[i]*predictions[i]));
      sumDW+=dataWeights[i];
    }
    if(sumDW>0)
    {
      for(int i=0;i<nSamples;i++)
      {
        dataWeights[i]/=sumDW;
      }
    }
  }
}
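
// Note on the update used in train() above: the committee output is
// F(x) = sum_t alpha_t * h_t(x), where h_t(x) in {-1,+1} is the decision of
// stored node t and alpha_t its weight.  For each candidate node, s1 is the
// weighted sum of correct decisions and s2 the weighted sum of incorrect
// ones, so alpha = (s1-s2)/(s1+s2) lies in [-1,1].  Sample weights are then
// refreshed as w_i = exp(-y_i * F(x_i)), the exponential-loss reweighting
// commonly used in boosting.  A small illustration of the arithmetic (the
// numbers below are made up, not taken from the toolkit):
//
//   single node with s1=0.6, s2=0.2  ->  alpha = (0.6-0.2)/(0.6+0.2) = 0.5
//   sample with y=+1, h(x)=+1        ->  F(x)=+0.5, w = exp(-0.5) ~= 0.61
//   sample with y=-1, h(x)=+1        ->  F(x)=+0.5, w = exp(+0.5) ~= 1.65
//
// so misclassified samples carry more weight in the next iteration.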

void GentleBoostBinary::printTree()
{
  LINFO("Printing Tree of %Zu nodes",itsNodes.size());
  // itsNodes and itsWeights are parallel containers, so index them together
  for(size_t i=0;i<itsNodes.size();i++)
  {
    rutz::shared_ptr<DecisionNode> n=itsNodes[i];
    if(!n.is_valid())
    {
      LINFO("Node[%d] <Invalid Pointer>",int(i));
      continue;
    }
    std::string output;
    n->printNode(output);
    LINFO("Weight: %f\n%s",itsWeights[i],output.c_str());
  }
}

void GentleBoostBinary::writeTree(std::ostream& outstream)
{
  // Write each weak learner as its weight followed by its node description
  for(size_t i=0;i<itsNodes.size();i++)
  {
    rutz::shared_ptr<DecisionNode> n=itsNodes[i];
    if(!n.is_valid())
    {
      continue;
    }
    outstream << sformat("TREEWEIGHT:%f; \n",itsWeights[i]);
    n->writeNode(outstream);
  }
  outstream << std::string("END\n");
}

void GentleBoostBinary::readTree(std::istream& instream)
{
  DecisionNode tmp;
  const int BUFFER_SIZE = 256;
  char buf[BUFFER_SIZE];
  int treeIdx=0;

  bool nodeIsValid = true;
  while(nodeIsValid)
  {
    instream.getline(buf,BUFFER_SIZE);
    float treeWeight;
    int numItemsFound = sscanf(buf,"TREEWEIGHT:%f; ",&treeWeight);
    if(numItemsFound == 1)
    {
      rutz::shared_ptr<DecisionNode> node = tmp.readNode(instream);
      if(!node.is_valid())
      {
        LFATAL("No tree associated with tree weight at index %d",treeIdx);
        nodeIsValid = false;
      }
      itsWeights.push_back(treeWeight);
      itsNodes.push_back(node);
      treeIdx++;
    }
    else if(std::string(buf).compare("END")==0)
    {
      nodeIsValid = false;
    }
    else
    {
      LFATAL("Incomplete tree representation at index %d",treeIdx);
      nodeIsValid = false;
    }
  }
}


void GentleBoostBinary::clear()
{
  itsNodes.clear();
  itsWeights.clear();
}
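
/* Usage sketch (illustrative only): the class and method signatures below are
   taken from this file, but the feature values, labels, and tree size are
   made-up example inputs, and the behavior of DecisionTree/DecisionNode is
   assumed from the calls above.  Note the data layout: data[f][i] is feature
   f of sample i, and labels are in {-1,+1}.

     GentleBoostBinary gb(4);                          // max size of each weak-learner tree
     std::vector<std::vector<float> > data(2);         // 2 features
     data[0].push_back(0.1F); data[0].push_back(0.9F); // feature 0 of samples 0 and 1
     data[1].push_back(1.0F); data[1].push_back(2.0F); // feature 1 of samples 0 and 1
     std::vector<int> labels;
     labels.push_back(-1); labels.push_back(1);
     gb.train(data, labels, 10);                       // run up to 10 boosting iterations
     std::vector<float> score = gb.predict(data);      // sign(score[i]) gives the predicted class
*/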