/*!@file Learn/Bayes.H Bayesian network classifier */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
// by the University of Southern California (USC) and the iLab at USC.  //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.       //
// 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Lior Elazary <elazary@usc.edu> 00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Learn/Bayes.H $ 00035 // $Id: Bayes.H 10794 2009-02-08 06:21:09Z itti $ 00036 // 00037 00038 //This is a Naive Bayes for now 00039 #ifndef LEARN_BAYES_H_DEFINED 00040 #define LEARN_BAYES_H_DEFINED 00041 00042 #include "Util/Types.H" // for uint 00043 #include <vector> 00044 #include <string> 00045 00046 class Bayes 00047 { 00048 public: 00049 00050 struct ClassInfo 00051 { 00052 ClassInfo(int id, double p, double sig) : classID(id), prob(p), statSig(sig) {} //constructor to set values 00053 int classID; //the class ID; 00054 double prob; //the probability of this class 00055 double statSig; //the statistical significance between the features value and the params 00056 }; 00057 00058 //! Construct a bayes classifer with a given number of features and 00059 //! number of classes 00060 Bayes(uint numFeatures, uint numClasses); 00061 00062 //! Destructor 00063 ~Bayes(); 00064 00065 //! Learn to associate a feature vector with a particuler class 00066 void learn(const std::vector<double> &fv, const uint cls); //TODO make as a Template 00067 00068 //! Learn to associate a feature vector with a particuler class name 00069 void learn(const std::vector<double> &fv, const char *name); //TODO make as a Template 00070 00071 //! classify a given feature vector 00072 int classify(const std::vector<double> &fv, double *prob = NULL); //TODO make as a template 00073 00074 //! 
classify a given feature vector (Return all classes and thier prob, cls contains the max) 00075 std::vector<ClassInfo> classifyRange(std::vector<double> &fv, int &retCls, const bool sort=true); 00076 00077 //! Return the probability of all the classes given the feature vector 00078 std::vector<double> getClassProb(const std::vector<double> &fv); 00079 00080 //! Get the mean for a particuler feature 00081 double getMean(const uint cls, const uint i) const; 00082 00083 //! Get the stdev Squared for a particuler feature 00084 double getStdevSq(const uint cls, const uint i) const; 00085 00086 //! set the mean 00087 void setMean(const uint cls, const uint i, const double val); 00088 00089 //! set the stdev Squared for a particuler feature 00090 void setStdevSq(const uint cls, const uint i, const double val); 00091 00092 //! Get the number of features 00093 uint getNumFeatures() const; 00094 00095 //! Get the number of classes 00096 uint getNumClasses() const; 00097 00098 //! Get the Freq of a given class 00099 uint getClassFreq(const uint cls) const; 00100 00101 //! Get the probability of a given class 00102 double getClassProb(const uint cls) const; 00103 00104 //! return the statistical significent of the FV for a given class 00105 double getStatSig(const std::vector<double> &fv, const uint cls) const; 00106 00107 //! Calculate a Normal Dist (use the srdev squared 00108 double gauss(const double x, const double mean, const double stdevSq) const; 00109 00110 //! Save the network to a file 00111 void save(const char *filename); 00112 00113 //! Load the network from a binary file 00114 bool load(const char *filename); 00115 00116 //! Load the network from a text file 00117 void import(const char *filename); 00118 00119 //! set feature name (for debuging) 00120 void setFeatureName(uint index, const char *name); 00121 00122 //! get feature name (for debuging) 00123 const char* getFeatureName(const uint index) const; 00124 00125 //! 
Add class by name and return its Id 00126 int addClass(const char *name); 00127 00128 //! Get the class name from a given Id 00129 const char* getClassName(const uint id); 00130 00131 //! Get the class id from a given name 00132 int getClassId(const char *name); 00133 00134 //! get the probability value associated with a classification 00135 double getMaxProb() const; 00136 00137 //! get the normalized probability value associated with a classification 00138 double getNormProb() const; 00139 private: 00140 uint itsNumFeatures; //the number of features we have 00141 uint itsNumClasses; //the Number of classes we have 00142 double itsMaxProb; // Stores the maximum probability with each object rec 00143 double itsSumProb; // Used to derive a normalized P value 00144 double itsNormProb; // normalized P of object 00145 std::vector<std::vector<double> > itsMean; //the mean for each feature per class 00146 std::vector<std::vector<double> > itsStdevSq; //the stdev squared for each feature 00147 //TODO: its long int sufficent? is there a better way of calc the mean and stdev? 00148 std::vector<uint64> itsClassFreq; //the Freq of a given class 00149 00150 std::vector<std::string> itsFeatureNames; //THe name of the features 00151 std::vector<std::string> itsClassNames; //The names of the clases 00152 00153 }; 00154 00155 // ###################################################################### 00156 /* So things look consistent in everyone's emacs... */ 00157 /* Local Variables: */ 00158 /* indent-tabs-mode: nil */ 00159 /* End: */ 00160 00161 #endif // LEARN_BAYES_H_DEFINED