/*!@file Neuro/InferoTemporalSalBayes.C Object recognition module with SalBayes */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
// University of Southern California (USC) and the iLab at USC.         //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.       //
//                                                                      //
// You should have received a copy of the GNU General Public License    //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
// Boston, MA 02111-1307 USA.                                           //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Lior Elazary
// $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Neuro/InferoTemporalSalBayes.C $
// $Id: InferoTemporalSalBayes.C 12337 2009-12-19 02:45:23Z itti $
//

#include "Neuro/InferoTemporalSalBayes.H"

#include "Component/OptionManager.H"
#include "Component/ModelOptionDef.H"
#include "Channels/ChannelMaps.H"
#include "Channels/ChannelOpts.H"
#include "Image/MathOps.H"
#include "Image/ShapeOps.H"
#include "Image/CutPaste.H"
#include "Neuro/NeuroOpts.H"
#include "Neuro/NeuroSimEvents.H"
#include "Neuro/Brain.H"
#include "Neuro/VisualCortex.H"
#include "Simulation/SimEventQueue.H"
#include "Media/MediaSimEvents.H"

#include <algorithm> // for std::sort() and std::max()
#include <cstdlib>
#include <iostream>

static const ModelOptionDef OPT_ITSB_FOVSize =
  { MODOPT_ARG(Dims), "FoveaSize", &MOC_ITC, OPTEXP_CORE,
    "Use the given fovea size for constructing a feature vector.",
    "it-fov-size", '\0', "<w>x<h>", "75x75" };

static const ModelOptionDef OPT_ITCMode =
  { MODOPT_ARG_STRING, "ITC Mode", &MOC_ITC, OPTEXP_CORE,
    "The mode of ITC. Train is for training from some data; Rec is for recognition.",
    "it-mode", '\0', "<Train|Rec>", "Rec" };
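// As a rough illustration of how these options fit together (the binary name
// and argument values here are hypothetical, not from this file), a training
// run might be invoked as:
//
//   ./myapp --it-mode=Train --it-fov-size=75x75 \
//           --it-bayesnet-file=SalBayes.net --it-object-db=objects.vdb
//
// The option names and defaults come from the ModelOptionDef entries above
// and below.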
static const ModelOptionDef OPT_ITCSalBayes_NetFile =
  { MODOPT_ARG_STRING, "ITC BaysNet File", &MOC_ITC, OPTEXP_CORE,
    "Name of the file to save/read the computed Bayesian Network.",
    "it-bayesnet-file", '\0', "<filename>", "SalBayes.net" };

const ModelOptionDef OPT_ITSiftObjectDBFileName =
  { MODOPT_ARG_STRING, "ITC SiftObjectDBFileName", &MOC_ITC, OPTEXP_CORE,
    "Filename for the SIFT object database. Specifying no file will disable the SIFT algorithm.",
    "it-object-db", '\0', "<filename>", "" };

const ModelOptionDef OPT_ITUseSift =
  { MODOPT_ARG(int), "ITC use sift", &MOC_ITC, OPTEXP_CORE,
    "Use SIFT recognition on the n most probable objects obtained from SalBayes. "
    "That is, the SIFT algorithm will only run on the top n objects returned "
    "from SalBayes. 0 disables SIFT and just returns the most probable object.",
    "it-use-sift", '\0', "<int>", "10" };

const ModelOptionDef OPT_ITUseMaxNMatches =
  { MODOPT_ARG(bool), "ITC use max num of matches", &MOC_ITC, OPTEXP_CORE,
    "When determining which object in the database matches, use the maximum "
    "number of keypoint matches instead of sorting by keypoint distance.",
    "it-use-max-num-matches", '\0', "<true|false>", "true" };

// ######################################################################
// functor to assist with VisualObjectMatch sorting:
class moreVOM
{
public:
  moreVOM(const float kcoeff, const float acoeff) :
    itsKcoeff(kcoeff), itsAcoeff(acoeff)
  { }

  bool operator()(const rutz::shared_ptr<VisualObjectMatch>& x,
                  const rutz::shared_ptr<VisualObjectMatch>& y)
  { return ( x->getScore(itsKcoeff, itsAcoeff) >
             y->getScore(itsKcoeff, itsAcoeff) ); }

private:
  float itsKcoeff, itsAcoeff;
};
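// moreVOM is used in predictWithSift() below to rank candidate matches so
// that the best-scoring one comes first, e.g.:
//
//   std::sort(matches.begin(), matches.end(), moreVOM(0.5F, 0.5F));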
// ######################################################################
InferoTemporalSalBayes::InferoTemporalSalBayes(OptionManager& mgr,
                                               const std::string& descrName,
                                               const std::string& tagName) :
  InferoTemporal(mgr, descrName, tagName),
  itsLevelSpec(&OPT_LevelSpec, this),
  itsFoveaSize(&OPT_ITSB_FOVSize, this),
  itsITCMode(&OPT_ITCMode, this),
  itsBayesNetFilename(&OPT_ITCSalBayes_NetFile, this),
  itsUseSift(&OPT_ITUseSift, this),
  itsSiftObjectDBFile(&OPT_ITSiftObjectDBFileName, this),
  itsUseMaxNMatches(&OPT_ITUseMaxNMatches, this),
  itsSiftObjectDB(new VisualObjectDB())
{}

// ######################################################################
void InferoTemporalSalBayes::start1()
{
  // if no filename was given for our object DB, start empty; otherwise load it:
  if (!itsSiftObjectDBFile.getVal().empty())
    itsSiftObjectDB->loadFrom(itsSiftObjectDBFile.getVal());

  InferoTemporal::start1();
}

// ######################################################################
void InferoTemporalSalBayes::stop1()
{}

// ######################################################################
InferoTemporalSalBayes::~InferoTemporalSalBayes()
{}

// ######################################################################
void InferoTemporalSalBayes::attentionShift(SimEventQueue& q,
                                            const Point2D<int>& location)
{
  // lazily initialize the Bayes net once we know how many submaps there are:
  if (!itsBayesNet.is_valid())
    {
      rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
      q.request(vcxm); // VisualCortex is now filling-in the maps...
      rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();

      itsBayesNet.reset(new Bayes(chm->numSubmaps(), 0));
      itsBayesNet->load(itsBayesNetFilename.getVal().c_str());
    }

  // do we have 3D information about the scene?
  /*
  if (SeC<SimEventSaliencyCoordMap> e = q.check<SimEventSaliencyCoordMap>(this))
    {
      itsHas3Dinfo = true;
      its3DSalLocation = e->getMaxSalCurrent();
      its3DSalVal = e->getMaxSalValue();
    }
  else
    {
      itsHas3Dinfo = false;
    }
  */

  // extract features:
  std::vector<double> fv = buildRawDV(q, location);

  rutz::shared_ptr<VisualObject> vobj;
  if (!itsSiftObjectDBFile.getVal().empty())
    {
      // build a VisualObject from the current retina image, for SIFT training/recognition:
      Image<PixRGB<float> > objImg;
      if (SeC<SimEventRetinaImage> e = q.check<SimEventRetinaImage>(this))
        objImg = e->frame().colorByte();
      if (!objImg.initialized()) return;

      vobj.reset(new VisualObject("NewObject", "NULL", objImg,
                                  Point2D<int>(-1,-1),
                                  std::vector<float>(),
                                  std::vector< rutz::shared_ptr<Keypoint> >(),
                                  false)); // use color?
    }

  if (itsITCMode.getVal().compare("Train") == 0) // training
    {
      rutz::shared_ptr<TestImages::SceneData> sceneData;
      // get the scene data, but don't mark it, so that we will get it again on the next saccade:
      std::string objName;
      if (SeC<SimEventInputFrame> e = q.check<SimEventInputFrame>(this, SEQ_UNMARKED, 0))
        {
          GenericFrame gf = e->frame();
          if (gf.hasMetaData(std::string("SceneData")))
            {
              rutz::shared_ptr<GenericFrame::MetaData> metaData = gf.getMetaData(std::string("SceneData"));
              if (metaData.get() != 0)
                {
                  sceneData.dyn_cast_from(metaData);
                  objName = getObjNameAtLoc(sceneData->objects, location);
                }
            } else {
              LINFO("Enter name for new object or [RETURN] to skip training:");
              std::getline(std::cin, objName, '\n');
            }
        }

      if (objName.length() > 0)
        {
          LINFO("Train on %s", objName.c_str());
          itsBayesNet->learn(fv, objName.c_str());
          itsBayesNet->save(itsBayesNetFilename.getVal().c_str());

          // train the SIFT algorithm:
          if (!itsSiftObjectDBFile.getVal().empty())
            {
              vobj->setName(objName);
              itsSiftObjectDB->addObject(vobj, false); // allow multiple objects with the same name
              itsSiftObjectDB->saveTo(itsSiftObjectDBFile.getVal());
            }
        }
    }
  else if (itsITCMode.getVal().compare("Rec") == 0) // recognition
    {
      if (itsUseSift.getVal() > 0 && !itsSiftObjectDBFile.getVal().empty())
        predictWithSift(fv, vobj, q);
      else
        predict(fv, q);
    }
  else
    LFATAL("Unknown IT Mode type %s", itsITCMode.getVal().c_str());
}

// ######################################################################
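// predict() implements the pure-Bayes recognition path ("Rec" mode with SIFT
// disabled): classify the raw feature vector with the Bayes net and, if a
// known class is found, post its name and probabilities as a
// SimEventObjectDescription for downstream modules.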
void InferoTemporalSalBayes::predict(std::vector<double> &fv, SimEventQueue& q)
{
  const int cls = itsBayesNet->classify(fv);

  if (cls != -1)
    {
      const double maxProb  = itsBayesNet->getMaxProb();
      const double normProb = itsBayesNet->getNormProb();

      std::string clsName(itsBayesNet->getClassName(cls));
      rutz::shared_ptr<TestImages::ObjData> objData(new TestImages::ObjData);
      objData->name = clsName;
      objData->id = cls;
      objData->maxProb = maxProb;
      objData->normProb = normProb;

      rutz::shared_ptr<SimEventObjectDescription>
        objDataEvent(new SimEventObjectDescription(this, objData));
      q.post(objDataEvent);
    }
}

// ######################################################################
void InferoTemporalSalBayes::predictWithSift(std::vector<double> &fv,
                                             rutz::shared_ptr<VisualObject> &vobj,
                                             SimEventQueue& q)
{
  int cls = -1;
  uint mink = 6U;      // min # of keypoints
  float kcoeff = 0.5;  // keypoint distance score, default 0.5F
  float acoeff = 0.5;  // affine distance score, default 0.5F
  //float minscore = 1.0F; // minscore, default 1.0F

  std::vector< rutz::shared_ptr<VisualObjectMatch> > matches;
  matches.clear();

  std::string maxObjName;
  uint maxKeyMatches = 0;

  std::vector<Bayes::ClassInfo> classInfo = itsBayesNet->classifyRange(fv, cls);
  for (uint i = 0; i < classInfo.size() && i < (uint)itsUseSift.getVal(); i++)
    {
      std::string clsName(itsBayesNet->getClassName(classInfo[i].classID));
      // check all objects containing this name
      // TODO: this should be hashed for greater efficiency
      for (uint j = 0; j < itsSiftObjectDB->numObjects(); j++)
        {
          rutz::shared_ptr<VisualObject> knownVObj = itsSiftObjectDB->getObject(j);
          if (clsName.compare(knownVObj->getName()) == 0)
            {
              // attempt a match:
              rutz::shared_ptr<VisualObjectMatch>
                match(new VisualObjectMatch(vobj, knownVObj,
                                            VOMA_SIMPLE,
                                            6U)); // keypoint selection threshold, from Lowe's code

              if (itsUseMaxNMatches.getVal())
                {
                  // track the max based on the number of keypoint matches:
                  if (match->size() > maxKeyMatches)
                    {
                      maxObjName = clsName;
                      maxKeyMatches = match->size();
                    }
                } else {
                  // apply some standard pruning:
                  match->prune(std::max(25U, mink * 5U), mink);

                  // if the match is good enough, store it:
                  if (match->size() >= mink &&
                      // match->getScore(kcoeff, acoeff) >= minscore &&
                      match->checkSIFTaffine())
                    {
                      matches.push_back(match);
                    }
                }
            }
        }
    }
  // once all candidates have been collected, rank them best-score-first:
  std::sort(matches.begin(), matches.end(), moreVOM(kcoeff, acoeff));

  rutz::shared_ptr<TestImages::ObjData> objData(new TestImages::ObjData);
  if (itsUseMaxNMatches.getVal())
    {
      objData->name = maxObjName;
      objData->id = (unsigned int) -1;
    } else {
      if (matches.size() > 0)
        {
          // get the first (best-scoring) match:
          objData->name = matches[0]->getVoTest()->getName();
          objData->id = (unsigned int) -1;
        }
    }
  rutz::shared_ptr<SimEventObjectDescription>
    objDataEvent(new SimEventObjectDescription(this, objData));
  q.post(objDataEvent);
}
// ######################################################################
std::string InferoTemporalSalBayes::getObjNameAtLoc(const std::vector<TestImages::ObjData> &objects, const Point2D<int>& loc)
{
  for (uint obj = 0; obj < objects.size(); obj++)
    {
      TestImages::ObjData objData = objects[obj];

      // find the object dimensions from the polygon
      if (objData.polygon.size() > 0)
        {
          Point2D<int> upperLeft = objData.polygon[0];
          Point2D<int> lowerRight = objData.polygon[0];

          for (uint i = 0; i < objData.polygon.size(); i++)
            {
              // find the bounding box of the polygon:
              if (objData.polygon[i].i < upperLeft.i)  upperLeft.i  = objData.polygon[i].i;
              if (objData.polygon[i].j < upperLeft.j)  upperLeft.j  = objData.polygon[i].j;

              if (objData.polygon[i].i > lowerRight.i) lowerRight.i = objData.polygon[i].i;
              if (objData.polygon[i].j > lowerRight.j) lowerRight.j = objData.polygon[i].j;
            }

          // quick bounding-box rejection, then the exact point-in-polygon test
          // (the test does not depend on any scan position, so no pixel loop is needed):
          if (loc.i >= upperLeft.i && loc.i <= lowerRight.i &&
              loc.j >= upperLeft.j && loc.j <= lowerRight.j &&
              pnpoly(objData.polygon, loc))
            return objData.name;
        }
    }
  return std::string("Unknown");
}

// ######################################################################
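// buildRawDV() samples one value per channel submap around the attended
// location to form the feature vector. Note the coordinate scaling below:
// retinal coordinates are divided by 2^mapLevel to reach saliency-map
// resolution. As an illustrative example (numbers assumed, not from this
// file): with a map level of 4 the divisor is 16, so the default 75x75
// fovea becomes a window of int(75/16 + 0.49) = 5 pixels on a side.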
std::vector<double> InferoTemporalSalBayes::buildRawDV(SimEventQueue& q, const Point2D<int>& foveaLoc)
{
  bool salientLocationWithinSubmaps = true;
  Point2D<int> objSalientLoc(-1,-1); // the feature location

  const int smlevel = itsLevelSpec.getVal().mapLevel();

  int x = int(foveaLoc.i / double(1 << smlevel) + 0.49);
  int y = int(foveaLoc.j / double(1 << smlevel) + 0.49);

  //LINFO("Getting from location %d,%d", x, y);

  int foveaW = int(itsFoveaSize.getVal().w() / double(1 << smlevel) + 0.49);
  int foveaH = int(itsFoveaSize.getVal().h() / double(1 << smlevel) + 0.49);

  int tl_x = x - (foveaW / 2);
  int tl_y = y - (foveaH / 2);

  rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
  q.request(vcxm); // VisualCortex is now filling-in the maps...
  rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();

  Dims mapDims = chm->getMap().getDims();

  // shift the fovea position if necessary, so that we don't go outside the image:
  if (tl_x < 0) tl_x = 0;
  if (tl_y < 0) tl_y = 0;
  if (tl_x + foveaW > mapDims.w()) tl_x = mapDims.w() - foveaW;
  if (tl_y + foveaH > mapDims.h()) tl_y = mapDims.h() - foveaH;

  if (!salientLocationWithinSubmaps)
    {
      // find the most salient location within the fovea:
      Image<float> SMap = chm->getMap();

      Image<float> tmp = SMap; // TODO: need to resize to fovea

      // find the max location within the fovea:
      float maxVal; Point2D<int> maxLoc;
      findMax(tmp, maxLoc, maxVal);
      // convert back to original SMap coordinates:
      // objSalientLoc.i = tl_x + maxLoc.i;
      // objSalientLoc.j = tl_y + maxLoc.j;
      objSalientLoc.i = x;
      objSalientLoc.j = y;
    }

  // go through all the submaps, building the feature vector:
  std::vector<double> FV;
  uint numSubmaps = chm->numSubmaps();
  for (uint i = 0; i < numSubmaps; i++)
    {
      //Image<float> submap = itsComplexChannel->getSubmap(i);
      Image<float> submap = chm->getRawCSmap(i);

      // resize submap to fixed scale if necessary:
      if (submap.getWidth() > mapDims.w())
        submap = downSize(submap, mapDims);
      else if (submap.getWidth() < mapDims.w())
        submap = rescale(submap, mapDims); // TODO: convert to quickInterpolate

      if (salientLocationWithinSubmaps) // get the most salient location within each submap
        {
          Image<float> tmp = submap;

          // use only the fovea region (crop if our fovea is smaller):
          if (foveaW < tmp.getWidth())
            tmp = crop(tmp, Point2D<int>(tl_x, tl_y), Dims(foveaW, foveaH));
          // tmp = maxNormalize(tmp, 0.0F, 10.0F, VCXNORM_MAXNORM); // find salient locations

          // find the max location within the fovea:
          float maxVal; Point2D<int> maxLoc;
          findMax(tmp, maxLoc, maxVal);
          //LINFO("%i: Max val %f, loc(%i,%i)", i, maxVal, maxLoc.i, maxLoc.j);

          objSalientLoc.i = tl_x + maxLoc.i;
          objSalientLoc.j = tl_y + maxLoc.j;
        }

      // clamp to the submap bounds:
      if (objSalientLoc.i < 0) objSalientLoc.i = 0;
      if (objSalientLoc.j < 0) objSalientLoc.j = 0;
      if (objSalientLoc.i > submap.getWidth()  - 1) objSalientLoc.i = submap.getWidth()  - 1;
      if (objSalientLoc.j > submap.getHeight() - 1) objSalientLoc.j = submap.getHeight() - 1;

      // LINFO("Location from %i,%i: (%i,%i)", objSalientLoc.i, objSalientLoc.j,
      //       submap.getWidth(), submap.getHeight());
      float featureVal = submap.getVal(objSalientLoc.i, objSalientLoc.j);
      FV.push_back(featureVal);
      // SHOWIMG(rescale(submap, 255, 255));
    }

  return FV;
}

// ######################################################################
/* So things look consistent in everyone's emacs... */
/* Local Variables:                                  */
/* indent-tabs-mode: nil                             */
/* End:                                              */