00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Component/ModelManager.H"
00039 #include "Learn/Bayes.H"
00040 #include "GUI/DebugWin.H"
00041 #include "NeovisionII/nv2_common.h"
00042 #include "NeovisionII/nv2_label_server.h"
00043
00044 #include "GUI/XWindow.H"
00045 #include "CUDA/CudaHmaxFL.H"
00046 #include "CUDA/CudaHmax.H"
00047 #include "CUDA/CudaImage.H"
00048 #include "Image/Image.H"
00049 #include "Image/ImageSet.H"
00050 #include "Image/ColorOps.H"
00051 #include "Image/ShapeOps.H"
00052 #include "Image/CutPaste.H"
00053 #include "Image/FilterOps.H"
00054 #include "Image/Rectangle.H"
00055 #include "Image/MathOps.H"
00056 #include "Image/DrawOps.H"
00057 #include "Image/MatrixOps.H"
00058 #include "Image/Transforms.H"
00059 #include "Image/Convolutions.H"
00060 #include "Learn/SVMClassifier.H"
00061 #include "Media/FrameSeries.H"
00062 #include "nub/ref.h"
00063 #include "Raster/GenericFrame.H"
00064 #include "Raster/Raster.H"
00065 #include "Util/Types.H"
00066 #include "Util/log.H"
00067
00068 #include <signal.h>
00069
00070 #include "rutz/fstring.h"
00071 #include "rutz/time.h"
00072 #include "rutz/timeformat.h"
00073
00074
00075 #include <fstream>
00076 #include <map>
00077 #include <vector>
00078 #include <utility>
00079 #include <iostream>
00080 #include <iomanip>
00081 #include <string>
00082 #include <unistd.h>
00083 #include <cstdlib>
00084
00085
00086
// Passed as the third argument to the CudaHmaxFL constructor in main()
// (presumably the number of filter orientations -- confirm against CudaHmaxFL).
#define NORI 4
// Length of each row of the C2 response buffer allocated in main(); also the
// second dimension handed to the SVM predictor and to writeFeatures().
#define NUM_PATCHES_PER_SIZE 250


// Not referenced anywhere in the visible part of this file.
const bool USECOLOR = false;

// Shutdown request flag. It is written from a signal handler, so it has to be
// a volatile sig_atomic_t: that is the only object type a handler may portably
// store to. main() polls it once per loop iteration.
volatile sig_atomic_t terminate = 0;

/// Signal handler installed by main(): requests a clean shutdown.
/** @param s the signal number that was delivered (unused). */
void terminateProc(int s)
{
  (void)s;  // every signal is treated the same way
  terminate = 1;
}
00099
/// Pick the most frequent label out of a history of labels.
/** Occurrences are counted while walking the history front to back. An entry
    becomes the current favourite whenever its running count reaches or
    exceeds the best count seen so far, so ties are resolved in favour of the
    entry encountered last.

    @param labels   label history to vote over
    @param mincount minimum number of occurrences the winner must have
    @return the winning label, or an empty string if the history is empty or
            the winner was seen fewer than mincount times */
