GentleBoost.C
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #include "GentleBoost.H"
00038 #include "Util/Assert.H"
00039 #include "Util/log.H"
00040 #include "Util/SortUtil.H"
00041 #include "Util/sformat.H"
00042 #include <limits>
00043 #include <math.h>
00044 #include <stdio.h>
00045
00046
00047 GentleBoost::GentleBoost(int maxTreeSize) :
00048 itsMaxTreeSize(maxTreeSize)
00049 {
00050 }
00051
00052 std::vector<std::vector<float> > GentleBoost::transpose(const std::vector<std::vector<float> >& data)
00053 {
00054 std::vector<std::vector<float> > out;
00055 if(data.size()==0)
00056 return out;
00057 out.resize(data[0].size());
00058 for(uint i=0;i<data[0].size();i++)
00059 {
00060 for(uint j=0;j<data.size();j++)
00061 {
00062 out[i].push_back(data[j][i]);
00063 }
00064 }
00065 return out;
00066 }
00067
00068 std::map<int,std::vector<float> > GentleBoost::predictPDF(const std::vector<std::vector<float> >& data)
00069 {
00070 ASSERT(itsLearners.size()>0);
00071 ASSERT(data.size()>0);
00072 std::map<int,std::vector<float> > perClassPreds;
00073 std::map<int,GentleBoostBinary>::iterator litr;
00074 for(litr=itsLearners.begin();litr!=itsLearners.end();litr++)
00075 {
00076 perClassPreds[litr->first] = litr->second.predict(data);
00077 }
00078 return perClassPreds;
00079 }
00080
00081 int GentleBoost::getMostLikelyClass(const std::map<int,std::vector<float> >& pdf, int index)
00082 {
00083 float maxClassVal=-1000;
00084 int maxClassIdx=-1;
00085 std::map<int,GentleBoostBinary>::iterator litr;
00086 for(litr=itsLearners.begin();litr!=itsLearners.end();litr++)
00087 {
00088 std::map<int,std::vector<float> >::const_iterator pitr=pdf.find(litr->first);
00089 ASSERT(pitr != pdf.end());
00090 if(maxClassVal < pitr->second[index])
00091 {
00092 maxClassVal=pitr->second[index];
00093 maxClassIdx=litr->first;
00094 }
00095 }
00096 ASSERT(maxClassIdx>=0);
00097 return maxClassIdx;
00098 }
00099
00100 std::vector<int> GentleBoost::getMostLikelyClass(const std::map<int,std::vector<float> >& pdf)
00101 {
00102 ASSERT(itsLearners.size()>0);
00103 ASSERT(pdf.size()>0);
00104 std::vector<int> preds;
00105 int nSamples=(pdf.begin())->second.size();
00106 for(int s=0;s<nSamples;s++)
00107 {
00108 int maxClassIdx = getMostLikelyClass(pdf,s);
00109 preds.push_back(maxClassIdx);
00110 }
00111 return preds;
00112 }
00113
00114 std::vector<int> GentleBoost::predict(const std::vector<std::vector<float> >& data)
00115 {
00116 std::map<int,std::vector<float> > perClassPreds=predictPDF(data);
00117 return getMostLikelyClass(perClassPreds);
00118 }
00119
00120
00121 void GentleBoost::train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, int maxIters)
00122 {
00123 std::map<int,std::vector<int> > perClassLabels = convertLabels(labels);
00124 std::map<int,std::vector<int> >::iterator litr;
00125 for(litr=perClassLabels.begin();litr!=perClassLabels.end();litr++)
00126 {
00127 itsLearners[litr->first] = GentleBoostBinary(itsMaxTreeSize);
00128 itsLearners[litr->first].train(data,litr->second,maxIters);
00129 }
00130 }
00131
00132 void GentleBoost::printAllTrees()
00133 {
00134 std::map<int,GentleBoostBinary>::iterator litr;
00135 for(litr=itsLearners.begin();litr!=itsLearners.end();litr++)
00136 {
00137 LINFO("Printing Gentle Boost Binary Classification Tree for Class %d",litr->first);
00138 litr->second.printTree();
00139 }
00140 }
00141
00142 void GentleBoost::writeAllTrees(std::ostream& outstream)
00143 {
00144 rutz::shared_ptr<std::string> output = rutz::shared_ptr<std::string>(new std::string());
00145 std::map<int,GentleBoostBinary>::iterator litr;
00146 outstream << sformat("MAXTREES:%d; \n",itsMaxTreeSize);
00147 for(litr=itsLearners.begin();litr!=itsLearners.end();litr++)
00148 {
00149 outstream << sformat("TREECLASS:%d; \n",litr->first);
00150 litr->second.writeTree(outstream);
00151 }
00152 outstream << std::string("END\n");
00153 }
00154
00155 void GentleBoost::readAllTrees(std::istream& instream)
00156 {
00157 int treeIdx=0;
00158 bool treeIsValid = true;
00159 const int BUFFER_SIZE = 256;
00160 char buf[BUFFER_SIZE];
00161 instream.getline(buf,BUFFER_SIZE);
00162 int numItemsFound = sscanf(buf,"MAXTREES:%d; ",&itsMaxTreeSize);
00163 if(numItemsFound != 1)
00164 LFATAL("Invalid GentleBoost format, MAXTREES undefined");
00165 while(treeIsValid)
00166 {
00167 instream.getline(buf,BUFFER_SIZE);
00168 int treeClass;
00169 int numItemsFound = sscanf(buf,"TREECLASS:%d; ",&treeClass);
00170 if(numItemsFound == 1)
00171 {
00172 GentleBoostBinary gbb;
00173 gbb.readTree(instream);
00174
00175 itsLearners[treeClass] = gbb;
00176
00177 treeIdx++;
00178 }
00179 else if(std::string(buf).compare("END")==0)
00180 {
00181 treeIsValid = false;
00182 }
00183 else
00184 {
00185 LFATAL("Incomplete tree representation at index %d",treeIdx);
00186 treeIsValid = false;
00187 }
00188
00189 }
00190
00191 }
00192
00193 void GentleBoost::save(std::string file)
00194 {
00195 std::ofstream outf(file.c_str(),std::ofstream::out);
00196 writeAllTrees(outf);
00197 }
00198
00199 void GentleBoost::load(std::string file)
00200 {
00201 std::ifstream inf(file.c_str(),std::ofstream::in);
00202 readAllTrees(inf);
00203 }
00204
00205
00206
00207 std::map<int,std::vector<int> > GentleBoost::convertLabels(const std::vector<int>& labels)
00208 {
00209 std::map<int,std::vector<int> > perClassLabels;
00210
00211 int nSamples=labels.size();
00212
00213 for(int i=0;i<nSamples;i++)
00214 {
00215 if(perClassLabels.find(labels[i])==perClassLabels.end())
00216 {
00217
00218 perClassLabels[labels[i]] = std::vector<int>(nSamples,-1);
00219 }
00220 perClassLabels[labels[i]][i]=1;
00221 }
00222
00223 return perClassLabels;
00224 }