/*!@file AppNeuro/test-BPnnet.C Test BPnnet class */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2003   //
// by the University of Southern California (USC) and the iLab at USC.  //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.
// 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Philip Williams <plw@usc.edu> 00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/AppNeuro/test-BPnnet.C $ 00035 // $Id: test-BPnnet.C 10982 2009-03-05 05:11:22Z itti $ 00036 // 00037 00038 #include "Channels/ChannelOpts.H" 00039 #include "Channels/Jet.H" 00040 #include "Channels/JetFiller.H" 00041 #include "Component/ModelManager.H" 00042 #include "Image/DrawOps.H" 00043 #include "Image/MathOps.H" 00044 #include "Image/ShapeOps.H" 00045 #include "Image/Transforms.H" 00046 #include "Media/MediaSimEvents.H" 00047 #include "Neuro/NeuroOpts.H" 00048 #include "Neuro/NeuroSimEvents.H" 00049 #include "Neuro/SimulationViewerStd.H" 00050 #include "Neuro/StdBrain.H" 00051 #include "Neuro/VisualCortex.H" 00052 #include "Raster/Raster.H" 00053 #include "Simulation/SimEventQueueConfigurator.H" 00054 #include "BPnnet/BPnnet.H" 00055 00056 #include <fstream> 00057 #include <iostream> 00058 00059 //! number of hidden units 00060 #define NHIDDEN 64 00061 00062 //! number of negative samples to generate 00063 #define NNEG 0 /*5*/ 00064 00065 //! sampling range in pixels for positive examples around most salient location 00066 #define SRANGE 3 00067 00068 //! 
sampling step for positive examples, in pixels 00069 #define SSTEP 3 00070 00071 int main(const int argc, const char** argv) 00072 { 00073 // Instantiate a ModelManager: 00074 ModelManager manager("Attention Model"); 00075 00076 // either run train, recog, or both: 00077 bool do_jet = false, do_train = false, do_reco = false, do_coords = false; 00078 if (argc > 1) 00079 { 00080 if (strcmp(argv[1], "train") == 0) do_train = true; 00081 else if (strcmp(argv[1], "reco") == 0) do_reco = true; 00082 else if (strcmp(argv[1], "coords") == 0) do_coords = true; 00083 else if (strcmp(argv[1], "jet") == 0) do_jet = true; 00084 else { LERROR("Incorrect argument(s). See USAGE."); return 1; } 00085 } 00086 else 00087 { 00088 LERROR("USAGE:\n %s jet <label> <img.ppm> <x> <y>\n" 00089 " %s train <param> <jetfile> <eta>\n" 00090 " %s reco <param> <img.ppm>\n" 00091 " %s coords <targetmap.pgm> <img.ppm> <label>\n" 00092 "where <param> is the stem for parameter files.", 00093 argv[0], argv[0], argv[0], argv[0]); 00094 return 1; 00095 } 00096 00097 initRandomNumbers(); 00098 00099 // Read input image from disk: 00100 Image< PixRGB<byte> > col_image; 00101 if (do_jet || do_reco || do_coords) 00102 col_image = Raster::ReadRGB(argv[3]); 00103 00104 // create brain: 00105 nub::soft_ref<SimEventQueueConfigurator> 00106 seqc(new SimEventQueueConfigurator(manager)); 00107 manager.addSubComponent(seqc); 00108 nub::soft_ref<StdBrain> brain(new StdBrain(manager)); 00109 manager.addSubComponent(brain); 00110 manager.exportOptions(MC_RECURSE); 00111 manager.setOptionValString(&OPT_RawVisualCortexChans, "IOC"); 00112 nub::soft_ref<SimEventQueue> seq = seqc->getQ(); 00113 00114 // get model started: 00115 manager.start(); 00116 const uint nborient = 00117 manager.getModelParamVal<uint>("NumOrientations", MC_RECURSE); 00118 00119 // build custom JetSpec for our Jets: 00120 JetSpec *jj = new JetSpec; int jlev = 2, jdepth = 5; 00121 jj->addIndexRange(RG, RAW, jlev, jlev + jdepth - 1); 00122 
jj->addIndexRange(BY, RAW, jlev, jlev + jdepth - 1); 00123 jj->addIndexRange(INTENS, RAW, jlev, jlev + jdepth - 1); 00124 jj->addIndexRange(ORI, RAW, 0, nborient - 1); // orientations 00125 jj->addIndexRange(ORI, RAW, jlev, jlev + jdepth - 1); 00126 rutz::shared_ptr<JetSpec> js(jj); 00127 00128 // initialize a Jet according to our JetSpec: 00129 Jet<float> j(js); 00130 00131 if (do_jet || do_reco || do_coords) 00132 { 00133 rutz::shared_ptr<SimEventInputFrame> 00134 e(new SimEventInputFrame(brain.get(), GenericFrame(col_image), 0)); 00135 seq->post(e); // post the image to the brain 00136 } 00137 00138 // do we want to extract a jet? 00139 if (do_jet) 00140 { 00141 // get a jet at specified coordinates: 00142 00143 LFATAL("fixme"); 00144 /* 00145 Point2D<int> p(atoi(argv[4]), atoi(argv[5])); 00146 JetFiller f(p, j, true); 00147 brain->getVC()->accept(f); 00148 std::cout<<argv[2]<<' '<<j<<std::endl; 00149 */ 00150 return 0; 00151 } 00152 00153 // do we just want to extract a bunch of coordinates from a target mask? 
00154 if (do_coords) 00155 { 00156 Image<byte> tmap = Raster::ReadGray(argv[2]); 00157 PixRGB<byte> yellowPix(255, 255, 0), greenPix(0, 255, 0); 00158 00159 // inflate the target mask a bit: 00160 Image<byte> blownup = chamfer34(tmap); 00161 blownup = binaryReverse(blownup, byte(255)); 00162 blownup -= 240; blownup *= 255; // exploit automatic range clamping 00163 Image<float> sm; 00164 00165 // mask saliency map by objectmask: 00166 LFATAL("fixme"); 00167 /////////// sm = rescale(brain->getVC()->getOutput(), col_image.getDims()) * blownup; 00168 00169 // find location of most salient point within target mask 00170 Point2D<int> p; float mval; 00171 findMax(sm, p, mval); 00172 00173 LINFO("===== Max Saliency %g at (%d, %d) =====", mval/120.0f, p.i, p.j); 00174 for (int jj = -SRANGE; jj <= SRANGE; jj += SSTEP) 00175 for (int ii = -SRANGE; ii <= SRANGE; ii += SSTEP) 00176 { 00177 Point2D<int> pp; 00178 pp.i = std::max(0, std::min(col_image.getWidth()-1, p.i + ii)); 00179 pp.j = std::max(0, std::min(col_image.getHeight()-1, p.j + jj)); 00180 LINFO("===== Positive Sample at (%d, %d) =====", pp.i, pp.j); 00181 std::cout<<argv[4]<<' '<<argv[3]<<' '<<pp.i<<' '<<pp.j<<std::endl; 00182 } 00183 drawPatch(col_image, p, 3, yellowPix); 00184 drawCircle(col_image, p, 40, yellowPix, 2); 00185 00186 // now generate a bunch of negative samples, outside target area: 00187 LFATAL("fixme"); 00188 ////// sm = rescale(brain->getVC()->getOutput(), col_image.getDims()) * binaryReverse(blownup, byte(120)); 00189 00190 for (int i = 0; i < NNEG; i ++) 00191 { 00192 findMax(sm, p, mval); 00193 00194 LINFO("===== Negative Sample at (%d, %d) =====", p.i, p.j); 00195 std::cout<<"unknown "<<argv[3]<<' '<<p.i<<' '<<p.j<<std::endl; 00196 00197 drawDisk(sm, p, std::max(col_image.getWidth(), 00198 col_image.getHeight()) / 12, 0.0f); 00199 drawPatch(col_image, p, 3, greenPix); 00200 drawCircle(col_image, p, 40, greenPix, 2); 00201 } 00202 00203 //Raster::Visu(col_image, "samples.pnm"); 00204 
//std::cerr<<"<<<< press [RETURN] to exit >>>"<<std::endl; 00205 //getchar(); 00206 00207 return 0; 00208 } 00209 00210 // read in the knowledge base: 00211 KnowledgeBase kb; char kn[256]; strcpy(kn, argv[2]); strcat(kn, "_kb.txt"); 00212 kb.load(kn); 00213 00214 // Create BPnnet and load from disk 00215 int numHidden = NHIDDEN; // arbitrary for testing 00216 BPnnet net(js->getJetSize(), numHidden, &kb); 00217 if (net.load(argv[2]) == false) net.randomizeWeights(); 00218 00219 if (do_train) 00220 { 00221 std::ifstream s(argv[3]); 00222 if (s.is_open() == false) LFATAL("Cannot read %s", argv[3]); 00223 double eta = atof(argv[4]); 00224 char buf[256]; double rms = 0.0; int nb = 0; 00225 while(!s.eof()) 00226 { 00227 s.get(buf, 256, ' '); 00228 if (strlen(buf) > 1) 00229 { 00230 SimpleVisualObject vo(buf); 00231 s>>j; s.getline(buf, 256); 00232 00233 rms += net.train(j, vo, eta); nb ++; 00234 00235 net.normalizeWeights(); 00236 } 00237 } 00238 s.close(); 00239 rms = sqrt(rms / (double)nb); 00240 LINFO("Trained %d jets, eta=%.10f: RMS=%.10f", nb, eta, rms); 00241 net.save(argv[2]); 00242 00243 std::cout<<rms<<std::endl; 00244 return 0; 00245 } 00246 00247 if (do_reco) 00248 { 00249 bool keep_going = true; 00250 while(keep_going) 00251 { 00252 (void) seq->evolve(); 00253 00254 00255 if (SeC<SimEventWTAwinner> e = seq->check<SimEventWTAwinner>(0)) 00256 { 00257 const Point2D<int> winner = e->winner().p; 00258 00259 LINFO("##### Winner (%d,%d) at %fms #####", 00260 winner.i, winner.j, seq->now().msecs()); 00261 Image< PixRGB<byte> > ctmp;/////////////////FIXME = brain->getSV()->getTraj(seq->now()); 00262 Raster::VisuRGB(ctmp, sformat("traj_%s.ppm", argv[3])); 00263 00264 LFATAL("fixme"); 00265 00266 ///////// JetFiller f(winner, j, true); 00267 //////// brain->getVC()->accept(f); 00268 00269 SimpleVisualObject vo; 00270 if (net.recognize(j, vo)) 00271 LINFO("##### Recognized: %s #####", vo.getName()); 00272 else 00273 LINFO("##### Not Recognized #####"); 00274 00275 
std::cout<<"<<<< press [RETURN] to continue >>>"<<std::endl; 00276 getchar(); 00277 if (seq->now().secs() > 3.0) 00278 { LINFO("##### Time limit reached #####"); keep_going = false;} 00279 } 00280 } 00281 } 00282 00283 return 0; 00284 } 00285 00286 00287 00288 // ###################################################################### 00289 /* So things look consistent in everyone's emacs... */ 00290 /* Local Variables: */ 00291 /* indent-tabs-mode: nil */ 00292 /* End: */