00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Component/ModelManager.H"
00039 #include "Learn/Bayes.H"
00040 #include "GUI/DebugWin.H"
00041 #include "NeovisionII/nv2_common.h"
00042 #include "NeovisionII/nv2_label_server.h"
00043
00044 #include "GUI/XWindow.H"
00045 #include "CUDA/CudaHmaxCBCL.H"
00046 #include "Image/Image.H"
00047 #include "Image/ImageSet.H"
00048 #include "Image/ColorOps.H"
00049 #include "Image/ShapeOps.H"
00050 #include "Image/CutPaste.H"
00051 #include "Image/FilterOps.H"
00052 #include "Image/Rectangle.H"
00053 #include "Image/MathOps.H"
00054 #include "Image/DrawOps.H"
00055 #include "Image/MatrixOps.H"
00056 #include "Image/Transforms.H"
00057 #include "Image/Convolutions.H"
00058 #include "Learn/SVMClassifier.H"
00059 #include "Media/FrameSeries.H"
00060 #include "nub/ref.h"
00061 #include "Raster/GenericFrame.H"
00062 #include "Raster/Raster.H"
00063 #include "Util/Types.H"
00064 #include "Util/log.H"
00065
00066 #include <signal.h>
00067
00068 #include "rutz/fstring.h"
00069 #include "rutz/time.h"
00070 #include "rutz/timeformat.h"
00071
00072 #include <fstream>
00073 #include <map>
00074 #include <vector>
00075 #include <utility>
00076 #include <iostream>
00077 #include <iomanip>
00078 #include <string>
00079 #include <unistd.h>
00080 #include <cstdlib>
00081
00082
// Compile-time colour switch. Nothing in the visible part of this file reads it.
const bool USECOLOR = false;

// Shutdown request flag: set by terminateProc() and polled by the main loop.
// It is written from a signal handler, so it is declared as
// volatile sig_atomic_t (the only object type that may portably be written
// from a handler) instead of a plain bool.
volatile sig_atomic_t terminate = 0;

