cudacbcl-server.C

Go to the documentation of this file.
00001 /*!@file CUDA/cudacbcl-server.C Test TCP label server that accepts image patches and returns labels */
00002 
00003 // //////////////////////////////////////////////////////////////////// //
00004 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
00005 // University of Southern California (USC) and the iLab at USC.         //
00006 // See http://iLab.usc.edu for information about this project.          //
00007 // //////////////////////////////////////////////////////////////////// //
00008 // Major portions of the iLab Neuromorphic Vision Toolkit are protected //
00009 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
00010 // in Visual Environments, and Applications'' by Christof Koch and      //
00011 // Laurent Itti, California Institute of Technology, 2001 (patent       //
00012 // pending; application number 09/912,225 filed July 23, 2001; see      //
00013 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
00014 // //////////////////////////////////////////////////////////////////// //
00015 // This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
00016 //                                                                      //
00017 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
00018 // redistribute it and/or modify it under the terms of the GNU General  //
00019 // Public License as published by the Free Software Foundation; either  //
00020 // version 2 of the License, or (at your option) any later version.     //
00021 //                                                                      //
00022 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
00023 // that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
00024 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
00025 // PURPOSE.  See the GNU General Public License for more details.       //
00026 //                                                                      //
00027 // You should have received a copy of the GNU General Public License    //
00028 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
00029 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
00030 // Boston, MA 02111-1307 USA.                                           //
00031 // //////////////////////////////////////////////////////////////////// //
00032 //
00033 // Primary maintainer for this file: Dan Parks <danielfp@usc.edu>
00034 // $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/CUDA/cudacbcl-server.C $
00035 // $Id: cudacbcl-server.C 14154 2010-10-21 05:07:25Z dparks $
00036 //
00037 
00038 #include "Component/ModelManager.H"
00039 #include "Learn/Bayes.H"
00040 #include "GUI/DebugWin.H"
00041 #include "NeovisionII/nv2_common.h"
00042 #include "NeovisionII/nv2_label_server.h"
00043 
00044 #include "GUI/XWindow.H"
00045 #include "CUDA/CudaHmaxCBCL.H"
00046 #include "Image/Image.H"
00047 #include "Image/ImageSet.H"
00048 #include "Image/ColorOps.H"
00049 #include "Image/ShapeOps.H"
00050 #include "Image/CutPaste.H"
00051 #include "Image/FilterOps.H"
00052 #include "Image/Rectangle.H"
00053 #include "Image/MathOps.H"
00054 #include "Image/DrawOps.H"
00055 #include "Image/MatrixOps.H"
00056 #include "Image/Transforms.H"
00057 #include "Image/Convolutions.H"
00058 #include "Learn/SVMClassifier.H"
00059 #include "Media/FrameSeries.H"
00060 #include "nub/ref.h"
00061 #include "Raster/GenericFrame.H"
00062 #include "Raster/Raster.H"
00063 #include "Util/Types.H"
00064 #include "Util/log.H"
00065 
00066 #include <signal.h>
00067 
00068 #include "rutz/fstring.h"
00069 #include "rutz/time.h"
00070 #include "rutz/timeformat.h"
00071 
00072 #include <fstream>
00073 #include <map>
00074 #include <vector>
00075 #include <utility>
00076 #include <iostream>
00077 #include <iomanip>
00078 #include <string>
00079 #include <unistd.h>
00080 #include <cstdlib>
00081 
00082 
// Compile-time switch for color processing; not referenced by the code
// visible in this file.
const bool USECOLOR = false;

// Shutdown request flag: written by terminateProc() from signal context and
// polled by the main loop.  It is declared volatile sig_atomic_t because
// that is the object type a signal handler may portably write to; with a
// plain bool the compiler is free to keep the value in a register, so the
// polling loop might never observe the change.
volatile sig_atomic_t terminate = 0;

