00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Neuro/InferoTemporalSalBayes.H"
00039
00040 #include "Component/OptionManager.H"
00041 #include "Component/ModelOptionDef.H"
00042 #include "Channels/ChannelMaps.H"
00043 #include "Channels/ChannelOpts.H"
00044 #include "Image/MathOps.H"
00045 #include "Image/ShapeOps.H"
00046 #include "Image/CutPaste.H"
00047 #include "Neuro/NeuroOpts.H"
00048 #include "Neuro/NeuroSimEvents.H"
00049 #include "Neuro/Brain.H"
00050 #include "Neuro/VisualCortex.H"
00051 #include "Simulation/SimEventQueue.H"
00052 #include "Media/MediaSimEvents.H"
00053
00054 #include <cstdlib>
00055 #include <iostream>
00056
// Command-line option: size of the fovea window (in retinal pixels)
// used when sampling submap values to build a feature vector.
static const ModelOptionDef OPT_ITSB_FOVSize =
{ MODOPT_ARG(Dims), "FoveaSize", &MOC_ITC, OPTEXP_CORE,
"Use the given fovea size for constructing a feature vector.",
"it-fov-size", '\0', "<w>x<h>", "75x75" };

// Command-line option: operating mode of the IT cortex model,
// either "Train" (learn from labeled data) or "Rec" (recognize).
static const ModelOptionDef OPT_ITCMode =
{ MODOPT_ARG_STRING, "ITC Mode", &MOC_ITC, OPTEXP_CORE,
"The mode of ITC . Train: is for training from some data, Rec is for recognition.",
"it-mode", '\0', "<Train|Rec>", "Rec" };

// Command-line option: file used to persist the learned Bayesian network.
static const ModelOptionDef OPT_ITCSalBayes_NetFile =
{ MODOPT_ARG_STRING, "ITC BaysNet File", &MOC_ITC, OPTEXP_CORE,
"Name of the file to save/read the computed Bayesian Network.",
"it-bayesnet-file", '\0', "<filename>", "SalBayes.net" };

// Command-line option: SIFT object database file; an empty name
// disables SIFT-based refinement entirely.
const ModelOptionDef OPT_ITSiftObjectDBFileName =
{ MODOPT_ARG_STRING, "ITC SiftObjectDBFileName", &MOC_ITC, OPTEXP_CORE,
"Filename for the sift object database. Specifying no file will disable the sift alg.",
"it-object-db", '\0', "<filename>", "" };

// Command-line option: number of top SalBayes candidates on which to
// run SIFT matching; 0 means Bayes-only recognition.
const ModelOptionDef OPT_ITUseSift =
{ MODOPT_ARG(int), "ITC use sift", &MOC_ITC, OPTEXP_CORE,
"Use sift recognition on the n most probable objects obtained from SalBayes. "
"That is, the sift algorithm will only run on the top n objects returned from SalBayes."
"0 disables sift and just return the most probable object",
"it-use-sift", '\0', "<int>", "10" };

