/*!@file TIGS/test-TopdownContext.C */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
// by the University of Southern California (USC) and the iLab at USC. //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.       //
//                                                                      //
// You should have received a copy of the GNU General Public License    //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
// Boston, MA 02111-1307 USA.                                           //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Rob Peters <rjpeters at usc dot edu>
// $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/TIGS/test-TopdownContext.C $
// $Id: test-TopdownContext.C 9412 2008-03-10 23:10:15Z farhan $
//

#ifndef APPNEURO_TEST_TOPDOWNCONTEXT_C_UTC20050726230120_DEFINED
#define APPNEURO_TEST_TOPDOWNCONTEXT_C_UTC20050726230120_DEFINED

#include "Component/GlobalOpts.H"
#include "Component/ModelManager.H"
#include "Component/ModelOptionDef.H"
#include "Image/Image.H"
#include "Image/MathOps.H"
#include "Image/Pixels.H"
#include "Image/ShapeOps.H"
#include "Media/FrameSeries.H"
#include "Media/MediaOpts.H"
#include "Psycho/EyeSFile.H"
#include "TIGS/FeatureExtractorFactory.H"
#include "TIGS/Figures.H"
#include "TIGS/SaliencyMapFeatureExtractor.H"
#include "TIGS/Scorer.H"
#include "TIGS/TigsOpts.H"
#include "TIGS/TopdownLearnerFactory.H"
#include "TIGS/TrainingSet.H"
#include "Util/Assert.H"
#include "Util/FileUtil.H"
#include "Util/Pause.H"
#include "Util/SimTime.H"
#include "Util/StringConversions.H"
#include "Util/csignals.H"
#include "Util/fpe.H"
#include "rutz/error_context.h"
#include "rutz/sfmt.h"
#include "rutz/shared_ptr.h"
#include "rutz/trace.h"

#include <deque>
#include <fstream>
#include <iostream>
#include <sstream>
#include <unistd.h>
#include <vector>

// Used by: Context
static const ModelOptionDef OPT_DoBottomUpContext =
  { MODOPT_FLAG, "DoBottomUpContext", &MOC_TIGS, OPTEXP_CORE,
    "Whether to scale the top-down prediction by a bottom-up map",
    "bottom-up-context", '\0', "", "false" };

// Used by: Context
static const ModelOptionDef OPT_TdcSaveSumo =
  { MODOPT_FLAG, "TdcSaveSumo", &MOC_TIGS, OPTEXP_CORE,
    "Whether to save the sumo display",
    "tdc-save-sumo", '\0', "", "false" };

// Used by: Context
static const ModelOptionDef OPT_TdcSaveSumo2 =
  { MODOPT_FLAG, "TdcSaveSumo2", &MOC_TIGS, OPTEXP_CORE,
    "Whether to save the sumo2 display",
    "tdc-save-sumo2", '\0', "", "false" };

// Used by: Context
static const ModelOptionDef OPT_TdcSaveMaps =
  { MODOPT_FLAG, "TdcSaveMaps", &MOC_TIGS, OPTEXP_CORE,
    "Whether to save the individual topdown context maps",
    "tdc-save-maps", '\0', "", "false" };

// Used by: Context
static const ModelOptionDef OPT_TdcSaveMapsNormalized =
  { MODOPT_FLAG, "TdcSaveMapsNormalized", &MOC_TIGS, OPTEXP_CORE,
    "Whether to rescale maps to [0,255] when saving with --tdc-save-maps",
    "tdc-save-maps-normalized", '\0', "", "true" };

// Used by: Context
static const ModelOptionDef OPT_TdcLocalMax =
  { MODOPT_ARG(float), "TdcLocalMax", &MOC_TIGS, OPTEXP_CORE,
    "Diameter of local max region to be applied to bias maps before scoring",
    "tdc-local-max", '\0', "<float>", "1" };

// Used by: Context
static const ModelOptionDef OPT_TdcTemporalMax =
  { MODOPT_ARG(unsigned int), "TdcTemporalMax", &MOC_TIGS, OPTEXP_CORE,
    "Number of frames across which to apply a temporal max to "
    "bias maps before scoring",
    "tdc-temporal-max", '\0', "<integer>", "1" };

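// For example (arbitrary values), a run with --tdc-local-max=5
// --tdc-temporal-max=3 would, before scoring, take a local max over a
// roughly 5-pixel-diameter disc on each bias map and then a running max
// over the 3 most recent maps; see Context::localMax() and
// Context::temporalMax() below.  The default value of 1 for both
// options leaves the maps unchanged.
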
// Used by: Context
static const ModelOptionDef OPT_TdcSaveRawData =
  { MODOPT_FLAG, "TdcSaveRawData", &MOC_TIGS, OPTEXP_CORE,
    "Whether to save a raw binary file containing the "
    "bottom-up and top-down maps",
    "tdc-save-raw-data", '\0', "", "false" };

// Used by: Context
static const ModelOptionDef OPT_TdcRectifyTd =
  { MODOPT_FLAG, "TdcRectifyTd", &MOC_TIGS, OPTEXP_CORE,
    "Whether to rectify the top-down maps",
    "tdc-rectify-td", '\0', "", "false" };

// Used by: TigsJob
static const ModelOptionDef OPT_TopdownContextSpec =
  { MODOPT_ARG_STRING, "TopdownContextSpec", &MOC_TIGS, OPTEXP_CORE,
    "Specification string for a topdown context",
    "context", '\0', "<string>", "" };

// Used by: TigsJob
static const ModelOptionDef OPT_MoviePeriod =
  { MODOPT_ARG(SimTime), "MoviePeriod", &MOC_TIGS, OPTEXP_CORE,
    "Inter-frame period (or rate) of input movie",
    "movie-period", '\0', "<float>{s|ms|us|ns|Hz}", "0.0s" };

// Obsolete
static const ModelOptionDef OPT_MovieHertzObsolete =
  { MODOPT_OBSOLETE, "MovieHertzObsolete", &MOC_TIGS, OPTEXP_CORE,
    "Obsolete; use --movie-period instead with a SimTime value",
    "movie-hertz", '\0', "<float>", "0.0" };

// Used by: TigsJob
static const ModelOptionDef OPT_NumSkipFrames =
  { MODOPT_ARG(int), "NumSkipFrames", &MOC_TIGS, OPTEXP_CORE,
    "Number of frames to skip over at beginning of input movie",
    "num-skip-frames", '\0', "<int>", "0" };

// Used by: TigsJob
static const ModelOptionDef OPT_NumTrainingFrames =
  { MODOPT_ARG(int), "NumTrainingFrames", &MOC_TIGS, OPTEXP_CORE,
    "Number of input movie frames to use as training data",
    "num-training-frames", '\0', "<int>", "0" };

// Used by: TigsJob
static const ModelOptionDef OPT_NumTestingFrames =
  { MODOPT_ARG(int), "NumTestingFrames", &MOC_TIGS, OPTEXP_CORE,
    "Number of input movie frames to use as testing data",
    "num-testing-frames", '\0', "<int>", "0" };

// Used by: TigsJob
static const ModelOptionDef OPT_SaveGhostFrames =
  { MODOPT_ARG_STRING, "SaveGhostFrames", &MOC_TIGS, OPTEXP_CORE,
    "Name of a file in which to save ghost frame info that can "
    "be used to accelerate processing in a subsequent run",
    "save-ghost-frames", '\0', "<filename>", "" };

// Used by: TigsInputFrameSeries
static const ModelOptionDef OPT_GhostInput =
  { MODOPT_ARG_STRING, "GhostInput", &MOC_TIGS, OPTEXP_CORE,
    "Read ghost frame info from this file",
    "ghost-input", '\0', "<filename>", "" };

namespace
{
  template <class T>
  Image<T> asRow(const Image<T>& in)
  {
    GVX_TRACE(__PRETTY_FUNCTION__);
    return Image<T>(in.getArrayPtr(), in.getSize(), 1);
  }
}

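// A Context pairs one FeatureExtractor with one TopdownLearner (both
// chosen by the --context spec), along with the TrainingSet and scorers
// that belong to that pairing.  During training frames it records
// (eye position, feature vector) samples into the TrainingSet; during
// test frames it asks the learner for a top-down bias map, optionally
// combines it with a bottom-up saliency map (--bottom-up-context), and
// scores the maps against the observed eye position.  A single
// SaliencyMapFeatureExtractor is shared by all Context instances through
// the static theirBottomUp member.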
class Context : public ModelComponent
{
public:
  Context(OptionManager& mgr,
          const std::string& fx_type_,
          const std::string& learner_type_)
    :
    ModelComponent(mgr, "Context", "Context"),
    itsXptSavePrefix(&OPT_XptSavePrefix, this),
    itsDoBottomUp(&OPT_DoBottomUpContext, this),
    itsSaveSumo(&OPT_TdcSaveSumo, this),
    itsSaveSumo2(&OPT_TdcSaveSumo2, this),
    itsSaveMaps(&OPT_TdcSaveMaps, this),
    itsSaveMapsNormalized(&OPT_TdcSaveMapsNormalized, this),
    itsLocalMaxSize(&OPT_TdcLocalMax, this),
    itsTemporalMax(&OPT_TdcTemporalMax, this),
    itsSaveRawData(&OPT_TdcSaveRawData, this),
    itsRectifyTd(&OPT_TdcRectifyTd, this),
    itsTdata(new TrainingSet(this->getManager(), fx_type_)),
    itsFxType(fx_type_),
    itsFx(makeFeatureExtractor(mgr, fx_type_)),
    itsLearnerType(learner_type_),
    itsLearner(makeTopdownLearner(mgr, learner_type_)),
    itsCtxName("-" + itsFxType + "-" + itsLearnerType),
    itsScorer(),
    itsBottomUpScorer(),
    itsComboScorer(),
    itsRawDataFile(0)
  {
    if (theirBottomUp.is_invalid())
      {
        GVX_ERR_CONTEXT(rutz::sfmt
                        ("constructing SaliencyMapFeatureExtractor "
                         "on behalf of %s", itsCtxName.c_str()));

        theirBottomUp.reset
          (new SaliencyMapFeatureExtractor(this->getManager()));
        this->addSubComponent(theirBottomUp);
      }

    this->addSubComponent(itsTdata);
    this->addSubComponent(itsFx);
    this->addSubComponent(itsLearner);
  }

  virtual void start2()
  {
    ASSERT(itsRawDataFile == 0);

    if (itsSaveRawData.getVal())
      {
        std::string rawdatfname = this->contextName() + ".rawdat";
        itsRawDataFile = fopen(rawdatfname.c_str(), "w");
        if (itsRawDataFile == 0)
          LFATAL("couldn't open %s for writing", rawdatfname.c_str());
      }
  }

  virtual void stop1()
  {
    itsScorer.showScore("finalscore:" + this->contextName());
    if (itsDoBottomUp.getVal())
      {
        itsBottomUpScorer.showScore("finalscore:" + this->contextName()
                                    + "...bu-only");
        itsComboScorer.showScore("finalscore:" + this->contextName()
                                 + "+bu");
      }

    std::ofstream ofs((this->contextName() + ".score").c_str());
    if (ofs.is_open())
      {
        itsScorer.writeScore(this->contextName(), ofs);
        if (itsDoBottomUp.getVal())
          {
            itsBottomUpScorer.writeScore(this->contextName() + "...bu-only", ofs);
            itsComboScorer.writeScore(this->contextName() + "+bu", ofs);
          }
      }
    ofs.close();

    if (itsRawDataFile != 0)
      {
        fclose(itsRawDataFile);
        itsRawDataFile = 0;
      }
  }

  std::string contextName() const
  {
    if (itsXptSavePrefix.getVal() == "")
      LFATAL("no xpt name specified!");

    return itsXptSavePrefix.getVal() + itsCtxName;
  }

  void loadTrainingSet(const std::string& xpt)
  {
    itsTdata->load(xpt + itsCtxName);
  }

  void trainingFrame(const TigsInputFrame& fin,
                     const Point2D<int>& eyepos, bool lastone,
                     OutputFrameSeries& ofs)
  {
    GVX_ERR_CONTEXT(rutz::sfmt("handling training frame in Context %s",
                               this->contextName().c_str()));

    const Image<float> features = itsFx->extract(fin);

    if (fin.origbounds().contains(eyepos))
      {
        const Image<float> biasmap =
          itsTdata->recordSample(eyepos, features);

        if (!ofs.isVoid() && itsSaveSumo.getVal())
          ofs.writeRGB(makeSumoDisplay(fin, biasmap, *itsTdata,
                                       eyepos, features),
                       this->contextName());
      }

    if (lastone)
      itsTdata->save(this->contextName());

    if (itsDoBottomUp.getVal())
      {
        // this is a no-op, except that we want to force the features
        // to be computed so that they can be cached and saved
        (void) theirBottomUp->extract(fin);
      }
  }

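  // The three helpers below post-process bias maps before scoring:
  // localMax() replaces each pixel by the maximum over a roughly
  // circular neighborhood of diameter --tdc-local-max (neighbors are
  // scaled by 0.999 so the center pixel wins exact ties); combineBuTd()
  // multiplies a bottom-up map by a rectified copy of the top-down map;
  // temporalMax() takes a running max over the last --tdc-temporal-max
  // maps.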
  Image<float> localMax(const Image<float>& in) const
  {
    if (itsLocalMaxSize.getVal() < 2.0f)
      return in;

    const double rad = itsLocalMaxSize.getVal() / 2.0;
    const double rad2 = rad*rad;
    const int bound = int(rad+1.0);

    Image<float> result(in.getDims(), ZEROS);

    const int w = in.getWidth();
    const int h = in.getHeight();

    for (int x = 0; x < w; ++x)
      for (int y = 0; y < h; ++y)
        {
          float maxv = in.getVal(x, y);

          for (int i = -bound; i <= bound; ++i)
            for (int j = -bound; j <= bound; ++j)
              {
                if (i*i + j*j <= rad2 && result.coordsOk(x+i, y+j))
                  maxv = std::max(maxv, 0.999f*in.getVal(x+i, y+j));
              }

          result.setVal(x, y, maxv);
        }

    return result;
  }

  Image<float> combineBuTd(const Image<float>& bu,
                           const Image<float>& td) const
  {
    Image<float> rtd = td;
    inplaceRectify(rtd);

    return bu * rtd;
  }

  Image<float> temporalMax(const Image<float>& img,
                           std::deque<Image<float> >& q) const
  {
    q.push_front(img);

    ASSERT(itsTemporalMax.getVal() > 0);

    while (q.size() > itsTemporalMax.getVal())
      q.pop_back();
    ASSERT(q.size() <= itsTemporalMax.getVal());

    Image<float> result = q[0];

    for (uint i = 1; i < q.size(); ++i)
      result = takeMax(result, q[i]);

    return result;
  }

  void testFrame(const TigsInputFrame& fin,
                 const Point2D<int>& eyepos,
                 OutputFrameSeries& ofs)
  {
    GVX_ERR_CONTEXT(rutz::sfmt("handling test frame in Context %s",
                               this->contextName().c_str()));

    LINFO("context %s", this->contextName().c_str());

    if (!fin.origbounds().contains(eyepos))
      return;

    const Image<float> features = itsFx->extract(fin);

    const Image<float> rawtdmap =
      reshape(itsLearner->getBiasMap(*itsTdata, asRow(features)),
              itsTdata->scaledInputDims());

    Image<float> tdmap =
      this->temporalMax(this->localMax(rawtdmap), itsTdQ);

    if (itsRectifyTd.getVal())
      inplaceRectify(tdmap);

    const int pos = itsTdata->p2p(eyepos);
    itsScorer.score(this->contextName(), tdmap, pos);

    if (itsDoBottomUp.getVal())
      {
        const Image<float> rawbumap =
          rescale(reshape(theirBottomUp->extract(fin),
                          Dims(512 >> 4, 512 >> 4)),
                  tdmap.getDims());

        const Image<float> bumap =
          this->temporalMax(this->localMax(rawbumap), itsBuQ);

        itsBottomUpScorer.score(this->contextName() + "...bu-only",
                                bumap, pos);

        const Image<float> rawcombomap =
          this->combineBuTd(rawbumap, rawtdmap);

        const Image<float> combomap =
          this->temporalMax(this->localMax(rawcombomap), itsComboQ);

        itsComboScorer.score(this->contextName() + "+bu",
                             combomap, pos);

        if (!ofs.isVoid() && itsSaveMaps.getVal())
          {
            const int flags =
              itsSaveMapsNormalized.getVal()
              ? (FLOAT_NORM_0_255 | FLOAT_NORM_WITH_SCALE)
              : FLOAT_NORM_PRESERVE;

            makeDirectory(this->contextName() + "-maps");

            ofs.writeFloat(tdmap, flags,
                           this->contextName() + "-maps/td");

            ofs.writeFloat(bumap, flags,
                           this->contextName() + "-maps/bu");

            ofs.writeFloat(combomap, flags,
                           this->contextName() + "-maps/combo");
          }

        if (!ofs.isVoid() && itsSaveSumo2.getVal())
          {
            ofs.writeRGB(makeSumoDisplay2(fin,
                                          tdmap,
                                          bumap,
                                          combomap,
                                          *itsTdata,
                                          eyepos),
                         this->contextName() + "-sumo2");
          }

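        // When --tdc-save-raw-data is in effect, each test frame appends
        // one record to <contextName>.rawdat: ten 32-bit floats (map
        // size, fixation index, the bottom-up and top-down values at the
        // fixation, their means, standard deviations, and z-scores),
        // followed by the full bottom-up map and then the full top-down
        // map, also as raw 32-bit floats.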
        if (itsSaveRawData.getVal())
          {
            ASSERT(itsRawDataFile != 0);

            ASSERT(bumap.getSize() == tdmap.getSize());

            const float sz = float(bumap.getSize());
            fwrite(&sz, sizeof(float), 1, itsRawDataFile);

            const float fpos = float(pos);
            fwrite(&fpos, sizeof(float), 1, itsRawDataFile);

            const float buval = bumap[pos];
            fwrite(&buval, sizeof(float), 1, itsRawDataFile);

            const float tdval = tdmap[pos];
            fwrite(&tdval, sizeof(float), 1, itsRawDataFile);

            const float bumean = mean(bumap);
            fwrite(&bumean, sizeof(float), 1, itsRawDataFile);

            const float tdmean = mean(tdmap);
            fwrite(&tdmean, sizeof(float), 1, itsRawDataFile);

            const float bustd = stdev(bumap);
            fwrite(&bustd, sizeof(float), 1, itsRawDataFile);

            const float tdstd = stdev(tdmap);
            fwrite(&tdstd, sizeof(float), 1, itsRawDataFile);

            const float buz = (buval - bumean) / bustd;
            fwrite(&buz, sizeof(float), 1, itsRawDataFile);

            const float tdz = (tdval - tdmean) / tdstd;
            fwrite(&tdz, sizeof(float), 1, itsRawDataFile);

            fwrite(bumap.getArrayPtr(), sizeof(float),
                   bumap.getSize(), itsRawDataFile);

            fwrite(tdmap.getArrayPtr(), sizeof(float),
                   tdmap.getSize(), itsRawDataFile);

            fflush(itsRawDataFile);
          }
      }

    if (!ofs.isVoid() && itsSaveSumo.getVal())
      ofs.writeRGB(makeSumoDisplay(fin, tdmap, *itsTdata,
                                   eyepos, features),
                   this->contextName());
  }

private:
  OModelParam<std::string> itsXptSavePrefix;
  OModelParam<bool> itsDoBottomUp;
  OModelParam<bool> itsSaveSumo;
  OModelParam<bool> itsSaveSumo2;
  OModelParam<bool> itsSaveMaps;
  OModelParam<bool> itsSaveMapsNormalized;
  OModelParam<float> itsLocalMaxSize;
  OModelParam<uint> itsTemporalMax;
  OModelParam<bool> itsSaveRawData;
  OModelParam<bool> itsRectifyTd;

  const nub::ref<TrainingSet> itsTdata;
  const std::string itsFxType;
  const nub::ref<FeatureExtractor> itsFx;
  const std::string itsLearnerType;
  const nub::ref<TopdownLearner> itsLearner;
  static nub::soft_ref<FeatureExtractor> theirBottomUp;
  std::string itsCtxName;
  MulticastScorer itsScorer;
  MulticastScorer itsBottomUpScorer;
  MulticastScorer itsComboScorer;
  std::deque<Image<float> > itsBuQ;
  std::deque<Image<float> > itsTdQ;
  std::deque<Image<float> > itsComboQ;
  FILE* itsRawDataFile;
};

nub::soft_ref<FeatureExtractor> Context::theirBottomUp;

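// A TigsJob owns the Context objects (created from the --context spec in
// addContext(); the spec has the form "<fx-type>,<learner-type>", split
// at the first comma, where the two names must be ones that
// FeatureExtractorFactory and TopdownLearnerFactory recognize -- the
// valid strings are not listed in this file).  It splits the input movie
// into three consecutive phases: --num-skip-frames frames that are
// ignored, --num-training-frames frames routed to
// Context::trainingFrame(), and --num-testing-frames frames routed to
// Context::testFrame().  It can also dump per-frame "ghost" metadata via
// --save-ghost-frames for reuse by a later run.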
class TigsJob : public ModelComponent
{
public:
  TigsJob(OptionManager& mgr)
    :
    ModelComponent(mgr, "TigsJob", "TigsJob"),
    itsXptSavePrefix(&OPT_XptSavePrefix, this),
    itsContextSpec(&OPT_TopdownContextSpec, this),
    itsMoviePeriod(&OPT_MoviePeriod, this),
    itsNumSkipFrames(&OPT_NumSkipFrames, this),
    itsNumTrainingFrames(&OPT_NumTrainingFrames, this),
    itsNumTestingFrames(&OPT_NumTestingFrames, this),
    itsSaveGhostFrames(&OPT_SaveGhostFrames, this),
    itsObsolete1(&OPT_MovieHertzObsolete, this)
  {}

  virtual void start2()
  {
    rutz::prof::prof_summary_file_name
      ((itsXptSavePrefix.getVal() + "-prof.out").c_str());

    if (itsSaveGhostFrames.getVal().length() > 0)
      {
        itsGhostOutput =
          rutz::shared_ptr<std::ofstream>
          (new std::ofstream(itsSaveGhostFrames.getVal().c_str()));

        if (!itsGhostOutput->is_open())
          LFATAL("couldn't open '%s' for writing",
                 itsSaveGhostFrames.getVal().c_str());
      }
  }

  virtual void paramChanged(ModelParamBase* const param,
                            const bool valueChanged,
                            ParamClient::ChangeStatus* status)
  {
    ModelComponent::paramChanged(param, valueChanged, status);

    if (param == &itsContextSpec)
      {
        if (itsContextSpec.getVal() != "")
          this->addContext(itsContextSpec.getVal());
      }
  }

  void addContext(const std::string& spec)
  {
    GVX_ERR_CONTEXT(rutz::sfmt
                    ("adding context for spec %s", spec.c_str()));

    std::string::size_type comma = spec.find_first_of(',');
    if (comma == spec.npos)
      LFATAL("missing comma in context spec '%s'", spec.c_str());
    if (comma+1 >= spec.length())
      LFATAL("bogus context spec '%s'", spec.c_str());

    std::string fx_type = spec.substr(0, comma);
    std::string learner_type = spec.substr(comma+1);

    LINFO("fxtype=%s, learnertype=%s",
          fx_type.c_str(), learner_type.c_str());

    nub::ref<Context> ctx(new Context(this->getManager(),
                                      fx_type, learner_type));

    this->addSubComponent(ctx);
    itsContexts.push_back(ctx);

    ctx->exportOptions(MC_RECURSE);
  }

  void loadTrainingSet(const std::string& xpt)
  {
    ASSERT(itsNumTrainingFrames.getVal() == 0);

    for (size_t c = 0; c < itsContexts.size(); ++c)
      itsContexts[c]->loadTrainingSet(xpt);
  }

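  // Frames with index below --num-skip-frames are ignored; the next
  // --num-training-frames frames go to trainingFrame(), the following
  // --num-testing-frames frames go to testFrame(), and once the
  // skip+train+test budget is exhausted the main loop is told to stop.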
  // returns true to continue looping, false to quit main loop
  bool handleFrame(int nframe,
                   const TigsInputFrame& fin,
                   const Point2D<int>& eyepos,
                   const bool islast,
                   OutputFrameSeries& ofs)
  {
    GVX_ERR_CONTEXT(rutz::sfmt("handling input frame %d in TigsJob",
                               nframe));

    if (nframe < itsNumSkipFrames.getVal())
      return true;

    if (nframe >= (itsNumSkipFrames.getVal()
                   + itsNumTrainingFrames.getVal()
                   + itsNumTestingFrames.getVal()))
      {
        LINFO("exceeded skip+train+test frames");
        return false;
      }

    if (itsGhostOutput.is_valid())
      {
        *itsGhostOutput << fin.toGhostString() << std::endl;
      }

    for (size_t c = 0; c < itsContexts.size(); ++c)
      {
        if (nframe < itsNumSkipFrames.getVal() + itsNumTrainingFrames.getVal())
          {
            const bool lasttraining =
              islast
              ||
              (nframe+1 == (itsNumSkipFrames.getVal()
                            + itsNumTrainingFrames.getVal()));

            itsContexts[c]->trainingFrame(fin, eyepos,
                                          lasttraining, ofs);
          }
        else if (nframe < (itsNumSkipFrames.getVal()
                           + itsNumTrainingFrames.getVal()
                           + itsNumTestingFrames.getVal()))
          {
            itsContexts[c]->testFrame(fin, eyepos, ofs);
          }
      }

    return true;
  }

  SimTime movieFrameLength() const
  {
    ASSERT(itsMoviePeriod.getVal() > SimTime::ZERO());
    return itsMoviePeriod.getVal();
  }

  std::string getSavePrefix() const { return itsXptSavePrefix.getVal(); }

private:
  OModelParam<std::string> itsXptSavePrefix;
  OModelParam<std::string> itsContextSpec;
  OModelParam<SimTime> itsMoviePeriod;
  OModelParam<int> itsNumSkipFrames;
  OModelParam<int> itsNumTrainingFrames;
  OModelParam<int> itsNumTestingFrames;
  OModelParam<std::string> itsSaveGhostFrames;
  OModelParam<bool> itsObsolete1;
  std::vector<nub::ref<Context> > itsContexts;
  rutz::shared_ptr<std::ofstream> itsGhostOutput;
};

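// Input source for the job: it either wraps an ordinary InputFrameSeries
// or, when --ghost-input names a file, replays "ghost" frames (per-frame
// metadata previously written with --save-ghost-frames), which lets a
// run proceed without re-decoding the original movie.  Exactly one of
// the two sources is active at any time.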
class TigsInputFrameSeries : public ModelComponent
{
public:
  TigsInputFrameSeries(OptionManager& mgr)
    :
    ModelComponent(mgr, "TigsInputFrameSeries", "TigsInputFrameSeries"),
    itsGhostInput(&OPT_GhostInput, this),
    itsIfs(new InputFrameSeries(mgr)),
    itsFirst(true)
  {
    this->addSubComponent(itsIfs);
  }

  virtual void paramChanged(ModelParamBase* param,
                            const bool valueChanged,
                            ChangeStatus* status)
  {
    if (param == &itsGhostInput && valueChanged)
      {
        if (itsGhostInput.getVal().length() == 0)
          {
            ASSERT(itsIfs.is_valid() == false);

            // close our ghost file:
            itsGhostFile.close();

            // make a new regular InputFrameSeries:
            itsIfs.reset(new InputFrameSeries(getManager()));
            this->addSubComponent(itsIfs);
            itsIfs->exportOptions(MC_RECURSE);
            itsFirst = true;
          }
        else
          {
            ASSERT(itsIfs.is_valid() == true);
            this->removeSubComponent(*itsIfs);
            itsIfs.reset(0);

            itsGhostFile.open(itsGhostInput.getVal().c_str());
            if (!itsGhostFile.is_open())
              LFATAL("couldn't open file '%s' for reading",
                     itsGhostInput.getVal().c_str());

            LINFO("reading ghost frames from '%s'",
                  itsGhostInput.getVal().c_str());

            // ok, let's read the first line from the file so that we
            // can set the input dims:
            if (!std::getline(itsGhostFile, itsNextLine))
              itsNextLine = "";
            itsFirst = false;

            rutz::shared_ptr<TigsInputFrame> f =
              TigsInputFrame::fromGhostString(itsNextLine);

            getManager().setOptionValString
              (&OPT_InputFrameDims,
               convertToString(f->origbounds().dims()));
          }

        // OK, now one way or the other we should have either a valid
        // InputFrameSeries or an open std::ifstream, but not both:
        ASSERT(itsIfs.is_valid() != itsGhostFile.is_open());
      }
  }

  rutz::shared_ptr<TigsInputFrame> getFrame(const SimTime stime,
                                            bool* islast)
  {
    if (itsIfs.is_valid())
      {
        if (itsFirst)
          {
            itsIfs->updateNext();
            itsNextFrame = itsIfs->readRGB();
            itsFirst = false;
          }

        if (!itsNextFrame.initialized())
          return rutz::shared_ptr<TigsInputFrame>();

        // get a new frame and swap it with itsNextFrame
        itsIfs->updateNext();
        Image<PixRGB<byte> > frame = itsIfs->readRGB();
        frame.swap(itsNextFrame);

        ASSERT(frame.initialized());

        *islast = (itsNextFrame.initialized() == false);

        return rutz::shared_ptr<TigsInputFrame>
          (new TigsInputFrame(frame, stime));
      }
    else
      {
        ASSERT(itsGhostFile.is_open());

        if (itsNextLine.length() == 0)
          return rutz::shared_ptr<TigsInputFrame>();

        // get a new line and swap it with itsNextLine
        std::string line;
        if (!std::getline(itsGhostFile, line))
          line = "";
        line.swap(itsNextLine);

        ASSERT(line.length() > 0);

        *islast = (itsNextLine.length() == 0);

        rutz::shared_ptr<TigsInputFrame> result =
          TigsInputFrame::fromGhostString(line);

        if (result->t() != stime)
          LFATAL("wrong time in ghost frame: expected %.2fms "
                 "but got %.2fms", stime.msecs(), result->t().msecs());

        LINFO("got ghost frame: %s", line.c_str());

        return result;
      }
  }

private:
  OModelParam<std::string> itsGhostInput;

  nub::soft_ref<InputFrameSeries> itsIfs;
  std::ifstream itsGhostFile;

  bool itsFirst;
  Image<PixRGB<byte> > itsNextFrame;
  std::string itsNextLine;
};

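// Build the model (TigsInputFrameSeries, OutputFrameSeries, TigsJob,
// EyeSFile), optionally load previously saved training sets named by any
// extra command-line arguments, then loop over frames: advance simulated
// time by one movie period per frame, fetch the next (real or ghost)
// frame, read eye positions from the eyeS file up to that time, and hand
// everything to TigsJob::handleFrame() until the input runs out, the
// frame budget is exceeded, or a signal is caught.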
int submain(int argc, const char** argv)
{
  GVX_TRACE("test-TopdownContext-main");

  volatile int signum = 0;
  catchsignals(&signum);

  rutz::prof::print_at_exit(true);

  fpExceptionsUnlock();
  fpExceptionsOff();
  fpExceptionsLock();

  ModelManager mgr("topdown context tester");

  nub::ref<TigsInputFrameSeries> ifs(new TigsInputFrameSeries(mgr));
  mgr.addSubComponent(ifs);

  nub::ref<OutputFrameSeries> ofs(new OutputFrameSeries(mgr));
  mgr.addSubComponent(ofs);

  nub::ref<TigsJob> job(new TigsJob(mgr));
  mgr.addSubComponent(job);

  nub::ref<EyeSFile> eyeS(new EyeSFile(mgr));
  mgr.addSubComponent(eyeS);

  mgr.exportOptions(MC_RECURSE);

  mgr.setOptionValString(&OPT_UseRandom, "false");

  if (mgr.parseCommandLine(argc, argv,
                           "[load-pfx1 [load-pfx2 [...]]]",
                           0, -1) == false)
    return 1;

  ofs->setModelParamVal("OutputMPEGStreamFrameRate", int(24),
                        MC_RECURSE | MC_IGNORE_MISSING);
  ofs->setModelParamVal("OutputMPEGStreamBitRate", int(2500000),
                        MC_RECURSE | MC_IGNORE_MISSING);

  mgr.start();

  {
    std::ofstream ofs((job->getSavePrefix() + ".model").c_str());
    if (ofs.is_open())
      {
        mgr.printout(ofs);
      }
    ofs.close();
  }

  mgr.printout(std::cout);

  for (uint e = 0; e < mgr.numExtraArgs(); ++e)
    job->loadTrainingSet(mgr.getExtraArg(e));

  PauseWaiter p;

  int nframes = 0;

  while (1)
    {
      GVX_TRACE("frames loop");

      if (signum != 0)
        {
          LINFO("caught signal %s; quitting", signame(signum));
          break;
        }

      if (p.checkPause())
        continue;

      LINFO("trying frame %d", nframes);

      const SimTime stime = job->movieFrameLength() * (nframes+1);

      bool islast;
      rutz::shared_ptr<TigsInputFrame> fin =
        ifs->getFrame(stime, &islast);

      if (fin.get() == 0)
        {
          LINFO("input exhausted; quitting");
          break;
        }

      const Point2D<int> eyepos = eyeS->readUpTo(stime);

      LINFO("simtime %.6fs, movie frame %d, eye sample %d, ratio %f, "
            "eyepos (x=%d, y=%d)",
            stime.secs(), nframes+1, eyeS->lineNumber(),
            double(eyeS->lineNumber())/double(nframes+1),
            eyepos.i, eyepos.j);

      if (!job->handleFrame(nframes, *fin, eyepos, islast, *ofs))
        break;

      ofs->updateNext();

      ++nframes;
    }

  mgr.stop();

  return 0;
}

int main(int argc, const char** argv)
{
  try {
    submain(argc, argv);
  } catch(...) {
    REPORT_CURRENT_EXCEPTION;
    std::terminate();
  }
}

// ######################################################################
/* So things look consistent in everyone's emacs... */
/* Local Variables: */
/* indent-tabs-mode: nil */
/* End: */

#endif // !APPNEURO_TEST_TOPDOWNCONTEXT_C_UTC20050726230120_DEFINED