/*!@file Learn/RecurBayes.C Recursive Bayesian network classifier */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
// by the University of Southern California (USC) and the iLab at USC.  //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.       //
// 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Lior Elazary <elazary@usc.edu> 00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Learn/RecurBayes.C $ 00035 // $Id: RecurBayes.C 10794 2009-02-08 06:21:09Z itti $ 00036 // 00037 00038 //This is a Naive RecurBayes for now 00039 #ifndef LEARN_BAYES_C_DEFINED 00040 #define LEARN_BAYES_C_DEFINED 00041 00042 #include "Learn/RecurBayes.H" 00043 #include "Util/Assert.H" 00044 #include "Util/log.H" 00045 #include <math.h> 00046 #include <fcntl.h> 00047 #include <limits> 00048 #include <string> 00049 00050 // ###################################################################### 00051 RecurBayes::RecurBayes(uint numClasses, uint numFeatures, uint numFix): 00052 itsNumFeatures(numFeatures), itsNumClasses(numClasses), itsNumFix(numFix), 00053 itsMean(numClasses, 00054 std::vector<std::vector<double> > 00055 (numFix, std::vector<double>(numFeatures,0))), 00056 itsVar(numClasses, 00057 std::vector<std::vector<double> > 00058 (numFix, std::vector<double>(numFeatures,0.01))), 00059 itsClassFreq(numClasses,0), 00060 itsFeatureNames(numFeatures), 00061 itsClassNames(numClasses, "No Name") 00062 { 00063 00064 } 00065 00066 // ###################################################################### 00067 RecurBayes::~RecurBayes() 00068 { 00069 } 00070 00071 00072 // ###################################################################### 00073 void RecurBayes::learn(std::vector<double> &fv, const char *name, uint fix) 00074 { 00075 //get the class id 00076 00077 int cls = getClassId(name); 00078 if (cls == -1) //this is a new class, add it to the 
network 00079 cls = addClass(name); 00080 00081 ASSERT(fv.size() == itsNumFeatures); 00082 00083 //update the class freq 00084 ASSERT((uint)cls < itsNumClasses); 00085 itsClassFreq[cls]++; 00086 00087 //compute the stddev and mean of each feature 00088 //This algorithm is due to Knuth (The Art of Computer Programming, volume 2: 00089 // Seminumerical Algorithms, 3rd edn., p. 232. Boston: Addison-Wesley.) 00090 for (uint i=0; i<fv.size(); i++){ 00091 double val = fv[i]; 00092 double delta = val - itsMean[cls][fix][i]; 00093 itsMean[cls][fix][i] += delta/itsClassFreq[cls]; 00094 if (itsClassFreq[cls] > 3) 00095 { 00096 itsVar[cls][fix][i] = (itsVar[cls][fix][i]*(itsClassFreq[cls]-2)) 00097 + delta*(val - itsMean[cls][fix][i]); 00098 } 00099 if (itsClassFreq[cls] > 1) //watch for divide by 0 00100 itsVar[cls][fix][i] /= double(itsClassFreq[cls]-1); 00101 } 00102 00103 } 00104 00105 // ###################################################################### 00106 double RecurBayes::getMean(uint cls, uint i, uint fix) 00107 { 00108 ASSERT(cls < itsNumClasses && i < itsNumFeatures); 00109 return itsMean[cls][fix][i]; 00110 } 00111 00112 // ###################################################################### 00113 double RecurBayes::getStdevSq(uint cls, uint i, uint fix) 00114 { 00115 ASSERT(cls < itsNumClasses && i < itsNumFeatures); 00116 return itsVar[cls][fix][i]; 00117 } 00118 00119 // ###################################################################### 00120 uint RecurBayes::getNumFeatures() 00121 { 00122 return itsNumFeatures; 00123 } 00124 00125 // ###################################################################### 00126 uint RecurBayes::getNumClasses() 00127 { 00128 return itsNumClasses; 00129 } 00130 00131 // ###################################################################### 00132 uint RecurBayes::getClassFreq(uint cls) 00133 { 00134 ASSERT(cls < itsNumClasses); 00135 return itsClassFreq[cls]; 00136 } 00137 00138 // 
###################################################################### 00139 double RecurBayes::getClassProb(uint cls) 00140 { 00141 ASSERT(cls < itsNumClasses); 00142 00143 //double totalFreq = 0; 00144 //for (uint i=0; i<itsNumClasses; i++) 00145 // totalFreq += itsClassFreq[i]; 00146 00147 //return double(itsClassFreq[cls])/totalFreq; 00148 00149 return double(1/itsNumClasses); 00150 } 00151 00152 // ###################################################################### 00153 int RecurBayes::classify(std::vector<double> &fv, double *prob, uint fix) 00154 { 00155 00156 //the maximum posterior (MAP alg): 00157 double maxProb = -std::numeric_limits<double>::max(); 00158 int maxCls = -1; 00159 double sumClassProb = 0; 00160 00161 for(uint cls=0; cls<itsNumClasses; cls++) 00162 { 00163 //Find the probability that the fv belongs to this class 00164 double probVal = log(getClassProb(cls)); //the prior probility 00165 for (uint i=0; i<itsNumFeatures; i++) //get the probilityposterior prob 00166 { 00167 if (itsMean[cls][fix][i] > 0) //only process if mean > 0 00168 { 00169 double g = gauss(fv[i], itsMean[cls][fix][i], itsVar[cls][fix][i]); 00170 probVal += log(g); 00171 00172 00173 // LINFO("Val %f Mean %f sigma %f g(%e) %e", 00174 // fv[i], itsMean[cls][i], itsStdevSq[cls][i], g, probVal); 00175 } 00176 } 00177 00178 LINFO("Class %i prob %f\n", cls, probVal); 00179 sumClassProb += probVal; 00180 if (probVal > maxProb){ //we have a new max 00181 maxProb = probVal; 00182 maxCls = cls; 00183 } 00184 } 00185 00186 if (prob != NULL) 00187 *prob = exp(maxProb)/exp(sumClassProb); 00188 00189 return maxCls; 00190 } 00191 00192 00193 // ###################################################################### 00194 double RecurBayes::getStatSig(std::vector<double> &fv, uint cls, uint fix) 00195 { 00196 ASSERT(fv.size() == itsNumFeatures); 00197 00198 double statSig = 0; 00199 00200 //simple t test 00201 for (uint i=0; i<fv.size(); i++){ 00202 //compute a t test for each feature 
00203 double val = fv[i]; 00204 if (itsVar[cls][fix][i] != 0.00) 00205 { 00206 double fsig = log(fabs(val - itsMean[cls][fix][i])) - log(itsVar[cls][fix][i]); 00207 statSig += fsig; 00208 } 00209 } 00210 00211 return statSig; 00212 00213 } 00214 00215 //// ###################################################################### 00216 double RecurBayes::gauss(double x, double mean, double stdevSq) 00217 { 00218 double delta = -(x - mean) * (x - mean); 00219 return exp(delta/(2*stdevSq))/(sqrt(2*M_PI*stdevSq)); 00220 } 00221 00222 //// ###################################################################### 00223 void RecurBayes::save(const char *filename) 00224 { 00225 00226 int fd; 00227 00228 if ((fd = creat(filename, 0644)) == -1) { 00229 printf("Can not open %s for saving\n", filename); 00230 return; 00231 } 00232 00233 //write the # Features and Classes 00234 write(fd, &itsNumFeatures, sizeof(uint)); 00235 write(fd, &itsNumClasses, sizeof(uint)); 00236 00237 //Write the class freq and names 00238 for(uint i=0; i<itsNumClasses; i++) 00239 { 00240 write(fd, &itsClassFreq[i], sizeof(uint64)); 00241 00242 uint clsNameLength = itsClassNames[i].size()+1; //1 for null str terminator 00243 write(fd, &clsNameLength, sizeof(uint)); 00244 write(fd, itsClassNames[i].c_str(), sizeof(char)*clsNameLength); 00245 } 00246 00247 00248 //Write the mean and stdev 00249 for(uint cls=0; cls<itsNumClasses; cls++) 00250 { 00251 for (uint i=0; i<itsNumFeatures; i++) //get the posterior prob 00252 { 00253 int fix = 1; 00254 write(fd, &itsMean[cls][fix][i], sizeof(double)); 00255 write(fd, &itsVar[cls][fix][i], sizeof(double)); 00256 } 00257 } 00258 00259 close(fd); 00260 00261 00262 } 00263 00264 //// ###################################################################### 00265 void RecurBayes::load(const char *filename) 00266 { 00267 00268 int fd; 00269 00270 if ((fd = open(filename, 0644)) == -1) { 00271 printf("Can not open %s for reading\n", filename); 00272 return; 00273 } 00274 00275 
itsNumClasses = 0; 00276 itsNumFeatures = 0; 00277 //read the # Features and Classes 00278 read(fd, &itsNumFeatures, sizeof(uint)); 00279 read(fd, &itsNumClasses, sizeof(uint)); 00280 00281 //read the class freq 00282 itsClassFreq.clear(); 00283 itsClassFreq.resize(itsNumClasses); 00284 itsClassNames.resize(itsNumClasses); 00285 00286 for(uint i=0; i<itsNumClasses; i++) 00287 { 00288 read(fd, &itsClassFreq[i], sizeof(uint64)); 00289 00290 uint clsNameLength; 00291 read(fd, &clsNameLength, sizeof(uint)); 00292 char clsName[clsNameLength]; 00293 read(fd, &clsName, sizeof(char)*clsNameLength); 00294 itsClassNames[i] = std::string(clsName); 00295 00296 } 00297 00298 00299 //Write the mean and stdev 00300 itsMean.clear(); 00301 itsMean.resize(itsNumClasses, 00302 std::vector<std::vector<double> > 00303 (itsNumFix, std::vector<double>(itsNumFeatures,0))); 00304 00305 00306 itsVar.clear(); 00307 itsVar.resize(itsNumClasses, 00308 std::vector<std::vector<double> > 00309 (itsNumFix, std::vector<double>(itsNumFeatures,0.01))); 00310 00311 for(uint cls=0; cls<itsNumClasses; cls++) 00312 { 00313 for (uint i=0; i<itsNumFeatures; i++) //get the posterior prob 00314 { 00315 read(fd, &itsMean[cls][i], sizeof(double)); 00316 read(fd, &itsVar[cls][i], sizeof(double)); 00317 } 00318 } 00319 00320 close(fd); 00321 00322 } 00323 00324 // ###################################################################### 00325 void RecurBayes::setFeatureName(uint index, const char *name) 00326 { 00327 ASSERT(index < itsNumFeatures); 00328 itsFeatureNames[index] = std::string(name); 00329 } 00330 00331 // ###################################################################### 00332 const char* RecurBayes::getFeatureName(const uint index) const 00333 { 00334 ASSERT(index < itsNumFeatures); 00335 return itsFeatureNames[index].c_str(); 00336 } 00337 00338 00339 // ###################################################################### 00340 int RecurBayes::addClass(const char *name) 00341 { 00342 //Add a 
new class 00343 00344 //check if the class exsists 00345 if (getClassId(name) == -1) 00346 { 00347 itsClassNames.push_back(std::string(name)); 00348 00349 itsMean.push_back(std::vector<std::vector<double> > 00350 (itsNumFix, std::vector<double>(itsNumFeatures,0))); 00351 00352 00353 itsVar.push_back(std::vector<std::vector<double> > 00354 (itsNumFix, std::vector<double>(itsNumFeatures,0.01))), 00355 00356 00357 itsClassFreq.push_back(1); 00358 00359 return itsNumClasses++; 00360 } 00361 00362 return -1; 00363 00364 } 00365 00366 // ###################################################################### 00367 const char* RecurBayes::getClassName(const uint id) 00368 { 00369 ASSERT(id < itsNumClasses); 00370 return itsClassNames[id].c_str(); 00371 00372 } 00373 00374 // ###################################################################### 00375 int RecurBayes::getClassId(const char *name) 00376 { 00377 //TODO: should use hash_map (but no hash_map on this machine :( ) 00378 00379 for(uint i=0; i<itsClassNames.size(); i++){ 00380 if (!strcmp(itsClassNames[i].c_str(), name)) 00381 return i; 00382 } 00383 00384 return -1; //no class found but that name 00385 00386 } 00387 00388 // ###################################################################### 00389 /* So things look consistent in everyone's emacs... */ 00390 /* Local Variables: */ 00391 /* indent-tabs-mode: nil */ 00392 /* End: */ 00393 00394 #endif // LEARN_BAYES_C_DEFINED