00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
#include "Component/ModelManager.H"
#include "Learn/Bayes.H"
#include "GUI/DebugWin.H"
#include "NeovisionII/nv2_common.h"
#include "NeovisionII/nv2_label_server.h"
#include "HMAX/HmaxFL.H"
#include "GUI/XWindow.H"
#include "Image/Image.H"
#include "Image/ImageSet.H"
#include "Image/ColorOps.H"
#include "Image/ShapeOps.H"
#include "Image/CutPaste.H"
#include "Image/FilterOps.H"
#include "Image/Rectangle.H"
#include "Image/MathOps.H"
#include "Image/DrawOps.H"
#include "Image/MatrixOps.H"
#include "Image/Transforms.H"
#include "Image/Convolutions.H"
#include "Learn/SVMClassifier.H"
#include "Media/FrameSeries.H"
#include "nub/ref.h"
#include "Raster/GenericFrame.H"
#include "Raster/Raster.H"
#include "Util/Types.H"
#include "Util/log.H"

#include "rutz/fstring.h"
#include "rutz/time.h"
#include "rutz/timeformat.h"

#include <signal.h>
#include <unistd.h>

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
00081
00082
00083
//! Orientation count handed to the HmaxFL constructor in main()
#define NORI 4
//! Number of floats in each row of the C2 response buffer allocated in main()
#define NUM_PATCHES_PER_SIZE 250


//! Color switch; nothing in the visible part of this file reads it
const bool USECOLOR = false;

//! Shutdown request flag: written by terminateProc(), polled by main()
/*! Declared volatile sig_atomic_t because that is the object type the
    language guarantees can be written from an asynchronous signal
    handler; a plain bool carries no such guarantee. */
volatile sig_atomic_t terminate = 0;

//! Signal handler: ask the main loop to finish at its next iteration
/*! The signal number is not needed, so the parameter is left unnamed. */
void terminateProc(int /* signum */)
{
  terminate = 1;
}
00096
//! Pick the most frequent label out of a history of labels
/*! Occurrences are tallied front to back. Whenever a label's running
    tally reaches or ties the best tally seen so far it becomes the
    current winner, so ties are resolved in favour of the label that
    appears later in the history.

    @param labels   label history to examine
    @param mincount minimum tally the winner must have reached
    @return the winning label, or an empty string when the history is
            empty or the winner's tally is below mincount */