//! Signal handler that asks the main loop to exit cleanly.
/*! @param s number of the signal that was caught (unused). */
void terminateProc(int s)
{
  (void) s;
  terminate = 1;
}
00091 
//! Pick the most frequent label out of a history of labels.
/*! Occurrences are tallied while walking the history front to back; an
    entry takes over as the winner whenever its running tally is at
    least as large as the best tally seen so far, so on a tie the entry
    nearer the back wins.

    @param labels   history of labels to vote over
    @param mincount minimum number of occurrences the winner must have
    @return the winning label, or an empty string if the history is
            empty or the winner occurs fewer than mincount times */
std::string getBestLabel(const std::deque<std::string>& labels,
                         const size_t mincount)
{
  if (labels.empty())
    return std::string();

  std::map<std::string, size_t> tally;
  size_t topCount = 0;
  std::deque<std::string>::const_iterator winner = labels.begin();

  for (std::deque<std::string>::const_iterator it = labels.begin();
       it != labels.end(); ++it)
    {
      size_t& seen = tally[*it];
      ++seen;

      // A smaller tally never displaces the current winner.
      if (seen < topCount)
        continue;

      topCount = seen;
      winner = it;
    }

  return (topCount >= mincount) ? *winner : std::string();
}
00119 
00120 namespace
00121 {
00122   void fillRegion(Image<PixRGB<byte> >& img, PixRGB<byte> col,
00123                   const int x0, const int x1,
00124                   const int y0, const int y1)
00125   {
00126     for (int x = x0; x < x1; ++x)
00127       for (int y = y0; y < y1; ++y)
00128         img.setVal(x, y, col);
00129   }
00130 
00131   Image<PixRGB<byte> > makeColorbars(const int w, const int h)
00132   {
00133     Image<PixRGB<byte> > result = Image<PixRGB<byte> >(w, h, ZEROS);
00134 
00135     const PixRGB<byte> cols[] =
00136       {
00137         PixRGB<byte>(255, 255, 255), // white
00138         PixRGB<byte>(255, 255, 0),   // yellow
00139         PixRGB<byte>(0,   255, 255), // cyan
00140         PixRGB<byte>(0,   255, 0),   // green
00141         PixRGB<byte>(255, 0,   255), // magenta
00142         PixRGB<byte>(255, 0,   0),   // red
00143         PixRGB<byte>(0,   0,   255)  // blue
00144       };
00145 
00146     int x1 = 0;
00147     for (int i = 0; i < 7; ++i)
00148       {
00149         const int x0 = x1+1;
00150         x1 = int(double(w)*(i+1)/7.0 + 0.5);
00151         fillRegion(result, cols[i],
00152                    x0, x1,
00153                    0, int(h*2.0/3.0));
00154       }
00155 
00156     x1 = 0;
00157     for (int i = 0; i < 16; ++i)
00158       {
00159         const int x0 = x1;
00160         x1 = int(double(w)*(i+1)/16.0 + 0.5);
00161         const int gray = int(255.0*i/15.0 + 0.5);
00162         fillRegion(result, PixRGB<byte>(gray, gray, gray),
00163                    x0, x1,
00164                    int(h*2.0/3.0)+1, int(h*5.0/6.0));
00165       }
00166 
00167     fillRegion(result, PixRGB<byte>(255, 0, 0),
00168                0, w,
00169                int(h*5.0/6.0)+1, h);
00170 
00171     writeText(result, Point2D<int>(1, int(h*5.0/6.0)+2),
00172               "iLab Neuromorphic Vision",
00173               PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 0, 0),
00174               SimpleFont::FIXED(10));
00175 
00176     return result;
00177   }
00178 
00179   Image<PixRGB<byte> > addLabels(const Image<PixRGB<byte> >& templ,
00180                                  const int fnum)
00181   {
00182     Image<PixRGB<byte> > result = templ;
00183 
00184     std::string fnumstr = sformat("%06d", fnum);
00185     writeText(result, Point2D<int>(1, 1),
00186               fnumstr.c_str(),
00187               PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 255, 255),
00188               SimpleFont::FIXED(10));
00189 
00190     rutz::time t = rutz::time::wall_clock_now();
00191 
00192     writeText(result, Point2D<int>(1, result.getHeight() - 14),
00193               rutz::format_time(t).c_str(),
00194               PixRGB<byte>(32, 32, 32), PixRGB<byte>(255, 0, 0),
00195               SimpleFont::FIXED(6));
00196 
00197     return result;
00198   }
00199 
00200   int maxKey(std::map<int, std::string> m)
00201   {
00202     map<int, std::string>::iterator cur,end;
00203     cur = m.begin(); end = m.end();
00204     int mKey=-1;
00205     while(cur!=end)
00206       {
00207         if(cur->first > mKey)
00208           mKey = cur->first;
00209         cur++;
00210       }
00211     return mKey;
00212   }
00213 
00214   std::map<int, std::string> loadLabels(std::string labelFile)
00215   {
00216     std::map<int, std::string> labels;
00217     FILE *fp = fopen(labelFile.c_str(),"r");
00218     int ret;
00219     if(fp==NULL) return labels;
00220     while(1)
00221       {
00222         int id; char clabel[80];
00223         ret = fscanf(fp,"%d ",&id);
00224         if(ret != 1)
00225         {
00226           fprintf(stderr,"fscanf failed with %d\n",ret);
00227           break;
00228         }
00229         ret = fscanf(fp,"%80s",clabel);
00230         if(ret != 1)
00231           {
00232             fprintf(stderr,"fscanf failed with %d\n",ret);
00233             break;
00234           }
00235         printf("loaded label %d %s\n",id,clabel);
00236         std::string label = std::string(clabel);
00237         labels.insert(std::pair<int, std::string>(id,label));
00238       }
00239     fclose(fp);
00240     return labels;
00241   }
00242 
00243   void writeLabels(std::string labelFile, std::map<int, std::string> labels)
00244   {
00245     FILE *fp = fopen(labelFile.c_str(),"w");
00246     if(fp==NULL) return;
00247     map<int, std::string>::iterator cur,end;
00248     cur = labels.begin(); end = labels.end();
00249     while(cur!=end)
00250       {
00251         fprintf(fp,"%d %80s\n",cur->first, (cur->second).c_str());
00252         cur++;
00253       }
00254     fclose(fp);
00255   }
00256 
00257   int findLabel(std::string label, std::map<int, std::string> labels)
00258   {
00259     map<int, std::string>::iterator cur,end;
00260     cur = labels.begin(); end = labels.end();
00261     while(cur!=end)
00262       {
00263         if(cur->second.compare(label)==0)
00264           return cur->first;
00265         cur++;
00266       }
00267     return -1;
00268   }
00269 
00270 
00271   int addLabel(std::string label, std::map<int, std::string> &labels)
00272   {
00273     int id = maxKey(labels)+1;
00274     labels.insert(std::pair<int, std::string>(id,label));
00275     return id;
00276   }
00277 
00278   bool idExists(int id, std::map<int, std::string> &labels)
00279   {
00280     if(labels.find(id) == labels.end())
00281       return false;
00282     else
00283       return true;
00284   }
00285 
00286 }
00287 
00288 
00289 int main(const int argc, const char **argv)
00290 {
00291 
00292   MYLOGVERB = LOG_INFO;
00293 
00294   ModelManager *mgr = new ModelManager("Cuda Hmax CBCL Model Server");
00295 
00296 
00297   mgr->exportOptions(MC_RECURSE);
00298 
00299 
00300   if (mgr->parseCommandLine(
00301                             (const int)argc, (const char**)argv, "<cudadev> <labelFile> <c0patches> <c1patches> <featuresFile> <localport> <server_ip> <serverport> <svmModelFile> <svmRangeFile> ", 8, 10) == false)
00302     return 1;
00303 
00304   std::string devArg, serverIP,serverPortStr,localPortStr;
00305   std::string c0Patches;
00306   std::string c1Patches;
00307   std::string svmModelFileName, svmRangeFileName;
00308   std::string c2FileName;
00309   std::string labelFileName, featuresFileName;
00310   std::string trainPosName; // Directory where positive images are
00311 
00312   // Load the SVM Classifier Model and Range in
00313   SVMClassifier svm;
00314 
00315   // Whether we are in training mode
00316   bool train = false;
00317 
00318 
00319   // Now we run
00320   mgr->start();
00321 
00322   // catch signals and redirect them to terminate for clean exit:
00323   signal(SIGHUP, terminateProc); signal(SIGINT, terminateProc);
00324   signal(SIGQUIT, terminateProc); signal(SIGTERM, terminateProc);
00325   signal(SIGALRM, terminateProc);
00326 
00327 
00328   devArg = mgr->getExtraArg(0);
00329   labelFileName = mgr->getExtraArg(1);
00330   c0Patches = mgr->getExtraArg(2);
00331   c1Patches = mgr->getExtraArg(3);
00332   featuresFileName = mgr->getExtraArg(4);
00333   localPortStr = mgr->getExtraArg(5);
00334   serverIP = mgr->getExtraArg(6);
00335   serverPortStr = mgr->getExtraArg(7);
00336   // If we are given the SVM info, load it in
00337   if(mgr->numExtraArgs() > 8)
00338   {
00339     if(mgr->numExtraArgs() == 9)
00340       {
00341         LFATAL("USAGE: prog <cudadev> <labelFile> <c0patches> <c1patches> <featuresFile> <localport> <server_ip> <serverport> <svmModelFile> <svmRangeFile>\n");
00342       }
00343     svmModelFileName = mgr->getExtraArg(8);
00344     svmRangeFileName = mgr->getExtraArg(9);
00345     svm.readModel(svmModelFileName);
00346     svm.readRange(svmRangeFileName);
00347   }
00348   else
00349   {
00350     // With no SVM data, we should be in training mode
00351     train = true;
00352   }
00353   std::map<int,std::string> labels = loadLabels(labelFileName);
00354 
00355   int dev = strtol(devArg.c_str(),NULL,0);
00356   CudaDevices::setCurrentDevice(dev);
00357 
00358   CudaHmaxCBCL hmax(c0Patches,c1Patches);
00359 
00360   XWinManaged xwin(Dims(256,256),
00361                    -1, -1, "ILab Robot Head Demo");
00362 
00363   int serverPort = strtol(serverPortStr.c_str(),NULL,0);
00364   int localPort = strtol(localPortStr.c_str(),NULL,0);
00365   struct nv2_label_server* labelServer =
00366     nv2_label_server_create(localPort,
00367                             serverIP.c_str(),
00368                             serverPort);
00369 
00370   nv2_label_server_set_verbosity(labelServer,1); //allow warnings
00371 
00372 
00373   const size_t max_label_history = 1;
00374   std::deque<std::string> recent_labels;
00375 
00376   Image<PixRGB<byte> > colorbars = makeColorbars(256, 256);
00377   bool clearFile=true;
00378 
00379   while(!terminate)
00380     {
00381       Point2D<int> clickLoc = xwin.getLastMouseClick();
00382       if (clickLoc.isValid())
00383         train = !train;
00384 
00385       struct nv2_image_patch p;
00386       const enum nv2_image_patch_result res =
00387         nv2_label_server_get_current_patch(labelServer, &p);
00388 
00389       std::string objName;
00390       if (res == NV2_IMAGE_PATCH_END)
00391         {
00392           LINFO("ok, quitting");
00393           break;
00394         }
00395       else if (res == NV2_IMAGE_PATCH_NONE)
00396         {
00397           usleep(10000);
00398           continue;
00399         }
00400       else if (res == NV2_IMAGE_PATCH_VALID)
00401         {
00402           if (p.type != NV2_PIXEL_TYPE_RGB24)
00403             {
00404               LINFO("got a non-rgb24 patch; ignoring %i", p.type);
00405               continue;
00406             }
00407 
00408           if (p.width * p.height == 1)
00409             {
00410               xwin.drawImage(addLabels(colorbars, p.id));
00411               continue;
00412             }
00413 
00414           Image<PixRGB<byte> > img(p.width, p.height, NO_INIT);
00415           // Get the test image from the socket
00416           memcpy(img.getArrayPtr(), p.data, p.width*p.height*3);
00417 
00418           Image<PixRGB<byte> > inputImg = rescale(img, 256, 256);
00419 
00420           xwin.drawImage(inputImg);
00421 
00422           Image<float> inputf = luminanceNTSC(inputImg);
00423 
00424           // Get the C2 Layer Response
00425           hmax.getC2(inputf.getArrayPtr(),inputf.getWidth(),inputf.getHeight());
00426           if(!train)
00427             {
00428               // Output the c2 responses into a libsvm
00429               double prob;
00430               float *feat = hmax.getC2Features();
00431               int numFeat = hmax.numC2Features();
00432               double pred = svm.predict(feat,numFeat,&prob);
00433               printf("Prediction is %f\n",pred);
00434               int predId = (int) pred;
00435               bool knowObject = idExists(predId,labels);
00436               if(knowObject)
00437               {
00438                 printf("Known object %d, prob %f\n",predId,prob);
00439                 objName = labels[predId];
00440               }
00441               else
00442               {
00443                 printf("Unknown object %d, prob %f\n",predId,prob);
00444                 char tmp[200];
00445                 sprintf(tmp,"Unknown-%d",predId);
00446                 objName = std::string(tmp);
00447               }
00448               recent_labels.push_back(objName);
00449               while (recent_labels.size() > max_label_history)
00450                 recent_labels.pop_front();
00451 
00452               struct nv2_patch_label l;
00453               l.protocol_version = NV2_LABEL_PROTOCOL_VERSION;
00454               l.patch_id = p.id;
00455               l.confidence = (int)(prob*NV2_MAX_LABEL_CONFIDENCE);
00456               snprintf(l.source, sizeof(l.source), "%s",
00457                        "HmaxFL");
00458               snprintf(l.name, sizeof(l.name), "%s",
00459                        objName.c_str());
00460               snprintf(l.extra_info, sizeof(l.extra_info),
00461                        "%ux%u #%u",
00462                        (unsigned int) p.width,
00463                        (unsigned int) p.height,
00464                        (unsigned int) p.id);
00465 
00466               nv2_label_server_send_label(labelServer, &l);
00467               LINFO("sent label '%s (%s)'\n", l.name, l.extra_info);
00468             }
00469           // Determine what the object is
00470           else
00471             {
00472               printf("Enter a label for this object:\n");
00473               std::getline(std::cin, objName);
00474               printf("You typed '%s'\n", objName.c_str());
00475 
00476               if (objName == "exit")
00477                 break;
00478               else if (objName != "")
00479                 {
00480                   int newId = findLabel(objName,labels);
00481                   if(newId == -1)
00482                     {
00483                       newId = addLabel(objName,labels);
00484                       printf("No existing label found, adding [%s]\n",objName.c_str());
00485                     }
00486                   else
00487                     {
00488                       printf("Found existing label\n");
00489                     }
00490                   hmax.writeOutFeatures(featuresFileName,newId,clearFile);
00491                   clearFile=false;
00492                 }
00493             }
00494 
00495           nv2_image_patch_destroy(&p);
00496         }
00497     }
00498 
00499   writeLabels(labelFileName,labels);
00500 
00501   if (terminate)
00502     LINFO("Ending application because a signal was caught");
00503 
00504   //nv2_label_server_destroy(labelServer);
00505   LINFO("Got Here");
00506 
00507   return 0;
00508 }
00509 
00510 
00511 
00512 
00513 // ######################################################################
00514 /* So things look consistent in everyone's emacs... */
00515 /* Local Variables: */
00516 /* indent-tabs-mode: nil */
00517 /* End: */
Generated on Sun May 8 08:40:25 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3