00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef LEARN_BACKPROPNETWORK_C_DEFINED
00039 #define LEARN_BACKPROPNETWORK_C_DEFINED
00040
00041 #include "Learn/BackpropNetwork.H"
00042
00043 #include "Image/CutPaste.H"
00044 #include "Image/MathOps.H"
00045 #include "Image/MatrixOps.H"
00046 #include "Util/CpuTimer.H"
00047 #include "Util/sformat.H"
00048 #include "rutz/rand.h"
00049 #include "rutz/trace.h"
00050
00051 #include <limits>
00052
00053 namespace
00054 {
00055 #if 1 // use the "standard" logistic sigmoid function
00056
00057
00058 void inplaceBackpropSigmoid(Image<float>& dst)
00059 {
00060 GVX_TRACE(__PRETTY_FUNCTION__);
00061
00062 Image<float>::iterator dptr = dst.beginw(), dstop = dst.endw();
00063
00064 while (dptr != dstop)
00065 {
00066 *dptr = 1.0f / (1.0f + exp(-(*dptr)));
00067 ++dptr;
00068 }
00069 }
00070
// Derivative of the logistic sigmoid, written in terms of the
// sigmoid's own output value: for y = sigmoid(x), dy/dx = y * (1 - y).
// Callers therefore pass the already-squashed activation, not the raw
// pre-activation input.
inline float backpropSigmoidDerivf(const float src)
{
  const float complement = 1.0f - src;
  return src * complement;
}
00075
00076 #else // use a cheaper-to-compute sigmoid function
00077
00078
00079
00080 void inplaceBackpropSigmoid(Image<float>& dst)
00081 {
00082 GVX_TRACE(__PRETTY_FUNCTION__);
00083
00084 Image<float>::iterator dptr = dst.beginw(), dstop = dst.endw();
00085
00086 while (dptr != dstop)
00087 {
00088 if (*dptr >= 0.0f)
00089 *dptr = 1.0f - 1.0f/(1.0f + (*dptr));
00090 else
00091 *dptr = -1.0f + 1.0f/(1.0f - (*dptr));
00092
00093 *dptr = (*dptr + 1.0f) * 0.5f;
00094
00095 ++dptr;
00096 }
00097 }
00098
// Derivative companion of the cheap sigmoid, written in terms of the
// sigmoid's output value src (which lies between 0 and 1).
//
// The output is first mapped back to the un-rescaled value
// r = x / (1 + |x|); the slope of that function is (1 - |r|)^2, i.e.
// (r - 1)^2 for r >= 0 and (r + 1)^2 for r < 0.
//
// NOTE(review): the constant factor 0.5 introduced by the final
// (r + 1) * 0.5 rescaling is not applied here; it only scales the
// gradient uniformly -- confirm that this is intended.
inline float backpropSigmoidDerivf(const float src)
{
  // undo the (r + 1) * 0.5 rescaling done by the forward function
  const float rawsrc = (src * 2.0f) - 1.0f;

  const float gap =
    (rawsrc >= 0.0f) ? (rawsrc - 1.0f) : (rawsrc + 1.0f);

  return gap * gap;
}
00108 #endif
00109
00110
00111
00112
00113 template <class T>
00114 Image<T> operator*(T numer, const Image<T>& denom)
00115 {
00116 Image<T> result(denom.getDims(), NO_INIT);
00117
00118 typename Image<T>::iterator dptr = result.beginw();
00119 typename Image<T>::iterator dstop = result.endw();
00120 typename Image<T>::const_iterator sptr = denom.begin();
00121
00122 while (dptr != dstop)
00123 *dptr++ = numer * (*sptr++);
00124
00125 return result;
00126 }
00127
00128 template <class T>
00129 Image<T> operator/(T numer, const Image<T>& denom)
00130 {
00131 Image<T> result(denom.getDims(), NO_INIT);
00132
00133 typename Image<T>::iterator dptr = result.beginw();
00134 typename Image<T>::iterator dstop = result.endw();
00135 typename Image<T>::const_iterator sptr = denom.begin();
00136
00137 while (dptr != dstop)
00138 *dptr++ = numer / (*sptr++);
00139
00140 return result;
00141 }
00142
00143 template <class T>
00144 Image<T> operator+(T numer, const Image<T>& denom)
00145 {
00146 Image<T> result(denom.getDims(), NO_INIT);
00147
00148 typename Image<T>::iterator dptr = result.beginw();
00149 typename Image<T>::iterator dstop = result.endw();
00150 typename Image<T>::const_iterator sptr = denom.begin();
00151
00152 while (dptr != dstop)
00153 *dptr++ = numer + (*sptr++);
00154
00155 return result;
00156 }
00157
00158 template <class T>
00159 Image<T> operator-(T numer, const Image<T>& denom)
00160 {
00161 Image<T> result(denom.getDims(), NO_INIT);
00162
00163 typename Image<T>::iterator dptr = result.beginw();
00164 typename Image<T>::iterator dstop = result.endw();
00165 typename Image<T>::const_iterator sptr = denom.begin();
00166
00167 while (dptr != dstop)
00168 *dptr++ = numer - (*sptr++);
00169
00170 return result;
00171 }
00172 }
00173
00174 void BackpropNetwork::train(const Image<float>& X,
00175 const Image<float>& D,
00176 const int h,
00177 const float eta,
00178 const float alph,
00179 const int iters,
00180 double* Efinal,
00181 double* Cfinal)
00182 {
00183 GVX_TRACE(__PRETTY_FUNCTION__);
00184
00185 const int n = X.getHeight();
00186 const int N = X.getWidth();
00187 const int m = D.getHeight();
00188
00189 LINFO("%d samples, %d input, %d hidden units, %d output units",
00190 N, n, h, m);
00191
00192 ASSERT(D.getWidth() == X.getWidth());
00193
00194
00195
00196 {
00197 double bestE = std::numeric_limits<double>::max();
00198 Image<float> bestW, bestV;
00199
00200 for (int r = 0; r < 10; ++r)
00201 {
00202
00203 rutz::urand_frange g(-1.0, 1.0, time((time_t*)0)+getpid());
00204 this->W = Image<float>(n+1, h, NO_INIT); g = fill(this->W, g);
00205 this->V = Image<float>(h+1, m, NO_INIT); g = fill(this->V, g);
00206
00207 const Image<float> Y = this->compute(X);
00208 double E = RMSerr(Y, D);
00209
00210 if (E < bestE)
00211 {
00212 LINFO("new best E=%f at init iteration %d", E, r);
00213 bestE = E;
00214 bestW = this->W;
00215 bestV = this->V;
00216 }
00217 }
00218
00219 this->W = bestW;
00220 this->V = bestV;
00221 }
00222
00223 Image<float> Xp;
00224 Image<float> Z;
00225 Image<float> Zp;
00226 Image<float> Y;
00227 Image<float> eY;
00228 Image<float> eZ;
00229 Image<float> dEdV;
00230 Image<float> dEdW;
00231 Image<float> delW;
00232 Image<float> delV;
00233 Image<float> delWprev;
00234 Image<float> delVprev;
00235
00236 double E = 1.0;
00237 double C = 0.0;
00238
00239 {GVX_TRACE("generate-Xp");
00240 Xp = Image<float>(X.getWidth(), X.getHeight() + 1, NO_INIT);
00241 Xp.clear(-1.0f);
00242 inplacePaste(Xp, X, Point2D<int>(0,1));
00243 }
00244
00245 CpuTimer t;
00246 CpuTimer t2;
00247
00248 for (int i = 0; i < iters; ++i)
00249 {
00250 GVX_TRACE("backprop loop");
00251
00252 {GVX_TRACE("compute-Z");
00253 Z = matrixMult(this->W, Xp);
00254 inplaceBackpropSigmoid(Z);
00255 }
00256
00257 {GVX_TRACE("generate-Zp");
00258 if (!Zp.initialized())
00259 {
00260 Zp = Image<float>(Z.getWidth(), Z.getHeight() + 1, NO_INIT);
00261 Zp.clear(-1.0f);
00262 }
00263 inplacePaste(Zp, Z, Point2D<int>(0,1));
00264 }
00265
00266 {GVX_TRACE("compute-Y");
00267 Y = matrixMult(this->V, Zp);
00268 inplaceBackpropSigmoid(Y);
00269 }
00270
00271 ASSERT(Y.getDims() == D.getDims());
00272
00273 E = RMSerr(Y, D);
00274 #if 0
00275 const float materr =
00276 (0.5 * E * E * double(Y.getSize()) / N);
00277 #endif
00278 C = corrcoef(Y, D);
00279
00280 {GVX_TRACE("compute-eY");
00281 eY.resize(Y.getDims());
00282 const int sz = eY.getSize();
00283 Image<float>::iterator const eYptr = eY.beginw();
00284 Image<float>::const_iterator const Yptr = Y.begin();
00285 Image<float>::const_iterator const Dptr = D.begin();
00286
00287 for (int k = 0; k < sz; ++k)
00288 eYptr[k] = (Yptr[k] - Dptr[k]) * backpropSigmoidDerivf(Yptr[k]);
00289 }
00290
00291 {GVX_TRACE("compute-eZ");
00292 const Image<float> eY_V = transpose(matrixMult(transpose(eY), V));
00293 eZ.resize(Zp.getDims());
00294 const int sz = eZ.getSize();
00295 Image<float>::iterator const eZptr = eZ.beginw();
00296 Image<float>::const_iterator const eY_Vptr = eY_V.begin();
00297 Image<float>::const_iterator const Zpptr = Zp.begin();
00298
00299 for (int k = 0; k < sz; ++k)
00300 eZptr[k] = eY_Vptr[k] * backpropSigmoidDerivf(Zpptr[k]);
00301 }
00302
00303 {GVX_TRACE("compute-dEdV");
00304 dEdV = matrixMult(eY, transpose(Zp));
00305 }
00306 {GVX_TRACE("compute-dEdW");
00307 dEdW = matrixMult(eZ, transpose(Xp));
00308 }
00309
00310 {GVX_TRACE("compute-delW");
00311 delW = (-eta) * crop(dEdW, Point2D<int>(0,1), Dims(n+1, h));
00312 if (delWprev.initialized())
00313 delW += alph * delWprev;
00314 delWprev = delW;
00315 }
00316
00317 {GVX_TRACE("compute-delV");
00318 delV = (-eta) * dEdV;
00319 if (delVprev.initialized())
00320 delV += alph * delVprev;
00321 delVprev = delV;
00322 }
00323
00324 this->V += delV;
00325 this->W += delW;
00326
00327 t2.mark();
00328 if (t2.real_secs() > 0.5)
00329 {
00330 t2.reset();
00331 t.mark();
00332 t.report(sformat("iteration %d, E=%f, C=%f", i, E, C).c_str());
00333 }
00334 }
00335
00336 if (Efinal != 0)
00337 *Efinal = E;
00338 if (Cfinal != 0)
00339 *Cfinal = C;
00340 }
00341
00342 Image<float> BackpropNetwork::compute(const Image<float>& X) const
00343 {
00344 GVX_TRACE(__PRETTY_FUNCTION__);
00345
00346 Image<float> Xp(X.getWidth(), X.getHeight() + 1, NO_INIT);
00347 Xp.clear(-1.0f);
00348 inplacePaste(Xp, X, Point2D<int>(0,1));
00349
00350 Image<float> Z = matrixMult(this->W, Xp);
00351 inplaceBackpropSigmoid(Z);
00352
00353 Image<float> Zp(Z.getWidth(), Z.getHeight() + 1, NO_INIT);
00354 Zp.clear(-1.0f);
00355 inplacePaste(Zp, Z, Point2D<int>(0,1));
00356
00357 Image<float> Y = matrixMult(this->V, Zp);
00358 inplaceBackpropSigmoid(Y);
00359
00360 ASSERT(Y.getWidth() == X.getWidth());
00361
00362 return Y;
00363 }
00364
00365
00366
00367
00368
00369
00370
00371 #endif // LEARN_BACKPROPNETWORK_C_DEFINED