test-LocalBinaryPatterns2.C
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Component/ModelManager.H"
00039 #include "Image/DrawOps.H"
00040 #include "Image/Kernels.H"
00041 #include "Image/CutPaste.H"
00042 #include "Image/ColorOps.H"
00043 #include "Image/FilterOps.H"
00044 #include "Raster/Raster.H"
00045 #include "Media/FrameSeries.H"
00046 #include "Util/Timer.H"
00047 #include "Util/CpuTimer.H"
00048 #include "Util/StringUtil.H"
00049 #include "Features/LocalBinaryPatterns.H"
00050 #include "Learn/LogLikelihoodClassifier.H"
00051 #include "Learn/SVMClassifier.H"
00052 #include "rutz/rand.h"
00053 #include "rutz/trace.h"
00054
00055 #include <math.h>
00056 #include <fcntl.h>
00057 #include <limits>
00058 #include <string>
00059
// Requested size of the training patches that main() tiles, without overlap,
// over the left half of each texture image (clamped there to the image size).
static const int TRAIN_WIDTH = 160;
static const int TRAIN_HEIGHT = 160;
// Requested size of the test patches that main() cuts at random positions in
// the right half of a randomly chosen texture image (clamped likewise).
static const int SAMPLE_WIDTH = 160;
static const int SAMPLE_HEIGHT = 160;

// Number of random test patches used to estimate classification accuracy.
static const int TEST_SIZE = 20;

// Whether to use the SVM classifier; when false the log-likelihood classifier is used.
static const bool USE_SVM = false;
00068
00069 int main(const int argc, const char **argv)
00070 {
00071
00072 MYLOGVERB = LOG_INFO;
00073 ModelManager manager("Test LocalBinaryPatterns");
00074
00075
00076 rutz::urand rgen(time((time_t*)0)+getpid());
00077
00078
00079 LogLikelihoodClassifier ll = LogLikelihoodClassifier(7);
00080 SVMClassifier svm = SVMClassifier();
00081 std::vector<LocalBinaryPatterns> lbp;
00082
00083 lbp.push_back(LocalBinaryPatterns(2,16,0,false,true));
00084 lbp.push_back(LocalBinaryPatterns(3,24,0,false,true));
00085
00086 if (manager.parseCommandLine(
00087 (const int)argc, (const char**)argv, "<texture1file> ... <textureNfile>", 2, 200) == false)
00088 return 0;
00089
00090 manager.start();
00091
00092 uint numCategories = manager.numExtraArgs();
00093 std::vector<std::string> texFile;
00094 for(uint i=0;i<numCategories;i++)
00095 {
00096 texFile.push_back(manager.getExtraArg(i));
00097 }
00098
00099 const Dims trainDims = Dims(TRAIN_WIDTH,TRAIN_HEIGHT);
00100 for(uint idx=0;idx<numCategories;idx++)
00101 {
00102 Image<float> tex = Raster::ReadGray(texFile[idx]);
00103
00104 float tw = std::min(TRAIN_WIDTH,int(tex.getWidth()/2.0));
00105 float th = std::min(TRAIN_HEIGHT,int(tex.getHeight()));
00106 for(uint xp=0;xp<=uint(tex.getWidth()/2.0-tw);xp+=tw)
00107 for(uint yp=0;yp<=uint(tex.getHeight()-th);yp+=th)
00108 {
00109 LINFO("Adding crop for id[%u] at pos [%ux%u]",idx,xp,yp);
00110 Image<float> samp = crop(tex,Rectangle(Point2D<int>(xp,yp),trainDims));
00111 for(uint o=0;o<lbp.size();o++)
00112 lbp[o].addModel(toRGB(samp),idx+1);
00113 }
00114 }
00115
00116
00117 std::vector<LocalBinaryPatterns::MapModelVector> allModels;
00118 for(uint o=0;o<lbp.size();o++)
00119 {
00120 LINFO("Building variance model for LBP class [%u]",o);
00121 lbp[o].buildModels();
00122 allModels.push_back(lbp[o].getModels());
00123 }
00124 LocalBinaryPatterns::MapModelVector completeModel;
00125 lbp[0].combineModels(allModels,completeModel);
00126 if(USE_SVM)
00127 {
00128 std::vector<std::vector<float> > data;
00129 std::vector<float> labels;
00130 lbp[0].getLabeledData(completeModel,data,labels);
00131 svm.train(data,labels);
00132 }
00133 else
00134 {
00135 ll.setModels(completeModel);
00136 }
00137
00138
00139 int numCorrect=0;
00140 const Dims sampleDims = Dims(SAMPLE_WIDTH,SAMPLE_HEIGHT);
00141
00142 for(uint s=0;s<TEST_SIZE;s++)
00143 {
00144
00145 int idx = rgen.idraw(numCategories);
00146 Image<float> tex = Raster::ReadGray(texFile[idx]);
00147
00148 float tw = std::min(SAMPLE_WIDTH,int(tex.getWidth()/2.0));
00149 float th = std::min(SAMPLE_HEIGHT,int(tex.getHeight()));
00150 int xp,yp;
00151 xp=rgen.idraw_range(tex.getWidth()/2.0,int(tex.getWidth()-tw));
00152 yp=rgen.idraw_range(0,int(tex.getHeight()-th));
00153 Image<float> samp = crop(tex,Rectangle(Point2D<int>(xp,yp),sampleDims));
00154
00155 std::vector<float> hist;
00156 for(uint o=0;o<lbp.size();o++)
00157 {
00158 std::vector<float> tmpHist = lbp[o].createHistogram(samp);
00159 hist.insert(hist.begin(),tmpHist.begin(),tmpHist.end());
00160 }
00161 int gtIdx = idx+1;
00162 int predIdx;
00163 if(USE_SVM)
00164 {
00165 predIdx = (int) svm.predict(hist);
00166 }
00167 else
00168 predIdx = ll.predict(hist);
00169 LINFO("Index Ground Truth [%d], Predicted [%d]",gtIdx,predIdx);
00170 if(predIdx == gtIdx) numCorrect++;
00171 }
00172 LINFO("Test Accuracy %f, Random chance would be %f",float(numCorrect)/TEST_SIZE,1.0F/numCategories);
00173 manager.stop();
00174
00175 }
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186