BackpropLearner.C
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef TIGS_BACKPROPLEARNER_C_DEFINED
00039 #define TIGS_BACKPROPLEARNER_C_DEFINED
00040
00041 #include "TIGS/BackpropLearner.H"
00042
00043 #include "Image/MathOps.H"
00044 #include "Image/MatrixOps.H"
00045 #include "Learn/BackpropNetwork.H"
00046 #include "TIGS/LeastSquaresLearner.H"
00047 #include "TIGS/TrainingSet.H"
00048 #include "rutz/trace.h"
00049
00050 BackpropLearner::BackpropLearner(OptionManager& mgr)
00051 :
00052 TopdownLearner(mgr, "BackpropLearner", "BackpropLearner"),
00053 itsLsq(new LeastSquaresLearner(mgr)),
00054 itsNetwork(0),
00055 itsInRange(0.0f, 1.0f),
00056 itsOutRange(0.0f, 1.0f)
00057 {
00058 itsLsq->dontSave();
00059 }
00060
00061 BackpropLearner::~BackpropLearner()
00062 {
00063 delete itsNetwork;
00064 }
00065
00066 Image<float> BackpropLearner::getBiasMap(const TrainingSet& tdata,
00067 const Image<float>& features) const
00068 {
00069 GVX_TRACE(__PRETTY_FUNCTION__);
00070
00071 if (itsNetwork == 0)
00072 {
00073 itsNetwork = new BackpropNetwork;
00074
00075 const int nhidden = 100;
00076 const float eta = 0.5f;
00077 const float alph = 0.5f;
00078 const int iters = 3000;
00079
00080 Image<float> XX = itsLsq->getBiasMap(tdata,
00081 tdata.getFeatures());
00082
00083 double preE = RMSerr(XX, tdata.getPositions());
00084 double preC = corrcoef(XX, tdata.getPositions());
00085 LINFO("preE=%f, preC=%f", preE, preC);
00086
00087 Image<float> X = transpose(XX);
00088 Image<float> D = transpose(tdata.getPositions());
00089
00090 itsInRange = rangeOf(X);
00091 itsOutRange = rangeOf(D);
00092
00093 X = remapRange(X, itsInRange, Range<float>(0.0f, 1.0f));
00094 D = remapRange(D, itsOutRange, Range<float>(0.0f, 1.0f));
00095
00096 double E, C;
00097 itsNetwork->train(X, D, nhidden, eta, alph, iters, &E, &C);
00098
00099 LINFO("E=%f, C=%f", E, C);
00100 }
00101
00102 Image<float> ff = transpose(itsLsq->getBiasMap(tdata, features));
00103 ff = remapRange(ff, itsInRange, Range<float>(0.0f, 1.0f));
00104 Image<float> bb = transpose(itsNetwork->compute(ff));
00105 bb = remapRange(bb, Range<float>(0.0f, 1.0f), itsOutRange);
00106 return bb;
00107 }
00108
00109
00110
00111
00112
00113
00114
00115
00116 #endif // TIGS_BACKPROPLEARNER_C_DEFINED