test-BPnnet.C

Go to the documentation of this file.
00001 /*!@file AppNeuro/test-BPnnet.C Test BPnnet class */
00002 
00003 // //////////////////////////////////////////////////////////////////// //
00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2003   //
00005 // by the University of Southern California (USC) and the iLab at USC.  //
00006 // See http://iLab.usc.edu for information about this project.          //
00007 // //////////////////////////////////////////////////////////////////// //
00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected //
00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
00010 // in Visual Environments, and Applications'' by Christof Koch and      //
00011 // Laurent Itti, California Institute of Technology, 2001 (patent       //
00012 // pending; application number 09/912,225 filed July 23, 2001; see      //
00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
00014 // //////////////////////////////////////////////////////////////////// //
00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
00016 //                                                                      //
00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
00018 // redistribute it and/or modify it under the terms of the GNU General  //
00019 // Public License as published by the Free Software Foundation; either  //
00020 // version 2 of the License, or (at your option) any later version.     //
00021 //                                                                      //
00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
00025 // PURPOSE.  See the GNU General Public License for more details.       //
00026 //                                                                      //
00027 // You should have received a copy of the GNU General Public License    //
00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
00030 // Boston, MA 02111-1307 USA.                                           //
00031 // //////////////////////////////////////////////////////////////////// //
00032 //
00033 // Primary maintainer for this file: Philip Williams <plw@usc.edu>
00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/AppNeuro/test-BPnnet.C $
00035 // $Id: test-BPnnet.C 10982 2009-03-05 05:11:22Z itti $
00036 //
00037 
00038 #include "Channels/ChannelOpts.H"
00039 #include "Channels/Jet.H"
00040 #include "Channels/JetFiller.H"
00041 #include "Component/ModelManager.H"
00042 #include "Image/DrawOps.H"
00043 #include "Image/MathOps.H"
00044 #include "Image/ShapeOps.H"
00045 #include "Image/Transforms.H"
00046 #include "Media/MediaSimEvents.H"
00047 #include "Neuro/NeuroOpts.H"
00048 #include "Neuro/NeuroSimEvents.H"
00049 #include "Neuro/SimulationViewerStd.H"
00050 #include "Neuro/StdBrain.H"
00051 #include "Neuro/VisualCortex.H"
00052 #include "Raster/Raster.H"
00053 #include "Simulation/SimEventQueueConfigurator.H"
00054 #include "BPnnet/BPnnet.H"
00055 
00056 #include <fstream>
00057 #include <iostream>
00058 
00059 //! number of hidden units
00060 #define NHIDDEN 64
00061 
00062 //! number of negative samples to generate
00063 #define NNEG 0 /*5*/
00064 
00065 //! sampling range in pixels for positive examples around most salient location
00066 #define SRANGE 3
00067 
00068 //! sampling step for positive examples, in pixels
00069 #define SSTEP 3
00070 
//! @brief Command-line test driver for the BPnnet classifier on Jet features.
//!
//! One of four mutually exclusive modes is selected by argv[1]:
//!  - "jet":    extract a feature Jet at coordinates (argv[4], argv[5]) from
//!              image argv[3] and print it tagged with label argv[2]
//!              (currently disabled -- dies at LFATAL("fixme") below).
//!  - "coords": print positive sample coordinates around the most salient
//!              point inside target mask argv[2] over image argv[3], plus
//!              NNEG negative samples outside it (saliency masking is also
//!              currently disabled by LFATAL("fixme")).
//!  - "train":  one training pass over jet file argv[3] with learning rate
//!              argv[4]; network parameter files use stem argv[2].
//!  - "reco":   run the attention model on image argv[3] and attempt to
//!              recognize the object at each attended (WTA winner) location.
//!
//! @return 0 on success, 1 on usage error (LFATAL presumably also aborts
//!         the process on the unimplemented paths -- standard iLab behavior,
//!         verify against log.H).
00071 int main(const int argc, const char** argv)
00072 {
00073   // Instantiate a ModelManager:
00074   ModelManager manager("Attention Model");
00075 
00076   // select exactly one mode from argv[1]: jet, train, reco or coords:
00077   bool do_jet = false, do_train = false, do_reco = false, do_coords = false;
00078   if (argc > 1)
00079     {
00080       if (strcmp(argv[1], "train") == 0) do_train = true;
00081       else if (strcmp(argv[1], "reco") == 0) do_reco = true;
00082       else if (strcmp(argv[1], "coords") == 0) do_coords = true;
00083       else if (strcmp(argv[1], "jet") == 0) do_jet = true;
00084       else { LERROR("Incorrect argument(s).  See USAGE."); return 1; }
00085     }
00086   else
00087     {
00088       LERROR("USAGE:\n  %s jet <label> <img.ppm> <x> <y>\n"
00089              "  %s train <param> <jetfile> <eta>\n"
00090              "  %s reco <param> <img.ppm>\n"
00091              "  %s coords <targetmap.pgm> <img.ppm> <label>\n"
00092              "where <param> is the stem for parameter files.",
00093              argv[0], argv[0], argv[0], argv[0]);
00094       return 1;
00095     }
00096 
00097   initRandomNumbers();
00098 
00099   // Read input image from disk (argv[3] is the image in all three
00100   // image-consuming modes; "train" reads no image):
00100   Image< PixRGB<byte> > col_image;
00101   if (do_jet || do_reco || do_coords)
00102     col_image = Raster::ReadRGB(argv[3]);
00103 
00104   // create brain:
00105   nub::soft_ref<SimEventQueueConfigurator>
00106     seqc(new SimEventQueueConfigurator(manager));
00107   manager.addSubComponent(seqc);
00108   nub::soft_ref<StdBrain> brain(new StdBrain(manager));
00109   manager.addSubComponent(brain);
          // export options first, then override the channel set; "IOC"
          // presumably selects Intensity/Orientation/Color channels only --
          // verify against ChannelOpts.H:
00110   manager.exportOptions(MC_RECURSE);
00111   manager.setOptionValString(&OPT_RawVisualCortexChans, "IOC");
00112   nub::soft_ref<SimEventQueue> seq = seqc->getQ();
00113 
00114   // get model started (must precede getModelParamVal below, which reads
00115   // the post-configuration value):
00115   manager.start();
00116   const uint nborient =
00117     manager.getModelParamVal<uint>("NumOrientations", MC_RECURSE);
00118 
00119   // build custom JetSpec for our Jets: RAW features over jdepth pyramid
00120   // levels starting at level jlev for each channel:
00120   JetSpec *jj = new JetSpec; int jlev = 2, jdepth = 5;
00121   jj->addIndexRange(RG, RAW, jlev, jlev + jdepth - 1);
00122   jj->addIndexRange(BY, RAW, jlev, jlev + jdepth - 1);
00123   jj->addIndexRange(INTENS, RAW, jlev, jlev + jdepth - 1);
          // NOTE(review): ORI gets two index ranges -- presumably this
          // defines a 2-D index (orientation x scale); verify JetSpec
          // semantics for repeated addIndexRange calls:
00124   jj->addIndexRange(ORI, RAW, 0, nborient - 1);  // orientations
00125   jj->addIndexRange(ORI, RAW, jlev, jlev + jdepth - 1);
00126   rutz::shared_ptr<JetSpec> js(jj);  // js takes ownership of jj
00127 
00128   // initialize a Jet according to our JetSpec:
00129   Jet<float> j(js);
00130 
00131   if (do_jet || do_reco || do_coords)
00132     {
00133       rutz::shared_ptr<SimEventInputFrame>
00134         e(new SimEventInputFrame(brain.get(), GenericFrame(col_image), 0));
00135       seq->post(e); // post the image to the brain
00136     }
00137 
00138   // do we want to extract a jet?
00139   if (do_jet)
00140     {
00141       // get a jet at specified coordinates:
00142 
          // NOTE(review): mode is disabled pending a port of the code below
          // to the current VisualCortex API; LFATAL presumably aborts, so
          // the return 0 is unreachable until the fixme is resolved.
00143       LFATAL("fixme");
00144       /*
00145       Point2D<int> p(atoi(argv[4]), atoi(argv[5]));
00146       JetFiller f(p, j, true);
00147       brain->getVC()->accept(f);
00148       std::cout<<argv[2]<<' '<<j<<std::endl;
00149       */
00150       return 0;
00151     }
00152 
00153   // do we just want to extract a bunch of coordinates from a target mask?
00154   if (do_coords)
00155     {
00156       Image<byte> tmap = Raster::ReadGray(argv[2]);
00157       PixRGB<byte> yellowPix(255, 255, 0), greenPix(0, 255, 0);
00158 
00159       // inflate the target mask a bit:
00160       Image<byte> blownup = chamfer34(tmap);
00161       blownup = binaryReverse(blownup, byte(255));
00162       blownup -= 240; blownup *= 255;  // exploit automatic range clamping
00163       Image<float> sm;
00164 
00165       // mask saliency map by objectmask:
          // NOTE(review): the masking below is disabled, so 'sm' stays empty
          // and the findMax/drawDisk calls that follow operate on dead data;
          // LFATAL presumably aborts before they run. Re-enable the rescale
          // line when porting to the current VisualCortex API.
00166       LFATAL("fixme");
00167       ///////////      sm = rescale(brain->getVC()->getOutput(), col_image.getDims()) * blownup;
00168 
00169       // find location of most salient point within target mask
00170       Point2D<int> p; float mval;
00171       findMax(sm, p, mval);
00172 
          // NOTE(review): 120.0f looks like a saliency normalization
          // constant -- TODO confirm where it comes from:
00173       LINFO("===== Max Saliency %g at (%d, %d) =====", mval/120.0f, p.i, p.j);
          // emit positive samples on a (2*SRANGE/SSTEP+1)^2 grid around the
          // max, clamped to the image bounds:
00174       for (int jj = -SRANGE; jj <= SRANGE; jj += SSTEP)
00175         for (int ii = -SRANGE; ii <= SRANGE; ii += SSTEP)
00176           {
00177             Point2D<int> pp;
00178             pp.i = std::max(0, std::min(col_image.getWidth()-1, p.i + ii));
00179             pp.j = std::max(0, std::min(col_image.getHeight()-1, p.j + jj));
00180             LINFO("===== Positive Sample at (%d, %d) =====", pp.i, pp.j);
00181             std::cout<<argv[4]<<' '<<argv[3]<<' '<<pp.i<<' '<<pp.j<<std::endl;
00182           }
00183       drawPatch(col_image, p, 3, yellowPix);
00184       drawCircle(col_image, p, 40, yellowPix, 2);
00185 
00186       // now generate a bunch of negative samples, outside target area:
00187       LFATAL("fixme");
00188       //////  sm = rescale(brain->getVC()->getOutput(), col_image.getDims()) * binaryReverse(blownup, byte(120));
00189 
          // pick NNEG negative samples by repeated max-finding, suppressing
          // a disk around each picked location (inhibition of return):
00190       for (int i = 0; i < NNEG; i ++)
00191         {
00192           findMax(sm, p, mval);
00193 
00194           LINFO("===== Negative Sample at (%d, %d) =====", p.i, p.j);
00195           std::cout<<"unknown "<<argv[3]<<' '<<p.i<<' '<<p.j<<std::endl;
00196 
00197           drawDisk(sm, p, std::max(col_image.getWidth(),
00198                                    col_image.getHeight()) / 12, 0.0f);
00199           drawPatch(col_image, p, 3, greenPix);
00200           drawCircle(col_image, p, 40, greenPix, 2);
00201         }
00202 
00203       //Raster::Visu(col_image, "samples.pnm");
00204       //std::cerr<<"<<<< press [RETURN] to exit >>>"<<std::endl;
00205       //getchar();
00206 
00207       return 0;
00208     }
00209 
00210   // read in the knowledge base from "<param>_kb.txt":
          // NOTE(review): fixed 256-byte buffer; a long argv[2] overflows kn
          // -- consider snprintf or std::string when touching this code.
00211   KnowledgeBase kb; char kn[256]; strcpy(kn, argv[2]); strcat(kn, "_kb.txt");
00212   kb.load(kn);
00213 
00214   // Create BPnnet and load from disk; start from random weights if no
00215   // saved parameter files exist yet for this stem:
00215   int numHidden = NHIDDEN; // arbitrary for testing
00216   BPnnet net(js->getJetSize(), numHidden, &kb);
00217   if (net.load(argv[2]) == false) net.randomizeWeights();
00218 
00219   if (do_train)
00220     {
00221       std::ifstream s(argv[3]);
00222       if (s.is_open() == false)  LFATAL("Cannot read %s", argv[3]);
00223       double eta = atof(argv[4]);
00224       char buf[256]; double rms = 0.0; int nb = 0;
          // NOTE(review): eof-before-read loop; a failed final read still
          // enters the body, relying on the strlen(buf) > 1 guard to skip
          // the bogus record -- prefer testing the stream after the read.
00225       while(!s.eof())
00226          {
                 // each record: "<label> <jet values...>\n"; read the label
                 // up to the first space, then stream the jet, then discard
                 // the remainder of the line:
00227            s.get(buf, 256, ' ');
00228            if (strlen(buf) > 1)
00229              {
00230                SimpleVisualObject vo(buf);
00231                s>>j; s.getline(buf, 256);
00232 
                   // accumulate squared error for the RMS report below:
00233                rms += net.train(j, vo, eta); nb ++;
00234 
00235                net.normalizeWeights();
00236              }
00237          }
00238       s.close();
00239       rms = sqrt(rms / (double)nb);
00240       LINFO("Trained %d jets, eta=%.10f: RMS=%.10f", nb, eta, rms);
00241       net.save(argv[2]);
00242 
00243       std::cout<<rms<<std::endl;
00244       return 0;
00245     }
00246 
00247   if (do_reco)
00248     {
00249       bool keep_going = true;
00250       while(keep_going)
00251         {
00252           (void) seq->evolve();
00253 
00254 
              // each winner-take-all attention shift yields one candidate
              // location to recognize:
00255           if (SeC<SimEventWTAwinner> e = seq->check<SimEventWTAwinner>(0))
00256             {
00257               const Point2D<int> winner = e->winner().p;
00258 
00259               LINFO("##### Winner (%d,%d) at %fms #####",
00260                     winner.i, winner.j, seq->now().msecs());
                  // NOTE(review): trajectory image is currently empty until
                  // the getTraj call is ported (see FIXME):
00261               Image< PixRGB<byte> > ctmp;/////////////////FIXME = brain->getSV()->getTraj(seq->now());
00262               Raster::VisuRGB(ctmp, sformat("traj_%s.ppm", argv[3]));
00263 
                  // NOTE(review): LFATAL presumably aborts here, so the
                  // recognition code below is dead until the JetFiller port
                  // is completed:
00264               LFATAL("fixme");
00265 
00266               /////////              JetFiller f(winner, j, true);
00267               ////////              brain->getVC()->accept(f);
00268 
00269               SimpleVisualObject vo;
00270               if (net.recognize(j, vo))
00271                 LINFO("##### Recognized: %s #####", vo.getName());
00272               else
00273                 LINFO("##### Not Recognized #####");
00274 
00275               std::cout<<"<<<< press [RETURN] to continue >>>"<<std::endl;
00276               getchar();
                  // stop after 3 seconds of simulated time:
00277               if (seq->now().secs() > 3.0)
00278                 { LINFO("##### Time limit reached #####"); keep_going = false;}
00279             }
00280         }
00281     }
00282 
00283   return 0;
00284 }
00285 
00286 
00287 
00288 // ######################################################################
00289 /* So things look consistent in everyone's emacs... */
00290 /* Local Variables: */
00291 /* indent-tabs-mode: nil */
00292 /* End: */
Generated on Sun May 8 08:04:11 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3