// Signal handler: records that a termination signal arrived.
// @param s  the signal number (unused).
void terminateProc(int s)
{
  (void)s; // the same action is taken for every signal this is installed on
  terminate = 1;
}
00091
// Returns the most frequent string in `labels`, or an empty string when
// `labels` is empty or the winning count is below `mincount`.
//
// Counts are accumulated while walking the deque front to back and the
// current winner is replaced whenever an entry's running count is greater
// than OR EQUAL TO the best seen so far, so on ties the later entry wins.
std::string getBestLabel(const std::deque<std::string>& labels,
                         const size_t mincount)
{
  if (labels.empty())
    return std::string();

  typedef std::deque<std::string>::const_iterator LabelIter;

  std::map<std::string, size_t> tally;
  size_t topCount = 0;
  LabelIter winner = labels.begin();

  for (LabelIter it = labels.begin(); it != labels.end(); ++it)
    {
      size_t& seen = tally[*it];
      ++seen;

      if (seen < topCount)
        continue;

      topCount = seen;
      winner = it;
    }

  if (topCount < mincount)
    return std::string();

  return *winner;
}
00119
00120 namespace
00121 {
00122 void fillRegion(Image<PixRGB<byte> >& img, PixRGB<byte> col,
00123 const int x0, const int x1,
00124 const int y0, const int y1)
00125 {
00126 for (int x = x0; x < x1; ++x)
00127 for (int y = y0; y < y1; ++y)
00128 img.setVal(x, y, col);
00129 }
00130
00131 Image<PixRGB<byte> > makeColorbars(const int w, const int h)
00132 {
00133 Image<PixRGB<byte> > result = Image<PixRGB<byte> >(w, h, ZEROS);
00134
00135 const PixRGB<byte> cols[] =
00136 {
00137 PixRGB<byte>(255, 255, 255),
00138 PixRGB<byte>(255, 255, 0),
00139 PixRGB<byte>(0, 255, 255),
00140 PixRGB<byte>(0, 255, 0),
00141 PixRGB<byte>(255, 0, 255),
00142 PixRGB<byte>(255, 0, 0),
00143 PixRGB<byte>(0, 0, 255)
00144 };
00145
00146 int x1 = 0;
00147 for (int i = 0; i < 7; ++i)
00148 {
00149 const int x0 = x1+1;
00150 x1 = int(double(w)*(i+1)/7.0 + 0.5);
00151 fillRegion(result, cols[i],
00152 x0, x1,
00153 0, int(h*2.0/3.0));
00154 }
00155
00156 x1 = 0;
00157 for (int i = 0; i < 16; ++i)
00158 {
00159 const int x0 = x1;
00160 x1 = int(double(w)*(i+1)/16.0 + 0.5);
00161 const int gray = int(255.0*i/15.0 + 0.5);
00162 fillRegion(result, PixRGB<byte>(gray, gray, gray),
00163 x0, x1,
00164 int(h*2.0/3.0)+1, int(h*5.0/6.0));
00165 }
00166
00167 fillRegion(result, PixRGB<byte>(255, 0, 0),
00168 0, w,
00169 int(h*5.0/6.0)+1, h);
00170
00171 writeText(result, Point2D<int>(1, int(h*5.0/6.0)+2),
00172 "iLab Neuromorphic Vision",
00173 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 0, 0),
00174 SimpleFont::FIXED(10));
00175
00176 return result;
00177 }
00178
00179 Image<PixRGB<byte> > addLabels(const Image<PixRGB<byte> >& templ,
00180 const int fnum)
00181 {
00182 Image<PixRGB<byte> > result = templ;
00183
00184 std::string fnumstr = sformat("%06d", fnum);
00185 writeText(result, Point2D<int>(1, 1),
00186 fnumstr.c_str(),
00187 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 255, 255),
00188 SimpleFont::FIXED(10));
00189
00190 rutz::time t = rutz::time::wall_clock_now();
00191
00192 writeText(result, Point2D<int>(1, result.getHeight() - 14),
00193 rutz::format_time(t).c_str(),
00194 PixRGB<byte>(32, 32, 32), PixRGB<byte>(255, 0, 0),
00195 SimpleFont::FIXED(6));
00196
00197 return result;
00198 }
00199
00200 int maxKey(std::map<int, std::string> m)
00201 {
00202 map<int, std::string>::iterator cur,end;
00203 cur = m.begin(); end = m.end();
00204 int mKey=-1;
00205 while(cur!=end)
00206 {
00207 if(cur->first > mKey)
00208 mKey = cur->first;
00209 cur++;
00210 }
00211 return mKey;
00212 }
00213
00214 std::map<int, std::string> loadLabels(std::string labelFile)
00215 {
00216 std::map<int, std::string> labels;
00217 FILE *fp = fopen(labelFile.c_str(),"r");
00218 int ret;
00219 if(fp==NULL) return labels;
00220 while(1)
00221 {
00222 int id; char clabel[80];
00223 ret = fscanf(fp,"%d ",&id);
00224 if(ret != 1)
00225 {
00226 fprintf(stderr,"fscanf failed with %d\n",ret);
00227 break;
00228 }
00229 ret = fscanf(fp,"%80s",clabel);
00230 if(ret != 1)
00231 {
00232 fprintf(stderr,"fscanf failed with %d\n",ret);
00233 break;
00234 }
00235 printf("loaded label %d %s\n",id,clabel);
00236 std::string label = std::string(clabel);
00237 labels.insert(std::pair<int, std::string>(id,label));
00238 }
00239 fclose(fp);
00240 return labels;
00241 }
00242
00243 void writeLabels(std::string labelFile, std::map<int, std::string> labels)
00244 {
00245 FILE *fp = fopen(labelFile.c_str(),"w");
00246 if(fp==NULL) return;
00247 map<int, std::string>::iterator cur,end;
00248 cur = labels.begin(); end = labels.end();
00249 while(cur!=end)
00250 {
00251 fprintf(fp,"%d %80s\n",cur->first, (cur->second).c_str());
00252 cur++;
00253 }
00254 fclose(fp);
00255 }
00256
00257 int findLabel(std::string label, std::map<int, std::string> labels)
00258 {
00259 map<int, std::string>::iterator cur,end;
00260 cur = labels.begin(); end = labels.end();
00261 while(cur!=end)
00262 {
00263 if(cur->second.compare(label)==0)
00264 return cur->first;
00265 cur++;
00266 }
00267 return -1;
00268 }
00269
00270
00271 int addLabel(std::string label, std::map<int, std::string> &labels)
00272 {
00273 int id = maxKey(labels)+1;
00274 labels.insert(std::pair<int, std::string>(id,label));
00275 return id;
00276 }
00277
00278 bool idExists(int id, std::map<int, std::string> &labels)
00279 {
00280 if(labels.find(id) == labels.end())
00281 return false;
00282 else
00283 return true;
00284 }
00285
00286 }
00287
00288
00289 int main(const int argc, const char **argv)
00290 {
00291
00292 MYLOGVERB = LOG_INFO;
00293
00294 ModelManager *mgr = new ModelManager("Cuda Hmax CBCL Model Server");
00295
00296
00297 mgr->exportOptions(MC_RECURSE);
00298
00299
00300 if (mgr->parseCommandLine(
00301 (const int)argc, (const char**)argv, "<cudadev> <labelFile> <c0patches> <c1patches> <featuresFile> <localport> <server_ip> <serverport> <svmModelFile> <svmRangeFile> ", 8, 10) == false)
00302 return 1;
00303
00304 std::string devArg, serverIP,serverPortStr,localPortStr;
00305 std::string c0Patches;
00306 std::string c1Patches;
00307 std::string svmModelFileName, svmRangeFileName;
00308 std::string c2FileName;
00309 std::string labelFileName, featuresFileName;
00310 std::string trainPosName;
00311
00312
00313 SVMClassifier svm;
00314
00315
00316 bool train = false;
00317
00318
00319
00320 mgr->start();
00321
00322
00323 signal(SIGHUP, terminateProc); signal(SIGINT, terminateProc);
00324 signal(SIGQUIT, terminateProc); signal(SIGTERM, terminateProc);
00325 signal(SIGALRM, terminateProc);
00326
00327
00328 devArg = mgr->getExtraArg(0);
00329 labelFileName = mgr->getExtraArg(1);
00330 c0Patches = mgr->getExtraArg(2);
00331 c1Patches = mgr->getExtraArg(3);
00332 featuresFileName = mgr->getExtraArg(4);
00333 localPortStr = mgr->getExtraArg(5);
00334 serverIP = mgr->getExtraArg(6);
00335 serverPortStr = mgr->getExtraArg(7);
00336
00337 if(mgr->numExtraArgs() > 8)
00338 {
00339 if(mgr->numExtraArgs() == 9)
00340 {
00341 LFATAL("USAGE: prog <cudadev> <labelFile> <c0patches> <c1patches> <featuresFile> <localport> <server_ip> <serverport> <svmModelFile> <svmRangeFile>\n");
00342 }
00343 svmModelFileName = mgr->getExtraArg(8);
00344 svmRangeFileName = mgr->getExtraArg(9);
00345 svm.readModel(svmModelFileName);
00346 svm.readRange(svmRangeFileName);
00347 }
00348 else
00349 {
00350
00351 train = true;
00352 }
00353 std::map<int,std::string> labels = loadLabels(labelFileName);
00354
00355 int dev = strtol(devArg.c_str(),NULL,0);
00356 CudaDevices::setCurrentDevice(dev);
00357
00358 CudaHmaxCBCL hmax(c0Patches,c1Patches);
00359
00360 XWinManaged xwin(Dims(256,256),
00361 -1, -1, "ILab Robot Head Demo");
00362
00363 int serverPort = strtol(serverPortStr.c_str(),NULL,0);
00364 int localPort = strtol(localPortStr.c_str(),NULL,0);
00365 struct nv2_label_server* labelServer =
00366 nv2_label_server_create(localPort,
00367 serverIP.c_str(),
00368 serverPort);
00369
00370 nv2_label_server_set_verbosity(labelServer,1);
00371
00372
00373 const size_t max_label_history = 1;
00374 std::deque<std::string> recent_labels;
00375
00376 Image<PixRGB<byte> > colorbars = makeColorbars(256, 256);
00377 bool clearFile=true;
00378
00379 while(!terminate)
00380 {
00381 Point2D<int> clickLoc = xwin.getLastMouseClick();
00382 if (clickLoc.isValid())
00383 train = !train;
00384
00385 struct nv2_image_patch p;
00386 const enum nv2_image_patch_result res =
00387 nv2_label_server_get_current_patch(labelServer, &p);
00388
00389 std::string objName;
00390 if (res == NV2_IMAGE_PATCH_END)
00391 {
00392 LINFO("ok, quitting");
00393 break;
00394 }
00395 else if (res == NV2_IMAGE_PATCH_NONE)
00396 {
00397 usleep(10000);
00398 continue;
00399 }
00400 else if (res == NV2_IMAGE_PATCH_VALID)
00401 {
00402 if (p.type != NV2_PIXEL_TYPE_RGB24)
00403 {
00404 LINFO("got a non-rgb24 patch; ignoring %i", p.type);
00405 continue;
00406 }
00407
00408 if (p.width * p.height == 1)
00409 {
00410 xwin.drawImage(addLabels(colorbars, p.id));
00411 continue;
00412 }
00413
00414 Image<PixRGB<byte> > img(p.width, p.height, NO_INIT);
00415
00416 memcpy(img.getArrayPtr(), p.data, p.width*p.height*3);
00417
00418 Image<PixRGB<byte> > inputImg = rescale(img, 256, 256);
00419
00420 xwin.drawImage(inputImg);
00421
00422 Image<float> inputf = luminanceNTSC(inputImg);
00423
00424
00425 hmax.getC2(inputf.getArrayPtr(),inputf.getWidth(),inputf.getHeight());
00426 if(!train)
00427 {
00428
00429 double prob;
00430 float *feat = hmax.getC2Features();
00431 int numFeat = hmax.numC2Features();
00432 double pred = svm.predict(feat,numFeat,&prob);
00433 printf("Prediction is %f\n",pred);
00434 int predId = (int) pred;
00435 bool knowObject = idExists(predId,labels);
00436 if(knowObject)
00437 {
00438 printf("Known object %d, prob %f\n",predId,prob);
00439 objName = labels[predId];
00440 }
00441 else
00442 {
00443 printf("Unknown object %d, prob %f\n",predId,prob);
00444 char tmp[200];
00445 sprintf(tmp,"Unknown-%d",predId);
00446 objName = std::string(tmp);
00447 }
00448 recent_labels.push_back(objName);
00449 while (recent_labels.size() > max_label_history)
00450 recent_labels.pop_front();
00451
00452 struct nv2_patch_label l;
00453 l.protocol_version = NV2_LABEL_PROTOCOL_VERSION;
00454 l.patch_id = p.id;
00455 l.confidence = (int)(prob*NV2_MAX_LABEL_CONFIDENCE);
00456 snprintf(l.source, sizeof(l.source), "%s",
00457 "HmaxFL");
00458 snprintf(l.name, sizeof(l.name), "%s",
00459 objName.c_str());
00460 snprintf(l.extra_info, sizeof(l.extra_info),
00461 "%ux%u #%u",
00462 (unsigned int) p.width,
00463 (unsigned int) p.height,
00464 (unsigned int) p.id);
00465
00466 nv2_label_server_send_label(labelServer, &l);
00467 LINFO("sent label '%s (%s)'\n", l.name, l.extra_info);
00468 }
00469
00470 else
00471 {
00472 printf("Enter a label for this object:\n");
00473 std::getline(std::cin, objName);
00474 printf("You typed '%s'\n", objName.c_str());
00475
00476 if (objName == "exit")
00477 break;
00478 else if (objName != "")
00479 {
00480 int newId = findLabel(objName,labels);
00481 if(newId == -1)
00482 {
00483 newId = addLabel(objName,labels);
00484 printf("No existing label found, adding [%s]\n",objName.c_str());
00485 }
00486 else
00487 {
00488 printf("Found existing label\n");
00489 }
00490 hmax.writeOutFeatures(featuresFileName,newId,clearFile);
00491 clearFile=false;
00492 }
00493 }
00494
00495 nv2_image_patch_destroy(&p);
00496 }
00497 }
00498
00499 writeLabels(labelFileName,labels);
00500
00501 if (terminate)
00502 LINFO("Ending application because a signal was caught");
00503
00504
00505 LINFO("Got Here");
00506
00507 return 0;
00508 }
00509
00510
00511
00512
00513
00514
00515
00516
00517