00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039 #ifndef OBJREC_CUDASIFTSERVER_C_DEFINED
00040 #define OBJREC_CUDASIFTSERVER_C_DEFINED
00041
00042 #include <signal.h>
00043 #include "Component/ModelManager.H"
00044 #include "Image/Image.H"
00045 #include "Image/ImageSet.H"
00046 #include "Image/ShapeOps.H"
00047 #include "Image/CutPaste.H"
00048 #include "Image/DrawOps.H"
00049 #include "Image/FilterOps.H"
00050 #include "Image/ColorOps.H"
00051 #include "Image/Transforms.H"
00052 #include "Image/MathOps.H"
00053 #include "Learn/Bayes.H"
00054 #include "GUI/DebugWin.H"
00055 #include "SIFT/ScaleSpace.H"
00056 #include "SIFT/VisualObject.H"
00057 #include "SIFT/Keypoint.H"
00058 #include "SIFT/VisualObjectDB.H"
00059
00060 #include "CUDASIFT/CUDAVisualObject.H"
00061 #include "NeovisionII/nv2_common.h"
00062 #include "NeovisionII/nv2_label_server.h"
00063 #include "rutz/fstring.h"
00064 #include "rutz/time.h"
00065 #include "rutz/timeformat.h"
00066
00067 #include "CUDASIFT/tpimageutil.h"
00068 #include "CUDASIFT/tpimage.h"
00069 #include "CUDASIFT/cudaImage.h"
00070 #include "CUDASIFT/cudaSift.h"
00071 #include "CUDASIFT/cudaSiftH.h"
00072
00073 #include <iostream>
00074
// Compile-time colour switch; nothing in the code visible here reads it.
const bool USECOLOR = false;

// Shutdown request flag: written by terminateProc() from signal context and
// polled by the main loop. It is a volatile sig_atomic_t because that is the
// only kind of object a signal handler may portably write to, and volatile
// keeps the polling loop from caching the value.
volatile sig_atomic_t terminate = 0;