std::string getBestLabel(const std::deque<std::string>& labels,
                         const size_t mincount)
{
  if (labels.empty())
    return std::string();

  typedef std::deque<std::string>::const_iterator LabelIter;

  std::map<std::string, size_t> tally;
  LabelIter winner = labels.begin();
  size_t winnerVotes = 0;

  for (LabelIter it = labels.begin(); it != labels.end(); ++it)
    {
      // operator[] creates a zero tally the first time a label is seen
      size_t& votes = tally[*it];
      ++votes;

      if (votes >= winnerVotes)
        {
          winnerVotes = votes;
          winner = it;
        }
    }

  return (winnerVotes >= mincount) ? *winner : std::string();
}
00124
00125 namespace
00126 {
00127 void fillRegion(Image<PixRGB<byte> >& img, PixRGB<byte> col,
00128 const int x0, const int x1,
00129 const int y0, const int y1)
00130 {
00131 for (int x = x0; x < x1; ++x)
00132 for (int y = y0; y < y1; ++y)
00133 img.setVal(x, y, col);
00134 }
00135
00136 Image<PixRGB<byte> > makeColorbars(const int w, const int h)
00137 {
00138 Image<PixRGB<byte> > result = Image<PixRGB<byte> >(w, h, ZEROS);
00139
00140 const PixRGB<byte> cols[] =
00141 {
00142 PixRGB<byte>(255, 255, 255),
00143 PixRGB<byte>(255, 255, 0),
00144 PixRGB<byte>(0, 255, 255),
00145 PixRGB<byte>(0, 255, 0),
00146 PixRGB<byte>(255, 0, 255),
00147 PixRGB<byte>(255, 0, 0),
00148 PixRGB<byte>(0, 0, 255)
00149 };
00150
00151 int x1 = 0;
00152 for (int i = 0; i < 7; ++i)
00153 {
00154 const int x0 = x1+1;
00155 x1 = int(double(w)*(i+1)/7.0 + 0.5);
00156 fillRegion(result, cols[i],
00157 x0, x1,
00158 0, int(h*2.0/3.0));
00159 }
00160
00161 x1 = 0;
00162 for (int i = 0; i < 16; ++i)
00163 {
00164 const int x0 = x1;
00165 x1 = int(double(w)*(i+1)/16.0 + 0.5);
00166 const int gray = int(255.0*i/15.0 + 0.5);
00167 fillRegion(result, PixRGB<byte>(gray, gray, gray),
00168 x0, x1,
00169 int(h*2.0/3.0)+1, int(h*5.0/6.0));
00170 }
00171
00172 fillRegion(result, PixRGB<byte>(255, 0, 0),
00173 0, w,
00174 int(h*5.0/6.0)+1, h);
00175
00176 writeText(result, Point2D<int>(1, int(h*5.0/6.0)+2),
00177 "iLab Neuromorphic Vision",
00178 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 0, 0),
00179 SimpleFont::FIXED(10));
00180
00181 return result;
00182 }
00183
00184 Image<PixRGB<byte> > addLabels(const Image<PixRGB<byte> >& templ,
00185 const int fnum)
00186 {
00187 Image<PixRGB<byte> > result = templ;
00188
00189 std::string fnumstr = sformat("%06d", fnum);
00190 writeText(result, Point2D<int>(1, 1),
00191 fnumstr.c_str(),
00192 PixRGB<byte>(0, 0, 0), PixRGB<byte>(255, 255, 255),
00193 SimpleFont::FIXED(10));
00194
00195 rutz::time t = rutz::time::wall_clock_now();
00196
00197 writeText(result, Point2D<int>(1, result.getHeight() - 14),
00198 rutz::format_time(t).c_str(),
00199 PixRGB<byte>(32, 32, 32), PixRGB<byte>(255, 0, 0),
00200 SimpleFont::FIXED(6));
00201
00202 return result;
00203 }
00204
//! Largest key stored in m, but never less than -1
/*! std::map keeps its keys sorted, so the largest one is simply the
    last element; no scan is needed. The map is taken by const
    reference so that calling this does not copy every label.

    @return -1 when m is empty or when all of its keys are below -1,
            otherwise the largest key */
int maxKey(const std::map<int, std::string>& m)
{
  if (m.empty())
    return -1;

  const int largest = m.rbegin()->first;
  return (largest > -1) ? largest : -1;
}
00218
//! Read "<id> <label>" pairs from a text file into a map
/*! Ids and labels are whitespace-delimited tokens, so any amount of
    padding between an id and its label is accepted. Labels are read
    into a std::string, which grows as needed; a label of any length is
    therefore read whole and cannot overrun a fixed-size buffer.

    Reading stops at the first entry that cannot be parsed; everything
    read before that point is still returned. Reaching the end of the
    file is the normal way to finish and is not reported; a malformed
    entry is reported on stderr. When an id occurs more than once the
    first label is kept, because std::map::insert() does not overwrite.

    @param labelFile path of the file to read
    @return the labels read; empty if the file cannot be opened */
std::map<int, std::string> loadLabels(const std::string& labelFile)
{
  std::map<int, std::string> labels;

  std::ifstream in(labelFile.c_str());
  if (!in.is_open())
    return labels;

  int id;
  while (in >> id)
    {
      std::string label;
      if (!(in >> label))
        {
          fprintf(stderr, "loadLabels: id %d in '%s' has no label\n",
                  id, labelFile.c_str());
          break;
        }

      printf("loaded label %d %s\n", id, label.c_str());
      labels.insert(std::make_pair(id, label));
    }

  // The loop ends with eof set both on a clean end of file and on a
  // dangling id (already reported above); anything else means an id
  // could not be parsed.
  if (!in.eof())
    fprintf(stderr, "loadLabels: stopped at a malformed entry in '%s'\n",
            labelFile.c_str());

  return labels;
}
00247
//! Write every (id, label) pair to a text file, one "<id> <label>" per line
/*! The file is truncated first and entries are written in ascending id
    order. Labels are written as they are, without padding: in printf
    a width such as "%80s" is a minimum field width, so it would only
    left-pad every label with spaces rather than bound its length.

    Failure to open or to finish writing the file is reported on
    stderr; nothing is thrown.

    NOTE: a label that contains whitespace is written verbatim and
    will not survive a reader that splits entries on whitespace. */
