00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Learn/svm.h"
00039 #include "SVMClassifierModule.H"
00040 #include "Learn/SVMClassifier.H"
00041
00042 #include <fstream>
00043 #include <iostream>
00044 #include <iomanip>
00045 #include <string>
00046 #include <cstdlib>
00047
00048 const ModelOptionCateg MOC_SVMClassifier = {
00049 MOC_SORTPRI_3, "SVMClassifier Related Options" };
00050
00051 const ModelOptionDef OPT_SVMModelFileNames =
00052 { MODOPT_ARG_STRING, "SVM Model File Names", &MOC_SVMClassifier, OPTEXP_CORE,
00053 "Colon separated list of filenames for the SVM model",
00054 "svm-model-filenames", '\0', "<filename1:filename2>", "" };
00055
00056 const ModelOptionDef OPT_SVMModelNames =
00057 { MODOPT_ARG_STRING, "SVM Model Names", &MOC_SVMClassifier, OPTEXP_CORE,
00058 "Colon separated list of names for the SVM",
00059 "svm-model-names", '\0', "<name1:name2>", "" };
00060
00061 const ModelOptionDef OPT_SVMRangeFileNames =
00062 { MODOPT_ARG_STRING, "SVM Range File Names", &MOC_SVMClassifier, OPTEXP_CORE,
00063 "Colon separated list of filenames for the SVM range",
00064 "svm-range-filenames", '\0', "<filename1:filename2>", "" };
00065
00066 const ModelOptionDef OPT_SVMOutputFileNames =
00067 { MODOPT_ARG_STRING, "SVM Training Output File Names", &MOC_SVMClassifier, OPTEXP_CORE,
00068 "Filename(s) for the SVM training to output to",
00069 "svm-output-filenames", '\0', "<filename1:filename2>", "" };
00070
00071 const ModelOptionDef OPT_SVMObjDBFileName =
00072 { MODOPT_ARG_STRING, "Object DB File Name", &MOC_SVMClassifier, OPTEXP_CORE,
00073 "Filename for the object database file",
00074 "svm-objdb-filename", '\0', "<filename>", "" };
00075
00076 const ModelOptionDef OPT_SVMTrainObjName =
00077 { MODOPT_ARG_STRING, "SVM Training Object Name", &MOC_SVMClassifier, OPTEXP_CORE,
00078 "Name of the object used in training",
00079 "svm-train-objname", '\0', "<name>", "" };
00080
00081 const ModelOptionDef OPT_SVMTrainObjId =
00082 { MODOPT_ARG(int), "SVM Training Object Id", &MOC_SVMClassifier, OPTEXP_CORE,
00083 "Id of the object used in training",
00084 "svm-train-objid", '\0', "<id>", "-1" };
00085
00086 const ModelOptionDef OPT_SVMMode =
00087 { MODOPT_ARG_STRING, "SVM Mode", &MOC_SVMClassifier, OPTEXP_CORE,
00088 "The mode of SVM Classifier. Train|Rec",
00089 "svm-mode", '\0', "<Train|Rec>", "Rec" };
00090
00091
00092
00093 SVMClassifierModule::SVMClassifierModule(OptionManager& mgr,
00094 const std::string& descrName,
00095 const std::string& tagName) :
00096 SimModule(mgr, descrName, tagName),
00097 itsSVMModelFileNamesStr(&OPT_SVMModelFileNames, this),
00098 itsSVMModelNamesStr(&OPT_SVMModelNames, this),
00099 itsSVMRangeFileNamesStr(&OPT_SVMRangeFileNames, this),
00100 itsSVMOutputFileNamesStr(&OPT_SVMOutputFileNames, this),
00101 itsSVMObjDBFileName(&OPT_SVMObjDBFileName, this),
00102 itsSVMTrainObjName(&OPT_SVMTrainObjName, this),
00103 itsSVMTrainObjId(&OPT_SVMTrainObjId, this),
00104 itsSVMMode(&OPT_SVMMode, this)
00105 {
00106 }
00107
00108 SVMClassifierModule::~SVMClassifierModule()
00109 {
00110 }
00111
00112
00113
00114 void SVMClassifierModule::start2()
00115 {
00116 SimModule::start2();
00117
00118
00119 split(itsSVMModelFileNamesStr.getVal(), ":", std::back_inserter(itsSVMModelFiles));
00120
00121
00122 split(itsSVMModelNamesStr.getVal(), ":", std::back_inserter(itsSVMModelNames));
00123
00124
00125 split(itsSVMRangeFileNamesStr.getVal(), ":", std::back_inserter(itsSVMRangeFiles));
00126
00127
00128 split(itsSVMOutputFileNamesStr.getVal(), ":", std::back_inserter(itsSVMOutputFiles));
00129
00130
00131 if(itsSVMRangeFiles.size()==0)
00132 {
00133 itsSVMRangeFiles = std::vector<std::string>(itsSVMModelFiles.size());
00134 }
00135 if(itsSVMRangeFiles.size()!= itsSVMModelFiles.size())
00136 LFATAL("If range files are specified, must be same number as model files");
00137
00138 if(itsSVMObjDBFileName.getVal().compare("") == 0) {
00139 LFATAL("Must specify object db file using --svm-objdb-filename");
00140 }
00141 itsObjDB.loadObjDB(itsSVMObjDBFileName.getVal());
00142
00143 if (itsSVMMode.getVal().compare("Train") == 0)
00144 {
00145 if(itsSVMOutputFiles.size() == 0 || itsSVMOutputFiles.size() != itsSVMModelNames.size()) {
00146 LFATAL("Must specify training output file(s) (and equal number of model names) if in training mode");
00147 }
00148
00149 SVMClassifier classifier;
00150 itsClassifiers.push_back(classifier);
00151 }
00152 else if (itsSVMMode.getVal().compare("Rec") == 0)
00153 {
00154 if(itsSVMModelFiles.size() == 0 || itsSVMModelFiles.size() != itsSVMModelNames.size()){
00155 LFATAL("Must specify svm model file(s) (and equal number of model names) if in recognition mode");
00156 }
00157 for(size_t c=0;c<itsSVMModelFiles.size();c++)
00158 {
00159 SVMClassifier classifier;
00160
00161 classifier.readModel(itsSVMModelFiles[c]);
00162
00163 if(itsSVMRangeFiles[c].compare("") != 0) {
00164 classifier.readRange(itsSVMRangeFiles[c]);
00165 }
00166 itsClassifiers.push_back(classifier);
00167 }
00168 }
00169 else
00170 LFATAL("Unknown SVM Mode type %s", itsSVMMode.getVal().c_str());
00171 }
00172
00173
00174 void SVMClassifierModule::stop1()
00175 {
00176 LINFO("Writing out object db in module %p",this);
00177 itsObjDB.writeObjDB(itsSVMObjDBFileName.getVal());
00178 SimModule::stop1();
00179 }
00180
00181 std::string SVMClassifierModule::getMode()
00182 {
00183 return itsSVMMode.getVal();
00184 }
00185
00186 void SVMClassifierModule::attentionShift(SimEventQueue& q,
00187 const Point2D<int>& location)
00188 {
00189 }
00190
00191
00192 std::vector<std::string> SVMClassifierModule::getModelNames()
00193 {
00194 return itsSVMModelNames;
00195 }
00196
00197 SVMObject SVMClassifierModule::determineLabel(std::vector<float> featureVector, int id, std::string name, int classifierId)
00198 {
00199 if (itsSVMMode.getVal().compare("Rec") == 0)
00200 {
00201 return recognizeLabel(featureVector,id,name,classifierId);
00202 }
00203 else if (itsSVMMode.getVal().compare("Train") == 0)
00204 {
00205 return trainLabel(featureVector,id,name,classifierId);
00206 }
00207 else
00208 {
00209 LFATAL("Invalid SVM Classification Mode");
00210 }
00211 return SVMObject();
00212 }
00213
00214 SVMObject SVMClassifierModule::trainLabel(std::vector<float> featureVector, int id, std::string name, int classifierId)
00215 {
00216 SVMObject so;
00217 printf("In SVMClassifierModule::determineLabel %s\n",itsSVMMode.getVal().c_str());
00218
00219
00220 if(itsSVMTrainObjId.getVal() != -1) {
00221 id = itsSVMTrainObjId.getVal();
00222 }
00223 if(itsSVMTrainObjName.getVal().compare("") != 0) {
00224 name = itsSVMTrainObjName.getVal();
00225 }
00226 if(name.compare("") == 0) {
00227
00228
00229
00230 LFATAL("Name is not specified while in Train mode");
00231 }
00232
00233 if(id == -1 && itsObjDB.getObject(name).initialized()) {
00234 so = itsObjDB.getObject(name);
00235 id = so.id;
00236 }
00237
00238 LINFO("Training on object %s[%d]\n",name.c_str(),id);
00239 so = itsObjDB.updateObject(id,name);
00240 so.confidence = 1.0;
00241 so.id = id;
00242 so.name = name;
00243 itsClassifiers[classifierId].train(itsSVMOutputFiles[classifierId],id,featureVector);
00244
00245 return so;
00246
00247 }
00248
00249 SVMObject SVMClassifierModule::recognizeLabel(std::vector<float> featureVector, int id, std::string name, int classifierId)
00250 {
00251 std::vector<SVMObject> objects = getLabelPDF(featureVector,id,name,classifierId);
00252 return getBestLabel(objects);
00253 }
00254
00255 SVMObject SVMClassifierModule::getBestLabel(const std::vector<SVMObject> &objects)
00256 {
00257 double maxProb=-1;
00258 int bestLabelIdx=0;
00259 if(objects.size() == 0)
00260 return SVMObject();
00261 for(size_t i=0;i<objects.size();i++)
00262 {
00263 if(maxProb<objects[i].confidence)
00264 {
00265 bestLabelIdx=i;
00266 maxProb=objects[i].confidence;
00267 }
00268 }
00269 return objects[bestLabelIdx];
00270 }
00271
00272 std::vector<SVMObject> SVMClassifierModule::getLabelPDF(std::vector<float> featureVector, int id, std::string name, int classifierId)
00273 {
00274 std::map<int,double> pdf = itsClassifiers[classifierId].predictPDF(featureVector);
00275
00276 if (itsSVMMode.getVal().compare("Rec") == 0)
00277 {
00278 std::vector<SVMObject> svmObjects;
00279 for(std::map<int,double>::iterator pdfIt=pdf.begin(); pdfIt!=pdf.end(); ++pdfIt)
00280 {
00281 SVMObject obj = itsObjDB.getObject(int(pdfIt->first));
00282 if (obj.id == -1)
00283 obj.id = pdfIt->first;
00284 obj.confidence = pdfIt->second;
00285 svmObjects.push_back(obj);
00286 }
00287 return svmObjects;
00288
00289 }
00290 else if (itsSVMMode.getVal().compare("Train") == 0)
00291 {
00292 std::vector<SVMObject> svmObjects;
00293 for(std::map<int,double>::iterator pdfIt=pdf.begin(); pdfIt!=pdf.end(); ++pdfIt)
00294 {
00295 SVMObject obj = itsObjDB.getObject(int(pdfIt->first));
00296 if(obj.id==id)
00297 obj.confidence = 1;
00298 else
00299 obj.confidence = 0;
00300 svmObjects.push_back(obj);
00301 }
00302 return svmObjects;
00303 }
00304 else
00305 {
00306 LFATAL("Invalid SVM Classification Mode");
00307 }
00308
00309 return std::vector<SVMObject>();
00310 }
00311