00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef TIGS_TRAININGSET_C_DEFINED
00039 #define TIGS_TRAININGSET_C_DEFINED
00040
00041 #include "TIGS/TrainingSet.H"
00042
00043 #include "Component/ModelOptionDef.H"
00044 #include "Image/ShapeOps.H"
00045 #include "Media/MediaOpts.H"
00046 #include "Raster/Raster.H"
00047 #include "TIGS/TigsOpts.H"
00048 #include "Util/AllocAux.H"
00049 #include "rutz/trace.h"
00050
00051
00052 static const ModelOptionDef OPT_TrainingSetDecimation =
00053 { MODOPT_ARG(int), "TrainingSetDecimation", &MOC_TIGS, OPTEXP_CORE,
00054 "Factor by which to decimate the number of samples in "
00055 "topdown context training sets",
00056 "tdata-decimation", '\0', "<int>", "1" };
00057
00058
00059 static const ModelOptionDef OPT_TrainingSetRebalance =
00060 { MODOPT_FLAG, "TrainingSetRebalance", &MOC_TIGS, OPTEXP_CORE,
00061 "Whether to rebalance the training set so that the distribution "
00062 "of eye positions is as flat as possible",
00063 "tdata-rebalance", '\0', "", "false" };
00064
00065
00066 static const ModelOptionDef OPT_TrainingSetRebalanceThresh =
00067 { MODOPT_ARG(uint), "TrainingSetRebalanceThresh", &MOC_TIGS, OPTEXP_CORE,
00068 "When rebalancing the training set's distribution of eye "
00069 "positions, only include positions for which at least this many "
00070 "samples are available",
00071 "tdata-rebalance-thresh", '\0', "<int>", "10" };
00072
00073
00074 static const ModelOptionDef OPT_TrainingSetRebalanceGroupSize =
00075 { MODOPT_ARG(uint), "TrainingSetRebalanceGroupSize", &MOC_TIGS, OPTEXP_CORE,
00076 "When rebalancing the training set's distribution of eye "
00077 "positions, pool the samples into this many samples per eye position",
00078 "tdata-rebalance-group-size", '\0', "<int>", "10" };
00079
00080 TrainingSet::TrainingSet(OptionManager& mgr, const std::string& fx_type)
00081 :
00082 ModelComponent(mgr, "TrainingSet", "TrainingSet"),
00083 itsRawInputDims(&OPT_InputFrameDims, this),
00084 itsDoRebalance(&OPT_TrainingSetRebalance, this),
00085 itsRebalanceThresh(&OPT_TrainingSetRebalanceThresh, this),
00086 itsRebalanceGroupSize(&OPT_TrainingSetRebalanceGroupSize, this),
00087 itsFxType(fx_type),
00088 itsReduction(32),
00089 itsNumFeatures(0),
00090 itsLocked(false),
00091 itsFeatureVec(),
00092 itsPositionVec(),
00093 itsPosGroups(),
00094 itsNumTraining(0),
00095 itsNumLoaded(0),
00096 itsFeatures(),
00097 itsPositions(),
00098 itsDecimationFactor(&OPT_TrainingSetDecimation, this)
00099 {}
00100
00101 Dims TrainingSet::scaledInputDims() const
00102 {
00103 ASSERT(itsRawInputDims.getVal().isNonEmpty());
00104 ASSERT(itsReduction > 0);
00105
00106 return itsRawInputDims.getVal() / int(itsReduction);
00107 }
00108
00109 size_t TrainingSet::numPositions() const
00110 {
00111 ASSERT(scaledInputDims().isNonEmpty());
00112 return scaledInputDims().sz();
00113 }
00114
00115 int TrainingSet::p2p(const int i, const int j) const
00116 {
00117 ASSERT(scaledInputDims().isNonEmpty());
00118 ASSERT(itsReduction > 0);
00119 return (j / itsReduction) * scaledInputDims().w() + (i / itsReduction);
00120 }
00121
00122 int TrainingSet::p2p(const Point2D<int>& p) const
00123 {
00124 return p2p(p.i, p.j);
00125 }
00126
00127 Image<float> TrainingSet::recordSample(const Point2D<int>& loc,
00128 const Image<float>& features)
00129 {
00130 GVX_TRACE(__PRETTY_FUNCTION__);
00131
00132 ASSERT(!itsLocked);
00133
00134 ASSERT(scaledInputDims().isNonEmpty());
00135
00136 if (itsNumFeatures == 0)
00137 {
00138
00139
00140 itsNumFeatures = features.getSize();
00141 LINFO("%s TrainingSet with %"ZU" features",
00142 itsFxType.c_str(), itsNumFeatures);
00143 }
00144
00145 ASSERT(itsNumFeatures > 0);
00146 ASSERT(size_t(features.getSize()) == itsNumFeatures);
00147
00148 ASSERT(loc.i >= 0);
00149 ASSERT(loc.j >= 0);
00150
00151 ASSERT(itsReduction > 0);
00152
00153 const Point2D<int> locr(loc.i / itsReduction, loc.j / itsReduction);
00154
00155 const size_t i1 = locr.i;
00156 const size_t i0 = locr.i > 0 ? (locr.i-1) : locr.i;
00157 const size_t i2 = locr.i < (scaledInputDims().w() - 1) ? (locr.i+1) : locr.i;
00158
00159 const size_t j1 = locr.j;
00160 const size_t j0 = locr.j > 0 ? (locr.j-1) : locr.j;
00161 const size_t j2 = locr.j < (scaledInputDims().h() - 1) ? (locr.j+1) : locr.j;
00162
00163 const size_t p00 = j0 * scaledInputDims().w() + i0;
00164 const size_t p01 = j1 * scaledInputDims().w() + i0;
00165 const size_t p02 = j2 * scaledInputDims().w() + i0;
00166
00167 const size_t p10 = j0 * scaledInputDims().w() + i1;
00168 const size_t p11 = j1 * scaledInputDims().w() + i1;
00169 const size_t p12 = j2 * scaledInputDims().w() + i1;
00170
00171 const size_t p20 = j0 * scaledInputDims().w() + i2;
00172 const size_t p21 = j1 * scaledInputDims().w() + i2;
00173 const size_t p22 = j2 * scaledInputDims().w() + i2;
00174
00175 const size_t np = this->numPositions();
00176
00177 for (size_t x = 0; x < np; ++x)
00178 {
00179 itsPositionVec.push_back(0.0f);
00180
00181 float& v = itsPositionVec.back();
00182
00183
00184
00185
00186 if (x == p00) v += 0.25;
00187 if (x == p01) v += 0.5;
00188 if (x == p02) v += 0.25;
00189
00190 if (x == p10) v += 0.5;
00191 if (x == p11) v += 1.0;
00192 if (x == p12) v += 0.5;
00193
00194 if (x == p20) v += 0.25;
00195 if (x == p21) v += 0.5;
00196 if (x == p22) v += 0.25;
00197 }
00198
00199 for (size_t x = 0; x < itsNumFeatures; ++x)
00200 {
00201 itsFeatureVec.push_back(features[x]);
00202 }
00203
00204 ++itsNumTraining;
00205
00206
00207
00208 return Image<float>(&*itsPositionVec.end() - this->numPositions(),
00209 scaledInputDims());
00210 }
00211
00212 void TrainingSet::load(const std::string& pfx)
00213 {
00214 if (itsDoRebalance.getVal())
00215 {
00216 this->loadRebalanced(pfx);
00217 return;
00218 }
00219
00220 GVX_TRACE(__PRETTY_FUNCTION__);
00221
00222 const std::string ffile = pfx+"-features.pfm";
00223 const std::string pfile = pfx+"-positions.pfm";
00224
00225 Image<float> feat = Raster::ReadFloat(ffile, RASFMT_PFM);
00226 Image<float> pos = Raster::ReadFloat(pfile, RASFMT_PFM);
00227
00228 ASSERT(feat.getHeight() == pos.getHeight());
00229
00230 if (itsNumFeatures == 0)
00231 {
00232
00233
00234 itsNumFeatures = feat.getWidth();
00235 LINFO("%s TrainingSet with %"ZU" features",
00236 itsFxType.c_str(), itsNumFeatures);
00237 }
00238
00239 ASSERT(size_t(feat.getWidth()) == itsNumFeatures);
00240
00241 if (itsDecimationFactor.getVal() > 1)
00242 {
00243 feat = blurAndDecY(feat, itsDecimationFactor.getVal());
00244 pos = blurAndDecY(pos, itsDecimationFactor.getVal());
00245
00246 ASSERT(feat.getHeight() == pos.getHeight());
00247 }
00248
00249 itsFeatureVec.insert(itsFeatureVec.end(), feat.begin(), feat.end());
00250 itsPositionVec.insert(itsPositionVec.end(), pos.begin(), pos.end());
00251
00252 itsNumTraining += feat.getHeight();
00253
00254 ++itsNumLoaded;
00255
00256
00257
00258
00259
00260 itsLocked = true;
00261
00262 LINFO("loaded %d samples from training set %s, %d total training samples from %d files",
00263 feat.getHeight(), pfx.c_str(), itsNumTraining, itsNumLoaded);
00264
00265
00266
00267
00268 invt_allocation_release_free_mem();
00269 }
00270
00271 void TrainingSet::loadRebalanced(const std::string& pfx)
00272 {
00273 GVX_TRACE(__PRETTY_FUNCTION__);
00274
00275 const std::string ffile = pfx+"-features.pfm";
00276 const std::string pfile = pfx+"-positions.pfm";
00277
00278 Image<float> feat = Raster::ReadFloat(ffile, RASFMT_PFM);
00279 Image<float> pos = Raster::ReadFloat(pfile, RASFMT_PFM);
00280
00281 ASSERT(feat.getHeight() == pos.getHeight());
00282
00283 if (itsNumFeatures == 0)
00284 {
00285
00286
00287 itsNumFeatures = feat.getWidth();
00288 LINFO("%s TrainingSet with %"ZU" features",
00289 itsFxType.c_str(), itsNumFeatures);
00290
00291 std::vector<PosGroup>().swap(itsPosGroups);
00292 itsPosGroups.resize(pos.getWidth(),
00293 PosGroup(itsRebalanceGroupSize.getVal(),
00294 feat.getWidth(), pos.getWidth()));
00295
00296 ASSERT(itsRebalanceThresh.getVal() >= itsRebalanceGroupSize.getVal());
00297 }
00298
00299 ASSERT(size_t(feat.getWidth()) == itsNumFeatures);
00300
00301 for (int y = 0; y < pos.getHeight(); ++y)
00302 {
00303 int nmax = 0;
00304 for (int x = 0; x < pos.getWidth(); ++x)
00305 {
00306 const float v = pos.getVal(x, y);
00307 if (v >= 1.0f)
00308 {
00309 ++nmax;
00310 itsPosGroups[x].add(feat.getArrayPtr() + y * feat.getWidth(),
00311 pos.getArrayPtr() + y * pos.getWidth());
00312 }
00313 }
00314
00315 if (nmax != 1)
00316 LFATAL("nmax = %d (expected nmax = 1) in row %d", nmax, y);
00317 }
00318
00319 std::vector<float>().swap(itsFeatureVec);
00320 std::vector<float>().swap(itsPositionVec);
00321
00322 uint nzero = 0;
00323 uint naccept = 0;
00324 uint nsamp = 0;
00325 Image<byte> bb(20, 15, ZEROS);
00326 itsNumTraining = 0;
00327 for (uint i = 0; i < itsPosGroups.size(); ++i)
00328 {
00329 if (itsPosGroups[i].totalcount == 0)
00330 ++nzero;
00331 if (itsPosGroups[i].totalcount >= itsRebalanceThresh.getVal())
00332 {
00333 ++naccept;
00334
00335 for (uint k = 0; k < itsPosGroups[i].counts.size(); ++k)
00336 {
00337 const Image<float> f =
00338 itsPosGroups[i].features[k] / itsPosGroups[i].counts[k];
00339
00340 const Image<float> p =
00341 itsPosGroups[i].positions[k] / itsPosGroups[i].counts[k];
00342
00343 itsFeatureVec.insert(itsFeatureVec.end(),
00344 f.begin(), f.end());
00345 itsPositionVec.insert(itsPositionVec.end(),
00346 p.begin(), p.end());
00347
00348 ++itsNumTraining;
00349 }
00350 bb[i] = 255;
00351 }
00352 nsamp += itsPosGroups[i].totalcount;
00353 }
00354
00355 LINFO("ngroups = %" ZU ", nsamp = %u, naccept = %u, nzero = %u",
00356 itsPosGroups.size(), nsamp, naccept, nzero);
00357
00358 ++itsNumLoaded;
00359
00360
00361
00362
00363
00364 itsLocked = true;
00365
00366 LINFO("loaded %d samples from training set %s, %d total training samples from %d files",
00367 feat.getHeight(), pfx.c_str(), itsNumTraining, itsNumLoaded);
00368
00369
00370
00371
00372 invt_allocation_release_free_mem();
00373 }
00374
00375 void TrainingSet::save(const std::string& pfx)
00376 {
00377 GVX_TRACE(__PRETTY_FUNCTION__);
00378
00379 const std::string ffile = pfx+"-features.pfm";
00380 const std::string pfile = pfx+"-positions.pfm";
00381
00382 if (Raster::fileExists(ffile))
00383 LINFO("save skipped; file already exists: %s", ffile.c_str());
00384 else
00385 Raster::WriteFloat(this->getFeatures(), FLOAT_NORM_PRESERVE, ffile, RASFMT_PFM);
00386
00387 if (Raster::fileExists(pfile))
00388 LINFO("save skipped; file already exists: %s", pfile.c_str());
00389 else
00390 Raster::WriteFloat(this->getPositions(), FLOAT_NORM_PRESERVE, pfile, RASFMT_PFM);
00391
00392 LINFO("saved training set %s", pfx.c_str());
00393 }
00394
00395 Image<float> TrainingSet::getFeatures() const
00396 {
00397 ASSERT(itsNumFeatures > 0);
00398
00399 if (itsFeatures.getHeight() != itsNumTraining)
00400 {
00401 itsFeatures = Image<float>(&itsFeatureVec[0],
00402 itsNumFeatures, itsNumTraining);
00403 }
00404
00405 return itsFeatures;
00406 }
00407
00408 Image<float> TrainingSet::getPositions() const
00409 {
00410 if (itsPositions.getHeight() != itsNumTraining)
00411 {
00412 itsPositions = Image<float>(&itsPositionVec[0],
00413 this->numPositions(), itsNumTraining);
00414 }
00415
00416 return itsPositions;
00417 }
00418
00419 uint TrainingSet::inputReduction() const
00420 {
00421 return itsReduction;
00422 }
00423
00424 const std::string& TrainingSet::fxType() const
00425 {
00426 return itsFxType;
00427 }
00428
00429
00430
00431
00432
00433
00434
00435
00436 #endif // TIGS_TRAININGSET_C_DEFINED