void writeLabels(const std::string& labelFile,
                 const std::map<int, std::string>& labels)
{
  FILE* fp = fopen(labelFile.c_str(), "w");
  if (fp == NULL)
    {
      fprintf(stderr, "writeLabels: cannot open '%s' for writing\n",
              labelFile.c_str());
      return;
    }

  typedef std::map<int, std::string>::const_iterator LabelIter;
  for (LabelIter it = labels.begin(); it != labels.end(); ++it)
    fprintf(fp, "%d %s\n", it->first, it->second.c_str());

  // buffered data is flushed here, so this is where a full disk shows up
  if (fclose(fp) != 0)
    fprintf(stderr, "writeLabels: error while writing '%s'\n",
            labelFile.c_str());
}
00261
//! Find the id under which a label text is stored
/*! The map is keyed by id, so this is a linear search over its values
    in ascending id order; when several ids carry the same text the
    smallest id is returned. Both arguments are taken by const
    reference so that a lookup does not copy the whole map.

    @param label  text to look for (compared exactly)
    @param labels id -> label map to search
    @return the matching id, or -1 when no entry has that text */
int findLabel(const std::string& label,
              const std::map<int, std::string>& labels)
{
  typedef std::map<int, std::string>::const_iterator LabelIter;

  for (LabelIter it = labels.begin(); it != labels.end(); ++it)
    {
      if (it->second == label)
        return it->first;
    }

  return -1;
}
00274
00275
00276 int addLabel(std::string label, std::map<int, std::string> &labels)
00277 {
00278 int id = maxKey(labels)+1;
00279 labels.insert(std::pair<int, std::string>(id,label));
00280 return id;
00281 }
00282
//! Tell whether labels has an entry for id
/*! Only looks the key up; the map is never modified. */
bool idExists(int id, std::map<int, std::string> &labels)
{
  const bool present = (labels.count(id) != 0);
  return present;
}
00290
//! Append one training sample to a sparse "index:value" text file
/*! The line written is
      <id> 1:<v> 2:<v> ... <dim1*dim2>:<v> <newline>
    where the values are features[i][j] taken row by row, printed in
    fixed notation with four decimals, each followed by one space.
    The file is opened in append mode; if it cannot be opened nothing
    is written and no error is reported.

    @param trainingFileName file to append to
    @param id               leading number of the line
    @param features         dim1 rows of dim2 floats each
    @param dim1             number of rows
    @param dim2             number of values per row */
