/*!@file Learn/DecisionTree.C Decision Tree Classifier */
// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
// University of Southern California (USC) and the iLab at USC.         //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.       //
//                                                                      //
// You should have received a copy of the GNU General Public License    //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
// Boston, MA 02111-1307 USA.                                           //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Dan Parks <danielfp@usc.edu>
// $HeadURL$
// $Id$
//

#include "DecisionTree.H"
#include "Util/Assert.H"
#include "Util/log.H"
#include "Util/SortUtil.H"
#include "Util/sformat.H"
#include <algorithm> // std::sort, std::min_element, std::max_element, std::distance
#include <limits>
#include <math.h>
#include <stdio.h>

DecisionNode::DecisionNode() :
  itsDim(-1),
  itsLeaf(true),
  itsLeftConstraint(-std::numeric_limits<float>::infinity()),
  itsRightConstraint(std::numeric_limits<float>::infinity()),
  itsClass(1),
  itsParent(NULL)
{
}

bool DecisionNode::isValid()
{
  return itsDim>=0;
}

int DecisionNode::printNode(std::string& output, int depth)
{
  int retDepth;
  if(itsParent.is_valid())
    retDepth=itsParent->printNode(output,depth)+1;
  else retDepth=depth;
  // Build the indent string directly; the original sprintf(indent,"%s\t",indent)
  // passed overlapping source/destination buffers, which is undefined behavior
  char indent[250];
  int numTabs = (retDepth<249) ? retDepth : 249;
  for(int i=0;i<numTabs;i++)
    indent[i]='\t';
  indent[numTabs]='\0';
  output = sformat("%s%sNode[%p]: <Leaf:%s> <Dim:%d, Class:%d, LeftConstraint:%f, RightConstraint:%f, Parent:%p>\n",
                   output.c_str(),indent,this,(itsLeaf)?"t":"f",itsDim,itsClass,itsLeftConstraint,itsRightConstraint,itsParent.get());
  return retDepth;
}

void DecisionNode::writeNode(std::ostream& outstream, bool needEnd)
{
  // Write this node's record first, then recurse up through its ancestors
  outstream << sformat("DIM:%d,CLASS:%d,LC:%f,RC:%f; \n",itsDim,itsClass,itsLeftConstraint,itsRightConstraint);
  if(itsParent.is_valid())
    itsParent->writeNode(outstream,false);
  if(needEnd)
    outstream << std::string("END\n");
}


rutz::shared_ptr<DecisionNode> DecisionNode::readNode(std::istream& instream)
{
  bool nodeIsValid = true;
  const int BUFFER_SIZE = 256;
  char buf[BUFFER_SIZE];
  int depth=0;
  rutz::shared_ptr<DecisionNode> root(NULL), curNode(NULL);
  while(nodeIsValid)
  {
    instream.getline(buf,BUFFER_SIZE);
    int dim,cls;
    float lc,rc;
    int numItemsFound = sscanf(buf,"DIM:%d,CLASS:%d,LC:%f,RC:%f; ",&dim,&cls,&lc,&rc);
    if(numItemsFound == 4)
    {
      rutz::shared_ptr<DecisionNode> node(new DecisionNode());
      node->setDim(dim);
      node->setClass(cls);
      node->setLeftConstraint(lc);
      node->setRightConstraint(rc);
      // If no root, set it up
      if(!root.is_valid())
      {
        root = node;
        curNode = root;
      }
      // Otherwise set parent and move up the chain
      else
      {
        curNode->setParent(node);
        curNode = node;
      }
      depth++;
    }
    else if(std::string(buf).compare("END")==0)
    {
      // End of the serialized chain; mark the last node read as a leaf
      if(curNode.is_valid())
      {
        curNode->setLeaf(true);
      }
      else
      {
        LFATAL("Empty node list in decision tree");
      }
      nodeIsValid = false;
    }
    else
    {
      LFATAL("Incomplete node representation at depth %d num found %d buffer[%s]",depth,numItemsFound,buf);
      nodeIsValid = false;
    }
  }
  return root;
}
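
// Illustration (hypothetical values) of the stream format produced by
// writeNode() and consumed by readNode(): one record per node, emitted from
// the node itself up through its ancestors, terminated by "END". A two-node
// chain might serialize as:
//
//   DIM:3,CLASS:1,LC:0.500000,RC:inf;
//   DIM:0,CLASS:0,LC:-inf,RC:0.250000;
//   END
//
// readNode() sscanf()s records until it sees "END", returning the first node
// read and wiring each subsequent node in as the parent of the previous one.
// (How %f renders the infinite default constraints is platform-dependent;
// "inf" above is the glibc rendering.)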


// Given a vector of data, decide whether each sample is within or outside the class.
// Returns a vector of [classId/0]: 0 if a sample is not determined by this node,
// and classId if it is
std::vector<int> DecisionNode::decide(const std::vector<std::vector<float> >& data)
{
  ASSERT(int(data.size())>itsDim && itsDim>=0);
  std::vector<int> y(data[itsDim].size(),1);
  rutz::shared_ptr<DecisionNode> parNode = itsParent;
  // Handle the parents' weighting
  if(itsParent.is_valid())
  {
    std::vector<int> py = parNode->decide(data);
    for(uint s=0;s<data[itsDim].size();s++)
    {
      y[s] *= py[s];
    }
  }

  int inVal=1;
  // The leaf node actually makes the class judgment
  if(itsLeaf)
  {
    inVal = itsClass;
  }

  // Check if data is within the right/left constraints
  for(uint s=0;s<data[itsDim].size();s++)
  {
    if(y[s]>0)
      y[s] *= (data[itsDim][s] < itsRightConstraint && data[itsDim][s] >= itsLeftConstraint) ? inVal : 0;
  }

  return y;
}

size_t DecisionNode::getDim()
{
  return itsDim;
}

void DecisionNode::setDim(size_t dim)
{
  itsDim = dim;
}

void DecisionNode::setLeaf(bool isLeaf)
{
  itsLeaf = isLeaf;
  // Non-leaf nodes do not have a class
  if(!itsLeaf)
    itsClass=0;
}

void DecisionNode::setParent(rutz::shared_ptr<DecisionNode> parent)
{
  itsParent = parent;
}


void DecisionNode::setLeftConstraint(float constraint)
{
  itsLeftConstraint = constraint;
}

void DecisionNode::setRightConstraint(float constraint)
{
  itsRightConstraint = constraint;
}

void DecisionNode::setClass(int classId)
{
  itsClass = classId;
}

int DecisionNode::getClass()
{
  return itsClass;
}

float DecisionNode::split(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, const std::vector<float>& weights, rutz::shared_ptr<DecisionNode>& left, rutz::shared_ptr<DecisionNode>& right, const rutz::shared_ptr<DecisionNode> parent)
{
  left = rutz::shared_ptr<DecisionNode>(new DecisionNode);
  right = rutz::shared_ptr<DecisionNode>(new DecisionNode);
  left->setParent(parent);
  right->setParent(parent);
  //LINFO("Splitting data on node %p",this);
  // Data is of size NDxNT, where ND = # of feature dimensions, and NT = # of training samples
  ASSERT(data.size() > 0);
  // Number of training samples
  uint tr_size = data[0].size();
  // Store the lowest error, the dimension of lowest error, and the direction [+1/-1] of lowest error
  std::vector<float> bestErr;
  std::vector<size_t> bestErrIdx;
  std::vector<float> bestErrDir;
  // Iterate over feature dimensions
  for(uint d=0;d<data.size();d++)
  {
    // Get the rank order of data for this dimension
    std::vector<size_t> dindices;
    util::sortrank(data[d],dindices);
    std::vector<float> dsorted=data[d];
    // Sort data for this dimension
    std::sort(dsorted.begin(),dsorted.end());

    // For the current dimension, build a weighted value for each of the positive
    // and negative data samples (and consolidate identical data)
    std::vector<float> vPos(tr_size);
    std::vector<float> vNeg(tr_size);

    uint i=0,j=0;
    while(i<dsorted.size())
    {
      uint k = 0;
      while(i + k < dsorted.size() && dsorted[i] == dsorted[i+k])
      {
        if(labels[dindices[i+k]] > 0)
          vPos[j] += weights[dindices[i+k]];
        else
          vNeg[j] += weights[dindices[i+k]];
        k++;
      }
      i += k;
      j++;
    }
    // Resize to the number of unique data points
    vNeg.resize(j);
    vPos.resize(j);
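
    // Worked example (hypothetical numbers) for the consolidation above and the
    // error scan below. Take one dimension with values {0.2, 0.5, 0.2, 0.9},
    // labels {+1, -1, +1, -1}, and uniform weights of 0.25:
    //   unique values  {0.2, 0.5,  0.9 }
    //   vPos           {0.5, 0.0,  0.0 }   (both 0.2 samples are positive)
    //   vNeg           {0.0, 0.25, 0.25}
    //   iPos (cumsum)  {0.5, 0.5,  0.5 }   totalP = 0.5
    //   iNeg (cumsum)  {0.0, 0.25, 0.5 }   totalN = 0.5
    //   err            {2.0, 1.5,  1.0 }   (positive class above the cut)
    //   invErr         {0.0, 0.5,  1.0 }   (positive class below the cut)
    // invErr[0] == 0, so the best cut puts the positive class below the midpoint
    // after the first unique value ((0.2+0.5)/2 = 0.35) with zero weighted error.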

    std::vector<float> err(vPos.size());
    std::vector<float> invErr(vPos.size());

    std::vector<float> iPos(vPos.size());
    std::vector<float> iNeg(vNeg.size());
    // Build a cumulative sum over the weights of the sorted data
    for(i=0;i<iPos.size();i++)
    {
      if(i==0)
      {
        iPos[0] = vPos[0];
        iNeg[0] = vNeg[0];
      }
      else
      {
        iPos[i] = iPos[i-1] + vPos[i];
        iNeg[i] = iNeg[i-1] + vNeg[i];
      }
    }

    // Total negative/positive training weight
    float totalN = (iNeg.size()>0) ? iNeg[iNeg.size()-1] : 0;
    float totalP = (iPos.size()>0) ? iPos[iPos.size()-1] : 0;

    // If there is no weight on the negative or positive side, any split is trivial:
    // all the data will end up in one child, with no error, so there is no point in
    // doing any more evaluation in this split()
    if(totalN<0.00001 || totalP<0.00001)
      return std::numeric_limits<float>::max();

    // Calculate the error if we were to split the data at each index (in both directions)
    for(i=0;i<j;i++)
    {
      // Deviation from original: treat error as a percentage of positive and negative
      err[i] = (iPos[i])/totalP + (totalN - iNeg[i])/totalN;
      invErr[i] = (iNeg[i])/totalN + (totalP - iPos[i])/totalP;
      //printf("i %d, j %d err %f, inverr %f iPos %f, iNeg %f, totalP %f, totalN %f\n",i,j,err[i],invErr[i],iPos[i],iNeg[i],totalP,totalN);
    }

    // Find the minimum error
    size_t errMinIdx = std::distance(err.begin(),std::min_element(err.begin(),err.end()));

    // Find the minimum inverse error
    size_t invErrMinIdx = std::distance(invErr.begin(),std::min_element(invErr.begin(),invErr.end()));

    //LINFO("For dimension %u, err[%zu] %f, invErr[%zu] %f, totalN %f, totalP %f line %f, invline %f",d,errMinIdx,err[errMinIdx],invErrMinIdx,invErr[invErrMinIdx],totalN,totalP,dsorted[errMinIdx],dsorted[invErrMinIdx]);

    // Technically for err/invErr, the higher one is the % right, and the lower one
    // is the % wrong (if we flip the error direction accordingly).
    // Determine the lowest error and store it for this dimension
    if(err[errMinIdx] < invErr[invErrMinIdx])
    {
      bestErr.push_back(err[errMinIdx]);
      bestErrIdx.push_back(errMinIdx);
      bestErrDir.push_back(-1);
    }
    else
    {
      bestErr.push_back(invErr[invErrMinIdx]);
      bestErrIdx.push_back(invErrMinIdx);
      bestErrDir.push_back(1);
    }
  }

  // Find the dimension that will minimize the error
  size_t bestDim = std::distance(bestErr.begin(),std::min_element(bestErr.begin(),bestErr.end()));

  std::vector<float> tmpDimSorted = data[bestDim];
  std::sort(tmpDimSorted.begin(),tmpDimSorted.end());
  std::vector<float> dimSorted(tmpDimSorted.size());

  uint i = 0;
  uint j = 0;

  // Consolidate the dimension as was done in the previous loop, and then use this
  // to extract the correct columns for the threshold
  while(i<dimSorted.size())
  {
    uint k = 0;
    while(i+k<dimSorted.size() && tmpDimSorted[i] == tmpDimSorted[i+k])
    {
      dimSorted[j] = tmpDimSorted[i];
      k++;
    }
    i += k;
    j++;
  }

  dimSorted.resize(j);

  // Select the midpoint between the two data points that act as the dividing point
  float threshold = (dimSorted[bestErrIdx[bestDim]] + dimSorted[std::min(bestErrIdx[bestDim]+1,dimSorted.size()-1)])/2.0;
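
  // Design note: placing the cut at the midpoint between adjacent unique values
  // (rather than on a training value itself) keeps both neighboring samples
  // strictly inside their respective children, given the left child's
  // value < threshold test and the right child's value >= threshold test;
  // the std::min() clamp handles a cut after the last unique value.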

  left->setDim(bestDim);
  left->setRightConstraint(threshold);
  left->setClass(bestErrDir[bestDim]);

  right->setDim(bestDim);
  right->setLeftConstraint(threshold);
  right->setClass(-bestErrDir[bestDim]);
  return bestErr[bestDim];
}


DecisionTree::DecisionTree(int maxSplits) :
  itsMaxSplits(maxSplits)
{
}

std::deque<rutz::shared_ptr<DecisionNode> > DecisionTree::getNodes()
{
  return itsNodes;
}

void DecisionTree::addNode(rutz::shared_ptr<DecisionNode> node)
{
  itsNodes.push_back(node);
}

void DecisionTree::printTree()
{
  std::deque<rutz::shared_ptr<DecisionNode> >::iterator itr;
  LINFO("Printing Tree of %zu nodes",itsNodes.size());
  int i=0;
  for(itr=itsNodes.begin();itr!=itsNodes.end();itr++)
  {
    rutz::shared_ptr<DecisionNode> n=*itr;
    if(!n.is_valid())
    {
      LINFO("Node[%d] <Invalid Pointer>",i);
      continue;
    }
    std::string output;
    n->printNode(output);
    LINFO("%s",output.c_str());
    i++;
  }
}

std::vector<int> DecisionTree::predict(const std::vector<std::vector<float> >& data, std::vector<float> weights)
{
  ASSERT(data.size()>0);
  // If weights are not specified, set them all to 1.0
  if(weights.size()==0)
    weights.resize(data[0].size(),1.0F);
  size_t sampleDim = data[0].size();
  std::vector<int> prediction(sampleDim);
  std::deque<rutz::shared_ptr<DecisionNode> >::iterator itr;
  // Accumulate each node's weighted vote for every sample
  for(itr=itsNodes.begin();itr!=itsNodes.end();itr++)
  {
    std::vector<int> y = (*itr)->decide(data);
    for(uint i=0;i<y.size();i++)
    {
      prediction[i] += (y[i])*weights[i];
    }
  }
  return prediction;
}
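
// Minimal usage sketch (illustrative only; the data values and variable names
// are hypothetical, and the snippet assumes the surrounding toolkit build):
//
//   // Two feature dimensions, four samples, laid out as data[dim][sample]
//   std::vector<std::vector<float> > data(2);
//   float f0[] = {0.1F, 0.2F, 0.8F, 0.9F};
//   float f1[] = {0.5F, 0.4F, 0.6F, 0.5F};
//   data[0] = std::vector<float>(f0, f0+4);
//   data[1] = std::vector<float>(f1, f1+4);
//   int l[]  = {1, 1, -1, -1};            // labels must be -1/+1 (see train())
//   std::vector<int> labels(l, l+4);
//   DecisionTree tree(4);                 // allow up to 4 splits
//   tree.train(data, labels, std::vector<float>()); // empty => uniform weights
//   std::vector<int> pred = tree.predict(data);
//   // the sign of pred[i] is the predicted class of sample i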

// function train(data, labels, weights=<empty>)
// This is a binary classifier, and labels must be -1/1
void DecisionTree::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, std::vector<float> weights)
{
  ASSERT(labels.size()==weights.size() || weights.size()==0);
  // If weights are not specified, default to uniform weights that sum to 1
  if(weights.size()==0)
  {
    weights.resize(labels.size(),1.0/float(labels.size()));
  }
  // Clear out the nodes for this tree
  itsNodes.clear();

  rutz::shared_ptr<DecisionNode> tmpNode(new DecisionNode());

  // Split the temporary node to get the first binary threshold split
  rutz::shared_ptr<DecisionNode> left,right;
  tmpNode->split(data,labels,weights,left,right);

  if(!left->isValid() || !right->isValid())
  {
    LINFO("Split could not find a non-trivial cut in the remaining data");
    return;
  }
  itsNodes.push_back(left);
  itsNodes.push_back(right);

  // Determine how well the predictions correlate with the ground truth labels.
  // Separately test the correlation and anticorrelation for both the left and right nodes
  std::vector<int> propAns = left->decide(data);
  float leftPos=0,leftNeg=0;
  for(uint a=0;a<propAns.size();a++)
  {
    if(propAns[a] == labels[a])
      leftPos += weights[a];
    else if(propAns[a] == -labels[a])
      leftNeg += weights[a];
  }
  propAns = right->decide(data);
  float rightPos=0,rightNeg=0;
  for(uint a=0;a<propAns.size();a++)
  {
    if(propAns[a] == labels[a])
      rightPos += weights[a];
    else if(propAns[a] == -labels[a])
      rightNeg += weights[a];
  }

  // Build a list of left/right node errors.
  // Here max(pos,neg) is the % correct, and min(pos,neg) is the % wrong
  std::deque<float> errs;
  errs.push_back(std::min(leftPos,leftNeg));
  errs.push_back(std::min(rightPos,rightNeg));

  // If we have no classification going on at left/right, then splitting is useless
  if(leftPos + leftNeg == 0)
    return;

  if(rightPos + rightNeg == 0)
    return;

  // Rank the errors in ascending order
  std::deque<size_t> eIndex;
  util::sortrank(errs,eIndex);

  // Copy the errors and nodes into tmp variables
  std::deque<float> eTmp = errs;
  std::deque<rutz::shared_ptr<DecisionNode> > tmpNodes = itsNodes;
  // Clear originals
  errs.clear();
  itsNodes.clear();
  // Put errors and nodes back in reverse order, with the highest error first
  for(int i=eIndex.size()-1;i>=0;i--)
  {
    errs.push_back(eTmp[eIndex[i]]);
    itsNodes.push_back(tmpNodes[eIndex[i]]);
  }

  std::deque<rutz::shared_ptr<DecisionNode> > splits;
  std::deque<float> splitErrs;
  std::deque<float> deltas;

  // We have already split once; now split the remaining desired number of times
  // (unless we run out of error first)
  for(uint i=1;i<itsMaxSplits;i++)
  {
    ASSERT(itsNodes.size()>deltas.size());
    std::deque<rutz::shared_ptr<DecisionNode> >::iterator nodeItr=itsNodes.begin()+deltas.size();
    // Go through each node and determine the optimal split for that node.
    // Only bother to determine the optimal split if it hasn't been done for
    // that node yet (i.e. a delta has not been calculated yet)
    for(uint j=deltas.size();j<errs.size();j++,nodeItr++)
    {
      ASSERT(nodeItr!=itsNodes.end());
      // Select the current node to test
      rutz::shared_ptr<DecisionNode> curNode = *(nodeItr);
      // Run the data through the node being tested
      std::vector<int> curNodeOut = curNode->decide(data);

      // Take only the data where the node classified it as within class as a mask
      std::vector<uint> mask;
      for(uint idx=0;idx<curNodeOut.size();idx++)
      {
        if(curNodeOut[idx] == curNode->getClass())
          mask.push_back(idx);
      }

      leftPos=0,leftNeg=0;
      rightPos=0,rightNeg=0;
      float spliterr;
      // Check that there is actually masked data
      if(mask.size()>0)
      {
        // Apply the mask to data, labels, and weights, and then split the node
        // based on this weighted subset
        std::vector<std::vector<float> > maskedData(data.size());
        std::vector<int> maskedLabels;
        std::vector<float> maskedWeights;
        bool allTrueLabels=true;
        bool allFalseLabels=true;
        for(uint idx=0;idx<mask.size();idx++)
        {
          for(uint idx2=0;idx2<data.size();idx2++)
          {
            maskedData[idx2].push_back(data[idx2][mask[idx]]);
          }
          if(labels[mask[idx]]==-1)
            allTrueLabels=false;
          if(labels[mask[idx]]==1)
            allFalseLabels=false;
          maskedLabels.push_back(labels[mask[idx]]);
          maskedWeights.push_back(weights[mask[idx]]);
        }
        // If all classified data is already correct, splitting is pointless
        if((allTrueLabels && curNode->getClass()==1)||(allFalseLabels && curNode->getClass()==-1))
        {
          LINFO("Ignoring split of node with no misclassifications");
          leftPos=leftNeg=rightPos=rightNeg=0;
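          // Note on the selection criterion below: deltas[j] = errs[j] - spliterr
          // is the reduction in weighted error obtained by replacing node j with
          // its best left/right split. The greedy loop that follows always expands
          // the candidate with the largest such reduction, and stops once no
          // split improves on its parent (delta ~ 0).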
          spliterr = std::numeric_limits<float>::max();
        }
        else
        {
          // Calculate the split error
          spliterr = curNode->split(maskedData,maskedLabels,maskedWeights,left,right,curNode);

          // Determine how well the predictions correlate with the ground truth labels.
          // Separately test the correlation and anticorrelation for both the left and right nodes
          std::vector<int> propAns = left->decide(data);
          for(uint a=0;a<propAns.size();a++)
          {
            if(propAns[a] == labels[a])
              leftPos += weights[a];
            else if(propAns[a] == -labels[a])
              leftNeg += weights[a];
          }
          propAns = right->decide(data);
          for(uint a=0;a<propAns.size();a++)
          {
            if(propAns[a] == labels[a])
              rightPos += weights[a];
            else if(propAns[a] == -labels[a])
              rightNeg += weights[a];
          }
        }
      }
      else
      {
        // Should never happen: no masked data means the node does not contain
        // any training data
        LFATAL("No masked data, node does not contain any training data");
        leftPos=leftNeg=rightPos=rightNeg=0;
        spliterr = std::numeric_limits<float>::max();
      }
      // Append the left/right nodes to the list of split nodes (which will be
      // used to select the best split later)
      splits.push_back(left);
      splits.push_back(right);

      // Build the delta error
      if(leftPos + leftNeg == 0 || rightPos + rightNeg == 0)
        deltas.push_back(0);
      else
      {
        LINFO("Delta for splitting node %u is %f, errs[] %f, spliterr %f",j,errs[j]-spliterr,errs[j],spliterr);
        deltas.push_back(errs[j]-spliterr);
      }

      splitErrs.push_back(std::min(leftPos,leftNeg));
      splitErrs.push_back(std::min(rightPos,rightNeg));
    }

    std::deque<float>::iterator maxElemItr = std::max_element(deltas.begin(),deltas.end());
    LINFO("Best Delta %f For Iter %u",*maxElemItr,i);
    // If even the largest delta is (near) zero, no split can improve the error, so we're done
    if(*maxElemItr < 0.000001)
    {
      LINFO("Delta is zero or too small, done");
      return;
    }
    // Get the best split index
    uint bestSplit = std::distance(deltas.begin(),maxElemItr);

    // Make the split node not a leaf anymore
    itsNodes[bestSplit]->setLeaf(false);

    // Remove the node that we are splitting
    itsNodes.erase(itsNodes.begin()+bestSplit);
    errs.erase(errs.begin()+bestSplit);
    deltas.erase(deltas.begin()+bestSplit);

    // Insert the new left/right pair at the end
    ASSERT((splits.begin()+2*bestSplit)->is_valid());
    ASSERT((splits.begin()+2*bestSplit+1)->is_valid());
    itsNodes.push_back(*(splits.begin()+2*bestSplit));
    itsNodes.push_back(*(splits.begin()+2*bestSplit+1));

    // Insert the corresponding split errs into the node err list
    errs.push_back(*(splitErrs.begin()+2*bestSplit));
    errs.push_back(*(splitErrs.begin()+2*bestSplit+1));

    // Remove the chosen splits from the split list, since they are now official nodes
    splitErrs.erase(splitErrs.begin()+2*bestSplit,splitErrs.begin()+2*bestSplit+2);
    splits.erase(splits.begin()+2*bestSplit,splits.begin()+2*bestSplit+2);
  }
}


// ######################################################################
/* So things look consistent in everyone's emacs... */
/* Local Variables: */
/* indent-tabs-mode: nil */
/* End: */