// Signal handler installed by main(); it only records that a shutdown was
// requested and leaves the actual cleanup to the main loop.
void terminateProc(int s)
{
  (void)s; // the signal number itself is not needed
  terminate = 1;
}
00083
00084 std::string matchObject(Image<PixRGB<byte> > &ima, VisualObjectDB& vdb, float &score)
00085 {
00086 std::vector< rutz::shared_ptr<VisualObjectMatch> > matches;
00087 #ifdef GPUSIFT
00088 rutz::shared_ptr<CUDAVisualObject>
00089 vo(new CUDAVisualObject("PIC", "PIC", ima,
00090 Point2D<int>(-1,-1),
00091 std::vector<float>(),
00092 std::vector< rutz::shared_ptr<Keypoint> >(),
00093 false,true));
00094 #else
00095 rutz::shared_ptr<VisualObject>
00096 vo(new VisualObject("PIC", "PIC", ima,
00097 Point2D<int>(-1,-1),
00098 std::vector<float>(),
00099 std::vector< rutz::shared_ptr<Keypoint> >(),
00100 false,true));
00101 #endif
00102
00103 const uint nmatches = vdb.getObjectMatches(vo, matches, VOMA_SIMPLE,
00104 100U,
00105 0.5F,
00106 0.5F,
00107 1.0F,
00108 3U,
00109 100U,
00110 false
00111 );
00112 score = 0;
00113 float avgScore = 0, affineAvgDist = 0;
00114 int nkeyp = 0;
00115 int objId = -1;
00116
00117
00118
00119 if (nmatches > 0)
00120 {
00121
00122 rutz::shared_ptr<VisualObject> obj;
00123 rutz::shared_ptr<VisualObjectMatch> vom;
00124
00125 for (unsigned int i = 0; i < 1; ++i)
00126 {
00127 vom = matches[i];
00128 obj = vom->getVoTest();
00129 score = vom->getScore();
00130 nkeyp = vom->size();
00131 avgScore = vom->getKeypointAvgDist();
00132 affineAvgDist = vom->getAffineAvgDist();
00133
00134 objId = atoi(obj->getName().c_str()+3);
00135
00136 std::string fullpath = obj->getName();
00137 std::string::size_type spos = fullpath.find_last_of('/');
00138 std::string protoname = fullpath.substr(0,spos);
00139 spos = protoname.find_last_of('/');
00140 protoname = protoname.substr(spos+1);
00141 std::cout << "protoname = " << protoname << std::endl;
00142
00143 LINFO("### Object match with '%s' score=%f ID:%i",
00144 obj->getName().c_str(), vom->getScore(), objId);
00145 return protoname;
00146 }
00147 }
00148
00149 return std::string("nomatch");
00150 }
00151
// Majority vote over a history of labels.
//
// Walks `labels` front to back while tallying how often each label has been
// seen so far. The winner is the entry whose running tally reaches the
// highest value; on a tie the entry closest to the back wins. Returns the
// winning label if its tally is at least `mincount`, otherwise (or when
// `labels` is empty) an empty string.
std::string getBestLabel(const std::deque<std::string>& labels,
                         const size_t mincount)
{
  typedef std::deque<std::string>::const_iterator LabelIter;

  std::map<std::string, size_t> tally;
  size_t topCount = 0;
  LabelIter winner = labels.end();

  for (LabelIter it = labels.begin(); it != labels.end(); ++it)
    {
      size_t& seen = tally[*it];
      ++seen;

      // Strictly smaller tallies never displace the current winner; equal
      // ones do, which is what lets later entries win ties.
      if (seen < topCount)
        continue;

      topCount = seen;
      winner = it;
    }

  if (winner == labels.end() || topCount < mincount)
    return std::string();

  return *winner;
}
00179
00180 namespace
00181 {
00182 void fillRegion(Image<PixRGB<byte> >& img, PixRGB<byte> col,
00183 const int x0, const int x1,
00184 const int y0, const int y1)
00185 {
00186 for (int x = x0; x < x1; ++x)
00187 for (int y = y0; y < y1; ++y)
00188 img.setVal(x, y, col);
00189 }
00190
00191 Image<PixRGB<byte> > makeColorbars(const int w, const int h)
00192 {
00193 Image<PixRGB<byte> > result = Image<PixRGB<byte> >(w, h, ZEROS);
00194
00195 const PixRGB<byte> cols[] =
00196 {
00197 PixRGB<byte>(255, 255, 255),
00198 PixRGB<byte>(255, 255, 0),
00199 PixRGB<byte>(0, 255, 255),
00200 PixRGB<byte>(0, 255, 0),
00201 PixRGB<byte>(255, 0, 255),
00202 PixRGB<byte>(255, 0, 0),
00203 PixRGB<byte>(0, 0, 255)
00204 };
00205
00206 int x1 = 0;
00207 for (int i = 0; i < 7; ++i)
00208 {
00209 const int x0 = x1+1;
00210 x1 = int(double(w)*(i+1)/7.0 + 0.5);
00211 fillRegion(result, cols[i],
00212 x0, x1,
00213 0, int(h*2.0/3.0));
00214 }
00215
00216 x1 = 0;
00217 for (int i = 0; i < 16; ++i)
00218 {
00219 const int x0 = x1;
00220 x1 = int(double(w)*(i+1)/16.0 + 0.5);
00221 const int gray = int(255.0*i/15.0 + 0.5);
00222 fillRegion(result, PixRGB<byte>(gray, gray, gray),
00223 x0, x1,
00224 int(h*2.0/3.0)+1, int(h*5.0/6.0));
00225 }
00226
00227 fillRegion(result, PixRGB<byte>(255, 0, 0),
00228 0, w,
00229 int(h*5.0/6.0)+1, h);
00230
00231 writeText(result, Point2D<int>(1, int(h*5.0/6.0)+2),
00232 "iLab Neuromorphic Vision",
00233 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 0, 0),
00234 SimpleFont::FIXED(10));
00235
00236 return result;
00237 }
00238
00239 Image<PixRGB<byte> > addLabels(const Image<PixRGB<byte> >& templ,
00240 const int fnum)
00241 {
00242 Image<PixRGB<byte> > result = templ;
00243
00244 std::string fnumstr = sformat("%06d", fnum);
00245 writeText(result, Point2D<int>(1, 1),
00246 fnumstr.c_str(),
00247 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 255, 255),
00248 SimpleFont::FIXED(10));
00249
00250 rutz::time t = rutz::time::wall_clock_now();
00251
00252 writeText(result, Point2D<int>(1, result.getHeight() - 14),
00253 rutz::format_time(t).c_str(),
00254 PixRGB<byte>(32, 32, 32), PixRGB<byte>(255, 0, 0),
00255 SimpleFont::FIXED(6));
00256
00257 return result;
00258 }
00259 }
00260
00261 int main(const int argc, const char **argv)
00262 {
00263
00264
00265 MYLOGVERB = LOG_INFO;
00266 ModelManager mgr("Test ObjRec");
00267
00268 if (mgr.parseCommandLine(argc, argv, "<cudadev> <vdb file> <localport> <server ip> <serverport>", 5, 5) == false)
00269 return 1;
00270
00271 mgr.start();
00272
00273
00274 signal(SIGHUP, terminateProc); signal(SIGINT, terminateProc);
00275 signal(SIGQUIT, terminateProc); signal(SIGTERM, terminateProc);
00276 signal(SIGALRM, terminateProc);
00277
00278
00279 const std::string devArg = mgr.getExtraArg(0);
00280 const std::string vdbFile = mgr.getExtraArg(1);
00281 const std::string localPortStr = mgr.getExtraArg(2);
00282 const std::string serverIP = mgr.getExtraArg(3);
00283 const std::string serverPortStr = mgr.getExtraArg(4);
00284
00285 bool train = false;
00286
00287 int dev = strtol(devArg.c_str(),NULL,0);
00288 std::cout << "device = " << dev << std::endl;
00289 cudaSetDevice(dev);
00290
00291
00292 LINFO("Loading db from %s\n", vdbFile.c_str());
00293 VisualObjectDB vdb;
00294 vdb.loadFrom(vdbFile,false);
00295
00296
00297
00298 XWinManaged xwin(Dims(256,256),
00299 -1, -1, "ILab NeoVision2 CUDASIFT Demo");
00300
00301 int serverPort = strtol(serverPortStr.c_str(),NULL,0);
00302 int localPort = strtol(localPortStr.c_str(),NULL,0);
00303
00304 struct nv2_label_server* labelServer =
00305 nv2_label_server_create(localPort,
00306 serverIP.c_str(),
00307 serverPort);
00308
00309 nv2_label_server_set_verbosity(labelServer,1);
00310
00311
00312 const size_t max_label_history = 1;
00313 std::deque<std::string> recent_labels;
00314
00315 Image<PixRGB<byte> > colorbars = makeColorbars(256, 256);
00316
00317 while (!terminate)
00318 {
00319 Point2D<int> clickLoc = xwin.getLastMouseClick();
00320 if (clickLoc.isValid())
00321 train = !train;
00322
00323 struct nv2_image_patch p;
00324 const enum nv2_image_patch_result res =
00325 nv2_label_server_get_current_patch(labelServer, &p);
00326
00327 std::string objName;
00328 if (res == NV2_IMAGE_PATCH_END)
00329 {
00330 LINFO("ok, quitting");
00331 break;
00332 }
00333 else if (res == NV2_IMAGE_PATCH_NONE)
00334 {
00335 usleep(10000);
00336 continue;
00337 }
00338 else if (res == NV2_IMAGE_PATCH_VALID)
00339 {
00340 if (p.type != NV2_PIXEL_TYPE_RGB24)
00341 {
00342 LINFO("got a non-rgb24 patch; ignoring %i", p.type);
00343 continue;
00344 }
00345
00346 if (p.width * p.height == 1)
00347 {
00348 xwin.drawImage(addLabels(colorbars, p.id));
00349 continue;
00350 }
00351
00352 Image<PixRGB<byte> > bimage(p.width, p.height, NO_INIT);
00353 memcpy(bimage.getArrayPtr(), p.data, p.width*p.height*3);
00354 Image<PixRGB<byte> > inputImage = bimage;
00355 printf("inputImage w=%d, h=%d\n",inputImage.getWidth(),inputImage.getHeight());
00356
00357 xwin.drawImage(inputImage);
00358 float score = 0.0;
00359 std::string objName = matchObject(inputImage, vdb, score);
00360
00361
00362 if (objName == "nomatch")
00363 {
00364 recent_labels.resize(0);
00365
00366 if (train)
00367 {
00368 printf("Enter a label for this object:\n");
00369 std::getline(std::cin, objName);
00370 printf("You typed '%s'\n", objName.c_str());
00371
00372 if (objName == "exit")
00373 break;
00374 else if (objName != "")
00375 {
00376 #ifdef GPUSIFT
00377 rutz::shared_ptr<CUDAVisualObject>
00378 vo(new CUDAVisualObject(objName.c_str(), "NULL", inputImage,
00379 Point2D<int>(-1,-1),
00380 std::vector<float>(),
00381 std::vector< rutz::shared_ptr<Keypoint> >(),
00382 false,true));
00383 #else
00384 rutz::shared_ptr<VisualObject>
00385 vo(new VisualObject(objName.c_str(), "NULL", inputImage,
00386 Point2D<int>(-1,-1),
00387 std::vector<float>(),
00388 std::vector< rutz::shared_ptr<Keypoint> >(),
00389 false,true));
00390 #endif
00391 vdb.addObject(vo);
00392 vdb.saveTo(vdbFile);
00393 }
00394 }
00395 }
00396 else
00397 {
00398 recent_labels.push_back(objName);
00399 while (recent_labels.size() > max_label_history)
00400 recent_labels.pop_front();
00401
00402 const std::string bestObjName =
00403 getBestLabel(recent_labels, 1);
00404
00405 if (bestObjName.size() > 0)
00406 {
00407 struct nv2_patch_label l;
00408 l.protocol_version = NV2_LABEL_PROTOCOL_VERSION;
00409 l.patch_id = p.id;
00410
00411
00412 l.confidence = (int)(score*10000.0F);
00413 snprintf(l.source, sizeof(l.source), "%s",
00414 "ObjRec");
00415 snprintf(l.name, sizeof(l.name), "%s",
00416 objName.c_str());
00417 snprintf(l.extra_info, sizeof(l.extra_info),
00418 "%ux%u #%u",
00419 (unsigned int) p.width,
00420 (unsigned int) p.height,
00421 (unsigned int) p.id);
00422
00423 nv2_label_server_send_label(labelServer, &l);
00424
00425 LINFO("sent label '%s (%s)'\n", l.name, l.extra_info);
00426 }
00427 }
00428
00429 nv2_image_patch_destroy(&p);
00430 }
00431
00432 }
00433
00434 if (terminate)
00435 LINFO("Ending application because a signal was caught");
00436
00437
00438
00439 }
00440
00441
00442
00443
00444
00445
00446
00447 #endif