00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
#include "Channels/ChannelOpts.H"
#include "Channels/Jet.H"
#include "Channels/JetFiller.H"
#include "Component/ModelManager.H"
#include "Image/DrawOps.H"
#include "Image/MathOps.H"
#include "Image/ShapeOps.H"
#include "Image/Transforms.H"
#include "Media/MediaSimEvents.H"
#include "Neuro/NeuroOpts.H"
#include "Neuro/NeuroSimEvents.H"
#include "Neuro/SimulationViewerStd.H"
#include "Neuro/StdBrain.H"
#include "Neuro/VisualCortex.H"
#include "Raster/Raster.H"
#include "Simulation/SimEventQueueConfigurator.H"
#include "BPnnet/BPnnet.H"

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
00058
00059
// Number of hidden units of the BPnnet built for the train and reco modes.
static const int NHIDDEN = 64;

// Number of negative ("unknown") samples emitted by the coords mode.
static const int NNEG = 0;

// Half-width, in pixels, of the window of positive samples taken around the
// saliency maximum in the coords mode.
static const int SRANGE = 3;

// Spacing, in pixels, between consecutive positive samples in that window.
static const int SSTEP = 3;
00070
00071 int main(const int argc, const char** argv)
00072 {
00073
00074 ModelManager manager("Attention Model");
00075
00076
00077 bool do_jet = false, do_train = false, do_reco = false, do_coords = false;
00078 if (argc > 1)
00079 {
00080 if (strcmp(argv[1], "train") == 0) do_train = true;
00081 else if (strcmp(argv[1], "reco") == 0) do_reco = true;
00082 else if (strcmp(argv[1], "coords") == 0) do_coords = true;
00083 else if (strcmp(argv[1], "jet") == 0) do_jet = true;
00084 else { LERROR("Incorrect argument(s). See USAGE."); return 1; }
00085 }
00086 else
00087 {
00088 LERROR("USAGE:\n %s jet <label> <img.ppm> <x> <y>\n"
00089 " %s train <param> <jetfile> <eta>\n"
00090 " %s reco <param> <img.ppm>\n"
00091 " %s coords <targetmap.pgm> <img.ppm> <label>\n"
00092 "where <param> is the stem for parameter files.",
00093 argv[0], argv[0], argv[0], argv[0]);
00094 return 1;
00095 }
00096
00097 initRandomNumbers();
00098
00099
00100 Image< PixRGB<byte> > col_image;
00101 if (do_jet || do_reco || do_coords)
00102 col_image = Raster::ReadRGB(argv[3]);
00103
00104
00105 nub::soft_ref<SimEventQueueConfigurator>
00106 seqc(new SimEventQueueConfigurator(manager));
00107 manager.addSubComponent(seqc);
00108 nub::soft_ref<StdBrain> brain(new StdBrain(manager));
00109 manager.addSubComponent(brain);
00110 manager.exportOptions(MC_RECURSE);
00111 manager.setOptionValString(&OPT_RawVisualCortexChans, "IOC");
00112 nub::soft_ref<SimEventQueue> seq = seqc->getQ();
00113
00114
00115 manager.start();
00116 const uint nborient =
00117 manager.getModelParamVal<uint>("NumOrientations", MC_RECURSE);
00118
00119
00120 JetSpec *jj = new JetSpec; int jlev = 2, jdepth = 5;
00121 jj->addIndexRange(RG, RAW, jlev, jlev + jdepth - 1);
00122 jj->addIndexRange(BY, RAW, jlev, jlev + jdepth - 1);
00123 jj->addIndexRange(INTENS, RAW, jlev, jlev + jdepth - 1);
00124 jj->addIndexRange(ORI, RAW, 0, nborient - 1);
00125 jj->addIndexRange(ORI, RAW, jlev, jlev + jdepth - 1);
00126 rutz::shared_ptr<JetSpec> js(jj);
00127
00128
00129 Jet<float> j(js);
00130
00131 if (do_jet || do_reco || do_coords)
00132 {
00133 rutz::shared_ptr<SimEventInputFrame>
00134 e(new SimEventInputFrame(brain.get(), GenericFrame(col_image), 0));
00135 seq->post(e);
00136 }
00137
00138
00139 if (do_jet)
00140 {
00141
00142
00143 LFATAL("fixme");
00144
00145
00146
00147
00148
00149
00150 return 0;
00151 }
00152
00153
00154 if (do_coords)
00155 {
00156 Image<byte> tmap = Raster::ReadGray(argv[2]);
00157 PixRGB<byte> yellowPix(255, 255, 0), greenPix(0, 255, 0);
00158
00159
00160 Image<byte> blownup = chamfer34(tmap);
00161 blownup = binaryReverse(blownup, byte(255));
00162 blownup -= 240; blownup *= 255;
00163 Image<float> sm;
00164
00165
00166 LFATAL("fixme");
00167
00168
00169
00170 Point2D<int> p; float mval;
00171 findMax(sm, p, mval);
00172
00173 LINFO("===== Max Saliency %g at (%d, %d) =====", mval/120.0f, p.i, p.j);
00174 for (int jj = -SRANGE; jj <= SRANGE; jj += SSTEP)
00175 for (int ii = -SRANGE; ii <= SRANGE; ii += SSTEP)
00176 {
00177 Point2D<int> pp;
00178 pp.i = std::max(0, std::min(col_image.getWidth()-1, p.i + ii));
00179 pp.j = std::max(0, std::min(col_image.getHeight()-1, p.j + jj));
00180 LINFO("===== Positive Sample at (%d, %d) =====", pp.i, pp.j);
00181 std::cout<<argv[4]<<' '<<argv[3]<<' '<<pp.i<<' '<<pp.j<<std::endl;
00182 }
00183 drawPatch(col_image, p, 3, yellowPix);
00184 drawCircle(col_image, p, 40, yellowPix, 2);
00185
00186
00187 LFATAL("fixme");
00188
00189
00190 for (int i = 0; i < NNEG; i ++)
00191 {
00192 findMax(sm, p, mval);
00193
00194 LINFO("===== Negative Sample at (%d, %d) =====", p.i, p.j);
00195 std::cout<<"unknown "<<argv[3]<<' '<<p.i<<' '<<p.j<<std::endl;
00196
00197 drawDisk(sm, p, std::max(col_image.getWidth(),
00198 col_image.getHeight()) / 12, 0.0f);
00199 drawPatch(col_image, p, 3, greenPix);
00200 drawCircle(col_image, p, 40, greenPix, 2);
00201 }
00202
00203
00204
00205
00206
00207 return 0;
00208 }
00209
00210
00211 KnowledgeBase kb; char kn[256]; strcpy(kn, argv[2]); strcat(kn, "_kb.txt");
00212 kb.load(kn);
00213
00214
00215 int numHidden = NHIDDEN;
00216 BPnnet net(js->getJetSize(), numHidden, &kb);
00217 if (net.load(argv[2]) == false) net.randomizeWeights();
00218
00219 if (do_train)
00220 {
00221 std::ifstream s(argv[3]);
00222 if (s.is_open() == false) LFATAL("Cannot read %s", argv[3]);
00223 double eta = atof(argv[4]);
00224 char buf[256]; double rms = 0.0; int nb = 0;
00225 while(!s.eof())
00226 {
00227 s.get(buf, 256, ' ');
00228 if (strlen(buf) > 1)
00229 {
00230 SimpleVisualObject vo(buf);
00231 s>>j; s.getline(buf, 256);
00232
00233 rms += net.train(j, vo, eta); nb ++;
00234
00235 net.normalizeWeights();
00236 }
00237 }
00238 s.close();
00239 rms = sqrt(rms / (double)nb);
00240 LINFO("Trained %d jets, eta=%.10f: RMS=%.10f", nb, eta, rms);
00241 net.save(argv[2]);
00242
00243 std::cout<<rms<<std::endl;
00244 return 0;
00245 }
00246
00247 if (do_reco)
00248 {
00249 bool keep_going = true;
00250 while(keep_going)
00251 {
00252 (void) seq->evolve();
00253
00254
00255 if (SeC<SimEventWTAwinner> e = seq->check<SimEventWTAwinner>(0))
00256 {
00257 const Point2D<int> winner = e->winner().p;
00258
00259 LINFO("##### Winner (%d,%d) at %fms #####",
00260 winner.i, winner.j, seq->now().msecs());
00261 Image< PixRGB<byte> > ctmp;
00262 Raster::VisuRGB(ctmp, sformat("traj_%s.ppm", argv[3]));
00263
00264 LFATAL("fixme");
00265
00266
00267
00268
00269 SimpleVisualObject vo;
00270 if (net.recognize(j, vo))
00271 LINFO("##### Recognized: %s #####", vo.getName());
00272 else
00273 LINFO("##### Not Recognized #####");
00274
00275 std::cout<<"<<<< press [RETURN] to continue >>>"<<std::endl;
00276 getchar();
00277 if (seq->now().secs() > 3.0)
00278 { LINFO("##### Time limit reached #####"); keep_going = false;}
00279 }
00280 }
00281 }
00282
00283 return 0;
00284 }
00285
00286
00287
00288
00289
00290
00291
00292