/*!@file Learn/Bayes.C Bayesian network classifier */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
// by the University of Southern California (USC) and the iLab at USC. //
// See http://iLab.usc.edu for information about this project.         //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and     //
// Laurent Itti, California Institute of Technology, 2001 (patent      //
// pending; application number 09/912,225 filed July 23, 2001; see     //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).    //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.      //
//                                                                     //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can  //
// redistribute it and/or modify it under the terms of the GNU General //
// Public License as published by the Free Software Foundation; either //
// version 2 of the License, or (at your option) any later version.    //
//                                                                     //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the  //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR     //
// PURPOSE.  See the GNU General Public License for more details.      //
//                                                                     //
// You should have received a copy of the GNU General Public License   //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write  //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,  //
// Boston, MA 02111-1307 USA.                                          //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Lior Elazary <elazary@usc.edu>
// $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Learn/Bayes.C $
// $Id: Bayes.C 14390 2011-01-13 20:17:22Z pez $
//

// This is a Naive Bayes classifier for now
#include "Learn/Bayes.H"
#include "Util/Assert.H"
#include "Util/log.H"
#include <math.h>
#include <fcntl.h>
#include <unistd.h>  // read(), write(), close()
#include <limits>
#include <string>
#include <cstring>   // strcmp()
#include <algorithm> // std::sort()

#include <cstdio>

// functor to assist with classInfo sorting:
class lessClassInfo
{
public:
  bool operator()(const Bayes::ClassInfo& classA,
                  const Bayes::ClassInfo& classB)
  { return (classA.prob) > (classB.prob); }
};

// ######################################################################
Bayes::Bayes(uint numFeatures, uint numClasses):
  itsNumFeatures(numFeatures), itsNumClasses(numClasses),
  itsMean(numClasses, std::vector<double>(numFeatures,0)),
  itsStdevSq(numClasses, std::vector<double>(numFeatures,0.01)),
  itsClassFreq(numClasses,0),
  itsFeatureNames(numFeatures),
  itsClassNames(numClasses, "No Name")
{
}

// ######################################################################
Bayes::~Bayes()
{}

// ######################################################################
void Bayes::learn(const std::vector<double> &fv, const uint cls)
{
  ASSERT(fv.size() == itsNumFeatures);

  //update the class freq
  ASSERT(cls < itsNumClasses);
  itsClassFreq[cls]++;

  //compute the stddev and mean of each feature
  //This algorithm is due to Knuth (The Art of Computer Programming, volume 2:
  //  Seminumerical Algorithms, 3rd edn., p. 232. Boston: Addison-Wesley.)
  for (uint i=0; i<fv.size(); i++){
    const double val = fv[i];
    const double delta = val - itsMean[cls][i];
    itsMean[cls][i] += delta/itsClassFreq[cls];
    if (itsClassFreq[cls] > 2) //watch for divide by 0
    {
      itsStdevSq[cls][i] = (itsStdevSq[cls][i]*(itsClassFreq[cls]-2))
        + delta*(val - itsMean[cls][i]);
      itsStdevSq[cls][i] /= double(itsClassFreq[cls]-1);
    }
  }
}

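// ######################################################################
// Note on the running update in learn() above: with n = itsClassFreq[cls]
// samples seen so far for this class, the code implements the standard
// online (Knuth/Welford) recurrences
//
//   mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
//   var_n  = ( (n-2)*var_{n-1} + (x_n - mean_{n-1})*(x_n - mean_n) ) / (n-1)
//
// The variance update only kicks in once n > 2, and itsStdevSq is seeded
// at 0.01 in the constructor, so the first few variance estimates differ
// slightly from the exact sample variance of the data seen so far.
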
// ######################################################################
void Bayes::learn(const std::vector<double> &fv, const char *name)
{
  //get the class id
  int cls = getClassId(name);
  if (cls == -1) //this is a new class, add it to the network
    cls = addClass(name);

  if(fv.size() != itsNumFeatures)
  {
    LINFO("NOTE: deleting the .net file may fix this if you are");
    LINFO("training with a different set of features.");
    LFATAL("fv.size() != itsNumFeatures: %d != %d",(int)fv.size(),
           (int)itsNumFeatures);
  }

  //ASSERT(fv.size() == itsNumFeatures);

  //update the class freq
  ASSERT((uint)cls < itsNumClasses);
  itsClassFreq[cls]++;

  //compute the stddev and mean of each feature
  //This algorithm is due to Knuth (The Art of Computer Programming, volume 2:
  //  Seminumerical Algorithms, 3rd edn., p. 232. Boston: Addison-Wesley.)
/*
  for (uint i=0; i<fv.size(); i++){
    const double val = fv[i];
    const double delta = val - itsMean[cls][i];
    itsMean[cls][i] += delta/itsClassFreq[cls];
    if (itsClassFreq[cls] > 3)
    {
      itsStdevSq[cls][i] = (itsStdevSq[cls][i]*(itsClassFreq[cls]-2))
        + delta*(val - itsMean[cls][i]);
    }
    if (itsClassFreq[cls] > 1) //watch for divide by 0
      itsStdevSq[cls][i] /= double(itsClassFreq[cls]-1);
  }
*/

  //watch for divide by 0: only update the variance once this class has
  //seen enough samples
  if(itsClassFreq[cls] > 3)
  {
    const double freq1 = 1.0F/(double)itsClassFreq[cls];
    const double freq2 = 1.0F/(double)(itsClassFreq[cls]-1);
    const uint64 freq = itsClassFreq[cls];

    for (uint i=0; i<fv.size(); i++)
    {
      const double val = fv[i];
      const double delta = val - itsMean[cls][i];

      itsMean[cls][i] += delta * freq1;
      itsStdevSq[cls][i] = (itsStdevSq[cls][i]*(freq-2))
        + delta*(val - itsMean[cls][i]);
      itsStdevSq[cls][i] *= freq2;
    }
  }
  else if(itsClassFreq[cls] > 1)
  {
    const double freq1 = 1.0F/(double)itsClassFreq[cls];
    const double freq2 = 1.0F/(double)(itsClassFreq[cls]-1);

    for (uint i=0; i<fv.size(); i++)
    {
      const double val = fv[i];
      const double delta = val - itsMean[cls][i];

      itsMean[cls][i] += delta * freq1;
      itsStdevSq[cls][i] *= freq2;
    }
  }
  else
  {
    const double freq1 = 1.0F/(double)itsClassFreq[cls];

    for (uint i=0; i<fv.size(); i++)
    {
      const double val = fv[i];
      const double delta = val - itsMean[cls][i];

      itsMean[cls][i] += delta * freq1;
    }
  }
}

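// ######################################################################
// Minimal training sketch (illustrative only, kept as a comment so the
// translation unit is unchanged; the feature values and class names are
// made up, and it assumes -- as addClass() suggests -- that starting with
// zero classes and letting them be created on demand is valid):
//
//   Bayes bayes(2 /* features */, 0 /* classes */);
//   std::vector<double> fv(2);
//   fv[0] = 0.7; fv[1] = 12.0;
//   bayes.learn(fv, "car");     // creates class "car" on first use
//   fv[0] = 0.1; fv[1] =  3.0;
//   bayes.learn(fv, "person");  // creates class "person"
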
// ######################################################################
double Bayes::getMean(const uint cls, const uint i) const
{
  ASSERT(cls < itsNumClasses && i < itsNumFeatures);
  return itsMean[cls][i];
}

// ######################################################################
double Bayes::getStdevSq(const uint cls, const uint i) const
{
  ASSERT(cls < itsNumClasses && i < itsNumFeatures);
  return itsStdevSq[cls][i];
}

// ######################################################################
void Bayes::setMean(const uint cls, const uint i, const double val)
{
  ASSERT(cls < itsNumClasses && i < itsNumFeatures);
  itsMean[cls][i] = val;
}

// ######################################################################
void Bayes::setStdevSq(const uint cls, const uint i, const double val)
{
  ASSERT(cls < itsNumClasses && i < itsNumFeatures);
  itsStdevSq[cls][i] = val;
}

// ######################################################################
uint Bayes::getNumFeatures() const
{
  return itsNumFeatures;
}

// ######################################################################
uint Bayes::getNumClasses() const
{
  return itsNumClasses;
}

// ######################################################################
uint Bayes::getClassFreq(const uint cls) const
{
  ASSERT(cls < itsNumClasses);
  return itsClassFreq[cls];
}

// ######################################################################
double Bayes::getClassProb(const uint cls) const
{
  ASSERT(cls < itsNumClasses);

  //double totalFreq = 0;
  //for (uint i=0; i<itsNumClasses; i++)
  //  totalFreq += itsClassFreq[i];

  //return double(itsClassFreq[cls])/totalFreq;

  //uniform prior over classes; note 1.0/double(...) avoids the integer
  //division 1/itsNumClasses, which truncates to 0 for more than one class
  return 1.0 / double(itsNumClasses);
}

// ######################################################################
int Bayes::classify(const std::vector<double> &fv, double *prob)
{
  //the maximum a posteriori (MAP) rule:
  itsMaxProb = -std::numeric_limits<double>::max();
  itsSumProb = 0.0F;
  itsNormProb = 0.0F;
  int maxCls = -1;
  //double sumClassProb = 0;

  for(uint cls=0; cls<itsNumClasses; cls++)
  {
    LINFO("Class %d of %d - %s",cls,itsNumClasses,itsClassNames[cls].c_str());
    //Find the probability that the fv belongs to this class
    double probVal = 0; //log(getClassProb(cls)); //the prior probability
    for (uint i=0; i<itsNumFeatures; i++) //accumulate the log posterior prob
    {
      if (itsMean[cls][i] > 0) //only process if mean > 0
      {
        const double g = gauss(fv[i], itsMean[cls][i], itsStdevSq[cls][i]);
        probVal += log(g);

        //LINFO("Val %f Mean %f sigma %f g(%e) %e",
        //  fv[i], itsMean[cls][i], itsStdevSq[cls][i], g, probVal);
      }
    }

    //if (probVal == NAN || probVal == -INFINITY) probVal = 1; //log of 0
    //printf("Class %i %s prob %f\n", cls, getClassName(cls), probVal);

    //sumClassProb += probVal;

    itsSumProb += exp(probVal);
    if (probVal > itsMaxProb){ //we have a new max
      itsMaxProb = probVal;
      maxCls = cls;
    }
  }

  itsMaxProb = exp(itsMaxProb);
  itsNormProb = itsMaxProb / itsSumProb;

  if (prob != NULL)
    *prob = itsMaxProb; //)/exp(sumClassProb);

  return maxCls;
}

// ######################################################################
std::vector<Bayes::ClassInfo> Bayes::classifyRange(std::vector<double> &fv,
                                                   int &retCls, const bool sort)
{
  std::vector<ClassInfo> classInfoRet;

  //the maximum a posteriori (MAP) rule:
  itsMaxProb = -std::numeric_limits<double>::max();
  itsSumProb = 0.0F;
  itsNormProb = 0.0F;
  int maxCls = -1;

  for(uint cls=0; cls<itsNumClasses; cls++)
  {
    //Find the probability that the fv belongs to this class
    double probVal = 0; //log(getClassProb(cls)); //the prior probability
    for (uint i=0; i<itsNumFeatures; i++) //accumulate the log posterior prob
    {
      if (itsMean[cls][i] > 0) //only process if mean > 0
      {
        const double g = gauss(fv[i], itsMean[cls][i], itsStdevSq[cls][i]);
        probVal += log(g);
      }
    }

    itsSumProb += exp(probVal);
    if (probVal > itsMaxProb){ //we have a new max
      itsMaxProb = probVal;
      maxCls = cls;
    }
    classInfoRet.push_back(ClassInfo(cls, probVal, getStatSig(fv, cls)));
  }

  itsMaxProb = exp(itsMaxProb);
  itsNormProb = itsMaxProb / itsSumProb;

  retCls = maxCls;
  if (sort)
    std::sort(classInfoRet.begin(), classInfoRet.end(), lessClassInfo());
  return classInfoRet;
}

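// ######################################################################
// What classify() and classifyRange() compute, in formula form: with a
// flat prior over classes, the naive-Bayes log score of class c is
//
//   score(c) = sum_i log N(fv[i]; itsMean[c][i], itsStdevSq[c][i])
//
// where the sum runs only over features whose learned mean is > 0. The
// returned class is argmax_c score(c) (MAP under a uniform prior),
// itsMaxProb holds exp() of the winning score, and
// itsNormProb = itsMaxProb / itsSumProb approximates P(c* | fv), since
// itsSumProb accumulates exp(score(c)) over all classes.
//
// Classification sketch (illustrative only; 'bayes' and 'fv' are assumed
// to be a trained Bayes object and a matching feature vector):
//
//   double p = 0;
//   const int winner = bayes.classify(fv, &p);
//   if (winner != -1)
//     LINFO("best class: %s (prob %e, norm %f)",
//           bayes.getClassName((uint)winner), p, bayes.getNormProb());
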
// ######################################################################
std::vector<double> Bayes::getClassProb(const std::vector<double> &fv)
{
  std::vector<double> classProb(itsNumClasses);

  for(uint cls=0; cls<itsNumClasses; cls++)
  {
    //Find the probability that the fv belongs to this class
    double probVal = log(1/(float)itsNumClasses); //the (uniform) prior probability
    for (uint i=0; i<itsNumFeatures; i++) //accumulate the log posterior prob
    {
      if (itsMean[cls][i] != 0) //only process if mean != 0
      {
        const double g = gauss(fv[i], itsMean[cls][i], itsStdevSq[cls][i]);
        probVal += log(g);
        //LINFO("%i: %f %f %f => %e,%e",
        //  cls, fv[i], itsMean[cls][i], itsStdevSq[cls][i],g, probVal);
      } else {
        probVal += log(0); //log(0) = -inf: feature never seen for this class
      }
    }

    classProb[cls] = probVal;
  }

  return classProb;
}

// ######################################################################
double Bayes::getStatSig(const std::vector<double> &fv, const uint cls) const
{
  ASSERT(fv.size() == itsNumFeatures);

  double statSig = 0;

  //simple t-test-like significance score
  for (uint i=0; i<fv.size(); i++){
    //compute the score for each feature
    const double val = fv[i];
    if (itsStdevSq[cls][i] != 0.00)
    {
      const double fsig =
        log(fabs(val - itsMean[cls][i])) - log(itsStdevSq[cls][i]);
      statSig += fsig;
    }
  }

  return statSig;
}

// ######################################################################
inline double Bayes::gauss(const double x, const double mean, const double stdevSq) const
{
  const double delta = -(x - mean) * (x - mean);
  return exp(delta/(2*stdevSq))/(sqrt(2*M_PI*stdevSq));
}

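// ######################################################################
// gauss() above is the univariate normal density, taking a variance (not
// a standard deviation) as its third argument:
//
//   N(x; mean, var) = exp( -(x - mean)^2 / (2*var) ) / sqrt(2*pi*var)
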
// ######################################################################
void Bayes::save(const char *filename)
{
  int fd;

  if ((fd = creat(filename, 0644)) == -1) {
    printf("Can not open %s for saving\n", filename);
    return;
  }

  //write the # Features and Classes
  if(write(fd, &itsNumFeatures, sizeof(uint)) != sizeof(uint)) LFATAL("Failed to write into: %s", filename);
  if(write(fd, &itsNumClasses, sizeof(uint)) != sizeof(uint)) LFATAL("Failed to write into: %s", filename);

  //Write the class freq and names
  for(uint i=0; i<itsNumClasses; i++)
  {
    if(write(fd, &itsClassFreq[i], sizeof(uint64)) != sizeof(uint64)) LFATAL("Failed to write into: %s", filename);
    uint clsNameLength = itsClassNames[i].size()+1; //1 for null str terminator
    if(write(fd, &clsNameLength, sizeof(uint)) != sizeof(uint)) LFATAL("Failed to write into: %s", filename);
    int sz = sizeof(char)*clsNameLength;
    if(write(fd, itsClassNames[i].c_str(), sz) != sz) LFATAL("Failed to write into: %s", filename);
  }

  //Write the mean and stdev
  for(uint cls=0; cls<itsNumClasses; cls++)
  {
    for (uint i=0; i<itsNumFeatures; i++) //for each feature
    {
      if(write(fd, &itsMean[cls][i], sizeof(double)) != sizeof(double)) LFATAL("Failed to write into: %s", filename);
      if(write(fd, &itsStdevSq[cls][i], sizeof(double)) != sizeof(double)) LFATAL("Failed to write into: %s", filename);
    }
  }

  close(fd);
}

// ######################################################################
bool Bayes::load(const char *filename)
{
  int fd;

  //note: was open(filename, 0644); 0644 is a creation mode, not an open() flag
  if ((fd = open(filename, O_RDONLY)) == -1) {
    printf("Can not open %s for reading\n", filename);
    return false;
  }

  itsNumClasses = 0;
  itsNumFeatures = 0;
  //read the # Features and Classes
  if(read(fd, &itsNumFeatures, sizeof(uint)) != sizeof(uint)) LFATAL("Failed to read from: %s", filename);
  if(read(fd, &itsNumClasses, sizeof(uint)) != sizeof(uint)) LFATAL("Failed to read from: %s", filename);

  //read the class freq and names
  itsClassFreq.clear();
  itsClassFreq.resize(itsNumClasses);
  itsClassNames.resize(itsNumClasses);

  for(uint i=0; i<itsNumClasses; i++)
  {
    if(read(fd, &itsClassFreq[i], sizeof(uint64)) != sizeof(uint64)) LFATAL("Failed to read from: %s", filename);

    uint clsNameLength;
    if(read(fd, &clsNameLength, sizeof(uint)) != sizeof(uint)) LFATAL("Failed to read from: %s", filename);
    char clsName[clsNameLength];
    int sz = sizeof(char)*clsNameLength;
    if(read(fd, &clsName, sz) != sz) LFATAL("Failed to read from: %s", filename);
    itsClassNames[i] = std::string(clsName);
  }

  //Read the mean and stdev
  itsMean.clear();
  itsMean.resize(itsNumClasses, std::vector<double>(itsNumFeatures,0));

  itsStdevSq.clear();
  itsStdevSq.resize(itsNumClasses, std::vector<double>(itsNumFeatures,0));

  for(uint cls=0; cls<itsNumClasses; cls++)
  {
    for (uint i=0; i<itsNumFeatures; i++) //for each feature
    {
      if(read(fd, &itsMean[cls][i], sizeof(double)) != sizeof(double)) LFATAL("Failed to read from: %s", filename);
      if(read(fd, &itsStdevSq[cls][i], sizeof(double)) != sizeof(double)) LFATAL("Failed to read from: %s", filename);
    }
  }

  close(fd);

  return true;
}

// ######################################################################
void Bayes::import(const char *filename)
{
  FILE *fd;

  if ((fd = fopen(filename, "r")) == NULL) {
    printf("Can not open %s for reading\n", filename);
    return;
  }

  //Read the mean and stdev
  for(uint cls=0; cls<itsNumClasses; cls++)
  {
    //read the means
    for (uint i=0; i<itsNumFeatures; i++)
    {
      float mean = 0;
      int ret = fscanf(fd, "%f", &mean);
      if (ret == -1) //end of file or error, exit
      {
        fclose(fd);
        return;
      }
      itsMean[cls][i] = mean;
    }

    //then read the variances
    for (uint i=0; i<itsNumFeatures; i++)
    {
      float stdevSq = 0;
      int ret = fscanf(fd, "%f", &stdevSq);
      if (ret == -1) //end of file or error, exit
      {
        fclose(fd);
        return;
      }
      itsStdevSq[cls][i] = stdevSq;
    }
  }

  fclose(fd);
}

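// ######################################################################
// Layout of the binary file written by save() and read back by load(),
// as implied by the code above (native byte order and type sizes, no
// padding between fields):
//
//   uint    itsNumFeatures
//   uint    itsNumClasses
//   then, repeated itsNumClasses times:
//     uint64  class frequency
//     uint    class-name length, including the trailing '\0'
//     char[]  class name (NUL-terminated)
//   then, for each class, for each feature:
//     double  mean
//     double  stdevSq
//
// import(), by contrast, reads a plain-text file of whitespace-separated
// floats: for each class, itsNumFeatures means followed by itsNumFeatures
// variances.
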
// ######################################################################
void Bayes::setFeatureName(uint index, const char *name)
{
  ASSERT(index < itsNumFeatures);
  itsFeatureNames[index] = std::string(name);
}

// ######################################################################
const char* Bayes::getFeatureName(const uint index) const
{
  ASSERT(index < itsNumFeatures);
  return itsFeatureNames[index].c_str();
}

// ######################################################################
int Bayes::addClass(const char *name)
{
  //Add a new class

  //only add the class if it does not already exist
  if (getClassId(name) == -1)
  {
    itsClassNames.push_back(std::string(name));

    itsMean.push_back(std::vector<double>(itsNumFeatures,0));
    itsStdevSq.push_back(std::vector<double>(itsNumFeatures,0.01));
    itsClassFreq.push_back(1);
    return itsNumClasses++;
  }

  return -1;
}

// ######################################################################
const char* Bayes::getClassName(const uint id)
{
  ASSERT(id < itsNumClasses);
  return itsClassNames[id].c_str();
}

// ######################################################################
int Bayes::getClassId(const char *name)
{
  //TODO: should use hash_map (but no hash_map on this machine :( )

  for(uint i=0; i<itsClassNames.size(); i++){
    if (!strcmp(itsClassNames[i].c_str(), name))
      return i;
  }

  return -1; //no class found by that name
}

// ######################################################################
double Bayes::getMaxProb() const
{ return itsMaxProb; }

// ######################################################################
double Bayes::getNormProb() const
{ return itsNormProb; }

// ######################################################################
/* So things look consistent in everyone's emacs... */
/* Local Variables: */
/* indent-tabs-mode: nil */
/* End: */
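
// ######################################################################
// End-to-end usage sketch (illustrative only; the file name and feature
// values are made up, and it relies only on the methods defined above):
//
//   Bayes bayes(3 /* features */, 0 /* classes, created on demand */);
//
//   std::vector<double> fv(3);
//   fv[0] = 0.2; fv[1] = 1.5; fv[2] = 7.0;
//   bayes.learn(fv, "target");
//   bayes.save("bayes.net");
//
//   // later, possibly in another process:
//   Bayes bayes2(3, 0);
//   if (bayes2.load("bayes.net"))
//   {
//     double p = 0;
//     const int cls = bayes2.classify(fv, &p);
//     if (cls != -1)
//       LINFO("classified as %s", bayes2.getClassName((uint)cls));
//   }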