RecurBayes.C

00001 /*!@file Learn/RecurBayes.C Recursive Bayesian network classifier */
00002 
00003 // //////////////////////////////////////////////////////////////////// //
00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
00005 // by the University of Southern California (USC) and the iLab at USC.  //
00006 // See http://iLab.usc.edu for information about this project.          //
00007 // //////////////////////////////////////////////////////////////////// //
00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected //
00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
00010 // in Visual Environments, and Applications'' by Christof Koch and      //
00011 // Laurent Itti, California Institute of Technology, 2001 (patent       //
00012 // pending; application number 09/912,225 filed July 23, 2001; see      //
00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
00014 // //////////////////////////////////////////////////////////////////// //
00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
00016 //                                                                      //
00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
00018 // redistribute it and/or modify it under the terms of the GNU General  //
00019 // Public License as published by the Free Software Foundation; either  //
00020 // version 2 of the License, or (at your option) any later version.     //
00021 //                                                                      //
00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
00025 // PURPOSE.  See the GNU General Public License for more details.       //
00026 //                                                                      //
00027 // You should have received a copy of the GNU General Public License    //
00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
00030 // Boston, MA 02111-1307 USA.                                           //
00031 // //////////////////////////////////////////////////////////////////// //
00032 //
00033 // Primary maintainer for this file: Lior Elazary <elazary@usc.edu>
00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Learn/RecurBayes.C $
00035 // $Id: RecurBayes.C 10794 2009-02-08 06:21:09Z itti $
00036 //
00037 
00038 //This is a Naive Bayes classifier for now
00039 #ifndef LEARN_BAYES_C_DEFINED
00040 #define LEARN_BAYES_C_DEFINED
00041 
#include "Learn/RecurBayes.H"
#include "Util/Assert.H"
#include "Util/log.H"
#include <math.h>
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>
#include <limits>
#include <string>
#include <vector>
00049 
00050 // ######################################################################
00051 RecurBayes::RecurBayes(uint numClasses, uint numFeatures, uint numFix):
00052   itsNumFeatures(numFeatures), itsNumClasses(numClasses), itsNumFix(numFix),
00053   itsMean(numClasses,
00054       std::vector<std::vector<double> >
00055           (numFix, std::vector<double>(numFeatures,0))),
00056   itsVar(numClasses,
00057       std::vector<std::vector<double> >
00058           (numFix, std::vector<double>(numFeatures,0.01))),
00059   itsClassFreq(numClasses,0),
00060   itsFeatureNames(numFeatures),
00061   itsClassNames(numClasses, "No Name")
00062 {
00063 
00064 }
00065 
00066 // ######################################################################
00067 RecurBayes::~RecurBayes()
00068 {
00069 }
00070 
00071 
00072 // ######################################################################
00073 void RecurBayes::learn(std::vector<double> &fv, const char *name, uint fix)
00074 {
00075   //get the class id
00076 
00077   int cls = getClassId(name);
00078   if (cls == -1) //this is a new class, add it to the network
00079     cls = addClass(name);
00080 
00081   ASSERT(fv.size() == itsNumFeatures);
00082 
00083   //update the class freq
00084   ASSERT((uint)cls < itsNumClasses);
00085   itsClassFreq[cls]++;
00086 
00087   //compute the stddev and mean of each feature
00088   //This algorithm is due to Knuth (The Art of Computer Programming, volume 2:
00089   //  Seminumerical Algorithms, 3rd edn., p. 232. Boston: Addison-Wesley.)
00090   for (uint i=0; i<fv.size(); i++){
00091     double val = fv[i];
00092     double delta = val - itsMean[cls][fix][i];
00093     itsMean[cls][fix][i] += delta/itsClassFreq[cls];
00094     if (itsClassFreq[cls] > 3)
00095     {
00096       itsVar[cls][fix][i] = (itsVar[cls][fix][i]*(itsClassFreq[cls]-2))
00097         + delta*(val - itsMean[cls][fix][i]);
00098     }
00099     if (itsClassFreq[cls] > 1) //watch for divide by 0
00100       itsVar[cls][fix][i] /= double(itsClassFreq[cls]-1);
00101   }
00102 
00103 }
00104 
00105 // ######################################################################
00106 double RecurBayes::getMean(uint cls, uint i, uint fix)
00107 {
00108   ASSERT(cls < itsNumClasses && i < itsNumFeatures);
00109   return itsMean[cls][fix][i];
00110 }
00111 
00112 // ######################################################################
00113 double RecurBayes::getStdevSq(uint cls, uint i, uint fix)
00114 {
00115   ASSERT(cls < itsNumClasses && i < itsNumFeatures);
00116   return itsVar[cls][fix][i];
00117 }
00118 
00119 // ######################################################################
00120 uint RecurBayes::getNumFeatures()
00121 {
00122   return itsNumFeatures;
00123 }
00124 
00125 // ######################################################################
00126 uint RecurBayes::getNumClasses()
00127 {
00128   return itsNumClasses;
00129 }
00130 
00131 // ######################################################################
00132 uint RecurBayes::getClassFreq(uint cls)
00133 {
00134   ASSERT(cls < itsNumClasses);
00135   return itsClassFreq[cls];
00136 }
00137 
00138 // ######################################################################
00139 double RecurBayes::getClassProb(uint cls)
00140 {
00141   ASSERT(cls < itsNumClasses);
00142 
00143   //double totalFreq = 0;
00144   //for (uint i=0; i<itsNumClasses; i++)
00145   //  totalFreq += itsClassFreq[i];
00146 
00147   //return double(itsClassFreq[cls])/totalFreq;
00148 
00149   return double(1/itsNumClasses);
00150 }
00151 
00152 // ######################################################################
00153 int RecurBayes::classify(std::vector<double> &fv, double *prob, uint fix)
00154 {
00155 
00156   //the maximum posterior  (MAP alg):
00157   double maxProb = -std::numeric_limits<double>::max();
00158   int maxCls = -1;
00159   double sumClassProb = 0;
00160 
00161   for(uint cls=0; cls<itsNumClasses; cls++)
00162   {
00163     //Find the probability that the fv belongs to this class
00164     double probVal = log(getClassProb(cls)); //the prior probility
00165     for (uint i=0; i<itsNumFeatures; i++) //get the probilityposterior prob
00166     {
00167       if (itsMean[cls][fix][i] > 0)  //only process if mean > 0
00168       {
00169         double g = gauss(fv[i], itsMean[cls][fix][i], itsVar[cls][fix][i]);
00170         probVal += log(g);
00171 
00172 
00173        //  LINFO("Val %f Mean %f sigma %f g(%e) %e",
00174        //      fv[i], itsMean[cls][i], itsStdevSq[cls][i], g, probVal);
00175       }
00176     }
00177 
00178     LINFO("Class %i prob %f\n", cls, probVal);
00179     sumClassProb += probVal;
00180     if (probVal > maxProb){ //we have a new max
00181       maxProb = probVal;
00182       maxCls = cls;
00183     }
00184   }
00185 
00186   if (prob != NULL)
00187     *prob = exp(maxProb)/exp(sumClassProb);
00188 
00189   return maxCls;
00190 }
00191 
00192 
00193 // ######################################################################
00194 double RecurBayes::getStatSig(std::vector<double> &fv, uint cls, uint fix)
00195 {
00196   ASSERT(fv.size() == itsNumFeatures);
00197 
00198   double statSig = 0;
00199 
00200   //simple t test
00201   for (uint i=0; i<fv.size(); i++){
00202     //compute a t test for each feature
00203     double val = fv[i];
00204     if (itsVar[cls][fix][i] != 0.00)
00205     {
00206       double fsig = log(fabs(val - itsMean[cls][fix][i])) - log(itsVar[cls][fix][i]);
00207       statSig += fsig;
00208     }
00209   }
00210 
00211   return statSig;
00212 
00213 }
00214 
00215 //// ######################################################################
00216 double RecurBayes::gauss(double x, double mean, double stdevSq)
00217 {
00218  double delta = -(x - mean) * (x - mean);
00219  return exp(delta/(2*stdevSq))/(sqrt(2*M_PI*stdevSq));
00220 }
00221 
00222 //// ######################################################################
00223 void RecurBayes::save(const char *filename)
00224 {
00225 
00226   int fd;
00227 
00228   if ((fd = creat(filename, 0644)) == -1) {
00229     printf("Can not open %s for saving\n", filename);
00230     return;
00231   }
00232 
00233   //write the #  Features and Classes
00234   write(fd, &itsNumFeatures, sizeof(uint));
00235   write(fd, &itsNumClasses, sizeof(uint));
00236 
00237   //Write the class freq and names
00238   for(uint i=0; i<itsNumClasses; i++)
00239   {
00240     write(fd, &itsClassFreq[i], sizeof(uint64));
00241 
00242     uint clsNameLength = itsClassNames[i].size()+1; //1 for null str terminator
00243     write(fd, &clsNameLength, sizeof(uint));
00244     write(fd, itsClassNames[i].c_str(), sizeof(char)*clsNameLength);
00245   }
00246 
00247 
00248   //Write the mean and stdev
00249   for(uint cls=0; cls<itsNumClasses; cls++)
00250   {
00251     for (uint i=0; i<itsNumFeatures; i++) //get the posterior prob
00252     {
00253       int fix = 1;
00254       write(fd, &itsMean[cls][fix][i], sizeof(double));
00255       write(fd, &itsVar[cls][fix][i], sizeof(double));
00256     }
00257   }
00258 
00259   close(fd);
00260 
00261 
00262 }
00263 
00264 //// ######################################################################
00265 void RecurBayes::load(const char *filename)
00266 {
00267 
00268   int fd;
00269 
00270   if ((fd = open(filename, 0644)) == -1) {
00271     printf("Can not open %s for reading\n", filename);
00272     return;
00273   }
00274 
00275   itsNumClasses = 0;
00276   itsNumFeatures = 0;
00277   //read the #  Features and Classes
00278   read(fd, &itsNumFeatures, sizeof(uint));
00279   read(fd, &itsNumClasses, sizeof(uint));
00280 
00281   //read the class freq
00282   itsClassFreq.clear();
00283   itsClassFreq.resize(itsNumClasses);
00284   itsClassNames.resize(itsNumClasses);
00285 
00286   for(uint i=0; i<itsNumClasses; i++)
00287   {
00288     read(fd, &itsClassFreq[i], sizeof(uint64));
00289 
00290     uint clsNameLength;
00291     read(fd, &clsNameLength, sizeof(uint));
00292     char clsName[clsNameLength];
00293     read(fd, &clsName, sizeof(char)*clsNameLength);
00294     itsClassNames[i] = std::string(clsName);
00295 
00296   }
00297 
00298 
00299   //Write the mean and stdev
00300   itsMean.clear();
00301   itsMean.resize(itsNumClasses,
00302       std::vector<std::vector<double> >
00303           (itsNumFix, std::vector<double>(itsNumFeatures,0)));
00304 
00305 
00306   itsVar.clear();
00307   itsVar.resize(itsNumClasses,
00308       std::vector<std::vector<double> >
00309           (itsNumFix, std::vector<double>(itsNumFeatures,0.01)));
00310 
00311   for(uint cls=0; cls<itsNumClasses; cls++)
00312   {
00313     for (uint i=0; i<itsNumFeatures; i++) //get the posterior prob
00314     {
00315       read(fd, &itsMean[cls][i], sizeof(double));
00316       read(fd, &itsVar[cls][i], sizeof(double));
00317     }
00318   }
00319 
00320   close(fd);
00321 
00322 }
00323 
00324 // ######################################################################
00325 void RecurBayes::setFeatureName(uint index, const char *name)
00326 {
00327   ASSERT(index < itsNumFeatures);
00328   itsFeatureNames[index] = std::string(name);
00329 }
00330 
00331 // ######################################################################
00332 const char* RecurBayes::getFeatureName(const uint index) const
00333 {
00334   ASSERT(index < itsNumFeatures);
00335   return itsFeatureNames[index].c_str();
00336 }
00337 
00338 
00339 // ######################################################################
00340 int RecurBayes::addClass(const char *name)
00341 {
00342   //Add a new class
00343 
00344   //check if the class exsists
00345   if (getClassId(name) == -1)
00346   {
00347     itsClassNames.push_back(std::string(name));
00348 
00349     itsMean.push_back(std::vector<std::vector<double> >
00350         (itsNumFix, std::vector<double>(itsNumFeatures,0)));
00351 
00352 
00353     itsVar.push_back(std::vector<std::vector<double> >
00354         (itsNumFix, std::vector<double>(itsNumFeatures,0.01))),
00355 
00356 
00357     itsClassFreq.push_back(1);
00358 
00359     return itsNumClasses++;
00360   }
00361 
00362   return -1;
00363 
00364 }
00365 
00366 // ######################################################################
00367 const char* RecurBayes::getClassName(const uint id)
00368 {
00369   ASSERT(id < itsNumClasses);
00370   return itsClassNames[id].c_str();
00371 
00372 }
00373 
00374 // ######################################################################
00375 int RecurBayes::getClassId(const char *name)
00376 {
00377   //TODO: should use hash_map (but no hash_map on this machine :( )
00378 
00379   for(uint i=0; i<itsClassNames.size(); i++){
00380     if (!strcmp(itsClassNames[i].c_str(), name))
00381       return i;
00382   }
00383 
00384   return -1;  //no class found but that name
00385 
00386 }
00387 
00388 // ######################################################################
00389 /* So things look consistent in everyone's emacs... */
00390 /* Local Variables: */
00391 /* indent-tabs-mode: nil */
00392 /* End: */
00393 
00394 #endif // LEARN_BAYES_C_DEFINED
Generated on Sun May 8 08:05:19 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3