GentleBoostBinary.C
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #include "Learn/GentleBoostBinary.H"
00038 #include "Util/Assert.H"
00039 #include "Util/log.H"
00040 #include "Util/sformat.H"
00041 #include "Util/SortUtil.H"
00042 #include <limits>
00043 #include <math.h>
00044 #include <stdio.h>
00045
00046
00047
00048 GentleBoostBinary::GentleBoostBinary(int maxTreeSize) :
00049 itsMaxTreeSize(maxTreeSize)
00050 {
00051 }
00052
00053
00054 std::vector<float> GentleBoostBinary::predict(const std::vector<std::vector<float> >& data)
00055 {
00056 return predict(data,itsWeights);
00057 }
00058
00059
00060 std::vector<float> GentleBoostBinary::predict(const std::vector<std::vector<float> >& data, std::vector<float> weights)
00061 {
00062 ASSERT(weights.size()==itsNodes.size());
00063 ASSERT(data.size()>0);
00064 std::vector<float> pred(data[0].size());
00065 for(size_t i=0;i<itsNodes.size();i++)
00066 {
00067 std::vector<int> tmppred=itsNodes[i]->decide(data);
00068 for(size_t j=0;j<pred.size();j++)
00069 {
00070 pred[j]+=float(tmppred[j])*weights[i];
00071 }
00072 }
00073 return pred;
00074 }
00075
00076
00077 void GentleBoostBinary::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, int maxIters)
00078 {
00079 std::vector<float> predictions;
00080 train(data,labels,maxIters,predictions);
00081
00082 }
00083
00084 void GentleBoostBinary::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, int maxIters, std::vector<float>& predictions)
00085 {
00086 ASSERT(data.size()>0);
00087 int nSamples = int(data[0].size());
00088 ASSERT(int(labels.size())==nSamples);
00089 std::vector<float> dataWeights;
00090 if(predictions.size()>0)
00091 {
00092 dataWeights = std::vector<float>(nSamples);
00093 for(size_t i=0;i<predictions.size();i++)
00094 {
00095 dataWeights[i]=exp(-(labels[i]*predictions[i]));
00096 }
00097 }
00098 else
00099 {
00100 dataWeights = std::vector<float>(nSamples,1.0F/float(nSamples));
00101 predictions = std::vector<float>(nSamples);
00102 }
00103
00104 for(int iter=0;iter<maxIters;iter++)
00105 {
00106 rutz::shared_ptr<DecisionTree> learner(new DecisionTree(itsMaxTreeSize));
00107 learner->train(data,labels,dataWeights);
00108 std::deque<rutz::shared_ptr<DecisionNode> > curNodes = learner->getNodes();
00109 if(curNodes.size()==0)
00110 {
00111 LINFO("Training complete, only trivial cuts found");
00112 return;
00113 }
00114 for(size_t idx=0;idx<curNodes.size();idx++)
00115 {
00116 rutz::shared_ptr<DecisionNode> curNode = curNodes[idx];
00117 std::vector<int> curNodeOut = curNode->decide(data);
00118 float s1=0.0F,s2=0.0F;
00119 for(size_t i=0;i<curNodeOut.size();i++)
00120 {
00121 if(curNodeOut[i]==1)
00122 {
00123 if(labels[i]==1)
00124 {
00125
00126 s1 += dataWeights[i];
00127 }
00128 else if(labels[i]==-1)
00129 {
00130
00131 s2 += dataWeights[i];
00132 }
00133 }
00134
00135 else if(curNodeOut[i]==-1)
00136 {
00137 if(labels[i]==-1)
00138 {
00139
00140 s1 += dataWeights[i];
00141 }
00142 else if(labels[i]==1)
00143 {
00144
00145 s2 += dataWeights[i];
00146 }
00147 }
00148 }
00149 if(s1==0.0F && s2==0.0F)
00150 continue;
00151 float alpha = (s1-s2)/(s1+s2);
00152 itsWeights.push_back(alpha);
00153 itsNodes.push_back(curNode);
00154 for(size_t i=0;i<predictions.size();i++)
00155 {
00156 predictions[i] += curNodeOut[i]*alpha;
00157 }
00158 }
00159 float sumDW=0;
00160 for(int i=0;i<nSamples;i++)
00161 {
00162 dataWeights[i] = exp(-1.0F * (labels[i]*predictions[i]));
00163 sumDW+=dataWeights[i];
00164 }
00165 if(sumDW>0)
00166 {
00167 for(int i=0;i<nSamples;i++)
00168 {
00169 dataWeights[i]/=sumDW;
00170 }
00171 }
00172 }
00173 }
00174
00175 void GentleBoostBinary::printTree()
00176 {
00177 std::deque<rutz::shared_ptr<DecisionNode> >::iterator itr;
00178 LINFO("Printing Tree of %Zu nodes",itsNodes.size());
00179 int i=0;
00180 for(itr=itsNodes.begin();itr!=itsNodes.end();itr++)
00181 {
00182 rutz::shared_ptr<DecisionNode> n=*itr;
00183 if(!n.is_valid())
00184 {
00185 LINFO("Node[%d] <Invalid Pointer>",i);
00186 continue;
00187 }
00188 std::string output;
00189 n->printNode(output);
00190 LINFO("Weight: %f\n%s",itsWeights[i],output.c_str());
00191 i++;
00192 }
00193 }
00194
00195 void GentleBoostBinary::writeTree(std::ostream& outstream)
00196 {
00197 rutz::shared_ptr<std::string> output = rutz::shared_ptr<std::string>(new std::string);
00198 std::deque<rutz::shared_ptr<DecisionNode> >::iterator itr;
00199 int i=0;
00200 for(itr=itsNodes.begin();itr!=itsNodes.end();itr++)
00201 {
00202 rutz::shared_ptr<DecisionNode> n=*itr;
00203 if(!n.is_valid())
00204 {
00205 continue;
00206 }
00207 outstream << sformat("TREEWEIGHT:%f; \n",itsWeights[i]);
00208 n->writeNode(outstream);
00209 }
00210 outstream << std::string("END\n");
00211 }
00212
00213 void GentleBoostBinary::readTree(std::istream& instream)
00214 {
00215 DecisionNode tmp;
00216 const int BUFFER_SIZE = 256;
00217 char buf[BUFFER_SIZE];
00218 int treeIdx=0;
00219
00220 bool nodeIsValid = true;
00221 while(nodeIsValid)
00222 {
00223 instream.getline(buf,BUFFER_SIZE);
00224 float treeWeight;
00225 int numItemsFound = sscanf(buf,"TREEWEIGHT:%f; ",&treeWeight);
00226 if(numItemsFound == 1)
00227 {
00228 rutz::shared_ptr<DecisionNode> node = tmp.readNode(instream);
00229 if(!node.is_valid())
00230 {
00231 LFATAL("No tree associated with tree weight at index %d",treeIdx);
00232 nodeIsValid = false;
00233 }
00234 itsWeights.push_back(treeWeight);
00235 itsNodes.push_back(node);
00236 treeIdx++;
00237 }
00238 else if(std::string(buf).compare("END")==0)
00239 {
00240 nodeIsValid = false;
00241 }
00242 else
00243 {
00244 LFATAL("Incomplete tree representation at index %d",treeIdx);
00245 nodeIsValid = false;
00246 }
00247 }
00248
00249 }
00250
00251
00252 void GentleBoostBinary::clear()
00253 {
00254 itsNodes.clear();
00255 itsWeights.clear();
00256 }