00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #include "DecisionTree.H"
00038 #include "Util/Assert.H"
00039 #include "Util/log.H"
00040 #include "Util/SortUtil.H"
00041 #include "Util/sformat.H"
00042 #include <limits>
00043 #include <math.h>
00044 #include <stdio.h>
00045
00046 DecisionNode::DecisionNode() :
00047 itsDim(-1),
00048 itsLeaf(true),
00049 itsLeftConstraint(-std::numeric_limits<float>::infinity()),
00050 itsRightConstraint(std::numeric_limits<float>::infinity()),
00051 itsClass(1),
00052 itsParent(NULL)
00053 {
00054
00055 }
00056
00057 bool DecisionNode::isValid()
00058 {
00059 return itsDim>=0;
00060 }
00061
00062 int DecisionNode::printNode(std::string& output,int depth)
00063 {
00064 int retDepth;
00065 if(itsParent.is_valid())
00066 retDepth=itsParent->printNode(output,depth)+1;
00067 else retDepth=depth;
00068 char indent[250];
00069 indent[0]='\0';
00070 for(int i=0;i<retDepth;i++)
00071 sprintf(indent,"%s\t",indent);
00072 output = sformat("%s%sNode[%p]:<Leaf:%s> <Dim:%d, Class:%d, LeftConstraint:%f, RightConstraint:%f, Parent:%p\n",output.c_str(),indent,this,(itsLeaf)?"t":"f",itsDim,itsClass,itsLeftConstraint,itsRightConstraint,itsParent.get());
00073 return retDepth;
00074 }
00075
00076 void DecisionNode::writeNode(std::ostream& outstream, bool needEnd)
00077 {
00078 outstream << sformat("DIM:%d,CLASS:%d,LC:%f,RC:%f; \n",itsDim,itsClass,itsLeftConstraint,itsRightConstraint);
00079 if(itsParent.is_valid())
00080 itsParent->writeNode(outstream,false);
00081 if(needEnd)
00082 outstream << std::string("END\n");
00083 }
00084
00085
00086 rutz::shared_ptr<DecisionNode> DecisionNode::readNode(std::istream& instream)
00087 {
00088 bool nodeIsValid = true;
00089 const int BUFFER_SIZE = 256;
00090 char buf[BUFFER_SIZE];
00091 int depth=0;
00092 rutz::shared_ptr<DecisionNode> root(NULL), curNode(NULL);
00093 while(nodeIsValid)
00094 {
00095 instream.getline(buf,BUFFER_SIZE);
00096 int dim,cls;
00097 float lc,rc;
00098 int numItemsFound = sscanf(buf,"DIM:%d,CLASS:%d,LC:%f,RC:%f; ",&dim,&cls,&lc,&rc);
00099 if(numItemsFound == 4)
00100 {
00101 rutz::shared_ptr<DecisionNode> node(new DecisionNode());
00102 node->setDim(dim);
00103 node->setClass(cls);
00104 node->setLeftConstraint(lc);
00105 node->setRightConstraint(rc);
00106
00107 if(!root.is_valid())
00108 {
00109 root = node;
00110 curNode = root;
00111 }
00112
00113 else
00114 {
00115 curNode->setParent(node);
00116 curNode = node;
00117 }
00118 depth++;
00119 }
00120 else if(std::string(buf).compare("END")==0)
00121 {
00122
00123 if(curNode.is_valid())
00124 {
00125 curNode->setLeaf(true);
00126 }
00127 else
00128 {
00129 LFATAL("Empty node list in decision tree");
00130 }
00131 nodeIsValid = false;
00132 }
00133 else
00134 {
00135 LFATAL("Incomplete node representation at depth %d num found %d buffer[%s]",depth,numItemsFound,buf);
00136 nodeIsValid = false;
00137 }
00138 }
00139 return root;
00140 }
00141
00142
00143
00144
00145
00146 std::vector<int> DecisionNode::decide(const std::vector<std::vector<float> >& data)
00147 {
00148 ASSERT(int(data.size())>itsDim && itsDim>=0);
00149 std::vector<int> y(data[itsDim].size(),1);
00150 rutz::shared_ptr<DecisionNode> parNode = itsParent;
00151
00152 if(itsParent.is_valid())
00153 {
00154 std::vector<int> py = parNode->decide(data);
00155 for(uint s=0;s<data[itsDim].size();s++)
00156 {
00157 y[s] *= py[s];
00158 }
00159 }
00160
00161 int inVal=1;
00162
00163 if(itsLeaf)
00164 {
00165 inVal = itsClass;
00166 }
00167
00168
00169 for(uint s=0;s<data[itsDim].size();s++)
00170 {
00171 if(y[s]>0)
00172 y[s] *= (data[itsDim][s] < itsRightConstraint && data[itsDim][s] >= itsLeftConstraint) ? inVal : 0;
00173 }
00174
00175 return y;
00176 }
00177
00178 size_t DecisionNode::getDim()
00179 {
00180 return itsDim;
00181 }
00182
00183 void DecisionNode::setDim(size_t dim)
00184 {
00185 itsDim = dim;
00186 }
00187
00188 void DecisionNode::setLeaf(bool isLeaf)
00189 {
00190 itsLeaf = isLeaf;
00191
00192 if(!itsLeaf)
00193 itsClass=0;
00194 }
00195
00196 void DecisionNode::setParent(rutz::shared_ptr<DecisionNode> parent)
00197 {
00198 itsParent = parent;
00199 }
00200
00201
00202 void DecisionNode::setLeftConstraint(float constraint)
00203 {
00204 itsLeftConstraint = constraint;
00205 }
00206
00207 void DecisionNode::setRightConstraint(float constraint)
00208 {
00209 itsRightConstraint = constraint;
00210 }
00211
00212 void DecisionNode::setClass(int classId)
00213 {
00214 itsClass = classId;
00215 }
00216
00217 int DecisionNode::getClass()
00218 {
00219 return itsClass;
00220 }
00221
00222 float DecisionNode::split(const std::vector<std::vector<float> >& data, const std::vector<int>& labels,const std::vector<float>& weights, rutz::shared_ptr<DecisionNode>& left, rutz::shared_ptr<DecisionNode>& right, const rutz::shared_ptr<DecisionNode> parent)
00223 {
00224 left = rutz::shared_ptr<DecisionNode>(new DecisionNode);
00225 right = rutz::shared_ptr<DecisionNode>(new DecisionNode);
00226 left->setParent(parent);
00227 right->setParent(parent);
00228
00229
00230 ASSERT(data.size() > 0);
00231
00232 uint tr_size = data[0].size();
00233
00234 std::vector<float> bestErr;
00235 std::vector<size_t> bestErrIdx;
00236 std::vector<float> bestErrDir;
00237
00238 for(uint d=0;d<data.size();d++)
00239 {
00240
00241 std::vector<size_t> dindices;
00242 util::sortrank(data[d],dindices);
00243 std::vector<float> dsorted=data[d];
00244
00245 std::sort(dsorted.begin(),dsorted.end());
00246
00247
00248 std::vector<float> vPos(tr_size);
00249 std::vector<float> vNeg(tr_size);
00250
00251 uint i=0,j=0;
00252 while(i<dsorted.size())
00253 {
00254 uint k = 0;
00255 while(i + k < dsorted.size() && dsorted[i] == dsorted[i+k])
00256 {
00257 if(labels[dindices[i+k]] > 0)
00258 vPos[j] += weights[dindices[i+k]];
00259 else
00260 vNeg[j] += weights[dindices[i+k]];
00261 k++;
00262 }
00263 i += k;
00264 j++;
00265 }
00266
00267 vNeg.resize(j);
00268 vPos.resize(j);
00269
00270 std::vector<float> err(vPos.size());
00271 std::vector<float> invErr(vPos.size());
00272
00273 std::vector<float> iPos(vPos.size());
00274 std::vector<float> iNeg(vNeg.size());
00275
00276 for(i=0;i<iPos.size();i++)
00277 {
00278 if(i==0)
00279 {
00280 iPos[0] = vPos[0];
00281 iNeg[0] = vNeg[0];
00282 }
00283 else
00284 {
00285 iPos[i] = iPos[i-1] + vPos[i];
00286 iNeg[i] = iNeg[i-1] + vNeg[i];
00287 }
00288 }
00289
00290
00291 float totalN = (iNeg.size()>0) ? iNeg[iNeg.size()-1] : 0;
00292 float totalP = (iPos.size()>0) ? iPos[iPos.size()-1] : 0;
00293
00294
00295
00296 if(totalN<0.00001 || totalP<0.00001)
00297 return std::numeric_limits<float>::max();
00298
00299
00300 for(i=0;i<j;i++)
00301 {
00302
00303 err[i] = (iPos[i])/totalP + (totalN - iNeg[i])/totalN;
00304 invErr[i] = (iNeg[i])/totalN + (totalP - iPos[i])/totalP;
00305
00306 }
00307
00308
00309 size_t errMinIdx = std::distance(err.begin(),std::min_element(err.begin(),err.end()));
00310
00311
00312 size_t invErrMinIdx = std::distance(invErr.begin(),std::min_element(invErr.begin(),invErr.end()));
00313
00314
00315
00316
00317
00318 if(err[errMinIdx] < invErr[invErrMinIdx])
00319 {
00320 bestErr.push_back(err[errMinIdx]);
00321 bestErrIdx.push_back(errMinIdx);
00322 bestErrDir.push_back(-1);
00323 }
00324 else
00325 {
00326 bestErr.push_back(invErr[invErrMinIdx]);
00327 bestErrIdx.push_back(invErrMinIdx);
00328 bestErrDir.push_back(1);
00329 }
00330 }
00331
00332
00333 size_t bestDim = std::distance(bestErr.begin(),std::min_element(bestErr.begin(),bestErr.end()));
00334
00335 std::vector<float> tmpDimSorted = data[bestDim];
00336 std::sort(tmpDimSorted.begin(),tmpDimSorted.end());
00337 std::vector<float> dimSorted(tmpDimSorted.size());
00338
00339 uint i = 0;
00340 uint j = 0;
00341
00342
00343 while(i<dimSorted.size())
00344 {
00345 uint k = 0;
00346 while(i+k<dimSorted.size() && tmpDimSorted[i] == tmpDimSorted[i+k])
00347 {
00348 dimSorted[j] = tmpDimSorted[i];
00349 k++;
00350 }
00351 i += k;
00352 j++;
00353 }
00354
00355 dimSorted.resize(j);
00356
00357
00358 float threshold = (dimSorted[bestErrIdx[bestDim]] + dimSorted[std::min(bestErrIdx[bestDim]+1,dimSorted.size()-1)])/2.0;
00359
00360 left->setDim(bestDim);
00361 left->setRightConstraint(threshold);
00362 left->setClass(bestErrDir[bestDim]);
00363
00364 right->setDim(bestDim);
00365 right->setLeftConstraint(threshold);
00366 right->setClass(-bestErrDir[bestDim]);
00367 return bestErr[bestDim];
00368 }
00369
00370
00371 DecisionTree::DecisionTree(int maxSplits) :
00372 itsMaxSplits(maxSplits)
00373 {
00374 }
00375
00376 std::deque<rutz::shared_ptr<DecisionNode> > DecisionTree::getNodes()
00377 {
00378 return itsNodes;
00379 }
00380
00381 void DecisionTree::addNode(rutz::shared_ptr<DecisionNode> node)
00382 {
00383 itsNodes.push_back(node);
00384 }
00385
00386 void DecisionTree::printTree()
00387 {
00388 std::deque<rutz::shared_ptr<DecisionNode> >::iterator itr;
00389 LINFO("Printing Tree of %Zu nodes",itsNodes.size());
00390 int i=0;
00391 for(itr=itsNodes.begin();itr!=itsNodes.end();itr++)
00392 {
00393 rutz::shared_ptr<DecisionNode> n=*itr;
00394 if(!n.is_valid())
00395 {
00396 LINFO("Node[%d] <Invalid Pointer>",i);
00397 continue;
00398 }
00399 std::string output;
00400 n->printNode(output);
00401 LINFO("%s",output.c_str());
00402 i++;
00403 }
00404 }
00405
00406 std::vector<int> DecisionTree::predict(const std::vector<std::vector<float> >& data, std::vector<float> weights)
00407 {
00408 ASSERT(data.size()>0);
00409 if(weights.size()==0)
00410 weights.resize(data[0].size(),1.0F);
00411 size_t sampleDim = data[0].size();
00412 std::vector<int> prediction(sampleDim);
00413 std::deque<rutz::shared_ptr<DecisionNode> >::iterator itr;
00414 for(itr=itsNodes.begin();itr!=itsNodes.end();itr++)
00415 {
00416 std::vector<int> y = (*itr)->decide(data);
00417 for(uint i=0;i<y.size();i++)
00418 {
00419 prediction[i] += (y[i])*weights[i];
00420 }
00421 }
00422 return prediction;
00423 }
00424
00425
00426
00427 void DecisionTree::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, std::vector<float> weights)
00428 {
00429 ASSERT(labels.size()==weights.size() || weights.size()==0);
00430
00431 if(weights.size()==0)
00432 {
00433 weights.resize(labels.size(),1.0/float(labels.size()));
00434 }
00435
00436 itsNodes.clear();
00437
00438 rutz::shared_ptr<DecisionNode> tmpNode(new DecisionNode());
00439
00440
00441 rutz::shared_ptr<DecisionNode> left,right;
00442 tmpNode->split(data,labels,weights,left,right);
00443
00444 if(!left->isValid() || !right->isValid())
00445 {
00446 LINFO("Split could not find a non-trivial cut in the remaining data");
00447 return;
00448 }
00449 itsNodes.push_back(left);
00450 itsNodes.push_back(right);
00451
00452
00453
00454 std::vector<int> propAns = left->decide(data);
00455 float leftPos=0,leftNeg=0;
00456 for(uint a=0;a<propAns.size();a++)
00457 {
00458 if(propAns[a] == labels[a])
00459 leftPos += weights[a];
00460 else if(propAns[a] == -labels[a])
00461 leftNeg += weights[a];
00462 }
00463 propAns = right->decide(data);
00464 float rightPos=0,rightNeg=0;
00465 for(uint a=0;a<propAns.size();a++)
00466 {
00467 if(propAns[a] == labels[a])
00468 rightPos += weights[a];
00469 else if(propAns[a] == -labels[a])
00470 rightNeg += weights[a];
00471 }
00472
00473
00474
00475 std::deque<float> errs;
00476 errs.push_back(std::min(leftPos,leftNeg));
00477 errs.push_back(std::min(rightPos,rightNeg));
00478
00479
00480 if(leftPos + leftNeg == 0)
00481 return;
00482
00483 if(rightPos + rightNeg == 0)
00484 return;
00485
00486
00487 std::deque<size_t> eIndex;
00488 util::sortrank(errs,eIndex);
00489
00490
00491 std::deque<float> eTmp = errs;
00492 std::deque<rutz::shared_ptr<DecisionNode> > tmpNodes = itsNodes;
00493
00494 errs.clear();
00495 itsNodes.clear();
00496
00497 for(int i=eIndex.size()-1;i>=0;i--)
00498 {
00499 errs.push_back(eTmp[eIndex[i]]);
00500 itsNodes.push_back(tmpNodes[eIndex[i]]);
00501 }
00502
00503 std::deque<rutz::shared_ptr<DecisionNode> > splits;
00504 std::deque<float> splitErrs;
00505 std::deque<float> deltas;
00506
00507
00508 for(uint i=1;i<itsMaxSplits;i++)
00509 {
00510 ASSERT(itsNodes.size()>deltas.size());
00511 std::deque<rutz::shared_ptr<DecisionNode> >::iterator nodeItr=itsNodes.begin()+deltas.size();
00512
00513
00514 for(uint j=deltas.size();j<errs.size();j++,nodeItr++)
00515 {
00516 ASSERT(nodeItr!=itsNodes.end());
00517
00518 rutz::shared_ptr<DecisionNode> curNode = *(nodeItr);
00519
00520 std::vector<int> curNodeOut = curNode->decide(data);
00521
00522
00523 std::vector<uint> mask;
00524 for(uint idx=0;idx<curNodeOut.size();idx++)
00525 {
00526 if(curNodeOut[idx] == curNode->getClass())
00527 mask.push_back(idx);
00528 }
00529
00530 leftPos=0,leftNeg=0;
00531 rightPos=0,rightNeg=0;
00532 float spliterr;
00533
00534 if(mask.size()>0)
00535 {
00536
00537 std::vector<std::vector<float> > maskedData(data.size());
00538 std::vector<int> maskedLabels;
00539 std::vector<float> maskedWeights;
00540 bool allTrueLabels=true;
00541 bool allFalseLabels=true;
00542 for(uint idx=0;idx<mask.size();idx++)
00543 {
00544 for(uint idx2=0;idx2<data.size();idx2++)
00545 {
00546 maskedData[idx2].push_back(data[idx2][mask[idx]]);
00547 }
00548 if(labels[mask[idx]]==-1)
00549 allTrueLabels=false;
00550 if(labels[mask[idx]]==1)
00551 allFalseLabels=false;
00552 maskedLabels.push_back(labels[mask[idx]]);
00553 maskedWeights.push_back(weights[mask[idx]]);
00554 }
00555
00556 if((allTrueLabels && curNode->getClass()==1)||(allFalseLabels && curNode->getClass()==-1))
00557 {
00558 LINFO("Ignoring split of node with no misclassifications");
00559 leftPos=leftNeg=rightPos=rightNeg=0;
00560 spliterr = std::numeric_limits<float>::max();
00561 }
00562 else
00563 {
00564
00565 spliterr = curNode->split(maskedData,maskedLabels,maskedWeights,left,right,curNode);
00566
00567
00568
00569 std::vector<int> propAns = left->decide(data);
00570 for(uint a=0;a<propAns.size();a++)
00571 {
00572 if(propAns[a] == labels[a])
00573 leftPos += weights[a];
00574 else if(propAns[a] == -labels[a])
00575 leftNeg += weights[a];
00576 }
00577 propAns = right->decide(data);
00578 for(uint a=0;a<propAns.size();a++)
00579 {
00580 if(propAns[a] == labels[a])
00581 rightPos += weights[a];
00582 else if(propAns[a] == -labels[a])
00583 rightNeg += weights[a];
00584 }
00585 }
00586 }
00587 else
00588 {
00589
00590 LFATAL("No masked data, node does not contain any training data");
00591 leftPos=leftNeg=rightPos=rightNeg=0;
00592 spliterr = std::numeric_limits<float>::max();
00593 }
00594
00595 splits.push_back(left);
00596 splits.push_back(right);
00597
00598
00599 if(leftPos + leftNeg == 0 || rightPos + rightNeg == 0)
00600 deltas.push_back(0);
00601 else
00602 {
00603 LINFO("Delta for splitting node %d is %f, errs[] %f, spliterr %f",j,errs[j]-spliterr,errs[j],spliterr);
00604 deltas.push_back(errs[j]-spliterr);
00605 }
00606
00607 splitErrs.push_back(std::min(leftPos,leftNeg));
00608 splitErrs.push_back(std::min(rightPos,rightNeg));
00609 }
00610
00611 std::deque<float>::iterator maxElemItr = std::max_element(deltas.begin(),deltas.end());
00612 LINFO("Best Delta %f For Iter %u",*maxElemItr,i);
00613
00614 if(*maxElemItr < 0.000001)
00615 {
00616 LINFO("Delta is zero or too small, done");
00617 return;
00618 }
00619
00620 uint bestSplit = std::distance(deltas.begin(),maxElemItr);
00621
00622
00623 itsNodes[bestSplit]->setLeaf(false);
00624
00625
00626 itsNodes.erase(itsNodes.begin()+bestSplit);
00627 errs.erase(errs.begin()+bestSplit);
00628 deltas.erase(deltas.begin()+bestSplit);
00629
00630
00631 ASSERT((splits.begin()+2*bestSplit)->is_valid());
00632 ASSERT((splits.begin()+2*bestSplit+1)->is_valid());
00633 itsNodes.push_back(*(splits.begin()+2*bestSplit));
00634 itsNodes.push_back(*(splits.begin()+2*bestSplit+1));
00635
00636
00637 errs.push_back(*(splitErrs.begin()+2*bestSplit));
00638 errs.push_back(*(splitErrs.begin()+2*bestSplit+1));
00639
00640
00641 splitErrs.erase(splitErrs.begin()+2*bestSplit,splitErrs.begin()+2*bestSplit+2);
00642 splits.erase(splits.begin()+2*bestSplit,splits.begin()+2*bestSplit+2);
00643 }
00644 }
00645
00646
00647
00648
00649
00650
00651
00652