/*!@file Learn/SVMClassifierModule.C Support Vector Machine Classifier module */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
// University of Southern California (USC) and the iLab at USC.         //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE. See the GNU General Public License for more details.        //
// 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Laurent Itti <itti@usc.edu> 00034 // $HeadURL: svn://dparks@isvn.usc.edu/software/invt/trunk/saliency/src/Learn/SVMClassifierModule.C $ 00035 // $Id: SVMClassifierModule.C 13332 2010-04-28 18:50:09Z dparks $ 00036 // 00037 00038 #include "Learn/svm.h" 00039 #include "SVMClassifierModule.H" 00040 #include "Learn/SVMClassifier.H" 00041 00042 #include <fstream> 00043 #include <iostream> 00044 #include <iomanip> 00045 #include <string> 00046 #include <cstdlib> 00047 00048 const ModelOptionCateg MOC_SVMClassifier = { 00049 MOC_SORTPRI_3, "SVMClassifier Related Options" }; 00050 00051 const ModelOptionDef OPT_SVMModelFileNames = 00052 { MODOPT_ARG_STRING, "SVM Model File Names", &MOC_SVMClassifier, OPTEXP_CORE, 00053 "Colon separated list of filenames for the SVM model", 00054 "svm-model-filenames", '\0', "<filename1:filename2>", "" }; 00055 00056 const ModelOptionDef OPT_SVMModelNames = 00057 { MODOPT_ARG_STRING, "SVM Model Names", &MOC_SVMClassifier, OPTEXP_CORE, 00058 "Colon separated list of names for the SVM", 00059 "svm-model-names", '\0', "<name1:name2>", "" }; 00060 00061 const ModelOptionDef OPT_SVMRangeFileNames = 00062 { MODOPT_ARG_STRING, "SVM Range File Names", &MOC_SVMClassifier, OPTEXP_CORE, 00063 "Colon separated list of filenames for the SVM range", 00064 "svm-range-filenames", '\0', "<filename1:filename2>", "" }; 00065 00066 const ModelOptionDef OPT_SVMOutputFileNames = 00067 { MODOPT_ARG_STRING, "SVM Training Output File Names", &MOC_SVMClassifier, OPTEXP_CORE, 00068 "Filename(s) for the SVM training to output to", 00069 
"svm-output-filenames", '\0', "<filename1:filename2>", "" }; 00070 00071 const ModelOptionDef OPT_SVMObjDBFileName = 00072 { MODOPT_ARG_STRING, "Object DB File Name", &MOC_SVMClassifier, OPTEXP_CORE, 00073 "Filename for the object database file", 00074 "svm-objdb-filename", '\0', "<filename>", "" }; 00075 00076 const ModelOptionDef OPT_SVMTrainObjName = 00077 { MODOPT_ARG_STRING, "SVM Training Object Name", &MOC_SVMClassifier, OPTEXP_CORE, 00078 "Name of the object used in training", 00079 "svm-train-objname", '\0', "<name>", "" }; 00080 00081 const ModelOptionDef OPT_SVMTrainObjId = 00082 { MODOPT_ARG(int), "SVM Training Object Id", &MOC_SVMClassifier, OPTEXP_CORE, 00083 "Id of the object used in training", 00084 "svm-train-objid", '\0', "<id>", "-1" }; 00085 00086 const ModelOptionDef OPT_SVMMode = 00087 { MODOPT_ARG_STRING, "SVM Mode", &MOC_SVMClassifier, OPTEXP_CORE, 00088 "The mode of SVM Classifier. Train|Rec", 00089 "svm-mode", '\0', "<Train|Rec>", "Rec" }; 00090 00091 00092 00093 SVMClassifierModule::SVMClassifierModule(OptionManager& mgr, 00094 const std::string& descrName, 00095 const std::string& tagName) : 00096 SimModule(mgr, descrName, tagName), 00097 itsSVMModelFileNamesStr(&OPT_SVMModelFileNames, this), 00098 itsSVMModelNamesStr(&OPT_SVMModelNames, this), 00099 itsSVMRangeFileNamesStr(&OPT_SVMRangeFileNames, this), 00100 itsSVMOutputFileNamesStr(&OPT_SVMOutputFileNames, this), 00101 itsSVMObjDBFileName(&OPT_SVMObjDBFileName, this), 00102 itsSVMTrainObjName(&OPT_SVMTrainObjName, this), 00103 itsSVMTrainObjId(&OPT_SVMTrainObjId, this), 00104 itsSVMMode(&OPT_SVMMode, this) 00105 { 00106 } 00107 00108 SVMClassifierModule::~SVMClassifierModule() 00109 { 00110 } 00111 00112 00113 00114 void SVMClassifierModule::start2() 00115 { 00116 SimModule::start2(); 00117 00118 // Parse which model files should be loaded 00119 split(itsSVMModelFileNamesStr.getVal(), ":", std::back_inserter(itsSVMModelFiles)); 00120 00121 // Parse which model names should be loaded 
00122 split(itsSVMModelNamesStr.getVal(), ":", std::back_inserter(itsSVMModelNames)); 00123 00124 // Parse which ranges should be loaded 00125 split(itsSVMRangeFileNamesStr.getVal(), ":", std::back_inserter(itsSVMRangeFiles)); 00126 00127 // Parse which output files 00128 split(itsSVMOutputFileNamesStr.getVal(), ":", std::back_inserter(itsSVMOutputFiles)); 00129 00130 // If no range files are specified, resize the list to be the same as the model file list 00131 if(itsSVMRangeFiles.size()==0) 00132 { 00133 itsSVMRangeFiles = std::vector<std::string>(itsSVMModelFiles.size()); 00134 } 00135 if(itsSVMRangeFiles.size()!= itsSVMModelFiles.size()) 00136 LFATAL("If range files are specified, must be same number as model files"); 00137 00138 if(itsSVMObjDBFileName.getVal().compare("") == 0) { 00139 LFATAL("Must specify object db file using --svm-objdb-filename"); 00140 } 00141 itsObjDB.loadObjDB(itsSVMObjDBFileName.getVal()); 00142 00143 if (itsSVMMode.getVal().compare("Train") == 0) // training 00144 { 00145 if(itsSVMOutputFiles.size() == 0 || itsSVMOutputFiles.size() != itsSVMModelNames.size()) { 00146 LFATAL("Must specify training output file(s) (and equal number of model names) if in training mode"); 00147 } 00148 // Load basic classifier for training 00149 SVMClassifier classifier; 00150 itsClassifiers.push_back(classifier); 00151 } 00152 else if (itsSVMMode.getVal().compare("Rec") == 0) // Recognition 00153 { 00154 if(itsSVMModelFiles.size() == 0 || itsSVMModelFiles.size() != itsSVMModelNames.size()){ 00155 LFATAL("Must specify svm model file(s) (and equal number of model names) if in recognition mode"); 00156 } 00157 for(size_t c=0;c<itsSVMModelFiles.size();c++) 00158 { 00159 SVMClassifier classifier; 00160 // Load model file 00161 classifier.readModel(itsSVMModelFiles[c]); 00162 // Load the range file 00163 if(itsSVMRangeFiles[c].compare("") != 0) { 00164 classifier.readRange(itsSVMRangeFiles[c]); 00165 } 00166 itsClassifiers.push_back(classifier); 00167 } 00168 } 
00169 else 00170 LFATAL("Unknown SVM Mode type %s", itsSVMMode.getVal().c_str()); 00171 } 00172 00173 // ###################################################################### 00174 void SVMClassifierModule::stop1() 00175 { 00176 LINFO("Writing out object db in module %p",this); 00177 itsObjDB.writeObjDB(itsSVMObjDBFileName.getVal()); 00178 SimModule::stop1(); 00179 } 00180 00181 std::string SVMClassifierModule::getMode() 00182 { 00183 return itsSVMMode.getVal(); 00184 } 00185 00186 void SVMClassifierModule::attentionShift(SimEventQueue& q, 00187 const Point2D<int>& location) 00188 { 00189 } 00190 00191 00192 std::vector<std::string> SVMClassifierModule::getModelNames() 00193 { 00194 return itsSVMModelNames; 00195 } 00196 00197 SVMObject SVMClassifierModule::determineLabel(std::vector<float> featureVector, int id, std::string name, int classifierId) 00198 { 00199 if (itsSVMMode.getVal().compare("Rec") == 0) // Recognition 00200 { 00201 return recognizeLabel(featureVector,id,name,classifierId); 00202 } 00203 else if (itsSVMMode.getVal().compare("Train") == 0) // Train 00204 { 00205 return trainLabel(featureVector,id,name,classifierId); 00206 } 00207 else 00208 { 00209 LFATAL("Invalid SVM Classification Mode"); 00210 } 00211 return SVMObject(); 00212 } 00213 00214 SVMObject SVMClassifierModule::trainLabel(std::vector<float> featureVector, int id, std::string name, int classifierId) 00215 { 00216 SVMObject so; 00217 printf("In SVMClassifierModule::determineLabel %s\n",itsSVMMode.getVal().c_str()); 00218 // Preprocess the id and name if we are in training mode 00219 // If the id or name for these images is given, assign it 00220 if(itsSVMTrainObjId.getVal() != -1) { 00221 id = itsSVMTrainObjId.getVal(); 00222 } 00223 if(itsSVMTrainObjName.getVal().compare("") != 0) { 00224 name = itsSVMTrainObjName.getVal(); 00225 } 00226 if(name.compare("") == 0) { 00227 // If the name is still not defined, prompt the user 00228 // LINFO("Enter name for new object or [RETURN] to skip 
training:"); 00229 // std::getline(std::cin, name, '\n'); 00230 LFATAL("Name is not specified while in Train mode"); 00231 } 00232 // If the id is still not defined, try to pull the id out 00233 if(id == -1 && itsObjDB.getObject(name).initialized()) { 00234 so = itsObjDB.getObject(name); 00235 id = so.id; 00236 } 00237 // Make sure the id is in the database. Check for mismatch, and add it, if not present 00238 LINFO("Training on object %s[%d]\n",name.c_str(),id); 00239 so = itsObjDB.updateObject(id,name); 00240 so.confidence = 1.0; 00241 so.id = id; 00242 so.name = name; 00243 itsClassifiers[classifierId].train(itsSVMOutputFiles[classifierId],id,featureVector); 00244 00245 return so; 00246 00247 } 00248 00249 SVMObject SVMClassifierModule::recognizeLabel(std::vector<float> featureVector, int id, std::string name, int classifierId) 00250 { 00251 std::vector<SVMObject> objects = getLabelPDF(featureVector,id,name,classifierId); 00252 return getBestLabel(objects); 00253 } 00254 00255 SVMObject SVMClassifierModule::getBestLabel(const std::vector<SVMObject> &objects) 00256 { 00257 double maxProb=-1; 00258 int bestLabelIdx=0; 00259 if(objects.size() == 0) 00260 return SVMObject(); 00261 for(size_t i=0;i<objects.size();i++) 00262 { 00263 if(maxProb<objects[i].confidence) 00264 { 00265 bestLabelIdx=i; 00266 maxProb=objects[i].confidence; 00267 } 00268 } 00269 return objects[bestLabelIdx]; 00270 } 00271 00272 std::vector<SVMObject> SVMClassifierModule::getLabelPDF(std::vector<float> featureVector, int id, std::string name, int classifierId) 00273 { 00274 std::map<int,double> pdf = itsClassifiers[classifierId].predictPDF(featureVector); 00275 00276 if (itsSVMMode.getVal().compare("Rec") == 0) // Recognition 00277 { 00278 std::vector<SVMObject> svmObjects; 00279 for(std::map<int,double>::iterator pdfIt=pdf.begin(); pdfIt!=pdf.end(); ++pdfIt) 00280 { 00281 SVMObject obj = itsObjDB.getObject(int(pdfIt->first)); 00282 if (obj.id == -1) //Assign the object ID that we have, if we 
did not find it the DB 00283 obj.id = pdfIt->first; 00284 obj.confidence = pdfIt->second; 00285 svmObjects.push_back(obj); 00286 } 00287 return svmObjects; 00288 00289 } 00290 else if (itsSVMMode.getVal().compare("Train") == 0) // Train 00291 { 00292 std::vector<SVMObject> svmObjects; 00293 for(std::map<int,double>::iterator pdfIt=pdf.begin(); pdfIt!=pdf.end(); ++pdfIt) 00294 { 00295 SVMObject obj = itsObjDB.getObject(int(pdfIt->first)); 00296 if(obj.id==id) 00297 obj.confidence = 1; 00298 else 00299 obj.confidence = 0; 00300 svmObjects.push_back(obj); 00301 } 00302 return svmObjects; 00303 } 00304 else 00305 { 00306 LFATAL("Invalid SVM Classification Mode"); 00307 } 00308 00309 return std::vector<SVMObject>(); 00310 } 00311