InferoTemporalSalBayes.C

Go to the documentation of this file.
00001 /*!@file Neuro/InferoTemporalSalBayes.C Object recognition module with SalBayes */
00002 
00003 // //////////////////////////////////////////////////////////////////// //
00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
00005 // University of Southern California (USC) and the iLab at USC.         //
00006 // See http://iLab.usc.edu for information about this project.          //
00007 // //////////////////////////////////////////////////////////////////// //
00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected //
00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
00010 // in Visual Environments, and Applications'' by Christof Koch and      //
00011 // Laurent Itti, California Institute of Technology, 2001 (patent       //
00012 // pending; application number 09/912,225 filed July 23, 2001; see      //
00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
00014 // //////////////////////////////////////////////////////////////////// //
00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
00016 //                                                                      //
00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
00018 // redistribute it and/or modify it under the terms of the GNU General  //
00019 // Public License as published by the Free Software Foundation; either  //
00020 // version 2 of the License, or (at your option) any later version.     //
00021 //                                                                      //
00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
00025 // PURPOSE.  See the GNU General Public License for more details.       //
00026 //                                                                      //
00027 // You should have received a copy of the GNU General Public License    //
00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
00030 // Boston, MA 02111-1307 USA.                                           //
00031 // //////////////////////////////////////////////////////////////////// //
00032 //
00033 // Primary maintainer for this file: Lior Elazary
00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Neuro/InferoTemporalSalBayes.C $
00035 // $Id: InferoTemporalSalBayes.C 12337 2009-12-19 02:45:23Z itti $
00036 //
00037 
00038 #include "Neuro/InferoTemporalSalBayes.H"
00039 
00040 #include "Component/OptionManager.H"
00041 #include "Component/ModelOptionDef.H"
00042 #include "Channels/ChannelMaps.H"
00043 #include "Channels/ChannelOpts.H"
00044 #include "Image/MathOps.H"
00045 #include "Image/ShapeOps.H"
00046 #include "Image/CutPaste.H"
00047 #include "Neuro/NeuroOpts.H"
00048 #include "Neuro/NeuroSimEvents.H"
00049 #include "Neuro/Brain.H"
00050 #include "Neuro/VisualCortex.H"
00051 #include "Simulation/SimEventQueue.H"
00052 #include "Media/MediaSimEvents.H"
00053 
00054 #include <cstdlib>
00055 #include <iostream>
00056 
00057 static const ModelOptionDef OPT_ITSB_FOVSize =
00058   { MODOPT_ARG(Dims), "FoveaSize", &MOC_ITC, OPTEXP_CORE,
00059     "Use the given fovea size for constructing a feature vector.",
00060     "it-fov-size", '\0', "<w>x<h>", "75x75" };
00061 
00062 static const ModelOptionDef OPT_ITCMode =
00063   { MODOPT_ARG_STRING, "ITC Mode", &MOC_ITC, OPTEXP_CORE,
00064     "The mode of ITC . Train: is for training from some data, Rec is for recognition.",
00065     "it-mode", '\0', "<Train|Rec>", "Rec" };
00066 
00067 static const ModelOptionDef OPT_ITCSalBayes_NetFile =
00068 { MODOPT_ARG_STRING, "ITC BaysNet File", &MOC_ITC, OPTEXP_CORE,
00069   "Name of the file to save/read the computed Bayesian Network.",
00070   "it-bayesnet-file", '\0', "<filename>", "SalBayes.net" };
00071 
00072 const ModelOptionDef OPT_ITSiftObjectDBFileName =
00073   { MODOPT_ARG_STRING, "ITC SiftObjectDBFileName", &MOC_ITC, OPTEXP_CORE,
00074     "Filename for the sift object database. Specifying no file will disable the sift alg.",
00075     "it-object-db", '\0', "<filename>", "" };
00076 
00077 const ModelOptionDef OPT_ITUseSift =
00078   { MODOPT_ARG(int), "ITC use sift", &MOC_ITC, OPTEXP_CORE,
00079     "Use sift recognition on the n most probable objects obtained from SalBayes. "
00080     "That is, the sift algorithm will only run on the top n objects returned from SalBayes."
00081     "0 disables sift and just return the most probable object",
00082     "it-use-sift", '\0', "<int>", "10" };
00083 
00084 const ModelOptionDef OPT_ITUseMaxNMatches =
00085   { MODOPT_ARG(bool), "ITC use max num of matches", &MOC_ITC, OPTEXP_CORE,
00086     "When determining a which object in the database matches, use the maximum "
00087     "number of keypoints matches instead of sorting by distance of keypoints.",
00088     "it-use-max-num-matches", '\0', "<true|false>", "true" };
00089 
00090 // ######################################################################
00091 // functor to assist with VisualObjectMatch sorting:
00092 class moreVOM
00093 {
00094 public:
00095   moreVOM(const float kcoeff, const float acoeff) :
00096     itsKcoeff(kcoeff), itsAcoeff(acoeff)
00097   { }
00098 
00099   bool operator()(const rutz::shared_ptr<VisualObjectMatch>& x,
00100                   const rutz::shared_ptr<VisualObjectMatch>& y)
00101   { return ( x->getScore(itsKcoeff, itsAcoeff) >
00102              y->getScore(itsKcoeff, itsAcoeff) ); }
00103 
00104 private:
00105   float itsKcoeff, itsAcoeff;
00106 };
00107 
00108 
00109 // ######################################################################
00110 InferoTemporalSalBayes::InferoTemporalSalBayes(OptionManager& mgr,
00111                                      const std::string& descrName,
00112                                      const std::string& tagName) :
00113   InferoTemporal(mgr, descrName, tagName),
00114   itsLevelSpec(&OPT_LevelSpec, this),
00115   itsFoveaSize(&OPT_ITSB_FOVSize, this),
00116   itsITCMode(&OPT_ITCMode, this),
00117   itsBayesNetFilename(&OPT_ITCSalBayes_NetFile, this),
00118   itsUseSift(&OPT_ITUseSift, this),
00119   itsSiftObjectDBFile(&OPT_ITSiftObjectDBFileName, this),
00120   itsUseMaxNMatches(&OPT_ITUseMaxNMatches, this),
00121   itsSiftObjectDB(new VisualObjectDB())
00122 {}
00123 
00124 // ######################################################################
00125 void InferoTemporalSalBayes::start1()
00126 {
00127   // if no filename given for our object DB, start empty, otherwise load it:
00128   if (!itsSiftObjectDBFile.getVal().empty())
00129     itsSiftObjectDB->loadFrom(itsSiftObjectDBFile.getVal());
00130 
00131   InferoTemporal::start1();
00132 }
00133 
00134 // ######################################################################
00135 void InferoTemporalSalBayes::stop1()
00136 {}
00137 
00138 // ######################################################################
00139 InferoTemporalSalBayes::~InferoTemporalSalBayes()
00140 {}
00141 
00142 // ######################################################################
00143 void InferoTemporalSalBayes::attentionShift(SimEventQueue& q,
00144                                             const Point2D<int>& location)
00145 {
00146   if (itsBayesNet.is_valid() == false)
00147     {
00148       rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
00149       q.request(vcxm); // VisualCortex is now filling-in the maps...
00150       rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();
00151 
00152       itsBayesNet.reset(new Bayes(chm->numSubmaps(), 0));
00153       itsBayesNet->load(itsBayesNetFilename.getVal().c_str());
00154     }
00155 
00156 
00157   // do we have 3D information about the scene?
00158   /*
00159   if(SeC<SimEventSaliencyCoordMap> e = q.check<SimEventSaliencyCoordMap>(this))
00160   {
00161     itsHas3Dinfo     = true;
00162     its3DSalLocation = e->getMaxSalCurrent();
00163     its3DSalVal      = e->getMaxSalValue();
00164   }
00165   else
00166   {
00167     itsHas3Dinfo     = false;
00168   }
00169   */
00170 
00171   //extract features
00172   std::vector<double> fv = buildRawDV(q, location);
00173 
00174   rutz::shared_ptr<VisualObject> vobj;
00175   if (!itsSiftObjectDBFile.getVal().empty())
00176   {
00177     //Train the Sift Recognition
00178     Image<PixRGB<float> > objImg;
00179     if (SeC<SimEventRetinaImage> e = q.check<SimEventRetinaImage>(this))
00180       objImg = e->frame().colorByte();
00181     if (!objImg.initialized()) return;
00182 
00183       vobj.reset(new VisualObject("NewObject", "NULL", objImg,
00184             Point2D<int>(-1,-1),
00185             std::vector<float>(),
00186             std::vector< rutz::shared_ptr<Keypoint> >(),
00187             false)); //Use color?
00188   }
00189 
00190   if (itsITCMode.getVal().compare("Train") == 0)          // training
00191   {
00192     rutz::shared_ptr<TestImages::SceneData> sceneData;
00193     //Get the scene data, but dont mark it so we will get it on the next saccade
00194     std::string objName;
00195     if (SeC<SimEventInputFrame> e = q.check<SimEventInputFrame>(this,SEQ_UNMARKED,0))
00196     {
00197 
00198       GenericFrame gf = e->frame();
00199       if (gf.hasMetaData(std::string("SceneData")))
00200       {
00201         rutz::shared_ptr<GenericFrame::MetaData> metaData = gf.getMetaData(std::string("SceneData"));
00202         if (metaData.get() != 0)
00203         {
00204           sceneData.dyn_cast_from(metaData);
00205           objName = getObjNameAtLoc(sceneData->objects, location);
00206         }
00207       } else {
00208         LINFO("Enter name for new object or [RETURN] to skip training:");
00209         std::getline(std::cin, objName, '\n');
00210       }
00211 
00212     }
00213 
00214     if (objName.length() > 0)
00215     {
00216       LINFO("Train on %s", objName.c_str());
00217       itsBayesNet->learn(fv, objName.c_str());
00218       itsBayesNet->save(itsBayesNetFilename.getVal().c_str());
00219 
00220       //Train The Sift alg
00221       if (!itsSiftObjectDBFile.getVal().empty())
00222       {
00223         vobj->setName(objName);
00224         itsSiftObjectDB->addObject(vobj, false); //allow multiple object names
00225         itsSiftObjectDB->saveTo(itsSiftObjectDBFile.getVal());
00226       }
00227     }
00228   }
00229   else if (itsITCMode.getVal().compare("Rec") == 0)      // Recognition
00230   {
00231     if (itsUseSift.getVal() > 0 && !itsSiftObjectDBFile.getVal().empty())
00232       predictWithSift(fv, vobj, q);
00233     else
00234       predict(fv, q);
00235   }
00236   else
00237     LFATAL("Unknown IT Mode type %s", itsITCMode.getVal().c_str());
00238 }
00239 // ######################################################################
00240 
00241 void InferoTemporalSalBayes::predict(std::vector<double> &fv, SimEventQueue& q)
00242 {
00243   const int cls = itsBayesNet->classify(fv);
00244 
00245   if (cls != -1)
00246   {
00247     const double maxProb  = itsBayesNet->getMaxProb();
00248     const double normProb = itsBayesNet->getNormProb();
00249 
00250     std::string clsName(itsBayesNet->getClassName(cls));
00251     rutz::shared_ptr<TestImages::ObjData> objData(new TestImages::ObjData);
00252     objData->name     = clsName;
00253     objData->id       = cls;
00254     objData->maxProb  = maxProb;
00255     objData->normProb = normProb;
00256 
00257     rutz::shared_ptr<SimEventObjectDescription>
00258       objDataEvent(new SimEventObjectDescription(this, objData));
00259     q.post(objDataEvent);
00260   }
00261 }
00262 // ######################################################################
00263 
00264 void InferoTemporalSalBayes::predictWithSift(std::vector<double> &fv,
00265   rutz::shared_ptr<VisualObject> &vobj,
00266   SimEventQueue& q)
00267 {
00268   int cls = -1;
00269   uint mink = 6U; //min # of keypoints
00270   float kcoeff = 0.5; //keypoint distance score default 0.5F
00271   float acoeff = 0.5; //affine distance score default 0.5F
00272   //float minscore=1.0F; //minscore  default 1.0F
00273 
00274   std::vector< rutz::shared_ptr<VisualObjectMatch> > matches;
00275   matches.clear();
00276 
00277   std::string maxObjName;
00278   uint maxKeyMatches = 0;
00279 
00280   std::vector<Bayes::ClassInfo> classInfo = itsBayesNet->classifyRange(fv, cls);
00281   for(uint i=0; i< classInfo.size() && i < (uint)itsUseSift.getVal(); i++)
00282   {
00283     std::string clsName(itsBayesNet->getClassName(classInfo[i].classID));
00284     //check all objects containing this name
00285     //TODO: this should be hashed for greater efficiency
00286     for(uint j=0; j<itsSiftObjectDB->numObjects(); j++)
00287     {
00288       rutz::shared_ptr<VisualObject> knownVObj = itsSiftObjectDB->getObject(j);
00289       if (clsName.compare(knownVObj->getName()) == 0)
00290       {
00291         // attempt a match:
00292         rutz::shared_ptr<VisualObjectMatch>
00293           match(new VisualObjectMatch(vobj, knownVObj,
00294                 VOMA_SIMPLE,
00295                 6U)); //keypoint selection threshold, from lowes code
00296 
00297         if (itsUseMaxNMatches.getVal())
00298         {
00299           //Find the max based on num of matches
00300           if (match->size() > maxKeyMatches)
00301           {
00302             maxObjName = clsName;
00303             maxKeyMatches = match->size();
00304           }
00305         } else {
00306           // apply some standard pruning:
00307           match->prune(std::max(25U, mink * 5U), mink);
00308 
00309           //// if the match is good enough, store it:
00310           if (match->size() >= mink &&
00311               //    match->getScore(kcoeff, acoeff) >= minscore &&
00312               match->checkSIFTaffine())
00313           {
00314             matches.push_back(match);
00315           }
00316         }
00317 
00318       }
00319     }
00320     std::sort(matches.begin(), matches.end(), moreVOM(kcoeff, acoeff));
00321   }
00322 
00323   rutz::shared_ptr<TestImages::ObjData> objData(new TestImages::ObjData);
00324   if (itsUseMaxNMatches.getVal())
00325   {
00326     objData->name = maxObjName;
00327     objData->id = (unsigned int) -1;
00328   } else {
00329     if (matches.size() > 0)
00330     {
00331       //Get the first match
00332       objData->name = matches[0]->getVoTest()->getName();
00333       objData->id = (unsigned int) -1;
00334     }
00335   }
00336   rutz::shared_ptr<SimEventObjectDescription>
00337     objDataEvent(new SimEventObjectDescription(this, objData));
00338   q.post(objDataEvent);
00339 }
00340 
00341 // ######################################################################
00342 std::string InferoTemporalSalBayes::getObjNameAtLoc(const std::vector<TestImages::ObjData> &objects, const Point2D<int>& loc)
00343 {
00344 
00345   for(uint obj=0; obj<objects.size(); obj++)
00346   {
00347     TestImages::ObjData objData = objects[obj];
00348 
00349     //find the object dimention from the polygon
00350     if (objData.polygon.size() > 0)
00351     {
00352       Point2D<int> upperLeft = objData.polygon[0];
00353       Point2D<int> lowerRight = objData.polygon[0];
00354 
00355       for(uint i=0; i<objData.polygon.size(); i++)
00356       {
00357         //find the bounds for the crop
00358         if (objData.polygon[i].i < upperLeft.i) upperLeft.i = objData.polygon[i].i;
00359         if (objData.polygon[i].j < upperLeft.j) upperLeft.j = objData.polygon[i].j;
00360 
00361         if (objData.polygon[i].i > lowerRight.i) lowerRight.i = objData.polygon[i].i;
00362         if (objData.polygon[i].j > lowerRight.j) lowerRight.j = objData.polygon[i].j;
00363       }
00364 
00365       //check if point is within the polygon
00366       for(int y=upperLeft.j; y<lowerRight.j; y++)
00367         for(int x=upperLeft.i; x<lowerRight.i; x++)
00368         {
00369           if (pnpoly(objData.polygon, loc))
00370             return objData.name;
00371         }
00372     }
00373 
00374   }
00375   return std::string("Unknown");
00376 }
00377 
00378 // ######################################################################
00379 std::vector<double> InferoTemporalSalBayes::buildRawDV(SimEventQueue& q, const Point2D<int>& foveaLoc)
00380 {
00381   bool salientLocationWithinSubmaps = true;
00382   Point2D<int> objSalientLoc(-1,-1);  //the feature location
00383 
00384   const int smlevel = itsLevelSpec.getVal().mapLevel();
00385 
00386   int x=int(foveaLoc.i / double(1 << smlevel) + 0.49);
00387   int y=int(foveaLoc.j / double(1 << smlevel) + 0.49);
00388 
00389   //LINFO("Getting from location %d,%d",x,y);
00390 
00391   int foveaW = int(itsFoveaSize.getVal().w() / double(1 << smlevel) + 0.49);
00392   int foveaH = int(itsFoveaSize.getVal().h() / double(1 << smlevel) + 0.49);
00393 
00394   int tl_x = x - (foveaW/2);
00395   int tl_y = y - (foveaH/2);
00396 
00397   rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
00398   q.request(vcxm); // VisualCortex is now filling-in the maps...
00399   rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();
00400 
00401   Dims mapDims = chm->getMap().getDims();
00402 
00403   //Shift the fovea location so we dont go outside the image
00404   //Sift the fovea position if nessesary
00405   if (tl_x < 0) tl_x = 0; if (tl_y < 0) tl_y = 0;
00406   if (tl_x+foveaW > mapDims.w()) tl_x = mapDims.w() - foveaW;
00407   if (tl_y+foveaH > mapDims.h()) tl_y = mapDims.h() - foveaH;
00408 
00409   if (!salientLocationWithinSubmaps)
00410   {
00411     //Find the most salient location within the fovea
00412     Image<float> SMap = chm->getMap();
00413 
00414     Image<float> tmp = SMap; //TODO need to resize to fovea
00415     //Find the max location within the fovea
00416 
00417     float maxVal; Point2D<int> maxLoc;
00418     findMax(tmp, maxLoc, maxVal);
00419     //convert back to original SMap cordinates
00420    // objSalientLoc.i=tl_x+maxLoc.i;
00421    // objSalientLoc.j=tl_y+maxLoc.j;
00422     objSalientLoc.i=x;
00423     objSalientLoc.j=y;
00424   }
00425 
00426   //Go through all the submaps building the DV
00427 
00428   std::vector<double> FV;
00429   uint numSubmaps = chm->numSubmaps();
00430   for (uint i = 0; i < numSubmaps; i++)
00431   {
00432     //Image<float> submap = itsComplexChannel->getSubmap(i);
00433     Image<float> submap = chm->getRawCSmap(i);
00434 
00435     // resize submap to fixed scale if necessary:
00436     if (submap.getWidth() > mapDims.w())
00437       submap = downSize(submap, mapDims);
00438     else if (submap.getWidth() < mapDims.w())
00439       submap = rescale(submap, mapDims); //TODO convert to  quickInterpolate
00440 
00441 
00442     if (salientLocationWithinSubmaps) //get the location from the salient location within each submap
00443     {
00444       Image<float> tmp = submap;
00445       //get only the fovea region
00446 
00447       if (foveaW < tmp.getWidth()) //crop if our fovea is smaller
00448         tmp = crop(tmp, Point2D<int>(tl_x, tl_y), Dims(foveaW, foveaH));
00449      // tmp = maxNormalize(tmp, 0.0F, 10.0F, VCXNORM_MAXNORM);  //find salient locations
00450 
00451       //Find the max location within the fovea
00452       float maxVal; Point2D<int> maxLoc; findMax(tmp, maxLoc, maxVal);
00453       //LINFO("%i: Max val %f, loc(%i,%i)", i, maxVal, maxLoc.i, maxLoc.j);
00454 
00455       objSalientLoc.i=tl_x+maxLoc.i;
00456       objSalientLoc.j=tl_y+maxLoc.j;
00457 
00458     }
00459 
00460     if (objSalientLoc.i < 0) objSalientLoc.i = 0;
00461     if (objSalientLoc.j < 0) objSalientLoc.j = 0;
00462 
00463     if (objSalientLoc.i > submap.getWidth()-1) objSalientLoc.i = submap.getWidth()-1;
00464     if (objSalientLoc.j > submap.getHeight()-1) objSalientLoc.j = submap.getHeight()-1;
00465 
00466 
00467 
00468    // LINFO("Location from %i,%i: (%i,%i)", objSalientLoc.i, objSalientLoc.j,
00469     //    submap.getWidth(), submap.getHeight());
00470     float featureVal = submap.getVal(objSalientLoc.i,objSalientLoc.j);
00471     FV.push_back(featureVal);
00472  //   SHOWIMG(rescale(submap, 255, 255));
00473 
00474   }
00475 
00476   return FV;
00477 
00478 }
00479 
00480 
00481 // ######################################################################
00482 /* So things look consistent in everyone's emacs... */
00483 /* Local Variables: */
00484 /* indent-tabs-mode: nil */
00485 /* End: */
Generated on Sun May 8 08:41:03 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3