00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
#include "Neuro/InferoTemporal.H"

#include "Component/OptionManager.H"
#include "Image/CutPaste.H"
#include "Neuro/NeuroOpts.H"
#include "Neuro/NeuroSimEvents.H"
#include "Neuro/Brain.H"
#include "Neuro/VisualCortex.H"
#include "SIFT/VisualObjectDB.H"
#include "SIFT/VisualObject.H"

#include <cstdlib>
#include <iostream>
#include <string>
#include <unistd.h>
00051
00052
00053 namespace
00054 {
00055 Image<PixRGB<byte> > getCroppedObject(const Image<PixRGB<byte> >& scene,
00056 const Image<float>& smoothMask)
00057 {
00058 if (!scene.initialized())
00059 return Image<PixRGB<byte> >();
00060
00061 if (!smoothMask.initialized())
00062 return Image<PixRGB<byte> >();
00063
00064 const float threshold = 1.0f;
00065
00066 const Rectangle r = findBoundingRect(smoothMask, threshold);
00067 return crop(scene, r);
00068 }
00069 }
00070
00071
00072 InferoTemporal::InferoTemporal(OptionManager& mgr,
00073 const std::string& descrName,
00074 const std::string& tagName) :
00075 SimModule(mgr, descrName, tagName),
00076 SIMCALLBACK_INIT(SimEventWTAwinner)
00077 { }
00078
00079
00080 InferoTemporal::~InferoTemporal()
00081 { }
00082
00083
00084 void InferoTemporal::
00085 onSimEventWTAwinner(SimEventQueue& q, rutz::shared_ptr<SimEventWTAwinner>& e)
00086 {
00087 this->attentionShift(q, e->winner().p);
00088 }
00089
00090
00091 InferoTemporalStub::InferoTemporalStub(OptionManager& mgr,
00092 const std::string& descrName,
00093 const std::string& tagName)
00094 :
00095 InferoTemporal(mgr, descrName, tagName)
00096 {}
00097
00098
00099 InferoTemporalStub::~InferoTemporalStub()
00100 {}
00101
00102
00103 void InferoTemporalStub::attentionShift(SimEventQueue& q,
00104 const Point2D<int>& location)
00105 {}
00106
00107
00108 InferoTemporalStd::InferoTemporalStd(OptionManager& mgr,
00109 const std::string& descrName,
00110 const std::string& tagName) :
00111 InferoTemporal(mgr, descrName, tagName),
00112 itsUseAttention(&OPT_AttentionObjRecog, this),
00113 itsObjectDatabaseFile(&OPT_ObjectDatabaseFileName, this),
00114 itsTrainObjectDB(&OPT_TrainObjectDB, this),
00115 itsPromptUserTrainDB(&OPT_PromptUserTrainDB, this),
00116 itsMatchObjects(&OPT_MatchObjects, this),
00117 itsRecogMinMatch(&OPT_RecognitionMinMatch, this),
00118 itsMatchingAlg(&OPT_MatchingAlgorithm, this),
00119 itsObjectDB(new VisualObjectDB())
00120 { }
00121
00122
00123 void InferoTemporalStd::start1()
00124 {
00125
00126 if (itsObjectDatabaseFile.getVal().empty())
00127 LINFO("Starting with empty VisualObjectDB.");
00128 else
00129 itsObjectDB->loadFrom(itsObjectDatabaseFile.getVal());
00130
00131 InferoTemporal::start1();
00132 }
00133
00134
00135 void InferoTemporalStd::stop1()
00136 {
00137
00138 if (itsObjectDatabaseFile.getVal().empty() == false)
00139 itsObjectDB->saveTo(itsObjectDatabaseFile.getVal());
00140 }
00141
00142
00143 InferoTemporalStd::~InferoTemporalStd()
00144 {}
00145
00146
00147 void InferoTemporalStd::attentionShift(SimEventQueue& q,
00148 const Point2D<int>& location)
00149 {
00150 Image<PixRGB<float> > objImg;
00151
00152
00153 if (SeC<SimEventRetinaImage> e = q.check<SimEventRetinaImage>(this))
00154 objImg = e->frame().colorByte();
00155 else
00156 LFATAL("Oooops, no input frame in the event queue?");
00157
00158
00159 Image<float> smoothMask;
00160 if (SeC<SimEventShapeEstimatorOutput>
00161 e = q.check<SimEventShapeEstimatorOutput>(this))
00162 smoothMask = e->smoothMask();
00163
00164
00165 if (itsUseAttention.getVal())
00166 objImg = getCroppedObject(objImg, smoothMask);
00167
00168 if (!objImg.initialized()) return;
00169
00170 rutz::shared_ptr<SimReqVCXfeatures> ef(new SimReqVCXfeatures(this, location));
00171 q.request(ef);
00172
00173
00174
00175 rutz::shared_ptr<VisualObject>
00176 obj(new VisualObject("NewObject", "NewObject", objImg, location, ef->features()));
00177
00178
00179 if (itsMatchObjects.getVal())
00180 {
00181
00182
00183 if (obj->numKeypoints() < 3)
00184 { LINFO("Not enough Keypoints -- NO RECOGNITION"); return; }
00185
00186 LINFO("Attempting object recognition...");
00187 std::vector< rutz::shared_ptr<VisualObjectMatch> > matches;
00188
00189 const uint nm =
00190 itsObjectDB->getObjectMatches(obj, matches, VOMA_KDTREEBBF,
00191 100U, 0.5F, 0.5F, 1.0F,
00192 uint(itsRecogMinMatch.getVal()),
00193 6U, false);
00194
00195 if (nm > 0)
00196 {
00197 LINFO("***** %u object recognition match(es) *****", nm);
00198 for (uint i = 0 ; i < nm; i ++)
00199 LINFO(" Match with '%s' [score = %f]",
00200 matches[i]->getVoTest()->getName().c_str(),
00201 matches[i]->getScore());
00202 }
00203 else
00204 LINFO("***** Could not identify attended object! *****");
00205 }
00206
00207
00208 if (itsTrainObjectDB.getVal())
00209 {
00210 std::string objname;
00211
00212
00213 if (itsPromptUserTrainDB.getVal())
00214 {
00215 LINFO("Enter name for new object or [RETURN] to skip training:");
00216 std::getline(std::cin, objname, '\n');
00217 }
00218 else
00219 {
00220
00221 char tmpn[14]; strcpy(tmpn, "Object-XXXXXX");
00222 if(mkstemp(tmpn) == -1)
00223 LFATAL("mkstemp failed");
00224 objname = tmpn;
00225 }
00226
00227
00228 if (objname.length() > 0)
00229 {
00230 LINFO("Adding new object '%s' to database.", objname.c_str());
00231 obj->setName(objname);
00232 obj->setImageFname(objname + ".png");
00233 itsObjectDB->addObject(obj);
00234 }
00235
00236 }
00237 }
00238
00239
00240
00241
00242
00243