std::string getBestLabel(const std::deque<std::string>& labels,
                         const size_t mincount)
{
  std::string winner;

  if (labels.empty())
    return winner;

  std::map<std::string, size_t> tally;
  size_t topCount = 0;
  std::deque<std::string>::const_iterator favourite = labels.begin();

  for (std::deque<std::string>::const_iterator it = labels.begin();
       it != labels.end(); ++it)
    {
      size_t& seen = tally[*it];
      ++seen;

      // strictly fewer votes than the leader: nothing changes
      if (seen < topCount)
        continue;

      topCount = seen;
      favourite = it;
    }

  if (topCount >= mincount)
    winner = *favourite;

  return winner;
}
00127
00128 namespace
00129 {
00130 void fillRegion(Image<PixRGB<byte> >& img, PixRGB<byte> col,
00131 const int x0, const int x1,
00132 const int y0, const int y1)
00133 {
00134 for (int x = x0; x < x1; ++x)
00135 for (int y = y0; y < y1; ++y)
00136 img.setVal(x, y, col);
00137 }
00138
00139 Image<PixRGB<byte> > makeColorbars(const int w, const int h)
00140 {
00141 Image<PixRGB<byte> > result = Image<PixRGB<byte> >(w, h, ZEROS);
00142
00143 const PixRGB<byte> cols[] =
00144 {
00145 PixRGB<byte>(255, 255, 255),
00146 PixRGB<byte>(255, 255, 0),
00147 PixRGB<byte>(0, 255, 255),
00148 PixRGB<byte>(0, 255, 0),
00149 PixRGB<byte>(255, 0, 255),
00150 PixRGB<byte>(255, 0, 0),
00151 PixRGB<byte>(0, 0, 255)
00152 };
00153
00154 int x1 = 0;
00155 for (int i = 0; i < 7; ++i)
00156 {
00157 const int x0 = x1+1;
00158 x1 = int(double(w)*(i+1)/7.0 + 0.5);
00159 fillRegion(result, cols[i],
00160 x0, x1,
00161 0, int(h*2.0/3.0));
00162 }
00163
00164 x1 = 0;
00165 for (int i = 0; i < 16; ++i)
00166 {
00167 const int x0 = x1;
00168 x1 = int(double(w)*(i+1)/16.0 + 0.5);
00169 const int gray = int(255.0*i/15.0 + 0.5);
00170 fillRegion(result, PixRGB<byte>(gray, gray, gray),
00171 x0, x1,
00172 int(h*2.0/3.0)+1, int(h*5.0/6.0));
00173 }
00174
00175 fillRegion(result, PixRGB<byte>(255, 0, 0),
00176 0, w,
00177 int(h*5.0/6.0)+1, h);
00178
00179 writeText(result, Point2D<int>(1, int(h*5.0/6.0)+2),
00180 "iLab Neuromorphic Vision",
00181 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 0, 0),
00182 SimpleFont::FIXED(10));
00183
00184 return result;
00185 }
00186
00187 Image<PixRGB<byte> > addLabels(const Image<PixRGB<byte> >& templ,
00188 const int fnum)
00189 {
00190 Image<PixRGB<byte> > result = templ;
00191
00192 std::string fnumstr = sformat("%06d", fnum);
00193 writeText(result, Point2D<int>(1, 1),
00194 fnumstr.c_str(),
00195 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 255, 255),
00196 SimpleFont::FIXED(10));
00197
00198 rutz::time t = rutz::time::wall_clock_now();
00199
00200 writeText(result, Point2D<int>(1, result.getHeight() - 14),
00201 rutz::format_time(t).c_str(),
00202 PixRGB<byte>(32, 32, 32), PixRGB<byte>(255, 0, 0),
00203 SimpleFont::FIXED(6));
00204
00205 return result;
00206 }
00207
/// Return the largest key in m, but never less than -1.
/** An empty map, or a map whose keys are all below -1, yields -1, so that
    maxKey(m)+1 is always a non-negative id.

    std::map keeps its keys sorted, so the largest key is simply the last
    element; no scan is required. The map is taken by const reference to
    avoid copying every label on each call.

    @param m id -> label table
    @return max(-1, largest key in m) */
int maxKey(const std::map<int, std::string>& m)
{
  if (m.empty())
    return -1;

  const int largest = m.rbegin()->first;
  return (largest > -1) ? largest : -1;
}
00221
/// Read an id -> label table from a text file.
/** The file is a whitespace-separated sequence of "<int id> <label>" records.
    A label is a single whitespace-free token; at most 80 characters of it
    are consumed per record. Reading stops at end of file or at the first
    record that cannot be parsed; everything read up to that point is
    returned. If an id occurs more than once, the first occurrence wins.

    @param labelFile path of the file to read
    @return the table that was read; empty if the file cannot be opened */
