00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Neuro/AttentionGate.H"
00039
00040 #include "Channels/ChannelBase.H"
00041 #include "Channels/ChannelMaps.H"
00042 #include "Component/OptionManager.H"
00043 #include "Image/Image.H"
00044 #include "Image/MathOps.H"
00045 #include "Image/ShapeOps.H"
00046 #include "Media/MediaSimEvents.H"
00047 #include "Neuro/NeuroOpts.H"
00048 #include "Neuro/VisualCortex.H"
00049 #include "Simulation/SimEventQueue.H"
00050 #include "Transport/FrameInfo.H"
00051 #include "Transport/FrameOstream.H"
00052 #include "Util/log.H"
00053
00054
00055
00056
00057
00058
00059
00060 AttentionGate::
00061 AttentionGate(OptionManager& mgr,
00062 const std::string& descrName,
00063 const std::string& tagName,
00064 const nub::soft_ref<VisualCortex> vcx) :
00065 SimModule(mgr, descrName, tagName),
00066 itsSaveResults(&OPT_AGsaveResults, this),
00067 itsVCX(vcx), itsLogSigO(0.10f), itsLogSigS(20.0f)
00068 { }
00069
00070
00071 AttentionGate::~AttentionGate()
00072 { }
00073
00074
00075
00076
00077
00078
00079 AttentionGateConfigurator::
00080 AttentionGateConfigurator(OptionManager& mgr,
00081 const std::string& descrName,
00082 const std::string& tagName) :
00083 ModelComponent(mgr, descrName, tagName),
00084 itsAGtype(&OPT_AttentionGateType, this),
00085 itsAG(new AttentionGateStd(mgr))
00086 {
00087 addSubComponent(itsAG);
00088 }
00089
00090
00091 AttentionGateConfigurator::~AttentionGateConfigurator()
00092 { }
00093
00094
00095 nub::ref<AttentionGate>
00096 AttentionGateConfigurator::getAG() const
00097 { return itsAG; }
00098
00099
00100 void AttentionGateConfigurator::
00101 paramChanged(ModelParamBase* const param,
00102 const bool valueChanged,
00103 ParamClient::ChangeStatus* status)
00104 {
00105 ModelComponent::paramChanged(param, valueChanged, status);
00106
00107
00108 if (param == &itsAGtype) {
00109
00110
00111
00112 removeSubComponent(*itsAG);
00113
00114
00115 if (itsAGtype.getVal().compare("Std") == 0)
00116 itsAG.reset(new AttentionGateStd(getManager()));
00117 else if (itsAGtype.getVal().compare("None") == 0)
00118 itsAG.reset(new AttentionGateStub(getManager()));
00119 else
00120 LFATAL("Unknown AG type %s", itsAGtype.getVal().c_str());
00121
00122
00123
00124
00125
00126
00127 addSubComponent(itsAG);
00128
00129
00130 itsAG->exportOptions(MC_RECURSE);
00131
00132
00133 LINFO("Selected AG of type %s", itsAGtype.getVal().c_str());
00134 }
00135 }
00136
00137
00138
00139
00140
00141
00142
00143
00144 AttentionGateStd::
00145 AttentionGateStd(OptionManager& mgr,
00146 const std::string& descrName,
00147 const std::string& tagName) :
00148 AttentionGate(mgr, descrName, tagName),
00149 SIMCALLBACK_INIT(SimEventInputFrame),
00150 SIMCALLBACK_INIT(SimEventAttentionGuidanceMapOutput),
00151 SIMCALLBACK_INIT(SimEventSaveOutput),
00152 itsSegmentDone(false),
00153 itsAGStageOneType(&OPT_AttentionGateStageOneType, this),
00154 itsAGStageTwoType(&OPT_AttentionGateStageTwoType, this),
00155 itsAGStageTwoEpochs(&OPT_AttentionGateStageTwoEpochs, this),
00156 itsMaxStageTwoFrames(10),
00157 itsStageTwoGetFeatures(AG_MAX),
00158 itsT(SimTime::ZERO()), itsTimeStep(SimTime::SECS(0.0001)),
00159 itsC(1.0), itsLeak(5000.0)
00160 { }
00161
00162
00163 AttentionGateStd::~AttentionGateStd()
00164 { }
00165
00166
00167 void AttentionGateStd::
00168 onSimEventSaveOutput(SimEventQueue& q, rutz::shared_ptr<SimEventSaveOutput>& e)
00169 {
00170
00171
00172 LINFO("Save Attention Gate");
00173 Image<float> ag = this->getValue(q);
00174 Image<float> lam = this->getLastAttentionMap();
00175 Image<float> cam = this->getCurrentAttentionMap();
00176
00177 if (ag.initialized())
00178 {
00179
00180 ag = normalizeFloat(ag,FLOAT_NORM_0_255)/255.0F;
00181 ag = logSig(ag,itsLogSigO,itsLogSigS) * 255.0F;
00182
00183 lam = normalizeFloat(lam,FLOAT_NORM_0_255)/255.0F;
00184 lam = logSig(lam,itsLogSigO,itsLogSigS) * 255.0F;
00185
00186 cam = normalizeFloat(cam,FLOAT_NORM_0_255)/255.0F;
00187 cam = logSig(cam,itsLogSigO,itsLogSigS) * 255.0F;
00188
00189 rutz::shared_ptr<SimEventAttentionGateOutput>
00190 ago(new SimEventAttentionGateOutput(this, ag, lam, cam, itsFrameNumber-1));
00191 q.post(ago);
00192
00193
00194 if (itsSaveResults.getVal())
00195 {
00196
00197
00198 nub::ref<FrameOstream> ofs = dynamic_cast<const SimModuleSaveInfo&>(e->sinfo()).ofs;
00199
00200 Image<float> lout = rescale(lam,itsSizeX,itsSizeY);
00201
00202
00203 if(itsLastFrame.initialized())
00204 {
00205 Image<float> maskImage = normalizeFloat(lout,FLOAT_NORM_0_255)/255.0F;
00206
00207 Image<PixRGB<float> > outImageFloat = itsLastFrame * maskImage;
00208 Image<PixRGB<byte> > outImage = outImageFloat;
00209
00210 ofs->writeRGB(outImage, "AG-LMASK", FrameInfo("Masked Image", SRC_POS));
00211 }
00212
00213
00214 if (ag.initialized())
00215 {
00216 const Image<float> out = rescale(ag,itsSizeX,itsSizeY);
00217 ofs->writeFloat(out, FLOAT_NORM_PRESERVE,
00218 "AG",
00219 FrameInfo("overall attention gate map", SRC_POS));
00220 }
00221 if(lout.initialized())
00222 {
00223 ofs->writeFloat(lout, FLOAT_NORM_PRESERVE,
00224 "AG-LAM",
00225 FrameInfo("last attention gate map", SRC_POS));
00226 }
00227 if(cam.initialized())
00228 {
00229 const Image<float> out = rescale(cam,itsSizeX,itsSizeY);
00230 ofs->writeFloat(out, FLOAT_NORM_PRESERVE,
00231 "AG-CAM",
00232 FrameInfo("current attention gate map", SRC_POS));
00233 }
00234 itsLastFrame = itsCurrFrame;
00235 }
00236 }
00237 }
00238
00239
00240 void AttentionGateStd::
00241 onSimEventInputFrame(SimEventQueue& q, rutz::shared_ptr<SimEventInputFrame>& e)
00242 {
00243
00244
00245 if (itsSaveResults.getVal())
00246 {
00247 itsSizeX = e->frame().getWidth();
00248 itsSizeY = e->frame().getHeight();
00249 itsCurrFrame = e->frame().asRgb();
00250 itsFrameNumber = (unsigned int)e->frameNum();
00251 }
00252 }
00253
00254
00255 void AttentionGateStd::
00256 onSimEventAttentionGuidanceMapOutput(SimEventQueue& q, rutz::shared_ptr<SimEventAttentionGuidanceMapOutput>& e)
00257 {
00258 Image<float> input = e->agm();
00259
00260 if(!itsStageOneGate.initialized())
00261 {
00262 itsStageOneGate.resize(input.getWidth(),input.getHeight(),0.0F);
00263 itsStageTwoGate.resize(input.getWidth(),input.getHeight(),0.0F);
00264 itsCurrentAttentionMap.resize(input.getWidth(),input.getHeight(),0.0F);
00265 itsLastAttentionMap.resize(input.getWidth(),input.getHeight(),0.0F);
00266 }
00267 itsInput = input;
00268
00269 if(itsAGStageTwoType.getVal().compare("None") != 0)
00270 {
00271 if(itsSegmentDone == false)
00272 {
00273 itsSegmenter.SIsetValThresh(std::vector<float>(1,256.0F),
00274 std::vector<float>(1,32.0F));
00275 itsSegmenter.SItoggleCandidateBandPass(false);
00276 int x = itsInput.getWidth();
00277 int y = itsInput.getHeight();
00278 itsSegmenter.SIsetFrame(&x,&y);
00279 itsSegmenter.SIsetAvg(10);
00280 itsSegmenter.SItoggleRemoveSingles(true);
00281 itsSegmenter.SIsetKillValue(1);
00282 itsSegmentDone = true;
00283 }
00284 }
00285 }
00286
00287
00288
00289 void AttentionGateStd::reset1()
00290 {
00291 itsStageOneGate.freeMem();
00292 itsStageTwoGate.freeMem();
00293 itsCurrentAttentionMap.freeMem();
00294 itsLastAttentionMap.freeMem();
00295
00296 itsTotalEpochs = (uint)((itsAGStageTwoEpochs.getVal() * 2) + 1);
00297 }
00298
00299
00300 Image<float> AttentionGateStd::getValue(SimEventQueue& q)
00301 {
00302
00303 if (itsAGStageOneType.getVal().compare("Simple") == 0)
00304 stageOneGateSimple(q);
00305 else if (itsAGStageOneType.getVal().compare("Complex") == 0)
00306 stageOneGateComplex(q);
00307 else
00308 {
00309 LINFO("Type of stage one attention gate given by");
00310 LFATAL("ag-so-type as '%s' is not valid", itsAGStageOneType.getVal().c_str());
00311 }
00312
00313 if (itsAGStageTwoType.getVal().compare("Std") == 0)
00314 {
00315 stageTwoGate(q);
00316 return itsStageTwoGate;
00317 }
00318 else if (itsAGStageTwoType.getVal().compare("None") == 0)
00319 return itsStageOneGate;
00320 else
00321 {
00322 LINFO("Type of stage two attention gate given by");
00323 LFATAL("ag-st-type as '%s' is not valid", itsAGStageTwoType.getVal().c_str());
00324 }
00325
00326 return itsStageOneGate;
00327 }
00328
00329
00330 void AttentionGateStd::stageOneGateSimple(SimEventQueue& q)
00331 {
00332 const SimTime dt = SimTime::computeDeltaT((q.now() - itsT), itsTimeStep);
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342 itsLastAttentionMap = itsCurrentAttentionMap;
00343
00344
00345 const Image<float> poke = itsInput - itsStageOneGate;
00346
00347 Image<float>::const_iterator p = poke.begin();
00348 Image<float>::iterator a = itsCurrentAttentionMap.beginw();
00349
00350 while(p != poke.end())
00351 {
00352 if(*p < 0) *a++ = 0;
00353 else *a++ = *p;
00354
00355 p++;
00356 }
00357
00358
00359
00360 const float delta = (dt.secs() / itsC);
00361 const Image<float> leak = (itsStageOneGate * itsLeak);
00362 itsStageOneGate -= leak * delta;
00363
00364
00365 a = itsStageOneGate.beginw();
00366 Image<float>::iterator i = itsInput.beginw();
00367
00368
00369 while(a != itsStageOneGate.endw())
00370 {
00371 if(*a < *i) *a = *i;
00372 a++; i++;
00373 }
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387 const Image<float> diff = itsLastAttentionMap - itsCurrentAttentionMap;
00388
00389
00390 Image<float>::const_iterator d = diff.begin();
00391 a = itsLastAttentionMap.beginw();
00392 while(d != diff.end())
00393 {
00394 if(*d < 0) *a++ = 0;
00395 else *a++ = *d;
00396 d++;
00397 }
00398
00399 itsT = q.now();
00400 }
00401
00402
00403 void AttentionGateStd::stageOneGateComplex(SimEventQueue& q)
00404 {
00405 LFATAL("Under Construction");
00406 }
00407
00408
00409 void AttentionGateStd::stageTwoGate(SimEventQueue& q)
00410 {
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427 std::vector<Image<float> > input;
00428 float min,max,avg;
00429 getMinMaxAvg(itsLastAttentionMap,min,max,avg);
00430 Image<float> newImage = ((itsLastAttentionMap - min) / (max - min)) * 255.0f;
00431 input.push_back(newImage);
00432 itsSegmenter.SIsegment(&input,false);
00433 Image<bool> candidates = itsSegmenter.SIreturnCandidates();
00434 itsStageTwoSegments = itsSegmenter.SIreturnBlobs();
00435 int segnum = itsSegmenter.SInumberBlobs();
00436
00437 LINFO("Compute Min Max XY");
00438 computeMinMaxXY(itsLastAttentionMap,itsStageTwoSegments,candidates);
00439
00440 LINFO("Extract Features");
00441 extractFeatureValues(q);
00442
00443 LINFO("Compute Feature Distances");
00444 computeFeatureDistance();
00445
00446 rutz::shared_ptr<SimEventAttentionGateStageTwoSegments>
00447 segs(new SimEventAttentionGateStageTwoSegments(this, candidates,
00448 itsStageTwoObjects.back(),
00449 segnum));
00450
00451 q.post(segs);
00452 }
00453
00454
00455 void AttentionGateStd::computeMinMaxXY(const Image<float>& attentionMap,
00456 const Image<int>& segments,
00457 const Image<bool>& candidates)
00458 {
00459 Image<float>::const_iterator attentionMapItr = attentionMap.begin();
00460 Image<int>::const_iterator segmentsItr = segments.begin();
00461 Image<bool>::const_iterator candidatesItr = candidates.begin();
00462
00463 int maxObjects = 0;
00464
00465
00466 while(candidatesItr != candidates.end())
00467 {
00468 if(*candidatesItr++)
00469 {
00470 if(*segmentsItr > maxObjects)
00471 {
00472 maxObjects = *segmentsItr;
00473 }
00474 }
00475 ++segmentsItr;
00476 }
00477
00478 segmentsItr = segments.begin();
00479 candidatesItr = candidates.begin();
00480
00481 itsStageTwoObjectX.resize(maxObjects+1,-1);
00482 itsStageTwoObjectY.resize(maxObjects+1,-1);
00483 itsStageTwoObjectVal.resize(maxObjects+1,0);
00484 itsStageTwoObjectID.resize(maxObjects+1,0);
00485
00486 int pos = 0;
00487 const int w = candidates.getWidth();
00488
00489
00490
00491
00492 while(candidatesItr != candidates.end())
00493 {
00494 if(*candidatesItr++)
00495 {
00496 if(*attentionMapItr > itsStageTwoObjectVal[*segmentsItr])
00497 {
00498 itsStageTwoObjectVal[*segmentsItr] = *attentionMapItr;
00499 itsStageTwoObjectX[*segmentsItr] = pos%w;
00500 itsStageTwoObjectY[*segmentsItr] = pos/w;
00501 itsStageTwoObjectID[*segmentsItr] = *segmentsItr;
00502 }
00503 }
00504 pos++; ++attentionMapItr; ++segmentsItr;
00505 }
00506 }
00507
00508
00509 void AttentionGateStd::extractFeatureValues(SimEventQueue& q)
00510 {
00511 std::vector<int>::iterator stageTwoObjectXItr = itsStageTwoObjectX.begin();
00512 std::vector<int>::iterator stageTwoObjectYItr = itsStageTwoObjectY.begin();
00513 std::vector<int>::iterator stageTwoObjectIDItr = itsStageTwoObjectID.begin();
00514 std::vector<float>::iterator stageTwoObjectValItr = itsStageTwoObjectVal.begin();
00515
00516 const float h = (float)itsLastAttentionMap.getHeight();
00517 const float w = (float)itsLastAttentionMap.getWidth();
00518
00519 std::vector<std::vector<float> > featureObjects;
00520 std::vector<int> featureX;
00521 std::vector<int> featureY;
00522 std::vector<int> featureID;
00523 std::vector<float> featureVal;
00524
00525 rutz::shared_ptr<SimReqVCXmaps> vcxm(new SimReqVCXmaps(this));
00526 q.request(vcxm);
00527 rutz::shared_ptr<ChannelMaps> chm = vcxm->channelmaps();
00528
00529 const uint numSubmaps = chm->numSubmaps();
00530
00531 int n = 0;
00532
00533 while(stageTwoObjectYItr != itsStageTwoObjectY.end())
00534 {
00535 const float i = (float)*stageTwoObjectXItr++;
00536 const float j = (float)*stageTwoObjectYItr++;
00537 const float val = (float)*stageTwoObjectValItr++;
00538 const int ID = *stageTwoObjectIDItr++;
00539
00540 std::vector<float> features;
00541
00542
00543 if(i >= 0.0f)
00544 {
00545
00546
00547 if(itsStageTwoGetFeatures == AG_CENTER)
00548 {
00549
00550 for (uint k = 0; k < numSubmaps; k++)
00551 {
00552 const Image<float> submap = chm->getRawCSmap(k);
00553 const float sm_w = (float)submap.getWidth();
00554 const float sm_h = (float)submap.getHeight();
00555
00556 const int sm_i = (int)floor(i * (sm_w/w) +
00557 0.5 * (sm_w/w));
00558 const int sm_j = (int)floor(j * (sm_h/h) +
00559 0.5 * (sm_h/h));
00560
00561 const float featureVal = submap.getVal(sm_i,sm_j);
00562 features.push_back(featureVal);
00563 }
00564 }
00565
00566
00567 else if(itsStageTwoGetFeatures == AG_MAX)
00568 {
00569
00570 for (uint k = 0; k < numSubmaps; k++)
00571 {
00572 const Image<float> submap = chm->getRawCSmap(k);
00573 const float sm_w = (float)submap.getWidth();
00574 const float sm_h = (float)submap.getHeight();
00575
00576
00577 if(sm_w == w)
00578 {
00579 const float featureVal = submap.getVal((int)i,(int)j);
00580 features.push_back(featureVal);
00581 }
00582
00583 else if(sm_w > w)
00584 {
00585
00586 const int l_i = (int)floor(i * (sm_w/w));
00587 const int h_i = l_i + (int)round(sm_w/w);
00588
00589 const int l_j = (int)floor(j * (sm_h/h));
00590 const int h_j = l_j + (int)round(sm_h/h);
00591
00592 float maxVal = 0;
00593
00594 for(int u = l_i; u < h_i; u++)
00595 {
00596 for(int v = l_j; v < h_j; v++)
00597 {
00598 const float featureVal = submap.getVal(u,v);
00599 if(featureVal > maxVal)
00600 maxVal = featureVal;
00601
00602
00603 }
00604 }
00605 features.push_back(maxVal);
00606 }
00607
00608 else
00609 {
00610 const int sm_i = (int)floor(i * (sm_w/w) +
00611 0.5 * (sm_w/w));
00612 const int sm_j = (int)floor(j * (sm_h/h) +
00613 0.5 * (sm_h/h));
00614 const float featureVal = submap.getVal(sm_i,sm_j);
00615 features.push_back(featureVal);
00616 }
00617 }
00618 }
00619 else
00620 {
00621 LFATAL("Unknown method given for stage two feature extraction");
00622 }
00623
00624 featureObjects.push_back(features);
00625 featureX.push_back((int)i);
00626 featureY.push_back((int)j);
00627 featureID.push_back(ID);
00628 featureVal.push_back(val);
00629 n++;
00630 }
00631 }
00632
00633
00634 SimEventAttentionGateStageTwoObjects stObjects;
00635 stObjects.features = featureObjects;
00636 stObjects.segments = itsStageTwoSegments;
00637 stObjects.x = featureX;
00638 stObjects.y = featureY;
00639 stObjects.id = featureID;
00640 stObjects.val = featureVal;
00641 stObjects.n = n;
00642
00643
00644 itsStageTwoObjects.push_back(stObjects);
00645
00646
00647 if(itsStageTwoObjects.size() > itsMaxStageTwoFrames)
00648 itsStageTwoObjects.pop_front();
00649
00650 }
00651
00652
00653 void AttentionGateStd::computeFeatureDistance()
00654 {
00655
00656
00657
00658
00659 int otherObjects = 0;
00660
00661
00662 if(itsStageTwoObjects.size() > 1)
00663 {
00664
00665 for(std::deque<SimEventAttentionGateStageTwoObjects>::iterator
00666 stageTwoObjectsItr = itsStageTwoObjects.begin();
00667 stageTwoObjectsItr != (itsStageTwoObjects.end() - 2);
00668 ++stageTwoObjectsItr)
00669 {
00670 otherObjects += stageTwoObjectsItr->n;
00671 }
00672
00673
00674 Image<float> fdMatrix;
00675 const int thisObjects = (itsStageTwoObjects.back()).n;
00676 fdMatrix.resize(otherObjects,thisObjects);
00677 Image<float>::iterator fdMatrixItr = fdMatrix.beginw();
00678
00679 std::vector<std::vector<float> >::iterator featuresA =
00680 (itsStageTwoObjects.back()).features.begin();
00681
00682 while(featuresA != (itsStageTwoObjects.back()).features.end())
00683 {
00684 std::deque<SimEventAttentionGateStageTwoObjects>::iterator
00685 stageTwoObjectsBItr = itsStageTwoObjects.begin();
00686
00687
00688
00689 while(stageTwoObjectsBItr != (itsStageTwoObjects.end() - 2))
00690 {
00691 std::vector<std::vector<float> >::iterator featuresB =
00692 stageTwoObjectsBItr->features.begin();
00693
00694
00695 while(featuresB != stageTwoObjectsBItr->features.end())
00696 {
00697
00698 std::vector<float>::iterator fa = featuresA->begin();
00699 std::vector<float>::iterator fb = featuresB->begin();
00700 float ss = 0.0f;
00701
00702 while(fa != featuresA->end())
00703 {
00704 ss += pow(*fa++ - *fb++ ,2);
00705 }
00706
00707
00708 *fdMatrixItr++ = sqrt(ss/featuresA->size());
00709
00710 ++featuresB;
00711 }
00712 ++stageTwoObjectsBItr;
00713 }
00714 ++featuresA;
00715 }
00716
00717
00718 (itsStageTwoObjects.end() - 1)->fdistance = fdMatrix;
00719 }
00720 }
00721
00722
00723
00724 Image<float> AttentionGateStd::getLastAttentionMap() const
00725 {
00726 return itsLastAttentionMap;
00727 }
00728
00729
00730 Image<float> AttentionGateStd::getCurrentAttentionMap() const
00731 {
00732 return itsCurrentAttentionMap;
00733 }
00734
00735
00736
00737
00738
00739
00740
00741 AttentionGateStub::
00742 AttentionGateStub(OptionManager& mgr,
00743 const std::string& descrName,
00744 const std::string& tagName) :
00745 AttentionGate(mgr, descrName, tagName)
00746 { }
00747
00748
00749 AttentionGateStub::~AttentionGateStub()
00750 { }
00751
00752
00753
00754
00755
00756