00001 /*!@file TIGS/LeastSquaresLearner.C */ 00002 00003 // //////////////////////////////////////////////////////////////////// // 00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005 // 00005 // by the University of Southern California (USC) and the iLab at USC. // 00006 // See http://iLab.usc.edu for information about this project. // 00007 // //////////////////////////////////////////////////////////////////// // 00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected // 00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency // 00010 // in Visual Environments, and Applications'' by Christof Koch and // 00011 // Laurent Itti, California Institute of Technology, 2001 (patent // 00012 // pending; application number 09/912,225 filed July 23, 2001; see // 00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status). // 00014 // //////////////////////////////////////////////////////////////////// // 00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit. // 00016 // // 00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can // 00018 // redistribute it and/or modify it under the terms of the GNU General // 00019 // Public License as published by the Free Software Foundation; either // 00020 // version 2 of the License, or (at your option) any later version. // 00021 // // 00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope // 00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the // 00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // 00025 // PURPOSE. See the GNU General Public License for more details. // 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Rob Peters <rjpeters at usc dot edu> 00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/TIGS/LeastSquaresLearner.C $ 00035 // $Id: LeastSquaresLearner.C 6191 2006-02-01 23:56:12Z rjpeters $ 00036 // 00037 00038 #ifndef TIGS_LEASTSQUARESLEARNER_C_DEFINED 00039 #define TIGS_LEASTSQUARESLEARNER_C_DEFINED 00040 00041 #include "TIGS/LeastSquaresLearner.H" 00042 00043 #include "Component/ModelOptionDef.H" 00044 #include "GUI/XWinManaged.H" 00045 #include "Image/LinearAlgebra.H" 00046 #include "Image/MatrixOps.H" 00047 #include "Image/MathOps.H" 00048 #include "Image/Range.H" 00049 #include "Raster/Raster.H" 00050 #include "TIGS/TigsOpts.H" 00051 #include "TIGS/TrainingSet.H" 00052 #include "Util/CpuTimer.H" 00053 #include "Util/log.H" 00054 #include "rutz/trace.h" 00055 00056 // Used by: LeastSquaresLearner 00057 static const ModelOptionDef OPT_LsqSvdThresholdFactor = 00058 { MODOPT_ARG(float), "LsqSvdThresholdFactor", &MOC_TIGS, OPTEXP_CORE, 00059 "Multiple of the largest eigenvalue below which eigenvectors " 00060 "with small eigenvalues will be thrown out", 00061 "lsq-svd-thresh", '\0', "<float>", "1.0e-8f" }; 00062 00063 // Used by: LeastSquaresLearner 00064 static const ModelOptionDef OPT_LsqUseWeightsFile = 00065 { MODOPT_FLAG, "LsqUseWeightsFile", &MOC_TIGS, OPTEXP_CORE, 00066 "Whether to write/read least-squares weights file(s)", 00067 "lsq-use-weights-files", '\0', "", "false" }; 00068 00069 namespace 00070 { 00071 void inspect(const Image<float>& img, const char* name) 00072 { 00073 float m = mean(img); 00074 Range<float> r = rangeOf(img); 00075 LINFO("%s: (w,h)=(%d,%d), range=[%f..%f], mean=%f", 00076 name, img.getWidth(), img.getHeight(), r.min(), r.max(), m); 00077 } 00078 } 00079 00080 LeastSquaresLearner::LeastSquaresLearner(OptionManager& mgr) 00081 : 00082 TopdownLearner(mgr, "LeastSquaresLearner", "LeastSquaresLearner"), 00083 itsSvdThresh(&OPT_LsqSvdThresholdFactor, this), 00084 itsXptSavePrefix(&OPT_XptSavePrefix, this), 00085 itsUseWeightsFile(&OPT_LsqUseWeightsFile, this), 00086 itsWeights() // don't initialize until we're done training 00087 {} 00088 00089 void LeastSquaresLearner::dontSave() 00090 { 00091 itsUseWeightsFile.setVal(false); 00092 } 00093 00094 Image<float> LeastSquaresLearner::getBiasMap(const TrainingSet& tdata, 00095 const Image<float>& features) const 00096 { 00097 GVX_TRACE(__PRETTY_FUNCTION__); 00098 if (!itsWeights.initialized()) 00099 { 00100 const Image<float> rawTrainFeatures = tdata.getFeatures(); 00101 inspect(rawTrainFeatures, "rawTrainFeatures"); 00102 00103 itsMeanFeatures = meanRow(tdata.getFeatures()); 00104 inspect(itsMeanFeatures, "itsMeanFeatures"); 00105 00106 Image<float> trainFeatures = 00107 subtractRow(rawTrainFeatures, itsMeanFeatures); 00108 inspect(trainFeatures, "trainFeatures"); 00109 00110 itsStdevFeatures = stdevRow(trainFeatures); 00111 inspect(itsStdevFeatures, "itsStdevFeatures"); 00112 00113 trainFeatures = divideRow(trainFeatures, itsStdevFeatures); 00114 inspect(trainFeatures, "trainFeatures"); 00115 00116 const std::string name = 00117 itsXptSavePrefix.getVal() + "-" + tdata.fxType() + "-lsq"; 00118 00119 const std::string weightsfile = name + "-weights.pfm"; 00120 00121 if (itsUseWeightsFile.getVal() && 00122 Raster::fileExists(weightsfile)) 00123 { 00124 itsWeights = Raster::ReadFloat(weightsfile.c_str(), RASFMT_PFM); 00125 00126 LINFO("loaded weights (%s) from %s", 00127 name.c_str(), weightsfile.c_str()); 00128 } 00129 else 00130 { 00131 00132 try { 00133 CpuTimer t; 00134 00135 int rank = 0; 00136 00137 LINFO("svd threshold factor is %e", 00138 double(itsSvdThresh.getVal())); 00139 00140 const Image<float> pinvFeatures = 00141 svdPseudoInvf(trainFeatures, SVD_LAPACK, &rank, 00142 itsSvdThresh.getVal()); 00143 00144 t.mark(); 00145 t.report(sformat("pinvFeatures (%s)", name.c_str()).c_str()); 00146 00147 LINFO("svd rank=%d, fullrank=%d", 00148 rank, trainFeatures.getWidth()); 00149 00150 LINFO("trainFeatures size %dx%d, pinvFeatures size %dx%d", 00151 trainFeatures.getWidth(), trainFeatures.getHeight(), 00152 pinvFeatures.getWidth(), pinvFeatures.getHeight()); 00153 00154 const bool do_precisioncheck = false; 00155 00156 if (do_precisioncheck) { 00157 const Image<float> precisioncheck = 00158 matrixMult(trainFeatures, pinvFeatures); 00159 00160 t.mark(); 00161 t.report(sformat("precisioncheck (%s)", name.c_str()).c_str()); 00162 00163 const Image<float> diff = 00164 precisioncheck - eye<float>(pinvFeatures.getWidth()); 00165 00166 t.mark(); 00167 t.report(sformat("diff (%s)", name.c_str()).c_str()); 00168 00169 LINFO("rms error after inversion: %f", 00170 RMSerr(precisioncheck, eye<float>(pinvFeatures.getWidth()))); 00171 } 00172 00173 const Image<float> rawTrainPositions = tdata.getPositions(); 00174 00175 itsMeanPositions = meanRow(rawTrainPositions); 00176 00177 itsWeights = 00178 matrixMult(pinvFeatures, 00179 subtractRow(rawTrainPositions, itsMeanPositions)); 00180 00181 t.mark(); 00182 t.report(sformat("itsWeights (%s)", name.c_str()).c_str()); 00183 00184 if (itsUseWeightsFile.getVal()) 00185 { 00186 Raster::WriteFloat(itsWeights, FLOAT_NORM_PRESERVE, 00187 weightsfile.c_str(), RASFMT_PFM); 00188 00189 LINFO("saved weights (%s) to %s", 00190 name.c_str(), weightsfile.c_str()); 00191 } 00192 } 00193 catch (SingularMatrixException& e) { 00194 XWinManaged win(e.mtx, "singular matrix", true); 00195 00196 int c = 0; 00197 while (!win.pressedCloseButton() && ++c < 100) 00198 usleep(10000); 00199 00200 exit(1); 00201 } 00202 } 00203 } 00204 00205 ASSERT(itsWeights.getHeight() == features.getWidth()); 00206 ASSERT(itsWeights.getWidth() == tdata.scaledInputDims().sz()); 00207 00208 const Image<float> featureVec = 00209 divideRow(subtractRow(features, itsMeanFeatures), 00210 itsStdevFeatures); 00211 00212 const Image<float> result = 00213 addRow(matrixMult(featureVec, itsWeights), 00214 itsMeanPositions); 00215 00216 ASSERT(result.getWidth() == tdata.scaledInputDims().sz()); 00217 ASSERT(result.getHeight() == features.getHeight()); 00218 00219 return result; 00220 } 00221 00222 // ###################################################################### 00223 /* So things look consistent in everyone's emacs... */ 00224 /* Local Variables: */ 00225 /* mode: c++ */ 00226 /* indent-tabs-mode: nil */ 00227 /* End: */ 00228 00229 #endif // TIGS_LEASTSQUARESLEARNER_C_DEFINED