std::map<int, std::string> loadLabels(const std::string& labelFile)
{
  std::map<int, std::string> labels;

  FILE *fp = fopen(labelFile.c_str(), "r");
  if (fp == NULL)
    return labels;

  for (;;)
    {
      int id = 0;
      // "%80s" stores up to 80 characters PLUS the terminating NUL, so the
      // buffer must hold 81 bytes.
      char clabel[81];

      int ret = fscanf(fp, "%d ", &id);
      if (ret != 1)
        {
          // a clean end of file is the normal way out of this loop; only
          // report malformed input and read errors
          if (ret != EOF || ferror(fp))
            fprintf(stderr, "fscanf failed with %d\n", ret);
          break;
        }

      ret = fscanf(fp, "%80s", clabel);
      if (ret != 1)
        {
          // an id without a label: the file is truncated or malformed
          fprintf(stderr, "fscanf failed with %d\n", ret);
          break;
        }

      printf("loaded label %d %s\n", id, clabel);
      labels.insert(std::pair<int, std::string>(id, std::string(clabel)));
    }

  fclose(fp);
  return labels;
}
00250
/// Write an id -> label table to a text file, one "<id> <label>" per line.
/** The file is truncated first. Records are written in ascending id order.

    The label is written with a plain "%s". A width such as "%80s" is a
    MINIMUM field width for printf-style output, which would left-pad every
    label with spaces up to 80 columns.

    Labels containing whitespace are written as-is; a whitespace-delimited
    reader will then see them as several tokens.

    @param labelFile path of the file to (over)write
    @param labels    table to store */
void writeLabels(const std::string& labelFile,
                 const std::map<int, std::string>& labels)
{
  FILE *fp = fopen(labelFile.c_str(), "w");
  if (fp == NULL)
    {
      // do not lose the labels silently
      fprintf(stderr, "could not open label file '%s' for writing\n",
              labelFile.c_str());
      return;
    }

  for (std::map<int, std::string>::const_iterator cur = labels.begin();
       cur != labels.end(); ++cur)
    {
      fprintf(fp, "%d %s\n", cur->first, cur->second.c_str());
    }

  fclose(fp);
}
00264
/// Look up the id stored for a label.
/** Entries are examined in ascending id order, so if several ids carry the
    same label the smallest one is returned. Both arguments are taken by
    const reference so that a lookup does not copy the table.

    @param label  label text to search for (exact match)
    @param labels id -> label table
    @return the id of the first entry whose label equals label, or -1 if
            there is none */
int findLabel(const std::string& label,
              const std::map<int, std::string>& labels)
{
  for (std::map<int, std::string>::const_iterator cur = labels.begin();
       cur != labels.end(); ++cur)
    {
      if (cur->second == label)
        return cur->first;
    }

  return -1;
}
00277
00278
00279 int addLabel(std::string label, std::map<int, std::string> &labels)
00280 {
00281 int id = maxKey(labels)+1;
00282 labels.insert(std::pair<int, std::string>(id,label));
00283 return id;
00284 }
00285
/// Tell whether the table has an entry for id.
/** @param id     id to look for
    @param labels id -> label table (not modified)
    @return true if id is a key of labels, false otherwise */
bool idExists(int id, std::map<int, std::string> &labels)
{
  const std::map<int, std::string>::iterator hit = labels.find(id);
  return hit != labels.end();
}
00293
/// Append one sample to a sparse-format training file.
/** The line written is "<id> 1:<v> 2:<v> ... " followed by a newline, where
    the values are features[i][j] taken row by row, the index of element
    (i,j) is i*dim2+j+1, and each value is printed in fixed notation with
    four decimals. Nothing is written if the file cannot be opened.

    @param trainingFileName file to append to (created if missing)
    @param id               class id placed at the start of the line
    @param features         dim1 rows of dim2 values each
    @param dim1             number of rows
    @param dim2             number of values per row */
