test-SIFT.C

Go to the documentation of this file.
00001 /*! @file SIFT/test-SIFT.C test the SIFT alg */
00002 
00003 // //////////////////////////////////////////////////////////////////// //
00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005   //
00005 // by the University of Southern California (USC) and the iLab at USC.  //
00006 // See http://iLab.usc.edu for information about this project.          //
00007 // //////////////////////////////////////////////////////////////////// //
00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected //
00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
00010 // in Visual Environments, and Applications'' by Christof Koch and      //
00011 // Laurent Itti, California Institute of Technology, 2001 (patent       //
00012 // pending; application number 09/912,225 filed July 23, 2001; see      //
00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
00014 // //////////////////////////////////////////////////////////////////// //
00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
00016 //                                                                      //
00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
00018 // redistribute it and/or modify it under the terms of the GNU General  //
00019 // Public License as published by the Free Software Foundation; either  //
00020 // version 2 of the License, or (at your option) any later version.     //
00021 //                                                                      //
00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
00025 // PURPOSE.  See the GNU General Public License for more details.       //
00026 //                                                                      //
00027 // You should have received a copy of the GNU General Public License    //
00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
00030 // Boston, MA 02111-1307 USA.                                           //
00031 // //////////////////////////////////////////////////////////////////// //
00032 //
00033 // Primary maintainer for this file: Lior Elazary <elazary@usc.edu>
00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/SIFT/test-SIFT.C $
00035 // $Id: test-SIFT.C 14376 2011-01-11 02:44:34Z pez $
00036 //
00037 
00038 
00039 #include "Component/ModelManager.H"
00040 #include "Image/Image.H"
00041 #include "Image/Pixels.H"
00042 #include "Media/FrameSeries.H"
00043 #include "Util/Timer.H"
00044 #include "GUI/XWinManaged.H"
00045 #include "SIFT/ScaleSpace.H"
00046 #include "SIFT/VisualObject.H"
00047 #include "SIFT/Keypoint.H"
00048 #include "SIFT/VisualObjectDB.H"
00049 #include "Image/CutPaste.H"
00050 #include "Image/ShapeOps.H"
00051 #include "Image/DrawOps.H"
00052 #include "Media/TestImages.H"
00053 #include "Raster/Raster.H"
00054 #include "Transport/FrameInfo.H"
00055 #include "Raster/Raster.H"
00056 #include "Raster/GenericFrame.H"
00057 
00058 #define DISPSCALE 1 // not referenced anywhere in the visible part of this file
00059 #define USECOLOR false // passed as the last argument of every VisualObject constructed in this file
00060 
00061 // The global visual object database: loaded (and, in training mode, saved) by main(), queried by matchObject()
00062 VisualObjectDB vdb;
00063 #define WIDTH 384 // only used by the commented-out display helpers below
00064 #define HEIGHT 288 // only used by the commented-out display helpers below
00065 
00066 std::string matchObject(Image<PixRGB<byte> > &ima); // forward declaration; defined after main()
00067 
00068 /*
00069 XWinManaged xwin(Dims(WIDTH,HEIGHT*2), 1, 1, "Test SIFT");
00070 
00071 
00072 rutz::shared_ptr<VisualObject> objTop, objBottom;
00073 
00074 void showObjs(rutz::shared_ptr<VisualObject> obj1, rutz::shared_ptr<VisualObject> obj2){
00075         //return ;
00076 
00077         Image<PixRGB<byte> > keyIma = rescale(obj1->getKeypointImage(),
00078                         WIDTH, HEIGHT);
00079         objTop = obj1;
00080 
00081         if (obj2.is_valid()){
00082                 keyIma = concatY(keyIma, rescale(obj2->getKeypointImage(),
00083                                         WIDTH, HEIGHT));
00084                 objBottom = obj2;
00085         }
00086 
00087         xwin.drawImage(keyIma);
00088 }
00089 
00090 void showKeypoint(rutz::shared_ptr<VisualObject> obj, int keypi,
00091                 Keypoint::CHANNEL channel = Keypoint::ORI){
00092 
00093         char winTitle[255];
00094         switch(channel){
00095                 case Keypoint::ORI:
00096                         sprintf(winTitle, "Keypoint view (Channel ORI)");
00097                         break;
00098                 case Keypoint::COL:
00099                         sprintf(winTitle, "Keypoint view (Channel COL)");
00100          break;
00101                 default:
00102                         sprintf(winTitle, "Keypoint view (Channel   )");
00103                         break;
00104         }
00105 
00106 
00107         rutz::shared_ptr<Keypoint> keyp = obj->getKeypoint(keypi);
00108         float x = keyp->getX();
00109         float y = keyp->getY();
00110         float s = keyp->getS();
00111         float o = keyp->getO();
00112         float m = keyp->getM();
00113 
00114         uint FVlength = keyp->getFVlength(channel);
00115         if (FVlength<=0) return; //dont show the Keypoint if we dont have a FV
00116 
00117         XWinManaged *xwinKey = new XWinManaged(Dims(WIDTH*2,HEIGHT), -1, -1, winTitle);
00118 
00119 
00120         //draw the circle around the keypoint
00121         const float sigma = 1.6F * powf(2.0F, s / float(6 - 3));
00122         const float sig = 1.5F * sigma;
00123         const int rad = int(3.0F * sig);
00124 
00125         Image<PixRGB<byte> > img = obj->getImage();
00126         Point2D<int> loc(int(x + 0.5F), int(y + 0.5F));
00127         drawCircle(img, loc, rad, PixRGB<byte>(255, 0, 0));
00128         drawDisk(img, loc, 2, PixRGB<byte>(255,0,0));
00129 
00130         s=s*5.0F; //mag for scale
00131         if (s > 0.0f) drawLine(img, loc,
00132                         Point2D<int>(int(x + s * cosf(o)  + 0.5F),
00133                                 int(y + s * sinf(o) + 0.5F)),
00134                         PixRGB<byte>(255, 0, 0));
00135 
00136         char info[255];
00137         sprintf(info, "(%0.2f,%0.2f) s=%0.2f o=%0.2f m=%0.2f", x, y, s, o, m);
00138 
00139         writeText(img, Point2D<int>(0, HEIGHT-20), info,
00140                         PixRGB<byte>(255), PixRGB<byte>(127));
00141 
00142 
00143         //draw the vectors from the features vectors
00144 
00145         Image<PixRGB<byte> > fvDisp(WIDTH, HEIGHT, NO_INIT);
00146         fvDisp.clear(PixRGB<byte>(255, 255, 255));
00147         int xBins = int((float)WIDTH/4);
00148         int yBins = int((float)HEIGHT/4);
00149 
00150         drawGrid(fvDisp, xBins, yBins, 1, 1, PixRGB<byte>(0, 0, 0));
00151 
00152 
00153 
00154         switch (channel){
00155                 case Keypoint::ORI:
00156                         for (int xx=0; xx<4; xx++){
00157                                 for (int yy=0; yy<4; yy++){
00158                                         for (int oo=0; oo<8; oo++){
00159                                                 Point2D<int> loc(xBins/2+(xBins*xx), yBins/2+(yBins*yy));
00160                                                 byte mag = keyp->getFVelement(xx*32+yy*8+oo, channel);
00161                                                 mag = mag/4;
00162                                                 drawDisk(fvDisp, loc, 2, PixRGB<byte>(255, 0, 0));
00163                                                 drawLine(fvDisp, loc,
00164                                                                 Point2D<int>(int(loc.i + mag*cosf(oo*M_PI/4)),
00165                                                                         int(loc.j + mag*sinf(oo*M_PI/4))),
00166                                                                 PixRGB<byte>(255, 0, 0));
00167                                         }
00168                                 }
00169                         }
00170                         break;
00171 
00172                 case Keypoint::COL:
00173                         for (int xx=0; xx<4; xx++){
00174                                 for (int yy=0; yy<4; yy++){
00175                                         for (int cc=0; cc<3; cc++){
00176                                                 Point2D<int> loc(xBins/2+(xBins*xx), yBins/2+(yBins*yy));
00177                                                 byte mag = keyp->getFVelement(xx*12+yy*3+cc, channel);
00178                                                 mag = mag/4;
00179                                                 drawDisk(fvDisp, loc, 2, PixRGB<byte>(255, 0, 0));
00180                                                 drawLine(fvDisp, loc,
00181                                                                 Point2D<int>(int(loc.i + mag*cosf(-1*cc*M_PI/2)),
00182                                                                         int(loc.j + mag*sinf(-1*cc*M_PI/2))),
00183                                                                 PixRGB<byte>(255, 0, 0));
00184                                         }
00185                                 }
00186                         }
00187                         break;
00188                 default:
00189                         break;
00190         }
00191 
00192 
00193 
00194         Image<PixRGB<byte> > disp = img;
00195         disp = concatX(disp, fvDisp);
00196 
00197 
00198         xwinKey->drawImage(disp);
00199 
00200         while(!xwinKey->pressedCloseButton()){
00201                 usleep(100);
00202         }
00203         delete xwinKey;
00204 
00205 }
00206 
00207 
00208 
00209 void analizeImage(){
00210    int key = -1;
00211 
00212         while(key != 24){ // q to quit window
00213                 key = xwin.getLastKeyPress();
00214                 Point2D<int>  point = xwin.getLastMouseClick();
00215                 if (point.i > -1 && point.j > -1){
00216 
00217                         //get the right object
00218                         rutz::shared_ptr<VisualObject> obj;
00219                         if (point.j < HEIGHT){
00220                                 obj = objTop;
00221                         } else {
00222                                 obj = objBottom;
00223                                 point.j = point.j - HEIGHT;
00224                         }
00225                         LINFO("ClickInfo: key = %i, p=%i,%i", key, point.i, point.j);
00226 
00227                         //find the keypoint
00228                         for(uint i=0; i<obj->numKeypoints(); i++){
00229                                 rutz::shared_ptr<Keypoint> keyp = obj->getKeypoint(i);
00230                                 float x = keyp->getX();
00231                                 float y = keyp->getY();
00232 
00233                                 if ( (point.i < (int)x + 5 && point.i > (int)x - 5) &&
00234                                           (point.j < (int)y + 5 && point.j > (int)y - 5)){
00235                                         showKeypoint(obj, i, Keypoint::ORI);
00236                                         showKeypoint(obj, i, Keypoint::COL);
00237                                 }
00238 
00239                         }
00240 
00241                 }
00242         }
00243 
00244 }
00245 */
00246 int main(const int argc, const char **argv)
00247 {
00248 
00249   MYLOGVERB = LOG_INFO;
00250   ModelManager manager("Test SIFT");
00251 
00252 
00253 
00254   nub::ref<InputFrameSeries> ifs(new InputFrameSeries(manager));
00255   manager.addSubComponent(ifs);
00256 
00257   nub::ref<OutputFrameSeries> ofs(new OutputFrameSeries(manager));
00258   manager.addSubComponent(ofs);
00259 
00260 
00261 
00262   if (manager.parseCommandLine(
00263         (const int)argc, (const char**)argv, "<database file> <trainingLabel>", 2, 2) == false)
00264     return 0;
00265 
00266   manager.start();
00267 
00268   Timer masterclock;                // master clock for simulations
00269   Timer timer;
00270 
00271   const char *vdbFile = manager.getExtraArg(0).c_str();
00272   const char *trainingLabel = manager.getExtraArg(1).c_str();
00273 
00274   int numMatches = 0; //the number of correct matches
00275   int totalObjects = 0; //the number of objects presented to the network
00276   int uObjId = 0; //a unique obj id for sift
00277 
00278   bool train = false;
00279   //load the database file
00280  // if (!train)
00281   vdb.loadFrom(std::string(vdbFile));
00282 
00283   while(1)
00284   {
00285     Image< PixRGB<byte> > inputImg;
00286     const FrameState is = ifs->updateNext();
00287     if (is == FRAME_COMPLETE)
00288       break;
00289 
00290     //grab the images
00291     GenericFrame input = ifs->readFrame();
00292     if (!input.initialized())
00293       break;
00294     inputImg = input.asRgb();
00295     totalObjects++;
00296 
00297     ofs->writeRGB(inputImg, "Input", FrameInfo("Input", SRC_POS));
00298 
00299 
00300     if (train)
00301     {
00302       //add the object to the database
00303       char objName[255]; sprintf(objName, "%s_%i", trainingLabel, uObjId);
00304       uObjId++;
00305       rutz::shared_ptr<VisualObject>
00306         vo(new VisualObject(objName, "NULL", inputImg,
00307               Point2D<int>(-1,-1),
00308               std::vector<float>(),
00309               std::vector< rutz::shared_ptr<Keypoint> >(),
00310               USECOLOR));
00311 
00312       vdb.addObject(vo);
00313     } else {
00314 
00315       //get the object classification
00316       std::string objName;
00317       std::string tmpName = matchObject(inputImg);
00318       int i = tmpName.find("_");
00319       objName.assign(tmpName, 0, i);
00320       LINFO("Object name %s", objName.c_str());
00321       printf("%i %s\n", ifs->frame(), objName.c_str());
00322 
00323       if (objName == trainingLabel)
00324         numMatches++;
00325 
00326       //printf("objid %i:class %i:rate=%0.2f\n",
00327       //    objData.description.c_str(), objData.id, cls,
00328       //    (float)numMatches/(float)totalObjects);
00329     }
00330   }
00331 
00332   if (train)
00333   {
00334     printf("Trained on %i objects\n", totalObjects);
00335     printf("Object in db %i\n" , vdb.numObjects());
00336     vdb.saveTo(std::string(vdbFile));
00337   } else {
00338     printf("Classification Rate: %i/%i %0.2f\n",
00339         numMatches, totalObjects,
00340         (float)numMatches/(float)totalObjects);
00341   }
00342 
00343 
00344 }
00345 
00346 std::string matchObject(Image<PixRGB<byte> > &ima){
00347 
00348   //find object in the database
00349   std::vector< rutz::shared_ptr<VisualObjectMatch> > matches;
00350   rutz::shared_ptr<VisualObject>
00351     vo(new VisualObject("PIC", "PIC", ima,
00352                         Point2D<int>(-1,-1),
00353                         std::vector<float>(),
00354                         std::vector< rutz::shared_ptr<Keypoint> >(),
00355                         USECOLOR));
00356 
00357   const uint nmatches = vdb.getObjectMatches(vo, matches, VOMA_SIMPLE,
00358       10000U, //max objs to return
00359       0.5F, //keypoint distance score default 0.5F
00360       0.5F, //affine distance score default 0.5F
00361       1.0F, //minscore  default 1.0F
00362       3U, //min # of keypoint match
00363       100U, //keypoint selection thershold
00364       false //sort by preattentive
00365       );
00366 
00367   std::string objName;
00368   //LINFO("Found %i", nmatches);
00369   if (nmatches > 0 ){
00370     rutz::shared_ptr<VisualObject> obj; //so we will have a ref to the last matches obj
00371     rutz::shared_ptr<VisualObjectMatch> vom;
00372     //for(unsigned int i=0; i< nmatches; i++){
00373     for(unsigned int i=0; i< 1; i++){
00374       vom = matches[i];
00375       obj = vom->getVoTest();
00376 
00377 //      LINFO("### Object match with '%s' score=%f ID:%i",
00378 //          obj->getName().c_str(), vom->getScore(), objId);
00379       objName = obj->getName();
00380     }
00381 
00382   }
00383 
00384   return objName;
00385 }
Generated on Sun May 8 08:06:49 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3