void writeFeatures(std::string trainingFileName, int id, float **features, int dim1, int dim2)
{
  std::ofstream out(trainingFileName.c_str(), std::ios::app);
  if (!out.is_open())
    return;

  // these only affect floating-point output, so setting them up front
  // leaves the integer id and indices unchanged
  out.setf(std::ios::fixed);
  out.precision(4);

  out << id << " ";

  int index = 1;  // indices are 1-based and run across all rows
  for (int row = 0; row < dim1; ++row)
    {
      const float* const values = features[row];
      for (int col = 0; col < dim2; ++col, ++index)
        out << index << ":" << values[col] << " ";
    }

  out << std::endl;
  out.close();
}
00312
00313 }
00314
00315
00316 int main(const int argc, const char **argv)
00317 {
00318
00319 MYLOGVERB = LOG_INFO;
00320
00321 ModelManager *mgr = new ModelManager("Hmax with Feature Learning Server");
00322
00323
00324 mgr->exportOptions(MC_RECURSE);
00325
00326
00327 if (mgr->parseCommandLine(
00328 (const int)argc, (const char**)argv, "<labelFile> <c1patchesDir> <featuresFile> <localport> <server_ip> <serverport> <svmModelFile> <svmRangeFile> ", 6, 8) == false)
00329 return 1;
00330
00331 std::string devArg, serverIP,serverPortStr,localPortStr;
00332 std::string c1PatchesBaseDir;
00333 std::string svmModelFileName, svmRangeFileName;
00334 std::string c2FileName;
00335 std::string labelFileName, featuresFileName;
00336 std::string trainPosName;
00337
00338
00339 SVMClassifier svm;
00340
00341
00342 bool train = false;
00343
00344
00345
00346 mgr->start();
00347
00348
00349 signal(SIGHUP, terminateProc); signal(SIGINT, terminateProc);
00350 signal(SIGQUIT, terminateProc); signal(SIGTERM, terminateProc);
00351 signal(SIGALRM, terminateProc);
00352
00353
00354 labelFileName = mgr->getExtraArg(0);
00355 c1PatchesBaseDir = mgr->getExtraArg(1);
00356 featuresFileName = mgr->getExtraArg(2);
00357 localPortStr = mgr->getExtraArg(3);
00358 serverIP = mgr->getExtraArg(4);
00359 serverPortStr = mgr->getExtraArg(5);
00360
00361 if(mgr->numExtraArgs() > 6)
00362 {
00363 svmModelFileName = mgr->getExtraArg(7);
00364 svm.readModel(svmModelFileName);
00365 if(mgr->numExtraArgs() == 8)
00366 {
00367 svmRangeFileName = mgr->getExtraArg(8);
00368 svm.readRange(svmRangeFileName);
00369 }
00370 }
00371 else
00372 {
00373
00374 train = true;
00375 }
00376 std::map<int,std::string> labels = loadLabels(labelFileName);
00377
00378
00379 std::vector<int> scss(9);
00380 scss[0] = 1; scss[1] = 3; scss[2] = 5; scss[3] = 7; scss[4] = 9;
00381 scss[5] = 11; scss[6] = 13; scss[7] = 15; scss[8] = 17;
00382 std::vector<int> spss(8);
00383 spss[0] = 8; spss[1] = 10; spss[2] = 12; spss[3] = 14;
00384 spss[4] = 16; spss[5] = 18; spss[6] = 20; spss[7] = 22;
00385
00386
00387
00388
00389
00390 HmaxFL hmax(NORI, spss, scss);
00391
00392
00393 hmax.readInC1Patches(c1PatchesBaseDir);
00394
00395 std::vector<int> patchSizes = hmax.getC1PatchSizes();
00396
00397
00398 float **c2Res = new float*[patchSizes.size()];
00399 for(unsigned int i=0;i<patchSizes.size();i++) {
00400 c2Res[i] = new float[NUM_PATCHES_PER_SIZE];
00401 }
00402
00403
00404 XWinManaged xwin(Dims(256,256),
00405 -1, -1, "ILab Robot Head Demo");
00406
00407 int serverPort = strtol(serverPortStr.c_str(),NULL,0);
00408 int localPort = strtol(localPortStr.c_str(),NULL,0);
00409 struct nv2_label_server* labelServer =
00410 nv2_label_server_create(localPort,
00411 serverIP.c_str(),
00412 serverPort);
00413
00414 nv2_label_server_set_verbosity(labelServer,1);
00415
00416
00417 const size_t max_label_history = 1;
00418 std::deque<std::string> recent_labels;
00419
00420 Image<PixRGB<byte> > colorbars = makeColorbars(256, 256);
00421
00422
00423 while(!terminate)
00424 {
00425 Point2D<int> clickLoc = xwin.getLastMouseClick();
00426 if (clickLoc.isValid())
00427 train = !train;
00428
00429 struct nv2_image_patch p;
00430 const enum nv2_image_patch_result res =
00431 nv2_label_server_get_current_patch(labelServer, &p);
00432
00433 std::string objName;
00434 if (res == NV2_IMAGE_PATCH_END)
00435 {
00436 LINFO("ok, quitting");
00437 break;
00438 }
00439 else if (res == NV2_IMAGE_PATCH_NONE)
00440 {
00441 usleep(10000);
00442 continue;
00443 }
00444 else if (res == NV2_IMAGE_PATCH_VALID)
00445 {
00446 if (p.type != NV2_PIXEL_TYPE_RGB24)
00447 {
00448 LINFO("got a non-rgb24 patch; ignoring %i", p.type);
00449 continue;
00450 }
00451
00452 if (p.width * p.height == 1)
00453 {
00454 xwin.drawImage(addLabels(colorbars, p.id));
00455 continue;
00456 }
00457
00458 Image<PixRGB<byte> > img(p.width, p.height, NO_INIT);
00459
00460 memcpy(img.getArrayPtr(), p.data, p.width*p.height*3);
00461
00462 Image<PixRGB<byte> > inputImg = rescale(img, 256, 256);
00463
00464 xwin.drawImage(inputImg);
00465
00466 Image<float> inputf = luminanceNTSC(inputImg);
00467
00468
00469 hmax.getC2(inputf,c2Res);
00470 if(!train)
00471 {
00472
00473 double pred = svm.predict(c2Res,patchSizes.size(),NUM_PATCHES_PER_SIZE);
00474 printf("Prediction is %f\n",pred);
00475 int predId = (int) pred;
00476 bool knowObject = idExists(predId,labels);
00477 if(knowObject)
00478 {
00479 printf("Known object %d\n",predId);
00480 objName = labels[predId];
00481 }
00482 else
00483 {
00484 printf("Unknown object %d\n",predId);
00485 char tmp[200];
00486 sprintf(tmp,"Unknown-%d",predId);
00487 objName = std::string(tmp);
00488 }
00489 recent_labels.push_back(objName);
00490 while (recent_labels.size() > max_label_history)
00491 recent_labels.pop_front();
00492
00493 struct nv2_patch_label l;
00494 l.protocol_version = NV2_LABEL_PROTOCOL_VERSION;
00495 l.patch_id = p.id;
00496
00497 l.confidence = (int)(100.0F);
00498 snprintf(l.source, sizeof(l.source), "%s",
00499 "HmaxFL");
00500 snprintf(l.name, sizeof(l.name), "%s",
00501 objName.c_str());
00502 snprintf(l.extra_info, sizeof(l.extra_info),
00503 "%ux%u #%u",
00504 (unsigned int) p.width,
00505 (unsigned int) p.height,
00506 (unsigned int) p.id);
00507
00508 nv2_label_server_send_label(labelServer, &l);
00509 LINFO("sent label '%s (%s)'\n", l.name, l.extra_info);
00510 }
00511
00512 else
00513 {
00514 printf("Enter a label for this object:\n");
00515 std::getline(std::cin, objName);
00516 printf("You typed '%s'\n", objName.c_str());
00517
00518 if (objName == "exit")
00519 break;
00520 else if (objName != "")
00521 {
00522 int newId = findLabel(objName,labels);
00523 if(newId == -1)
00524 {
00525 newId = addLabel(objName,labels);
00526 printf("No existing label found, adding [%s]\n",objName.c_str());
00527 }
00528 else
00529 {
00530 printf("Found existing label\n");
00531 }
00532 writeFeatures(featuresFileName,newId,c2Res,patchSizes.size(),NUM_PATCHES_PER_SIZE);
00533 }
00534 }
00535
00536 nv2_image_patch_destroy(&p);
00537 }
00538 }
00539
00540 writeLabels(labelFileName,labels);
00541
00542 for(unsigned int i=0;i<patchSizes.size();i++) {
00543 delete[] c2Res[i];
00544 }
00545 delete [] c2Res;
00546
00547 if (terminate)
00548 LINFO("Ending application because a signal was caught");
00549
00550
00551 LINFO("Got Here");
00552
00553 return 0;
00554 }
00555
00556
00557
00558
00559
00560
00561
00562
00563