// Command-line option: pick the SIFT winner by raw keypoint-match
// count (true) rather than by sorted match score (false).
const ModelOptionDef OPT_ITUseMaxNMatches =
{ MODOPT_ARG(bool), "ITC use max num of matches", &MOC_ITC, OPTEXP_CORE,
"When determining a which object in the database matches, use the maximum "
"number of keypoints matches instead of sorting by distance of keypoints.",
"it-use-max-num-matches", '\0', "<true|false>", "true" };
00089
00090
00091
00092 class moreVOM
00093 {
00094 public:
00095 moreVOM(const float kcoeff, const float acoeff) :
00096 itsKcoeff(kcoeff), itsAcoeff(acoeff)
00097 { }
00098
00099 bool operator()(const rutz::shared_ptr<VisualObjectMatch>& x,
00100 const rutz::shared_ptr<VisualObjectMatch>& y)
00101 { return ( x->getScore(itsKcoeff, itsAcoeff) >
00102 y->getScore(itsKcoeff, itsAcoeff) ); }
00103
00104 private:
00105 float itsKcoeff, itsAcoeff;
00106 };
00107
00108
00109
// ######################################################################
//! Constructor: register all command-line options with the manager and
//! create an (initially empty) SIFT object database; the database is
//! actually loaded from disk in start1().
InferoTemporalSalBayes::InferoTemporalSalBayes(OptionManager& mgr,
                                               const std::string& descrName,
                                               const std::string& tagName) :
  InferoTemporal(mgr, descrName, tagName),
  itsLevelSpec(&OPT_LevelSpec, this),
  itsFoveaSize(&OPT_ITSB_FOVSize, this),
  itsITCMode(&OPT_ITCMode, this),
  itsBayesNetFilename(&OPT_ITCSalBayes_NetFile, this),
  itsUseSift(&OPT_ITUseSift, this),
  itsSiftObjectDBFile(&OPT_ITSiftObjectDBFileName, this),
  itsUseMaxNMatches(&OPT_ITUseMaxNMatches, this),
  itsSiftObjectDB(new VisualObjectDB())
{}
00123
00124
00125 void InferoTemporalSalBayes::start1()
00126 {
00127
00128 if (!itsSiftObjectDBFile.getVal().empty())
00129 itsSiftObjectDB->loadFrom(itsSiftObjectDBFile.getVal());
00130
00131 InferoTemporal::start1();
00132 }
00133
00134
// ######################################################################
//! Shutdown hook: nothing to release here (smart pointers clean up).
void InferoTemporalSalBayes::stop1()
{}
00137
00138
// ######################################################################
//! Destructor: all members are self-cleaning (RAII), so nothing to do.
InferoTemporalSalBayes::~InferoTemporalSalBayes()
{}
00141
00142
00143 void InferoTemporalSalBayes::attentionShift(SimEventQueue& q,
00144 const Point2D<int>& location)
00145 {
00146 if (itsBayesNet.is_valid() == false)
00147 {
00148 rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
00149 q.request(vcxm);
00150 rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();
00151
00152 itsBayesNet.reset(new Bayes(chm->numSubmaps(), 0));
00153 itsBayesNet->load(itsBayesNetFilename.getVal().c_str());
00154 }
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172 std::vector<double> fv = buildRawDV(q, location);
00173
00174 rutz::shared_ptr<VisualObject> vobj;
00175 if (!itsSiftObjectDBFile.getVal().empty())
00176 {
00177
00178 Image<PixRGB<float> > objImg;
00179 if (SeC<SimEventRetinaImage> e = q.check<SimEventRetinaImage>(this))
00180 objImg = e->frame().colorByte();
00181 if (!objImg.initialized()) return;
00182
00183 vobj.reset(new VisualObject("NewObject", "NULL", objImg,
00184 Point2D<int>(-1,-1),
00185 std::vector<float>(),
00186 std::vector< rutz::shared_ptr<Keypoint> >(),
00187 false));
00188 }
00189
00190 if (itsITCMode.getVal().compare("Train") == 0)
00191 {
00192 rutz::shared_ptr<TestImages::SceneData> sceneData;
00193
00194 std::string objName;
00195 if (SeC<SimEventInputFrame> e = q.check<SimEventInputFrame>(this,SEQ_UNMARKED,0))
00196 {
00197
00198 GenericFrame gf = e->frame();
00199 if (gf.hasMetaData(std::string("SceneData")))
00200 {
00201 rutz::shared_ptr<GenericFrame::MetaData> metaData = gf.getMetaData(std::string("SceneData"));
00202 if (metaData.get() != 0)
00203 {
00204 sceneData.dyn_cast_from(metaData);
00205 objName = getObjNameAtLoc(sceneData->objects, location);
00206 }
00207 } else {
00208 LINFO("Enter name for new object or [RETURN] to skip training:");
00209 std::getline(std::cin, objName, '\n');
00210 }
00211
00212 }
00213
00214 if (objName.length() > 0)
00215 {
00216 LINFO("Train on %s", objName.c_str());
00217 itsBayesNet->learn(fv, objName.c_str());
00218 itsBayesNet->save(itsBayesNetFilename.getVal().c_str());
00219
00220
00221 if (!itsSiftObjectDBFile.getVal().empty())
00222 {
00223 vobj->setName(objName);
00224 itsSiftObjectDB->addObject(vobj, false);
00225 itsSiftObjectDB->saveTo(itsSiftObjectDBFile.getVal());
00226 }
00227 }
00228 }
00229 else if (itsITCMode.getVal().compare("Rec") == 0)
00230 {
00231 if (itsUseSift.getVal() > 0 && !itsSiftObjectDBFile.getVal().empty())
00232 predictWithSift(fv, vobj, q);
00233 else
00234 predict(fv, q);
00235 }
00236 else
00237 LFATAL("Unknown IT Mode type %s", itsITCMode.getVal().c_str());
00238 }
00239
00240
00241 void InferoTemporalSalBayes::predict(std::vector<double> &fv, SimEventQueue& q)
00242 {
00243 const int cls = itsBayesNet->classify(fv);
00244
00245 if (cls != -1)
00246 {
00247 const double maxProb = itsBayesNet->getMaxProb();
00248 const double normProb = itsBayesNet->getNormProb();
00249
00250 std::string clsName(itsBayesNet->getClassName(cls));
00251 rutz::shared_ptr<TestImages::ObjData> objData(new TestImages::ObjData);
00252 objData->name = clsName;
00253 objData->id = cls;
00254 objData->maxProb = maxProb;
00255 objData->normProb = normProb;
00256
00257 rutz::shared_ptr<SimEventObjectDescription>
00258 objDataEvent(new SimEventObjectDescription(this, objData));
00259 q.post(objDataEvent);
00260 }
00261 }
00262
00263
00264 void InferoTemporalSalBayes::predictWithSift(std::vector<double> &fv,
00265 rutz::shared_ptr<VisualObject> &vobj,
00266 SimEventQueue& q)
00267 {
00268 int cls = -1;
00269 uint mink = 6U;
00270 float kcoeff = 0.5;
00271 float acoeff = 0.5;
00272
00273
00274 std::vector< rutz::shared_ptr<VisualObjectMatch> > matches;
00275 matches.clear();
00276
00277 std::string maxObjName;
00278 uint maxKeyMatches = 0;
00279
00280 std::vector<Bayes::ClassInfo> classInfo = itsBayesNet->classifyRange(fv, cls);
00281 for(uint i=0; i< classInfo.size() && i < (uint)itsUseSift.getVal(); i++)
00282 {
00283 std::string clsName(itsBayesNet->getClassName(classInfo[i].classID));
00284
00285
00286 for(uint j=0; j<itsSiftObjectDB->numObjects(); j++)
00287 {
00288 rutz::shared_ptr<VisualObject> knownVObj = itsSiftObjectDB->getObject(j);
00289 if (clsName.compare(knownVObj->getName()) == 0)
00290 {
00291
00292 rutz::shared_ptr<VisualObjectMatch>
00293 match(new VisualObjectMatch(vobj, knownVObj,
00294 VOMA_SIMPLE,
00295 6U));
00296
00297 if (itsUseMaxNMatches.getVal())
00298 {
00299
00300 if (match->size() > maxKeyMatches)
00301 {
00302 maxObjName = clsName;
00303 maxKeyMatches = match->size();
00304 }
00305 } else {
00306
00307 match->prune(std::max(25U, mink * 5U), mink);
00308
00309
00310 if (match->size() >= mink &&
00311
00312 match->checkSIFTaffine())
00313 {
00314 matches.push_back(match);
00315 }
00316 }
00317
00318 }
00319 }
00320 std::sort(matches.begin(), matches.end(), moreVOM(kcoeff, acoeff));
00321 }
00322
00323 rutz::shared_ptr<TestImages::ObjData> objData(new TestImages::ObjData);
00324 if (itsUseMaxNMatches.getVal())
00325 {
00326 objData->name = maxObjName;
00327 objData->id = (unsigned int) -1;
00328 } else {
00329 if (matches.size() > 0)
00330 {
00331
00332 objData->name = matches[0]->getVoTest()->getName();
00333 objData->id = (unsigned int) -1;
00334 }
00335 }
00336 rutz::shared_ptr<SimEventObjectDescription>
00337 objDataEvent(new SimEventObjectDescription(this, objData));
00338 q.post(objDataEvent);
00339 }
00340
00341
00342 std::string InferoTemporalSalBayes::getObjNameAtLoc(const std::vector<TestImages::ObjData> &objects, const Point2D<int>& loc)
00343 {
00344
00345 for(uint obj=0; obj<objects.size(); obj++)
00346 {
00347 TestImages::ObjData objData = objects[obj];
00348
00349
00350 if (objData.polygon.size() > 0)
00351 {
00352 Point2D<int> upperLeft = objData.polygon[0];
00353 Point2D<int> lowerRight = objData.polygon[0];
00354
00355 for(uint i=0; i<objData.polygon.size(); i++)
00356 {
00357
00358 if (objData.polygon[i].i < upperLeft.i) upperLeft.i = objData.polygon[i].i;
00359 if (objData.polygon[i].j < upperLeft.j) upperLeft.j = objData.polygon[i].j;
00360
00361 if (objData.polygon[i].i > lowerRight.i) lowerRight.i = objData.polygon[i].i;
00362 if (objData.polygon[i].j > lowerRight.j) lowerRight.j = objData.polygon[i].j;
00363 }
00364
00365
00366 for(int y=upperLeft.j; y<lowerRight.j; y++)
00367 for(int x=upperLeft.i; x<lowerRight.i; x++)
00368 {
00369 if (pnpoly(objData.polygon, loc))
00370 return objData.name;
00371 }
00372 }
00373
00374 }
00375 return std::string("Unknown");
00376 }
00377
00378
// ######################################################################
//! Build a raw feature vector: one value per VisualCortex submap,
//! sampled at each submap's most active point inside a fovea-sized
//! window centered on foveaLoc (given in retinal pixel coordinates).
std::vector<double> InferoTemporalSalBayes::buildRawDV(SimEventQueue& q, const Point2D<int>& foveaLoc)
{
  // when true (always, currently), the sampling point is found
  // independently in each submap; the '!salientLocationWithinSubmaps'
  // branch below is effectively dead code:
  bool salientLocationWithinSubmaps = true;
  Point2D<int> objSalientLoc(-1,-1);

  const int smlevel = itsLevelSpec.getVal().mapLevel();

  // convert retinal coordinates to saliency-map scale (the +0.49 makes
  // the integer truncation round to nearest):
  int x=int(foveaLoc.i / double(1 << smlevel) + 0.49);
  int y=int(foveaLoc.j / double(1 << smlevel) + 0.49);

  // fovea size, also scaled down to saliency-map resolution:
  int foveaW = int(itsFoveaSize.getVal().w() / double(1 << smlevel) + 0.49);
  int foveaH = int(itsFoveaSize.getVal().h() / double(1 << smlevel) + 0.49);

  // top-left corner of the fovea window:
  int tl_x = x - (foveaW/2);
  int tl_y = y - (foveaH/2);

  // fetch the current channel maps from the VisualCortex:
  rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
  q.request(vcxm);
  rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();

  Dims mapDims = chm->getMap().getDims();

  // clamp the fovea window so it lies fully inside the map:
  if (tl_x < 0) tl_x = 0; if (tl_y < 0) tl_y = 0;
  if (tl_x+foveaW > mapDims.w()) tl_x = mapDims.w() - foveaW;
  if (tl_y+foveaH > mapDims.h()) tl_y = mapDims.h() - foveaH;

  if (!salientLocationWithinSubmaps)
    {
      // NOTE(review): dead branch -- salientLocationWithinSubmaps is
      // hardcoded to true above. Also, findMax()'s result is discarded
      // here (maxLoc unused); the fovea center is used instead.
      Image<float> SMap = chm->getMap();

      Image<float> tmp = SMap;

      float maxVal; Point2D<int> maxLoc;
      findMax(tmp, maxLoc, maxVal);

      objSalientLoc.i=x;
      objSalientLoc.j=y;
    }

  // sample one feature value per submap:
  std::vector<double> FV;
  uint numSubmaps = chm->numSubmaps();
  for (uint i = 0; i < numSubmaps; i++)
    {
      Image<float> submap = chm->getRawCSmap(i);

      // bring the submap to saliency-map dimensions:
      if (submap.getWidth() > mapDims.w())
        submap = downSize(submap, mapDims);
      else if (submap.getWidth() < mapDims.w())
        submap = rescale(submap, mapDims);

      if (salientLocationWithinSubmaps)
        {
          Image<float> tmp = submap;

          // restrict the peak search to the fovea window, unless the
          // window is at least as wide as the whole map:
          if (foveaW < tmp.getWidth())
            tmp = crop(tmp, Point2D<int>(tl_x, tl_y), Dims(foveaW, foveaH));

          // this submap's most active location within the window:
          float maxVal; Point2D<int> maxLoc; findMax(tmp, maxLoc, maxVal);

          // translate back to full-map coordinates:
          objSalientLoc.i=tl_x+maxLoc.i;
          objSalientLoc.j=tl_y+maxLoc.j;
        }

      // clamp to valid pixel coordinates:
      if (objSalientLoc.i < 0) objSalientLoc.i = 0;
      if (objSalientLoc.j < 0) objSalientLoc.j = 0;

      if (objSalientLoc.i > submap.getWidth()-1) objSalientLoc.i = submap.getWidth()-1;
      if (objSalientLoc.j > submap.getHeight()-1) objSalientLoc.j = submap.getHeight()-1;

      // the feature value is the submap activity at that location:
      float featureVal = submap.getVal(objSalientLoc.i,objSalientLoc.j);
      FV.push_back(featureVal);
    }

  return FV;
}
00479
00480
00481
00482
00483
00484
00485