00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Channels/ChannelBase.H"
00039 #include "Channels/Jet.H"
00040 #include "Channels/JetFiller.H"
00041 #include "Component/ModelManager.H"
00042 #include "Image/DrawOps.H"
00043 #include "Image/MathOps.H"
00044 #include "Image/Pixels.H"
00045 #include "Image/ShapeOps.H"
00046 #include "Media/MediaSimEvents.H"
00047 #include "Channels/ChannelOpts.H"
00048 #include "Channels/RawVisualCortex.H"
00049 #include "Simulation/SimEventQueueConfigurator.H"
00050 #include "Simulation/SimEventQueue.H"
00051 #include "Raster/Raster.H"
00052 #include "Util/log.H"
00053
00054 #include <fstream>
00055 #include <string>
00056
//! Smallest and largest scale indices used by this program.
//! SMIN..SMAX is the scale range requested for every feature in the JetSpec.
//! The search sweeps its coarsest-grid level k from SMAX down to SMIN.
//! These are typed, scoped constants rather than macros.
const int SMIN = 1;
const int SMAX = 3;
00059
00060 namespace
00061 {
00062
00063 Image<float> getRaoJetMap(RawVisualCortex& vc,
00064 const Jet<float>& targ,
00065 const int smin, const int smax)
00066 {
00067 const int w = vc.getInputDims().w() >> smin;
00068 const int h = vc.getInputDims().h() >> smin;
00069 Image<float> result(w, h, NO_INIT);
00070 Jet<float> currJet(targ.getSpec());
00071
00072
00073
00074
00075 for (int x = 0; x < w; x++)
00076 for (int y = 0; y < h; y++)
00077 {
00078 JetFiller f(Point2D<int>(x << smin, y << smin), currJet, false);
00079 vc.accept(f);
00080 result.setVal(x, y, float(raodistance(targ, currJet, smin, smax)));
00081 }
00082 return result;
00083 }
00084 }
00085
00086 int main(const int argc, const char** argv)
00087 {
00088
00089 ModelManager manager("Test Rao Model");
00090
00091
00092 nub::soft_ref<SimEventQueueConfigurator>
00093 seqc(new SimEventQueueConfigurator(manager));
00094 manager.addSubComponent(seqc);
00095
00096 nub::soft_ref<RawVisualCortex> vcx(new RawVisualCortex(manager));
00097 manager.addSubComponent(vcx);
00098 manager.setOptionValString(&OPT_RawVisualCortexChans, "ICO");
00099
00100
00101 if (manager.parseCommandLine(argc, argv,
00102 "jet <image.ppm> <x> <y> <jet.txt> -OR- "
00103 "jet2 <image.ppm> <target.pgm> <jet.txt> -OR- "
00104 "snr <image.ppm> <target.pgm> <jet.txt> <lambda> <map.pgm> -OR- "
00105 "search <image.ppm> <jet.txt> <lambda>",
00106 4, 6) == false)
00107 return(1);
00108
00109
00110 manager.setModelParamVal("GaborChannelIntensity", 100.0, MC_RECURSE);
00111 nub::soft_ref<SimEventQueue> seq = seqc->getQ();
00112
00113 bool do_jet = false, do_jet2 = false, do_snr = false, do_search = false;
00114 std::string action = manager.getExtraArg(0);
00115 if (action.compare("jet") == 0)
00116 {
00117 if (manager.numExtraArgs() != 5)
00118 LFATAL("USAGE: %s jet <image.ppm> <x> <y> <jet.txt>", argv[0]);
00119 do_jet = true;
00120 }
00121 else if (action.compare("jet2") == 0)
00122 {
00123 if (manager.numExtraArgs() != 4)
00124 LFATAL("USAGE: %s jet2 <image.ppm> <target.pgm> <jet.txt>", argv[0]);
00125 do_jet2 = true;
00126 }
00127 else if (action.compare("snr") == 0)
00128 {
00129 if (manager.numExtraArgs() != 6)
00130 LFATAL("USAGE: %s snr <image.ppm> <target.pgm> <jet.txt> <lambda> <map.pgm>", argv[0]);
00131 do_snr = true;
00132 }
00133 else if (action.compare("search") == 0)
00134 {
00135 if (manager.numExtraArgs() != 4)
00136 LFATAL("USAGE: %s search <image.ppm> <jet.txt> <lambda>", argv[0]);
00137 do_search = true;
00138 }
00139 else
00140 LFATAL("Incorrect usage -- try to run without args to see usage.");
00141
00142
00143 Image< PixRGB<byte> > col_image =
00144 Raster::ReadRGB(manager.getExtraArg(1));
00145
00146
00147 manager.start();
00148
00149
00150 vcx->input(InputFrame::fromRgb(&col_image));
00151 vcx->getOutput();
00152
00153 rutz::shared_ptr<JetSpec> js(new JetSpec);
00154
00155
00156 js->addIndexRange(COLBAND, RAW, 0, 5);
00157 js->addIndexRange(COLBAND, RAW, SMIN, SMAX);
00158 js->addIndexRange(INTENS, RAW, 0, 3);
00159 js->addIndexRange(INTENS, RAW, SMIN, SMAX);
00160
00161
00162
00163
00164
00165
00166
00167 const uint nori = vcx->subChan("orientation")->getModelParamVal<uint>("NumOrientations");
00168 js->addIndexRange(ORI, RAW, 0, nori - 1);
00169 js->addIndexRange(ORI, RAW, SMIN, SMAX);
00170 js->print();
00171
00172
00173
00174 if (do_jet)
00175 {
00176 int x = manager.getExtraArgAs<int>(2);
00177 int y = manager.getExtraArgAs<int>(3);
00178 Point2D<int> p(x, y);
00179
00180
00181 Jet<float> j(js);
00182 JetFiller f(p, j, false);
00183 vcx->accept(f);
00184
00185
00186 std::ofstream s(manager.getExtraArg(4).c_str());
00187 if (s.is_open() == false)
00188 LFATAL("Cannot write %s", manager.getExtraArg(4).c_str());
00189 s<<j<<std::endl;
00190 s.close();
00191 LINFO("Saved Jet(%d, %d) to %s -- DONE.", x, y,
00192 manager.getExtraArg(4).c_str());
00193 }
00194
00195
00196
00197 if (do_jet2)
00198 {
00199
00200 Image<byte> target =
00201 Raster::ReadGray(manager.getExtraArg(2));
00202
00203
00204 int w = target.getWidth(), h = target.getHeight();
00205 int t = h+1, b = -1, l = w+1, r = -1;
00206 for (int i = 0; i < w; i++)
00207 for (int j = 0; j < h; j++)
00208 if (target.getVal(i,j) > 200){
00209 if (j < t) t = j;
00210 if (i < l) l = i;
00211 if (i > r) r = i;
00212 if (j > b) b = j;
00213 }
00214 int x = (r + l) / 2;
00215 int y = (t + b)/ 2;
00216 Point2D<int> p(x, y);
00217
00218
00219 Jet<float> j(js);
00220 JetFiller f(p, j, false);
00221 vcx->accept(f);
00222
00223
00224 std::ofstream s(manager.getExtraArg(3).c_str());
00225 if (s.is_open() == false)
00226 LFATAL("Cannot write %s", manager.getExtraArg(3).c_str());
00227 s<<j<<std::endl;
00228 s.close();
00229 LINFO("Saved Jet(%d, %d) to %s -- DONE.", x, y,
00230 manager.getExtraArg(3).c_str());
00231 }
00232
00233
00234
00235 if (do_snr)
00236 {
00237
00238 Image<byte> targetMask =
00239 Raster::ReadGray(manager.getExtraArg(2));
00240
00241 std::ifstream s(manager.getExtraArg(3).c_str());
00242 if (s.is_open() == false)
00243 LFATAL("Cannot read %s", manager.getExtraArg(3).c_str());
00244 Jet<float> j(js);
00245 s>>j; s.close();
00246
00247 float lambda = manager.getExtraArgAs<float>(4);
00248
00249
00250 Image<float> jmap = getRaoJetMap(*vcx, j, SMAX, SMAX);
00251 float mi, ma; getMinMax(jmap, mi, ma);
00252 LINFO("jmap range [%f .. %f]", mi, ma);
00253
00254
00255
00256
00257 Image<float> fmap = exp(jmap * (-1.0f / (lambda*65535.0f)));
00258 float denom = sum(fmap); fmap *= 1.0f / denom;
00259
00260
00261 Image<float> targetMap(fmap.getDims(), ZEROS);
00262 float sT = 0.0f, sD = 0.0f;
00263 float BG_FIRING_RATE = 0.1f;
00264 if (targetMask.initialized()){
00265
00266 if (targetMask.getWidth() > fmap.getWidth())
00267 targetMap = downSize(targetMask, fmap.getDims());
00268 else if (targetMask.getWidth() < fmap.getWidth())
00269 targetMap = rescale(targetMask, fmap.getDims());
00270 else
00271 targetMap = targetMask;
00272
00273
00274 Image<float>::const_iterator aptr = targetMap.begin(),
00275 astop = targetMap.end();
00276 Image<float>::const_iterator sptr = fmap.begin(),
00277 sstop = fmap.end();
00278 while (aptr != astop && sptr != sstop)
00279 {
00280 if (*aptr > 0.0f){
00281 if (sT < *sptr)
00282 sT = *sptr;
00283 }
00284 else {
00285 if (sD < *sptr)
00286 sD = *sptr;
00287 }
00288 aptr++; sptr++;
00289 }
00290 }
00291
00292 float SNR = log((sT + BG_FIRING_RATE) / (sD + BG_FIRING_RATE));
00293 LINFO ("sT = %f, sD = %f", sT, sD);
00294 LINFO ("-------------- SNR = %f dB", SNR);
00295
00296
00297 FILE * fout = fopen("snr", "w");
00298 fprintf(fout, " SNR = %f dB", SNR);
00299 fclose(fout);
00300
00301
00302 Raster::WriteFloat(fmap, FLOAT_NORM_0_255, manager.getExtraArg(5), RASFMT_PNM);
00303
00304 LINFO("Saved Fmap to %s -- DONE.", manager.getExtraArg(5).c_str());
00305 }
00306
00307
00308
00309 if (do_search)
00310 {
00311 std::ifstream s(manager.getExtraArg(2).c_str());
00312 if (s.is_open() == false)
00313 LFATAL("Cannot read %s", manager.getExtraArg(2).c_str());
00314 Jet<float> j(js);
00315 s>>j; s.close();
00316
00317
00318
00319 float lambda = atof(manager.getExtraArg(3).c_str());
00320 for (int k = SMAX; k >= SMIN; k --)
00321 {
00322 LINFO("Using scales [%d .. %d], lambda = %f", k, SMAX, lambda);
00323
00324
00325 Image<float> jmap = getRaoJetMap(*vcx, j, k, SMAX);
00326 float mi, ma; getMinMax(jmap, mi, ma);
00327 LINFO("jmap range [%f .. %f]", mi, ma);
00328
00329
00330
00331
00332 Image<float> fmap = exp(jmap * (-1.0f / (lambda*65535.0f)));
00333 float denom = sum(fmap); fmap *= 1.0f / denom;
00334
00335
00336 float xhat = 0.0f, yhat = 0.0f;
00337 int w = fmap.getWidth(), h = fmap.getHeight();
00338 for (int jj = 0; jj < h; jj ++)
00339 for (int ii = 0; ii < w; ii ++)
00340 {
00341 float val = fmap.getVal(ii, jj);
00342 xhat += ii * val;
00343 yhat += jj * val;
00344 }
00345 getMinMax(fmap, mi, ma);
00346 LINFO("fmap range [%f .. %f]", mi, ma);
00347 Point2D<int> win(int(xhat + 0.499f) << k, int(yhat + 0.499f) << k);
00348 LINFO("Saccade to (%d, %d)", win.i, win.j);
00349
00350
00351 Raster::VisuFloat(fmap, FLOAT_NORM_0_255, "fmap.pgm");
00352 Image< PixRGB<byte> > traj(col_image);
00353 drawPatch(traj, win, 5, PixRGB<byte>(255, 255, 0));
00354 int foar = std::max(traj.getWidth(), traj.getHeight()) / 12;
00355 drawCircle(traj, win, foar, PixRGB<byte>(255, 255, 0), 3);
00356 Raster::VisuRGB(traj, "traj.ppm");
00357
00358
00359
00360 lambda *= 0.5f;
00361 }
00362 }
00363
00364
00365 manager.stop();
00366 return 0;
00367 }
00368
00369
00370
00371
00372
00373