00001 /*! @file SIFT/test-SIFT.C test the SIFT alg */ 00002 00003 // //////////////////////////////////////////////////////////////////// // 00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2000-2005 // 00005 // by the University of Southern California (USC) and the iLab at USC. // 00006 // See http://iLab.usc.edu for information about this project. // 00007 // //////////////////////////////////////////////////////////////////// // 00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected // 00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency // 00010 // in Visual Environments, and Applications'' by Christof Koch and // 00011 // Laurent Itti, California Institute of Technology, 2001 (patent // 00012 // pending; application number 09/912,225 filed July 23, 2001; see // 00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status). // 00014 // //////////////////////////////////////////////////////////////////// // 00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit. // 00016 // // 00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can // 00018 // redistribute it and/or modify it under the terms of the GNU General // 00019 // Public License as published by the Free Software Foundation; either // 00020 // version 2 of the License, or (at your option) any later version. // 00021 // // 00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope // 00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the // 00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // 00025 // PURPOSE. See the GNU General Public License for more details. // 00026 // // 00027 // You should have received a copy of the GNU General Public License // 00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00030 // Boston, MA 02111-1307 USA. // 00031 // //////////////////////////////////////////////////////////////////// // 00032 // 00033 // Primary maintainer for this file: Lior Elazary <elazary@usc.edu> 00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/SIFT/test-SIFT.C $ 00035 // $Id: test-SIFT.C 14376 2011-01-11 02:44:34Z pez $ 00036 // 00037 00038 00039 #include "Component/ModelManager.H" 00040 #include "Image/Image.H" 00041 #include "Image/Pixels.H" 00042 #include "Media/FrameSeries.H" 00043 #include "Util/Timer.H" 00044 #include "GUI/XWinManaged.H" 00045 #include "SIFT/ScaleSpace.H" 00046 #include "SIFT/VisualObject.H" 00047 #include "SIFT/Keypoint.H" 00048 #include "SIFT/VisualObjectDB.H" 00049 #include "Image/CutPaste.H" 00050 #include "Image/ShapeOps.H" 00051 #include "Image/DrawOps.H" 00052 #include "Media/TestImages.H" 00053 #include "Raster/Raster.H" 00054 #include "Transport/FrameInfo.H" 00055 #include "Raster/Raster.H" 00056 #include "Raster/GenericFrame.H" 00057 00058 #define DISPSCALE 1 00059 #define USECOLOR false 00060 00061 //the visual database 00062 VisualObjectDB vdb; 00063 #define WIDTH 384 00064 #define HEIGHT 288 00065 00066 std::string matchObject(Image<PixRGB<byte> > &ima); 00067 00068 /* 00069 XWinManaged xwin(Dims(WIDTH,HEIGHT*2), 1, 1, "Test SIFT"); 00070 00071 00072 rutz::shared_ptr<VisualObject> objTop, objBottom; 00073 00074 void showObjs(rutz::shared_ptr<VisualObject> obj1, rutz::shared_ptr<VisualObject> obj2){ 00075 //return ; 00076 00077 Image<PixRGB<byte> > keyIma = rescale(obj1->getKeypointImage(), 00078 WIDTH, HEIGHT); 00079 objTop = obj1; 00080 00081 if (obj2.is_valid()){ 00082 keyIma = concatY(keyIma, rescale(obj2->getKeypointImage(), 00083 WIDTH, HEIGHT)); 00084 objBottom = obj2; 00085 } 00086 00087 xwin.drawImage(keyIma); 00088 } 00089 00090 void showKeypoint(rutz::shared_ptr<VisualObject> obj, int keypi, 00091 Keypoint::CHANNEL channel = Keypoint::ORI){ 00092 00093 char winTitle[255]; 00094 switch(channel){ 00095 case Keypoint::ORI: 00096 sprintf(winTitle, "Keypoint view (Channel ORI)"); 00097 break; 00098 case Keypoint::COL: 00099 sprintf(winTitle, "Keypoint view (Channel COL)"); 00100 break; 00101 default: 00102 sprintf(winTitle, "Keypoint view (Channel )"); 00103 break; 00104 } 00105 00106 00107 rutz::shared_ptr<Keypoint> keyp = obj->getKeypoint(keypi); 00108 float x = keyp->getX(); 00109 float y = keyp->getY(); 00110 float s = keyp->getS(); 00111 float o = keyp->getO(); 00112 float m = keyp->getM(); 00113 00114 uint FVlength = keyp->getFVlength(channel); 00115 if (FVlength<=0) return; //dont show the Keypoint if we dont have a FV 00116 00117 XWinManaged *xwinKey = new XWinManaged(Dims(WIDTH*2,HEIGHT), -1, -1, winTitle); 00118 00119 00120 //draw the circle around the keypoint 00121 const float sigma = 1.6F * powf(2.0F, s / float(6 - 3)); 00122 const float sig = 1.5F * sigma; 00123 const int rad = int(3.0F * sig); 00124 00125 Image<PixRGB<byte> > img = obj->getImage(); 00126 Point2D<int> loc(int(x + 0.5F), int(y + 0.5F)); 00127 drawCircle(img, loc, rad, PixRGB<byte>(255, 0, 0)); 00128 drawDisk(img, loc, 2, PixRGB<byte>(255,0,0)); 00129 00130 s=s*5.0F; //mag for scale 00131 if (s > 0.0f) drawLine(img, loc, 00132 Point2D<int>(int(x + s * cosf(o) + 0.5F), 00133 int(y + s * sinf(o) + 0.5F)), 00134 PixRGB<byte>(255, 0, 0)); 00135 00136 char info[255]; 00137 sprintf(info, "(%0.2f,%0.2f) s=%0.2f o=%0.2f m=%0.2f", x, y, s, o, m); 00138 00139 writeText(img, Point2D<int>(0, HEIGHT-20), info, 00140 PixRGB<byte>(255), PixRGB<byte>(127)); 00141 00142 00143 //draw the vectors from the features vectors 00144 00145 Image<PixRGB<byte> > fvDisp(WIDTH, HEIGHT, NO_INIT); 00146 fvDisp.clear(PixRGB<byte>(255, 255, 255)); 00147 int xBins = int((float)WIDTH/4); 00148 int yBins = int((float)HEIGHT/4); 00149 00150 drawGrid(fvDisp, xBins, yBins, 1, 1, PixRGB<byte>(0, 0, 0)); 00151 00152 00153 00154 switch (channel){ 00155 case Keypoint::ORI: 00156 for (int xx=0; xx<4; xx++){ 00157 for (int yy=0; yy<4; yy++){ 00158 for (int oo=0; oo<8; oo++){ 00159 Point2D<int> loc(xBins/2+(xBins*xx), yBins/2+(yBins*yy)); 00160 byte mag = keyp->getFVelement(xx*32+yy*8+oo, channel); 00161 mag = mag/4; 00162 drawDisk(fvDisp, loc, 2, PixRGB<byte>(255, 0, 0)); 00163 drawLine(fvDisp, loc, 00164 Point2D<int>(int(loc.i + mag*cosf(oo*M_PI/4)), 00165 int(loc.j + mag*sinf(oo*M_PI/4))), 00166 PixRGB<byte>(255, 0, 0)); 00167 } 00168 } 00169 } 00170 break; 00171 00172 case Keypoint::COL: 00173 for (int xx=0; xx<4; xx++){ 00174 for (int yy=0; yy<4; yy++){ 00175 for (int cc=0; cc<3; cc++){ 00176 Point2D<int> loc(xBins/2+(xBins*xx), yBins/2+(yBins*yy)); 00177 byte mag = keyp->getFVelement(xx*12+yy*3+cc, channel); 00178 mag = mag/4; 00179 drawDisk(fvDisp, loc, 2, PixRGB<byte>(255, 0, 0)); 00180 drawLine(fvDisp, loc, 00181 Point2D<int>(int(loc.i + mag*cosf(-1*cc*M_PI/2)), 00182 int(loc.j + mag*sinf(-1*cc*M_PI/2))), 00183 PixRGB<byte>(255, 0, 0)); 00184 } 00185 } 00186 } 00187 break; 00188 default: 00189 break; 00190 } 00191 00192 00193 00194 Image<PixRGB<byte> > disp = img; 00195 disp = concatX(disp, fvDisp); 00196 00197 00198 xwinKey->drawImage(disp); 00199 00200 while(!xwinKey->pressedCloseButton()){ 00201 usleep(100); 00202 } 00203 delete xwinKey; 00204 00205 } 00206 00207 00208 00209 void analizeImage(){ 00210 int key = -1; 00211 00212 while(key != 24){ // q to quit window 00213 key = xwin.getLastKeyPress(); 00214 Point2D<int> point = xwin.getLastMouseClick(); 00215 if (point.i > -1 && point.j > -1){ 00216 00217 //get the right object 00218 rutz::shared_ptr<VisualObject> obj; 00219 if (point.j < HEIGHT){ 00220 obj = objTop; 00221 } else { 00222 obj = objBottom; 00223 point.j = point.j - HEIGHT; 00224 } 00225 LINFO("ClickInfo: key = %i, p=%i,%i", key, point.i, point.j); 00226 00227 //find the keypoint 00228 for(uint i=0; i<obj->numKeypoints(); i++){ 00229 rutz::shared_ptr<Keypoint> keyp = obj->getKeypoint(i); 00230 float x = keyp->getX(); 00231 float y = keyp->getY(); 00232 00233 if ( (point.i < (int)x + 5 && point.i > (int)x - 5) && 00234 (point.j < (int)y + 5 && point.j > (int)y - 5)){ 00235 showKeypoint(obj, i, Keypoint::ORI); 00236 showKeypoint(obj, i, Keypoint::COL); 00237 } 00238 00239 } 00240 00241 } 00242 } 00243 00244 } 00245 */ 00246 int main(const int argc, const char **argv) 00247 { 00248 00249 MYLOGVERB = LOG_INFO; 00250 ModelManager manager("Test SIFT"); 00251 00252 00253 00254 nub::ref<InputFrameSeries> ifs(new InputFrameSeries(manager)); 00255 manager.addSubComponent(ifs); 00256 00257 nub::ref<OutputFrameSeries> ofs(new OutputFrameSeries(manager)); 00258 manager.addSubComponent(ofs); 00259 00260 00261 00262 if (manager.parseCommandLine( 00263 (const int)argc, (const char**)argv, "<database file> <trainingLabel>", 2, 2) == false) 00264 return 0; 00265 00266 manager.start(); 00267 00268 Timer masterclock; // master clock for simulations 00269 Timer timer; 00270 00271 const char *vdbFile = manager.getExtraArg(0).c_str(); 00272 const char *trainingLabel = manager.getExtraArg(1).c_str(); 00273 00274 int numMatches = 0; //the number of correct matches 00275 int totalObjects = 0; //the number of objects presented to the network 00276 int uObjId = 0; //a unique obj id for sift 00277 00278 bool train = false; 00279 //load the database file 00280 // if (!train) 00281 vdb.loadFrom(std::string(vdbFile)); 00282 00283 while(1) 00284 { 00285 Image< PixRGB<byte> > inputImg; 00286 const FrameState is = ifs->updateNext(); 00287 if (is == FRAME_COMPLETE) 00288 break; 00289 00290 //grab the images 00291 GenericFrame input = ifs->readFrame(); 00292 if (!input.initialized()) 00293 break; 00294 inputImg = input.asRgb(); 00295 totalObjects++; 00296 00297 ofs->writeRGB(inputImg, "Input", FrameInfo("Input", SRC_POS)); 00298 00299 00300 if (train) 00301 { 00302 //add the object to the database 00303 char objName[255]; sprintf(objName, "%s_%i", trainingLabel, uObjId); 00304 uObjId++; 00305 rutz::shared_ptr<VisualObject> 00306 vo(new VisualObject(objName, "NULL", inputImg, 00307 Point2D<int>(-1,-1), 00308 std::vector<float>(), 00309 std::vector< rutz::shared_ptr<Keypoint> >(), 00310 USECOLOR)); 00311 00312 vdb.addObject(vo); 00313 } else { 00314 00315 //get the object classification 00316 std::string objName; 00317 std::string tmpName = matchObject(inputImg); 00318 int i = tmpName.find("_"); 00319 objName.assign(tmpName, 0, i); 00320 LINFO("Object name %s", objName.c_str()); 00321 printf("%i %s\n", ifs->frame(), objName.c_str()); 00322 00323 if (objName == trainingLabel) 00324 numMatches++; 00325 00326 //printf("objid %i:class %i:rate=%0.2f\n", 00327 // objData.description.c_str(), objData.id, cls, 00328 // (float)numMatches/(float)totalObjects); 00329 } 00330 } 00331 00332 if (train) 00333 { 00334 printf("Trained on %i objects\n", totalObjects); 00335 printf("Object in db %i\n" , vdb.numObjects()); 00336 vdb.saveTo(std::string(vdbFile)); 00337 } else { 00338 printf("Classification Rate: %i/%i %0.2f\n", 00339 numMatches, totalObjects, 00340 (float)numMatches/(float)totalObjects); 00341 } 00342 00343 00344 } 00345 00346 std::string matchObject(Image<PixRGB<byte> > &ima){ 00347 00348 //find object in the database 00349 std::vector< rutz::shared_ptr<VisualObjectMatch> > matches; 00350 rutz::shared_ptr<VisualObject> 00351 vo(new VisualObject("PIC", "PIC", ima, 00352 Point2D<int>(-1,-1), 00353 std::vector<float>(), 00354 std::vector< rutz::shared_ptr<Keypoint> >(), 00355 USECOLOR)); 00356 00357 const uint nmatches = vdb.getObjectMatches(vo, matches, VOMA_SIMPLE, 00358 10000U, //max objs to return 00359 0.5F, //keypoint distance score default 0.5F 00360 0.5F, //affine distance score default 0.5F 00361 1.0F, //minscore default 1.0F 00362 3U, //min # of keypoint match 00363 100U, //keypoint selection thershold 00364 false //sort by preattentive 00365 ); 00366 00367 std::string objName; 00368 //LINFO("Found %i", nmatches); 00369 if (nmatches > 0 ){ 00370 rutz::shared_ptr<VisualObject> obj; //so we will have a ref to the last matches obj 00371 rutz::shared_ptr<VisualObjectMatch> vom; 00372 //for(unsigned int i=0; i< nmatches; i++){ 00373 for(unsigned int i=0; i< 1; i++){ 00374 vom = matches[i]; 00375 obj = vom->getVoTest(); 00376 00377 // LINFO("### Object match with '%s' score=%f ID:%i", 00378 // obj->getName().c_str(), vom->getScore(), objId); 00379 objName = obj->getName(); 00380 } 00381 00382 } 00383 00384 return objName; 00385 }