void writeFeatures(std::string trainingFileName, int id, float **features, int dim1, int dim2)
{
  std::ofstream out(trainingFileName.c_str(), std::ios::app);
  if (!out.is_open())
    return;

  out << id << " ";

  // the format flags only influence the floating-point values, so it is
  // enough to set them once before the loops
  out << std::setiosflags(std::ios::fixed) << std::setprecision(4);

  for (int row = 0; row < dim1; ++row)
    {
      const int rowBase = row * dim2;
      const float* const values = features[row];

      for (int col = 0; col < dim2; ++col)
        out << (rowBase + col + 1) << ":" << values[col] << " ";
    }

  out << std::endl;
  out.close();
}
00315
00316 }
00317
00318
00319 int main(const int argc, const char **argv)
00320 {
00321
00322 MYLOGVERB = LOG_INFO;
00323
00324 ModelManager *mgr = new ModelManager("Hmax with Feature Learning Server");
00325
00326
00327 mgr->exportOptions(MC_RECURSE);
00328
00329
00330 if (mgr->parseCommandLine(
00331 (const int)argc, (const char**)argv, "<cudadev> <labelFile> <c1patchesDir> <featuresFile> <localport> <server_ip> <serverport> <svmModelFile> <svmRangeFile> ", 7, 9) == false)
00332 return 1;
00333
00334 std::string devArg, serverIP,serverPortStr,localPortStr;
00335 std::string c1PatchesBaseDir;
00336 std::string svmModelFileName, svmRangeFileName;
00337 std::string c2FileName;
00338 std::string labelFileName, featuresFileName;
00339 std::string trainPosName;
00340
00341
00342 SVMClassifier svm;
00343
00344
00345 bool train = false;
00346
00347
00348
00349 mgr->start();
00350
00351
00352 signal(SIGHUP, terminateProc); signal(SIGINT, terminateProc);
00353 signal(SIGQUIT, terminateProc); signal(SIGTERM, terminateProc);
00354 signal(SIGALRM, terminateProc);
00355
00356
00357 devArg = mgr->getExtraArg(0);
00358 labelFileName = mgr->getExtraArg(1);
00359 c1PatchesBaseDir = mgr->getExtraArg(2);
00360 featuresFileName = mgr->getExtraArg(3);
00361 localPortStr = mgr->getExtraArg(4);
00362 serverIP = mgr->getExtraArg(5);
00363 serverPortStr = mgr->getExtraArg(6);
00364
00365 if(mgr->numExtraArgs() > 7)
00366 {
00367 if(mgr->numExtraArgs() == 8)
00368 ;
00369 svmModelFileName = mgr->getExtraArg(7);
00370 svmRangeFileName = mgr->getExtraArg(8);
00371 svm.readModel(svmModelFileName);
00372 svm.readRange(svmRangeFileName);
00373 }
00374 else
00375 {
00376
00377 train = true;
00378 }
00379 std::map<int,std::string> labels = loadLabels(labelFileName);
00380
00381 MemoryPolicy mp = GLOBAL_DEVICE_MEMORY;
00382 int dev = strtol(devArg.c_str(),NULL,0);
00383
00384
00385 std::vector<int> scss(9);
00386 scss[0] = 1; scss[1] = 3; scss[2] = 5; scss[3] = 7; scss[4] = 9;
00387 scss[5] = 11; scss[6] = 13; scss[7] = 15; scss[8] = 17;
00388 std::vector<int> spss(8);
00389 spss[0] = 8; spss[1] = 10; spss[2] = 12; spss[3] = 14;
00390 spss[4] = 16; spss[5] = 18; spss[6] = 20; spss[7] = 22;
00391
00392
00393
00394
00395
00396 CudaHmaxFL hmax(mp,dev,NORI, spss, scss);
00397
00398
00399 hmax.readInC1Patches(c1PatchesBaseDir);
00400
00401 std::vector<int> patchSizes = hmax.getC1PatchSizes();
00402
00403
00404 float **c2Res = new float*[patchSizes.size()];
00405 for(unsigned int i=0;i<patchSizes.size();i++) {
00406 c2Res[i] = new float[NUM_PATCHES_PER_SIZE];
00407 }
00408
00409
00410 XWinManaged xwin(Dims(256,256),
00411 -1, -1, "ILab Robot Head Demo");
00412
00413 int serverPort = strtol(serverPortStr.c_str(),NULL,0);
00414 int localPort = strtol(localPortStr.c_str(),NULL,0);
00415 struct nv2_label_server* labelServer =
00416 nv2_label_server_create(localPort,
00417 serverIP.c_str(),
00418 serverPort);
00419
00420 nv2_label_server_set_verbosity(labelServer,1);
00421
00422
00423 const size_t max_label_history = 1;
00424 std::deque<std::string> recent_labels;
00425
00426 Image<PixRGB<byte> > colorbars = makeColorbars(256, 256);
00427
00428
00429 while(!terminate)
00430 {
00431 Point2D<int> clickLoc = xwin.getLastMouseClick();
00432 if (clickLoc.isValid())
00433 train = !train;
00434
00435 struct nv2_image_patch p;
00436 const enum nv2_image_patch_result res =
00437 nv2_label_server_get_current_patch(labelServer, &p);
00438
00439 std::string objName;
00440 if (res == NV2_IMAGE_PATCH_END)
00441 {
00442 LINFO("ok, quitting");
00443 break;
00444 }
00445 else if (res == NV2_IMAGE_PATCH_NONE)
00446 {
00447 usleep(10000);
00448 continue;
00449 }
00450 else if (res == NV2_IMAGE_PATCH_VALID)
00451 {
00452 if (p.type != NV2_PIXEL_TYPE_RGB24)
00453 {
00454 LINFO("got a non-rgb24 patch; ignoring %i", p.type);
00455 continue;
00456 }
00457
00458 if (p.width * p.height == 1)
00459 {
00460 xwin.drawImage(addLabels(colorbars, p.id));
00461 continue;
00462 }
00463
00464 Image<PixRGB<byte> > img(p.width, p.height, NO_INIT);
00465
00466 memcpy(img.getArrayPtr(), p.data, p.width*p.height*3);
00467
00468 Image<PixRGB<byte> > inputImg = rescale(img, 256, 256);
00469
00470 xwin.drawImage(inputImg);
00471
00472 Image<float> inputf = luminanceNTSC(inputImg);
00473
00474
00475 hmax.getC2(CudaImage<float>(inputf,mp,dev),c2Res);
00476 if(!train)
00477 {
00478
00479 double pred = svm.predict(c2Res,patchSizes.size(),NUM_PATCHES_PER_SIZE);
00480 printf("Prediction is %f\n",pred);
00481 int predId = (int) pred;
00482 bool knowObject = idExists(predId,labels);
00483 if(knowObject)
00484 {
00485 printf("Known object %d\n",predId);
00486 objName = labels[predId];
00487 }
00488 else
00489 {
00490 printf("Unknown object %d\n",predId);
00491 char tmp[200];
00492 sprintf(tmp,"Unknown-%d",predId);
00493 objName = std::string(tmp);
00494 }
00495 recent_labels.push_back(objName);
00496 while (recent_labels.size() > max_label_history)
00497 recent_labels.pop_front();
00498
00499 struct nv2_patch_label l;
00500 l.protocol_version = NV2_LABEL_PROTOCOL_VERSION;
00501 l.patch_id = p.id;
00502
00503 l.confidence = (int)(100.0F);
00504 snprintf(l.source, sizeof(l.source), "%s",
00505 "HmaxFL");
00506 snprintf(l.name, sizeof(l.name), "%s",
00507 objName.c_str());
00508 snprintf(l.extra_info, sizeof(l.extra_info),
00509 "%ux%u #%u",
00510 (unsigned int) p.width,
00511 (unsigned int) p.height,
00512 (unsigned int) p.id);
00513
00514 nv2_label_server_send_label(labelServer, &l);
00515 LINFO("sent label '%s (%s)'\n", l.name, l.extra_info);
00516 }
00517
00518 else
00519 {
00520 printf("Enter a label for this object:\n");
00521 std::getline(std::cin, objName);
00522 printf("You typed '%s'\n", objName.c_str());
00523
00524 if (objName == "exit")
00525 break;
00526 else if (objName != "")
00527 {
00528 int newId = findLabel(objName,labels);
00529 if(newId == -1)
00530 {
00531 newId = addLabel(objName,labels);
00532 printf("No existing label found, adding [%s]\n",objName.c_str());
00533 }
00534 else
00535 {
00536 printf("Found existing label\n");
00537 }
00538 writeFeatures(featuresFileName,newId,c2Res,patchSizes.size(),NUM_PATCHES_PER_SIZE);
00539 }
00540 }
00541
00542 nv2_image_patch_destroy(&p);
00543 }
00544 }
00545
00546 writeLabels(labelFileName,labels);
00547
00548 for(unsigned int i=0;i<patchSizes.size();i++) {
00549 delete[] c2Res[i];
00550 }
00551 delete [] c2Res;
00552
00553 if (terminate)
00554 LINFO("Ending application because a signal was caught");
00555
00556
00557 LINFO("Got Here");
00558
00559 return 0;
00560 }
00561
00562
00563
00564
00565
00566
00567
00568
00569