00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef NEURO_WINNERTAKEALLSTDOPTIM_C_DEFINED
00039 #define NEURO_WINNERTAKEALLSTDOPTIM_C_DEFINED
00040
00041 #include "Neuro/WinnerTakeAllStdOptim.H"
00042
00043 #include "Util/JobWithSemaphore.H"
00044 #include "Util/MainJobServer.H"
00045 #include "rutz/trace.h"
00046
00047 struct WinnerTakeAllStdOptim::EvolveJob : public JobWithSemaphore
00048 {
00049 EvolveJob(Image<float>::iterator vitr_,
00050 Image<float>::iterator vstop_,
00051 Image<float>::const_iterator initr_,
00052 const float ginput_,
00053 const float dt_c_,
00054 const float gsum_,
00055 const float isum_,
00056 const float Ei_,
00057 const float Vth_)
00058 :
00059 vitr(vitr_),
00060 vstop(vstop_),
00061 initr(initr_),
00062 ginput(ginput_),
00063 dt_c(dt_c_),
00064 gsum(gsum_),
00065 isum(isum_),
00066 Ei(Ei_),
00067 Vth(Vth_),
00068 vwinner()
00069 {}
00070
00071 virtual ~EvolveJob()
00072 {}
00073
00074 virtual void run()
00075 {
00076 while (vitr != vstop)
00077 {
00078 const float Iin = ginput * (*initr++);
00079
00080
00081 (*vitr) += dt_c * (Iin - (*vitr) * gsum + isum);
00082
00083
00084 if ((*vitr) < Ei) (*vitr) = Ei;
00085
00086
00087 if ((*vitr) >= Vth) { vwinner = vitr; *vitr = 0.0F; }
00088
00089 ++vitr;
00090 }
00091
00092 this->markFinished();
00093 }
00094
00095 virtual const char* jobType() const
00096 { return "WinnerTakeAllStdOptimEvolveJob"; }
00097
00098 Image<float>::iterator vitr;
00099 Image<float>::iterator vstop;
00100 Image<float>::const_iterator initr;
00101 const float ginput;
00102 const float dt_c;
00103 const float gsum;
00104 const float isum;
00105 const float Ei;
00106 const float Vth;
00107 Image<float>::iterator vwinner;
00108 };
00109
00110
00111
00112
00113
00114
00115
00116 WinnerTakeAllStdOptim::WinnerTakeAllStdOptim(OptionManager& mgr,
00117 const std::string& descrName,
00118 const std::string& tagName) :
00119 WinnerTakeAllAdapter(mgr, descrName, tagName),
00120 itsTimeStep(SimTime::SECS(0.0001)),
00121 itsEl(0.0F),
00122 itsEe(100.0e-3F),
00123 itsEi(-20.0e-3F),
00124 itsC(1.0E-9F),
00125 itsVth(0.001F),
00126 itsV(),
00127 itsT(),
00128 itsGleak(1.0e-8F),
00129 itsGinh(1.0e-2F),
00130 itsGinput(5.0e-8F),
00131 itsGIN_Gl(1.0e-8F),
00132 itsGIN_Ge(0.0F),
00133 itsGIN_El(0.0F),
00134 itsGIN_Ee(100.0e-3F),
00135 itsGIN_Ei(-20.0e-3F),
00136 itsGIN_C(1.0E-9F),
00137 itsGIN_Vth(0.001F),
00138 itsGIN_V(itsGIN_Ei),
00139 itsInputCopy()
00140 {
00141 GVX_TRACE(__PRETTY_FUNCTION__);
00142 }
00143
00144
00145 WinnerTakeAllStdOptim::~WinnerTakeAllStdOptim()
00146 {
00147 GVX_TRACE(__PRETTY_FUNCTION__);
00148 }
00149
00150
00151 void WinnerTakeAllStdOptim::reset1()
00152 {
00153 GVX_TRACE(__PRETTY_FUNCTION__);
00154 itsV.freeMem();
00155 itsInputCopy.freeMem();
00156 itsT = SimTime::ZERO();
00157
00158 WinnerTakeAllAdapter::reset1();
00159 }
00160
00161
00162 void WinnerTakeAllStdOptim::input(const Image<float>& in)
00163 {
00164 GVX_TRACE(__PRETTY_FUNCTION__);
00165 if (itsV.initialized() == false)
00166 {
00167
00168 itsV.resize(in.getDims(), NO_INIT);
00169 itsV.clear(itsEi);
00170
00171 itsGe = 0.0F;
00172 itsGi = itsGinh;
00173 }
00174
00175
00176 itsInputCopy = in;
00177 }
00178
00179
00180 void WinnerTakeAllStdOptim::integrate(const SimTime& t, Point2D<int>& winner)
00181 {
00182 GVX_TRACE(__PRETTY_FUNCTION__);
00183 winner.i = -1;
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193 const SimTime dt = SimTime::computeDeltaT(t - itsT, itsTimeStep);
00194 const float dt_c = float(dt.secs()) / itsC;
00195
00196 for (SimTime tt = itsT; tt < t; tt += dt)
00197 {
00198 if (tt == SimTime::ZERO())
00199 continue;
00200
00201 ASSERT(dt != SimTime::ZERO());
00202
00203 const float gsum = itsGleak + itsGe + itsGi;
00204 const float isum = itsGleak * itsEl + itsGe * itsEe + itsGi * itsEi;
00205
00206 JobServer& srv = getMainJobServer();
00207
00208 const unsigned int ntiles = srv.getParallelismHint();
00209
00210 std::vector<rutz::shared_ptr<EvolveJob> > jobs;
00211
00212 for (unsigned int i = 0; i < ntiles; ++i)
00213 {
00214 const int start = (i*itsV.getSize()) / ntiles;
00215 const int end = ((i+1)*itsV.getSize()) / ntiles;
00216
00217 jobs.push_back
00218 (rutz::make_shared(new EvolveJob
00219 (itsV.beginw() + start,
00220 itsV.beginw() + end,
00221 itsInputCopy.begin() + start,
00222 itsGinput,
00223 dt_c,
00224 gsum,
00225 isum,
00226 itsEi,
00227 itsVth)));
00228
00229 srv.enqueueJob(jobs.back());
00230 }
00231
00232 for (size_t i = 0; i < jobs.size(); ++i)
00233 {
00234 jobs[i]->wait();
00235
00236 if (jobs[i]->vwinner != Image<float>::iterator())
00237 {
00238 const size_t offset = jobs[i]->vwinner - itsV.beginw();
00239 winner.i = offset % itsV.getWidth();
00240 winner.j = offset / itsV.getWidth();
00241
00242
00243 itsGIN_Ge = itsGleak * 10.0F;
00244 }
00245 }
00246
00247 itsGe = 0.0F;
00248 itsGi = 0.0F;
00249
00250
00251
00252
00253
00254
00255
00256 const float dt_c2 = float(dt.secs()) / itsGIN_C;
00257
00258 itsGIN_V += dt_c2 *
00259 (- itsGIN_Gl * (itsGIN_V - itsGIN_El)
00260 - itsGIN_Ge * (itsGIN_V - itsGIN_Ee));
00261
00262
00263 if (itsGIN_V < itsGIN_Ei) itsGIN_V = itsGIN_Ei;
00264
00265
00266 if (itsGIN_V >= itsGIN_Vth)
00267 {
00268 itsGIN_V = 0.0F;
00269 this->inhibit();
00270 }
00271 }
00272
00273
00274 itsT = t;
00275 }
00276
00277
00278 Image<float> WinnerTakeAllStdOptim::getV() const
00279 {
00280 GVX_TRACE(__PRETTY_FUNCTION__);
00281 return itsV;
00282 }
00283
00284
00285 void WinnerTakeAllStdOptim::inhibit()
00286 {
00287 GVX_TRACE(__PRETTY_FUNCTION__);
00288 itsGe = 0.0F;
00289 itsGi = itsGinh;
00290 itsGIN_Ge = 0.0F;
00291 LDEBUG("WTA inhibition firing...");
00292 }
00293
00294
00295 void WinnerTakeAllStdOptim::saccadicSuppression(const bool on)
00296 {
00297 GVX_TRACE(__PRETTY_FUNCTION__);
00298 if (itsUseSaccadicSuppression.getVal() == false) return;
00299 if (on) inhibit();
00300 LINFO("------- WTA saccadic suppression %s -------", on ? "on":"off");
00301 }
00302
00303
00304 void WinnerTakeAllStdOptim::blinkSuppression(const bool on)
00305 {
00306 GVX_TRACE(__PRETTY_FUNCTION__);
00307 if (itsUseBlinkSuppression.getVal() == false) return;
00308 if (on) inhibit();
00309 LINFO("------- WTA blink suppression %s -------", on ? "on":"off");
00310 }
00311
00312
00313
00314
00315
00316
00317
00318
00319 #endif // NEURO_WINNERTAKEALLSTDOPTIM_C_DEFINED