00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef APPNEURO_TEST_TOPDOWNCONTEXT_C_UTC20050726230120_DEFINED
00039 #define APPNEURO_TEST_TOPDOWNCONTEXT_C_UTC20050726230120_DEFINED
00040
00041 #include "Component/GlobalOpts.H"
00042 #include "Component/ModelManager.H"
00043 #include "Component/ModelOptionDef.H"
00044 #include "Image/Image.H"
00045 #include "Image/MathOps.H"
00046 #include "Image/Pixels.H"
00047 #include "Image/ShapeOps.H"
00048 #include "Media/FrameSeries.H"
00049 #include "Media/MediaOpts.H"
00050 #include "Psycho/EyeSFile.H"
00051 #include "TIGS/FeatureExtractorFactory.H"
00052 #include "TIGS/Figures.H"
00053 #include "TIGS/SaliencyMapFeatureExtractor.H"
00054 #include "TIGS/Scorer.H"
00055 #include "TIGS/TigsOpts.H"
00056 #include "TIGS/TopdownLearnerFactory.H"
00057 #include "TIGS/TrainingSet.H"
00058 #include "Util/Assert.H"
00059 #include "Util/FileUtil.H"
00060 #include "Util/Pause.H"
00061 #include "Util/SimTime.H"
00062 #include "Util/StringConversions.H"
00063 #include "Util/csignals.H"
00064 #include "Util/fpe.H"
00065 #include "rutz/error_context.h"
00066 #include "rutz/sfmt.h"
00067 #include "rutz/shared_ptr.h"
00068 #include "rutz/trace.h"
00069
00070 #include <deque>
00071 #include <fstream>
00072 #include <iostream>
00073 #include <sstream>
00074 #include <unistd.h>
00075 #include <vector>
00076
00077
00078 static const ModelOptionDef OPT_DoBottomUpContext =
00079 { MODOPT_FLAG, "DoBottomUpContext", &MOC_TIGS, OPTEXP_CORE,
00080 "Whether to scale the top-down prediction by a bottom-up map",
00081 "bottom-up-context", '\0', "", "false" };
00082
00083
00084 static const ModelOptionDef OPT_TdcSaveSumo =
00085 { MODOPT_FLAG, "TdcSaveSumo", &MOC_TIGS, OPTEXP_CORE,
00086 "Whether to save the sumo display",
00087 "tdc-save-sumo", '\0', "", "false" };
00088
00089
00090 static const ModelOptionDef OPT_TdcSaveSumo2 =
00091 { MODOPT_FLAG, "TdcSaveSumo2", &MOC_TIGS, OPTEXP_CORE,
00092 "Whether to save the sumo2 display",
00093 "tdc-save-sumo2", '\0', "", "false" };
00094
00095
00096 static const ModelOptionDef OPT_TdcSaveMaps =
00097 { MODOPT_FLAG, "TdcSaveMaps", &MOC_TIGS, OPTEXP_CORE,
00098 "Whether to save the individual topdown context maps",
00099 "tdc-save-maps", '\0', "", "false" };
00100
00101
00102 static const ModelOptionDef OPT_TdcSaveMapsNormalized =
00103 { MODOPT_FLAG, "TdcSaveMapsNormalized", &MOC_TIGS, OPTEXP_CORE,
00104 "Whether to rescale maps to [0,255] when saving with --tdc-save-maps",
00105 "tdc-save-maps-normalized", '\0', "", "true" };
00106
00107
00108 static const ModelOptionDef OPT_TdcLocalMax =
00109 { MODOPT_ARG(float), "TdcLocalMax", &MOC_TIGS, OPTEXP_CORE,
00110 "Diameter of local max region to be applied to bias maps before scoring",
00111 "tdc-local-max", '\0', "<float>", "1" };
00112
00113
00114 static const ModelOptionDef OPT_TdcTemporalMax =
00115 { MODOPT_ARG(unsigned int), "TdcTemporalMax", &MOC_TIGS, OPTEXP_CORE,
00116 "Number of frames across which to apply a temporal max to "
00117 "bias maps before scoring",
00118 "tdc-temporal-max", '\0', "<integer>", "1" };
00119
00120
00121 static const ModelOptionDef OPT_TdcSaveRawData =
00122 { MODOPT_FLAG, "TdcSaveRawData", &MOC_TIGS, OPTEXP_CORE,
00123 "Whether to save a raw binary file containing the "
00124 "bottom-up and top-down maps",
00125 "tdc-save-raw-data", '\0', "", "false" };
00126
00127
00128 static const ModelOptionDef OPT_TdcRectifyTd =
00129 { MODOPT_FLAG, "TdcRectifyTd", &MOC_TIGS, OPTEXP_CORE,
00130 "Whether to rectify the top-down maps",
00131 "tdc-rectify-td", '\0', "", "false" };
00132
00133
00134 static const ModelOptionDef OPT_TopdownContextSpec =
00135 { MODOPT_ARG_STRING, "TopdownContextSpec", &MOC_TIGS, OPTEXP_CORE,
00136 "Specification string for a topdown context",
00137 "context", '\0', "<string>", "" };
00138
00139
00140 static const ModelOptionDef OPT_MoviePeriod =
00141 { MODOPT_ARG(SimTime), "MoviePeriod", &MOC_TIGS, OPTEXP_CORE,
00142 "Inter-frame period (or rate) of input movie",
00143 "movie-period", '\0', "<float>{s|ms|us|ns|Hz}", "0.0s" };
00144
00145
00146 static const ModelOptionDef OPT_MovieHertzObsolete =
00147 { MODOPT_OBSOLETE, "MovieHertzObsolete", &MOC_TIGS, OPTEXP_CORE,
00148 "Obsolete; use --movie-period instead with a SimTime value",
00149 "movie-hertz", '\0', "<float>", "0.0" };
00150
00151
00152 static const ModelOptionDef OPT_NumSkipFrames =
00153 { MODOPT_ARG(int), "NumSkipFrames", &MOC_TIGS, OPTEXP_CORE,
00154 "Number of frames to skip over at beginning of input movie",
00155 "num-skip-frames", '\0', "<int>", "0" };
00156
00157
00158 static const ModelOptionDef OPT_NumTrainingFrames =
00159 { MODOPT_ARG(int), "NumTrainingFrames", &MOC_TIGS, OPTEXP_CORE,
00160 "Number of input movie frames to use as training data",
00161 "num-training-frames", '\0', "<int>", "0" };
00162
00163
00164 static const ModelOptionDef OPT_NumTestingFrames =
00165 { MODOPT_ARG(int), "NumTestingFrames", &MOC_TIGS, OPTEXP_CORE,
00166 "Number of input movie frames to use as testing data",
00167 "num-testing-frames", '\0', "<int>", "0" };
00168
00169
00170 static const ModelOptionDef OPT_SaveGhostFrames =
00171 { MODOPT_ARG_STRING, "SaveGhostFrames", &MOC_TIGS, OPTEXP_CORE,
00172 "Name of a file in which to save ghost frame info that can "
00173 "be used to accelerate processing in a subsequent run",
00174 "save-ghost-frames", '\0', "<filanem>", "" };
00175
00176
00177 static const ModelOptionDef OPT_GhostInput =
00178 { MODOPT_ARG_STRING, "GhostInput", &MOC_TIGS, OPTEXP_CORE,
00179 "Read ghost frame info from this file",
00180 "ghost-input", '\0', "<filename>", "" };
00181
00182 namespace
00183 {
00184 template <class T>
00185 Image<T> asRow(const Image<T>& in)
00186 {
00187 GVX_TRACE(__PRETTY_FUNCTION__);
00188 return Image<T>(in.getArrayPtr(), in.getSize(), 1);
00189 }
00190 }
00191
00192 class Context : public ModelComponent
00193 {
00194 public:
00195 Context(OptionManager& mgr,
00196 const std::string& fx_type_,
00197 const std::string& learner_type_)
00198 :
00199 ModelComponent(mgr, "Context", "Context"),
00200 itsXptSavePrefix(&OPT_XptSavePrefix, this),
00201 itsDoBottomUp(&OPT_DoBottomUpContext, this),
00202 itsSaveSumo(&OPT_TdcSaveSumo, this),
00203 itsSaveSumo2(&OPT_TdcSaveSumo2, this),
00204 itsSaveMaps(&OPT_TdcSaveMaps, this),
00205 itsSaveMapsNormalized(&OPT_TdcSaveMapsNormalized, this),
00206 itsLocalMaxSize(&OPT_TdcLocalMax, this),
00207 itsTemporalMax(&OPT_TdcTemporalMax, this),
00208 itsSaveRawData(&OPT_TdcSaveRawData, this),
00209 itsRectifyTd(&OPT_TdcRectifyTd, this),
00210 itsTdata(new TrainingSet(this->getManager(), fx_type_)),
00211 itsFxType(fx_type_),
00212 itsFx(makeFeatureExtractor(mgr, fx_type_)),
00213 itsLearnerType(learner_type_),
00214 itsLearner(makeTopdownLearner(mgr, learner_type_)),
00215 itsCtxName("-" + itsFxType + "-" + itsLearnerType),
00216 itsScorer(),
00217 itsBottomUpScorer(),
00218 itsComboScorer(),
00219 itsRawDataFile(0)
00220 {
00221 if (theirBottomUp.is_invalid())
00222 {
00223 GVX_ERR_CONTEXT(rutz::sfmt
00224 ("constructing SaliencyMapFeatureExtractor "
00225 "on behalf of %s", itsCtxName.c_str()));
00226
00227 theirBottomUp.reset
00228 (new SaliencyMapFeatureExtractor(this->getManager()));
00229 this->addSubComponent(theirBottomUp);
00230 }
00231
00232 this->addSubComponent(itsTdata);
00233 this->addSubComponent(itsFx);
00234 this->addSubComponent(itsLearner);
00235 }
00236
00237 virtual void start2()
00238 {
00239 ASSERT(itsRawDataFile == 0);
00240
00241 if (itsSaveRawData.getVal())
00242 {
00243 std::string rawdatfname = this->contextName()+".rawdat";
00244 itsRawDataFile = fopen(rawdatfname.c_str(), "w");
00245 if (itsRawDataFile == 0)
00246 LFATAL("couldn't open %s for writing", rawdatfname.c_str());
00247 }
00248 }
00249
00250 virtual void stop1()
00251 {
00252 itsScorer.showScore("finalscore:" + this->contextName());
00253 if (itsDoBottomUp.getVal())
00254 {
00255 itsBottomUpScorer.showScore("finalscore:" + this->contextName()
00256 + "...bu-only");
00257 itsComboScorer.showScore("finalscore:" + this->contextName()
00258 + "+bu");
00259 }
00260
00261 std::ofstream ofs((this->contextName() + ".score").c_str());
00262 if (ofs.is_open())
00263 {
00264 itsScorer.writeScore(this->contextName(), ofs);
00265 if (itsDoBottomUp.getVal())
00266 {
00267 itsBottomUpScorer.writeScore(this->contextName() + "...bu-only", ofs);
00268 itsComboScorer.writeScore(this->contextName() + "+bu", ofs);
00269 }
00270 }
00271 ofs.close();
00272
00273 if (itsRawDataFile != 0)
00274 {
00275 fclose(itsRawDataFile);
00276 itsRawDataFile = 0;
00277 }
00278 }
00279
00280 std::string contextName() const
00281 {
00282 if (itsXptSavePrefix.getVal() == "")
00283 LFATAL("no xpt name specified!");
00284
00285 return itsXptSavePrefix.getVal() + itsCtxName;
00286 }
00287
00288 void loadTrainingSet(const std::string& xpt)
00289 {
00290 itsTdata->load(xpt + itsCtxName);
00291 }
00292
00293 void trainingFrame(const TigsInputFrame& fin,
00294 const Point2D<int>& eyepos, bool lastone,
00295 OutputFrameSeries& ofs)
00296 {
00297 GVX_ERR_CONTEXT(rutz::sfmt("handling training frame in Context %s",
00298 this->contextName().c_str()));
00299
00300 const Image<float> features = itsFx->extract(fin);
00301
00302 if (fin.origbounds().contains(eyepos))
00303 {
00304 const Image<float> biasmap =
00305 itsTdata->recordSample(eyepos, features);
00306
00307 if (!ofs.isVoid() && itsSaveSumo.getVal())
00308 ofs.writeRGB(makeSumoDisplay(fin, biasmap, *itsTdata,
00309 eyepos, features),
00310 this->contextName());
00311 }
00312
00313 if (lastone)
00314 itsTdata->save(this->contextName());
00315
00316 if (itsDoBottomUp.getVal())
00317 {
00318
00319
00320 (void) theirBottomUp->extract(fin);
00321 }
00322 }
00323
00324 Image<float> localMax(const Image<float>& in) const
00325 {
00326 if (itsLocalMaxSize.getVal() < 2.0f)
00327 return in;
00328
00329 const double rad = itsLocalMaxSize.getVal() / 2.0;
00330 const double rad2 = rad*rad;
00331 const int bound = int(rad+1.0);
00332
00333 Image<float> result(in.getDims(), ZEROS);
00334
00335 const int w = in.getWidth();
00336 const int h = in.getHeight();
00337
00338 for (int x = 0; x < w; ++x)
00339 for (int y = 0; y < h; ++y)
00340 {
00341 float maxv = in.getVal(x, y);
00342
00343 for (int i = -bound; i <= bound; ++i)
00344 for (int j = -bound; j <= bound; ++j)
00345 {
00346 if (i*i + j*j <= rad2 && result.coordsOk(x+i,y+j))
00347 maxv = std::max(maxv, 0.999f*in.getVal(x+i,y+j));
00348 }
00349
00350 result.setVal(x, y, maxv);
00351 }
00352
00353 return result;
00354 }
00355
00356 Image<float> combineBuTd(const Image<float>& bu,
00357 const Image<float>& td) const
00358 {
00359 Image<float> rtd = td;
00360 inplaceRectify(rtd);
00361
00362 return bu * rtd;
00363 }
00364
00365 Image<float> temporalMax(const Image<float>& img,
00366 std::deque<Image<float> >& q) const
00367 {
00368 q.push_front(img);
00369
00370 ASSERT(itsTemporalMax.getVal() > 0);
00371
00372 while (q.size() > itsTemporalMax.getVal())
00373 q.pop_back();
00374 ASSERT(q.size() <= itsTemporalMax.getVal());
00375
00376 Image<float> result = q[0];
00377
00378 for (uint i = 1; i < q.size(); ++i)
00379 result = takeMax(result, q[i]);
00380
00381 return result;
00382 }
00383
00384 void testFrame(const TigsInputFrame& fin,
00385 const Point2D<int>& eyepos,
00386 OutputFrameSeries& ofs)
00387 {
00388 GVX_ERR_CONTEXT(rutz::sfmt("handling training frame in Context %s",
00389 this->contextName().c_str()));
00390
00391 LINFO("context %s", this->contextName().c_str());
00392
00393 if (!fin.origbounds().contains(eyepos))
00394 return;
00395
00396 const Image<float> features = itsFx->extract(fin);
00397
00398 const Image<float> rawtdmap =
00399 reshape(itsLearner->getBiasMap(*itsTdata, asRow(features)),
00400 itsTdata->scaledInputDims());
00401
00402 Image<float> tdmap =
00403 this->temporalMax(this->localMax(rawtdmap), itsTdQ);
00404
00405 if (itsRectifyTd.getVal())
00406 inplaceRectify(tdmap);
00407
00408 const int pos = itsTdata->p2p(eyepos);
00409 itsScorer.score(this->contextName(), tdmap, pos);
00410
00411 if (itsDoBottomUp.getVal())
00412 {
00413 const Image<float> rawbumap =
00414 rescale(reshape(theirBottomUp->extract(fin),
00415 Dims(512 >> 4, 512 >> 4)),
00416 tdmap.getDims());
00417
00418 const Image<float> bumap =
00419 this->temporalMax(this->localMax(rawbumap), itsBuQ);
00420
00421 itsBottomUpScorer.score(this->contextName() + "...bu-only",
00422 bumap, pos);
00423
00424 const Image<float> rawcombomap =
00425 this->combineBuTd(rawbumap, rawtdmap);
00426
00427 const Image<float> combomap =
00428 this->temporalMax(this->localMax(rawcombomap), itsComboQ);
00429
00430 itsComboScorer.score(this->contextName() + "+bu",
00431 combomap, pos);
00432
00433 if (!ofs.isVoid() && itsSaveMaps.getVal())
00434 {
00435 const int flags =
00436 itsSaveMapsNormalized.getVal()
00437 ? (FLOAT_NORM_0_255 | FLOAT_NORM_WITH_SCALE)
00438 : FLOAT_NORM_PRESERVE;
00439
00440 makeDirectory(this->contextName() + "-maps");
00441
00442 ofs.writeFloat(tdmap, flags,
00443 this->contextName() + "-maps/td");
00444
00445 ofs.writeFloat(bumap, flags,
00446 this->contextName() + "-maps/bu");
00447
00448 ofs.writeFloat(combomap, flags,
00449 this->contextName() + "-maps/combo");
00450 }
00451
00452 if (!ofs.isVoid() && itsSaveSumo2.getVal())
00453 {
00454 ofs.writeRGB(makeSumoDisplay2(fin,
00455 tdmap,
00456 bumap,
00457 combomap,
00458 *itsTdata,
00459 eyepos),
00460 this->contextName() + "-sumo2");
00461 }
00462
00463 if (itsSaveRawData.getVal())
00464 {
00465 ASSERT(itsRawDataFile != 0);
00466
00467 ASSERT(bumap.getSize() == tdmap.getSize());
00468
00469 const float sz = float(bumap.getSize());
00470 fwrite(&sz, sizeof(float), 1, itsRawDataFile);
00471
00472 const float fpos = float(pos);
00473 fwrite(&fpos, sizeof(float), 1, itsRawDataFile);
00474
00475 const float buval = bumap[pos];
00476 fwrite(&buval, sizeof(float), 1, itsRawDataFile);
00477
00478 const float tdval = tdmap[pos];
00479 fwrite(&tdval, sizeof(float), 1, itsRawDataFile);
00480
00481 const float bumean = mean(bumap);
00482 fwrite(&bumean, sizeof(float), 1, itsRawDataFile);
00483
00484 const float tdmean = mean(tdmap);
00485 fwrite(&tdmean, sizeof(float), 1, itsRawDataFile);
00486
00487 const float bustd = stdev(bumap);
00488 fwrite(&bustd, sizeof(float), 1, itsRawDataFile);
00489
00490 const float tdstd = stdev(tdmap);
00491 fwrite(&tdstd, sizeof(float), 1, itsRawDataFile);
00492
00493 const float buz = (buval - bumean) / bustd;
00494 fwrite(&buz, sizeof(float), 1, itsRawDataFile);
00495
00496 const float tdz = (tdval - tdmean) / tdstd;
00497 fwrite(&tdz, sizeof(float), 1, itsRawDataFile);
00498
00499 fwrite(bumap.getArrayPtr(), sizeof(float),
00500 bumap.getSize(), itsRawDataFile);
00501
00502 fwrite(tdmap.getArrayPtr(), sizeof(float),
00503 tdmap.getSize(), itsRawDataFile);
00504
00505 fflush(itsRawDataFile);
00506 }
00507 }
00508
00509 if (!ofs.isVoid() && itsSaveSumo.getVal())
00510 ofs.writeRGB(makeSumoDisplay(fin, tdmap, *itsTdata,
00511 eyepos, features),
00512 this->contextName());
00513 }
00514
00515 private:
00516 OModelParam<std::string> itsXptSavePrefix;
00517 OModelParam<bool> itsDoBottomUp;
00518 OModelParam<bool> itsSaveSumo;
00519 OModelParam<bool> itsSaveSumo2;
00520 OModelParam<bool> itsSaveMaps;
00521 OModelParam<bool> itsSaveMapsNormalized;
00522 OModelParam<float> itsLocalMaxSize;
00523 OModelParam<uint> itsTemporalMax;
00524 OModelParam<bool> itsSaveRawData;
00525 OModelParam<bool> itsRectifyTd;
00526
00527 const nub::ref<TrainingSet> itsTdata;
00528 const std::string itsFxType;
00529 const nub::ref<FeatureExtractor> itsFx;
00530 const std::string itsLearnerType;
00531 const nub::ref<TopdownLearner> itsLearner;
00532 static nub::soft_ref<FeatureExtractor> theirBottomUp;
00533 std::string itsCtxName;
00534 MulticastScorer itsScorer;
00535 MulticastScorer itsBottomUpScorer;
00536 MulticastScorer itsComboScorer;
00537 std::deque<Image<float> > itsBuQ;
00538 std::deque<Image<float> > itsTdQ;
00539 std::deque<Image<float> > itsComboQ;
00540 FILE* itsRawDataFile;
00541 };
00542
00543 nub::soft_ref<FeatureExtractor> Context::theirBottomUp;
00544
00545 class TigsJob : public ModelComponent
00546 {
00547 public:
00548 TigsJob(OptionManager& mgr)
00549 :
00550 ModelComponent(mgr, "TigsJob", "TigsJob"),
00551 itsXptSavePrefix(&OPT_XptSavePrefix, this),
00552 itsContextSpec(&OPT_TopdownContextSpec, this),
00553 itsMoviePeriod(&OPT_MoviePeriod, this),
00554 itsNumSkipFrames(&OPT_NumSkipFrames, this),
00555 itsNumTrainingFrames(&OPT_NumTrainingFrames, this),
00556 itsNumTestingFrames(&OPT_NumTestingFrames, this),
00557 itsSaveGhostFrames(&OPT_SaveGhostFrames, this),
00558 itsObsolete1(&OPT_MovieHertzObsolete, this)
00559 {}
00560
00561 virtual void start2()
00562 {
00563 rutz::prof::prof_summary_file_name
00564 ((itsXptSavePrefix.getVal() + "-prof.out").c_str());
00565
00566 if (itsSaveGhostFrames.getVal().length() > 0)
00567 {
00568 itsGhostOutput =
00569 rutz::shared_ptr<std::ofstream>
00570 (new std::ofstream(itsSaveGhostFrames.getVal().c_str()));
00571
00572 if (!itsGhostOutput->is_open())
00573 LFATAL("couldn't open '%s' for writing",
00574 itsSaveGhostFrames.getVal().c_str());
00575 }
00576 }
00577
00578 virtual void paramChanged(ModelParamBase* const param,
00579 const bool valueChanged,
00580 ParamClient::ChangeStatus* status)
00581 {
00582 ModelComponent::paramChanged(param, valueChanged, status);
00583
00584 if (param == &itsContextSpec)
00585 {
00586 if (itsContextSpec.getVal() != "")
00587 this->addContext(itsContextSpec.getVal());
00588 }
00589 }
00590
00591 void addContext(const std::string& spec)
00592 {
00593 GVX_ERR_CONTEXT(rutz::sfmt
00594 ("adding context for spec %s", spec.c_str()));
00595
00596 std::string::size_type comma = spec.find_first_of(',');
00597 if (comma == spec.npos)
00598 LFATAL("missing comma in context spec '%s'", spec.c_str());
00599 if (comma+1 >= spec.length())
00600 LFATAL("bogus context spec '%s'", spec.c_str());
00601
00602 std::string fx_type = spec.substr(0, comma);
00603 std::string learner_type = spec.substr(comma+1);
00604
00605 LINFO("fxtype=%s, learnertype=%s",
00606 fx_type.c_str(), learner_type.c_str());
00607
00608 nub::ref<Context> ctx(new Context(this->getManager(),
00609 fx_type, learner_type));
00610
00611 this->addSubComponent(ctx);
00612 itsContexts.push_back(ctx);
00613
00614 ctx->exportOptions(MC_RECURSE);
00615 }
00616
00617 void loadTrainingSet(const std::string& xpt)
00618 {
00619 ASSERT(itsNumTrainingFrames.getVal() == 0);
00620
00621 for (size_t c = 0; c < itsContexts.size(); ++c)
00622 itsContexts[c]->loadTrainingSet(xpt);
00623 }
00624
00625
00626 bool handleFrame(int nframe,
00627 const TigsInputFrame& fin,
00628 const Point2D<int>& eyepos,
00629 const bool islast,
00630 OutputFrameSeries& ofs)
00631 {
00632 GVX_ERR_CONTEXT(rutz::sfmt("handling input frame %d in TigsJob",
00633 nframe));
00634
00635 if (nframe < itsNumSkipFrames.getVal())
00636 return true;
00637
00638 if (nframe >= (itsNumSkipFrames.getVal()
00639 +itsNumTrainingFrames.getVal()
00640 +itsNumTestingFrames.getVal()))
00641 {
00642 LINFO("exceeded skip+train+test frames");
00643 return false;
00644 }
00645
00646 if (itsGhostOutput.is_valid())
00647 {
00648 *itsGhostOutput << fin.toGhostString() << std::endl;
00649 }
00650
00651 for (size_t c = 0; c < itsContexts.size(); ++c)
00652 {
00653 if (nframe < itsNumSkipFrames.getVal()+itsNumTrainingFrames.getVal())
00654 {
00655 const bool lasttraining =
00656 islast
00657 ||
00658 (nframe+1 == (itsNumSkipFrames.getVal()
00659 +itsNumTrainingFrames.getVal()));
00660
00661 itsContexts[c]->trainingFrame(fin, eyepos,
00662 lasttraining, ofs);
00663 }
00664 else if (nframe < itsNumSkipFrames.getVal()+itsNumTrainingFrames.getVal()+itsNumTestingFrames.getVal())
00665 {
00666 itsContexts[c]->testFrame(fin, eyepos, ofs);
00667 }
00668 }
00669
00670 return true;
00671 }
00672
00673 SimTime movieFrameLength() const
00674 {
00675 ASSERT(itsMoviePeriod.getVal() > SimTime::ZERO());
00676 return itsMoviePeriod.getVal();
00677 }
00678
00679 std::string getSavePrefix() const { return itsXptSavePrefix.getVal(); }
00680
00681 private:
00682 OModelParam<std::string> itsXptSavePrefix;
00683 OModelParam<std::string> itsContextSpec;
00684 OModelParam<SimTime> itsMoviePeriod;
00685 OModelParam<int> itsNumSkipFrames;
00686 OModelParam<int> itsNumTrainingFrames;
00687 OModelParam<int> itsNumTestingFrames;
00688 OModelParam<std::string> itsSaveGhostFrames;
00689 OModelParam<bool> itsObsolete1;
00690 std::vector<nub::ref<Context> > itsContexts;
00691 rutz::shared_ptr<std::ofstream> itsGhostOutput;
00692 };
00693
00694 class TigsInputFrameSeries : public ModelComponent
00695 {
00696 public:
00697 TigsInputFrameSeries(OptionManager& mgr)
00698 :
00699 ModelComponent(mgr, "TigsInputFrameSeries", "TigsInputFrameSeries"),
00700 itsGhostInput(&OPT_GhostInput, this),
00701 itsIfs(new InputFrameSeries(mgr)),
00702 itsFirst(true)
00703 {
00704 this->addSubComponent(itsIfs);
00705 }
00706
00707 virtual void paramChanged(ModelParamBase* param,
00708 const bool valueChanged,
00709 ChangeStatus* status)
00710 {
00711 if (param == &itsGhostInput && valueChanged)
00712 {
00713 if (itsGhostInput.getVal().length() == 0)
00714 {
00715 ASSERT(itsIfs.is_valid() == false);
00716
00717
00718 itsGhostFile.close();
00719
00720
00721 itsIfs.reset(new InputFrameSeries(getManager()));
00722 this->addSubComponent(itsIfs);
00723 itsIfs->exportOptions(MC_RECURSE);
00724 itsFirst = true;
00725 }
00726 else
00727 {
00728 ASSERT(itsIfs.is_valid() == true);
00729 this->removeSubComponent(*itsIfs);
00730 itsIfs.reset(0);
00731
00732 itsGhostFile.open(itsGhostInput.getVal().c_str());
00733 if (!itsGhostFile.is_open())
00734 LFATAL("couldn't open file '%s' for reading",
00735 itsGhostInput.getVal().c_str());
00736
00737 LINFO("reading ghost frames from '%s'",
00738 itsGhostInput.getVal().c_str());
00739
00740
00741
00742 if (!std::getline(itsGhostFile, itsNextLine))
00743 itsNextLine = "";
00744 itsFirst = false;
00745
00746 rutz::shared_ptr<TigsInputFrame> f =
00747 TigsInputFrame::fromGhostString(itsNextLine);
00748
00749 getManager().setOptionValString
00750 (&OPT_InputFrameDims,
00751 convertToString(f->origbounds().dims()));
00752 }
00753
00754
00755
00756 ASSERT(itsIfs.is_valid() != itsGhostFile.is_open());
00757 }
00758 }
00759
00760 rutz::shared_ptr<TigsInputFrame> getFrame(const SimTime stime,
00761 bool* islast)
00762 {
00763 if (itsIfs.is_valid())
00764 {
00765 if (itsFirst)
00766 {
00767 itsIfs->updateNext();
00768 itsNextFrame = itsIfs->readRGB();
00769 itsFirst = false;
00770 }
00771
00772 if (!itsNextFrame.initialized())
00773 return rutz::shared_ptr<TigsInputFrame>();
00774
00775
00776 itsIfs->updateNext();
00777 Image<PixRGB<byte> > frame = itsIfs->readRGB();
00778 frame.swap(itsNextFrame);
00779
00780 ASSERT(frame.initialized());
00781
00782 *islast = (itsNextFrame.initialized() == false);
00783
00784 return rutz::shared_ptr<TigsInputFrame>
00785 (new TigsInputFrame(frame, stime));
00786 }
00787 else
00788 {
00789 ASSERT(itsGhostFile.is_open());
00790
00791 if (itsNextLine.length() == 0)
00792 return rutz::shared_ptr<TigsInputFrame>();
00793
00794
00795 std::string line;
00796 if (!std::getline(itsGhostFile, line))
00797 line = "";
00798 line.swap(itsNextLine);
00799
00800 ASSERT(line.length() > 0);
00801
00802 *islast = (itsNextLine.length() == 0);
00803
00804 rutz::shared_ptr<TigsInputFrame> result =
00805 TigsInputFrame::fromGhostString(line);
00806
00807 if (result->t() != stime)
00808 LFATAL("wrong time in ghost frame: expected %.2fms "
00809 "but got %.2fms", stime.msecs(), result->t().msecs());
00810
00811 LINFO("got ghost frame: %s", line.c_str());
00812
00813 return result;
00814 }
00815 }
00816
00817 private:
00818 OModelParam<std::string> itsGhostInput;
00819
00820 nub::soft_ref<InputFrameSeries> itsIfs;
00821 std::ifstream itsGhostFile;
00822
00823 bool itsFirst;
00824 Image<PixRGB<byte> > itsNextFrame;
00825 std::string itsNextLine;
00826 };
00827
00828 int submain(int argc, const char** argv)
00829 {
00830 GVX_TRACE("test-TopdownContext-main");
00831
00832 volatile int signum = 0;
00833 catchsignals(&signum);
00834
00835 rutz::prof::print_at_exit(true);
00836
00837 fpExceptionsUnlock();
00838 fpExceptionsOff();
00839 fpExceptionsLock();
00840
00841 ModelManager mgr("topdown context tester");
00842
00843 nub::ref<TigsInputFrameSeries> ifs(new TigsInputFrameSeries(mgr));
00844 mgr.addSubComponent(ifs);
00845
00846 nub::ref<OutputFrameSeries> ofs(new OutputFrameSeries(mgr));
00847 mgr.addSubComponent(ofs);
00848
00849 nub::ref<TigsJob> job(new TigsJob(mgr));
00850 mgr.addSubComponent(job);
00851
00852 nub::ref<EyeSFile> eyeS(new EyeSFile(mgr));
00853 mgr.addSubComponent(eyeS);
00854
00855 mgr.exportOptions(MC_RECURSE);
00856
00857 mgr.setOptionValString(&OPT_UseRandom, "false");
00858
00859 if (mgr.parseCommandLine(argc, argv,
00860 "[load-pfx1 [load-pfx2 [...]]]",
00861 0, -1) == false)
00862 return 1;
00863
00864 ofs->setModelParamVal("OutputMPEGStreamFrameRate", int(24),
00865 MC_RECURSE | MC_IGNORE_MISSING);
00866 ofs->setModelParamVal("OutputMPEGStreamBitRate", int(2500000),
00867 MC_RECURSE | MC_IGNORE_MISSING);
00868
00869 mgr.start();
00870
00871 {
00872 std::ofstream ofs((job->getSavePrefix() + ".model").c_str());
00873 if (ofs.is_open())
00874 {
00875 mgr.printout(ofs);
00876 }
00877 ofs.close();
00878 }
00879
00880 mgr.printout(std::cout);
00881
00882 for (uint e = 0; e < mgr.numExtraArgs(); ++e)
00883 job->loadTrainingSet(mgr.getExtraArg(e));
00884
00885 PauseWaiter p;
00886
00887 int nframes = 0;
00888
00889 while (1)
00890 {
00891 GVX_TRACE("frames loop");
00892
00893 if (signum != 0)
00894 {
00895 LINFO("caught signal %s; quitting", signame(signum));
00896 break;
00897 }
00898
00899 if (p.checkPause())
00900 continue;
00901
00902 LINFO("trying frame %d", nframes);
00903
00904 const SimTime stime = job->movieFrameLength() * (nframes+1);
00905
00906 bool islast;
00907 rutz::shared_ptr<TigsInputFrame> fin =
00908 ifs->getFrame(stime, &islast);
00909
00910 if (fin.get() == 0)
00911 {
00912 LINFO("input exhausted; quitting");
00913 break;
00914 }
00915
00916 const Point2D<int> eyepos = eyeS->readUpTo(stime);
00917
00918 LINFO("simtime %.6fs, movie frame %d, eye sample %d, ratio %f, "
00919 "eyepos (x=%d, y=%d)",
00920 stime.secs(), nframes+1, eyeS->lineNumber(),
00921 double(eyeS->lineNumber())/double(nframes+1),
00922 eyepos.i, eyepos.j);
00923
00924 if (!job->handleFrame(nframes, *fin, eyepos, islast, *ofs))
00925 break;
00926
00927 ofs->updateNext();
00928
00929 ++nframes;
00930 }
00931
00932 mgr.stop();
00933
00934 return 0;
00935 }
00936
00937 int main(int argc, const char** argv)
00938 {
00939 try {
00940 submain(argc, argv);
00941 } catch(...) {
00942 REPORT_CURRENT_EXCEPTION;
00943 std::terminate();
00944 }
00945 }
00946
00947
00948
00949
00950
00951
00952
#endif // !APPNEURO_TEST_TOPDOWNCONTEXT_C_UTC20050726230120_DEFINED