00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064 #if !defined(INVT_HAVE_LIBSURF) \
00065 || !defined(INVT_HAVE_LIBTORCH)
00066
00067 #include "Util/log.H"
00068
00069 int main()
00070 {
00071 LERROR("Sorry, this program needs the SURF, PMK and torch libraries.") ;
00072 return 255 ;
00073 }
00074
00075 #else // the actual program in all its hideous glory
00076
00077
00078
00079
00080 #include "Neuro/GistEstimatorSurfPMK.H"
00081
00082
00083 #include "Neuro/StdBrain.H"
00084 #include "Neuro/NeuroOpts.H"
00085 #include "Neuro/NeuroSimEvents.H"
00086
00087 #include "Media/SimFrameSeries.H"
00088 #include "Media/MediaOpts.H"
00089
00090 #include "Simulation/SimEventQueue.H"
00091 #include "Simulation/SimEventQueueConfigurator.H"
00092
00093 #include "Channels/ChannelOpts.H"
00094 #include "Component/ModelManager.H"
00095 #include "Component/ModelOptionDef.H"
00096
00097 #include "Image/Point2D.H"
00098
00099 #include "nub/ref.h"
00100
00101
00102 #include <opensurf/opensurf.hh>
00103
00104
00105
00106
00107 #include <glob.h>
00108 #include <unistd.h>
00109
00110
00111 #include <fstream>
00112 #include <sstream>
00113 #include <ios>
00114 #include <numeric>
00115 #include <algorithm>
00116 #include <functional>
00117 #include <map>
00118 #include <vector>
00119 #include <iterator>
00120 #include <stdexcept>
00121 #include <utility>
00122 #include <limits>
00123 #include <cmath>
00124
00125
00126 #include <torch/QCTrainer.h>
00127 #include <torch/SVMClassification.h>
00128 #include <torch/Kernel.h>
00129 #include <torch/MatDataSet.h>
00130
00131
00132
00133
00134
00135
00136
00137
// Renders any streamable value as text, exactly as operator<< would print
// it on a default-formatted output stream.
template<typename T>
std::string to_string(const T& value)
{
   std::ostringstream buffer ;
   buffer << value ;
   const std::string text = buffer.str() ;
   return text ;
}
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155 namespace {
00156
00157 const ModelOptionCateg MOC_SURFPMK = {
00158 MOC_SORTPRI_3,
00159 "Options specific to the Surf PMK program",
00160 } ;
00161
00162
00163
00164
00165 #ifndef SPMK_DEFAULT_TRAINING_DESCRIPTORS_FILE
00166 #define SPMK_DEFAULT_TRAINING_DESCRIPTORS_FILE "surf_descriptors.txt"
00167 #endif
00168
00169 const ModelOptionDef OPT_SurfDescriptors = {
00170 MODOPT_ARG_STRING, "SurfDescriptors", & MOC_SURFPMK, OPTEXP_CORE,
00171 "This option specifies the name of the file where SURF descriptors\n"
00172 "for the training images are to be accumulated. This is a plain text\n"
00173 "file containing the descriptors that will be fed into the hierarchical\n"
00174 "K-means procedure during the second training phase.\n",
00175 "surf-descriptors", '\0', "surf-descriptors-file",
00176 SPMK_DEFAULT_TRAINING_DESCRIPTORS_FILE,
00177 } ;
00178
00179
00180
00181
00182 #ifndef SPMK_DEFAULT_VOCABULARY_FILE
00183 #define SPMK_DEFAULT_VOCABULARY_FILE "surf_vocabulary.txt"
00184 #endif
00185
00186 const ModelOptionDef OPT_SurfVocabulary = {
00187 MODOPT_ARG_STRING, "SurfVocabulary", & MOC_SURFPMK, OPTEXP_CORE,
00188 "This option specifies the name of the file in which the \"prototypical\"\n"
00189 "SURF descriptors are (or are to be) stored. This is a plain text\n"
00190 "file containing the centroids of the hierarchical K-means clusters,\n"
00191 "which are used during gist vector computation to create feature maps\n"
00192 "and, subsequently, the flattened multi-level histograms using the\n"
00193 "pyramid matching as described in the Murillo paper.\n",
00194 "surf-vocabulary", '\0', "surf-vocabulary-file",
00195 SPMK_DEFAULT_VOCABULARY_FILE,
00196 } ;
00197
00198
00199
00200
00201
00202 #ifndef SPMK_DEFAULT_TRAINING_HISTOGRAMS_FILE
00203 #define SPMK_DEFAULT_TRAINING_HISTOGRAMS_FILE "training_histograms.txt"
00204 #endif
00205
00206 const ModelOptionDef OPT_HistogramsFile = {
00207 MODOPT_ARG_STRING, "HistogramsFile", & MOC_SURFPMK, OPTEXP_CORE,
00208 "This option specifies the name of the training histograms database,\n"
00209 "a plain text file containing one histogram entry per line. The\n"
00210 "first field specifies the name plus number of the entry (e.g.,\n"
00211 "foo.mpg:1, bar.mpg:5, and so on). The second field specifies the ground\n"
00212 "truth for this particular image. The remaining fields are simply the\n"
00213 "numbers making up the image's flattened out multi-level histogram,\n"
00214 "which serves as its gist vector.\n",
00215 "training-histograms", '\0', "training-histograms-file",
00216 SPMK_DEFAULT_TRAINING_HISTOGRAMS_FILE,
00217 } ;
00218
00219
00220
00221
00222
00223
00224
00225 #ifndef SPMK_DEFAULT_SVM_CLASSIFIER_FILE
00226 #define SPMK_DEFAULT_SVM_CLASSIFIER_FILE "svm_classifier.bin"
00227 #endif
00228
00229 const ModelOptionDef OPT_SvmClassifierFile = {
00230 MODOPT_ARG_STRING, "SvmClassifierFile", & MOC_SURFPMK, OPTEXP_CORE,
00231 "This option specifies the name of the file that will hold the SVM\n"
00232 "classifier for a given segment. This file is read and written by the\n"
00233 "torch library.",
00234 "svm-classifier", '\0', "svm-classifier-file",
00235 SPMK_DEFAULT_SVM_CLASSIFIER_FILE,
00236 } ;
00237
00238
00239
00240
00241
00242
00243 #ifndef SPMK_DEFAULT_SVM_TEMP_FILE
00244 #define SPMK_DEFAULT_SVM_TEMP_FILE "/tmp/train-surfpmk-torch-dataset.txt"
00245 #endif
00246
00247 const ModelOptionDef OPT_SvmTempFile = {
00248 MODOPT_ARG_STRING, "SvmTempFile", & MOC_SURFPMK, OPTEXP_CORE,
00249 "This option specifies the name of the temp file that will hold the SVM\n"
00250 "training data in the format required by the torch library. This file is\n"
00251 "automatically deleted when it is no longer required.",
00252 "svm-temp", '\0', "svm-temp-file",
00253 SPMK_DEFAULT_SVM_TEMP_FILE,
00254 } ;
00255
00256
00257
00258 #ifndef SPMK_DEFAULT_CLASSIFICATION_RESULTS_FILE
00259 #define SPMK_DEFAULT_CLASSIFICATION_RESULTS_FILE \
00260 "surfpmk_classifications.txt"
00261 #endif
00262
00263 const ModelOptionDef OPT_ResultsFile = {
00264 MODOPT_ARG_STRING, "ResultsFile", & MOC_SURFPMK, OPTEXP_CORE,
00265 "This option specifies the name of the classification results file,\n"
00266 "a plain text file containing one result entry per line. The first\n"
00267 "field specifies the name of the input image plus number of the entry,\n"
00268 "(e.g., foo.mpg:1, bar.mpg:5, and so on). Then comes the ground truth\n"
00269 "for this image followed by its classification result.\n",
00270 "results-file", '\0', "classification-results-file",
00271 SPMK_DEFAULT_CLASSIFICATION_RESULTS_FILE,
00272 } ;
00273
00274
00275
00276
00277
00278
00279
00280
00281 #ifndef SPMK_DEFAULT_IMAGE_NAME
00282 #define SPMK_DEFAULT_IMAGE_NAME "some_image"
00283 #endif
00284 #ifndef SPMK_DEFAULT_SEGMENT_NUMBER
00285 #define SPMK_DEFAULT_SEGMENT_NUMBER "0"
00286 #endif
00287
00288 const ModelOptionDef OPT_ImageName = {
00289 MODOPT_ARG_STRING, "ImageName", & MOC_SURFPMK, OPTEXP_CORE,
00290 "This option specifies the \"root\" name for an image. The image number\n"
00291 "will be automatically appended to this \"root\" name with a colon as the\n"
00292 "separator between name and frame number. The current input MPEG file\n"
00293 "name is a good choice for the value of this option.\n",
00294 "image-name", '\0', "input-MPEG-file-name",
00295 SPMK_DEFAULT_IMAGE_NAME,
00296 } ;
00297
00298 const ModelOptionDef OPT_SegmentNumber = {
00299 MODOPT_ARG_STRING, "SegmentNumber", & MOC_SURFPMK, OPTEXP_CORE,
00300 "This option specifies the segment number for an image in the training\n"
00301 "set. The segment number is used to specify the ground truth for the\n"
00302 "image classification.\n",
00303 "segment-number", '\0', "image-segment-number",
00304 SPMK_DEFAULT_SEGMENT_NUMBER,
00305 } ;
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369 #ifndef SPMK_SURF_CMD
00370 #define SPMK_SURF_CMD "surf"
00371 #endif
00372 #ifndef SPMK_VOCABULARY_CMD
00373 #define SPMK_VOCABULARY_CMD "vocab"
00374 #endif
00375 #ifndef SPMK_HISTOGRAM_CMD
00376 #define SPMK_HISTOGRAM_CMD "hist"
00377 #endif
00378 #ifndef SPMK_SVM_CMD
00379 #define SPMK_SVM_CMD "svm"
00380 #endif
00381 #ifndef SPMK_CLASSIFY_CMD
00382 #define SPMK_CLASSIFY_CMD "classify"
00383 #endif
00384
00385
00386 #ifndef SPMK_ACTIONS
00387 #define SPMK_ACTIONS ("{"SPMK_SURF_CMD"|"SPMK_VOCABULARY_CMD"|"\
00388 SPMK_HISTOGRAM_CMD"|"SPMK_SVM_CMD"|"\
00389 SPMK_CLASSIFY_CMD"}")
00390 #endif
00391
00392 }
00393
00394
00395
00396
00397
00398
00399 namespace {
00400
00401 class SPMKSimulation {
00402 ModelManager model_manager ;
00403 nub::soft_ref<SimEventQueueConfigurator> configurator ;
00404 nub::soft_ref<StdBrain> brain ;
00405 nub::ref<SimInputFrameSeries> input_frame_series ;
00406
00407
00408 OModelParam<std::string> sd_option ;
00409 OModelParam<std::string> sv_option ;
00410 OModelParam<std::string> th_option ;
00411 OModelParam<std::string> sc_option ;
00412 OModelParam<std::string> st_option ;
00413 OModelParam<std::string> rf_option ;
00414 OModelParam<std::string> in_option ;
00415 OModelParam<std::string> sn_option ;
00416
00417 public :
00418 SPMKSimulation(const std::string& model_name) ;
00419 void parse_command_line(int argc, const char* argv[]) ;
00420 void run() ;
00421 ~SPMKSimulation() ;
00422
00423 private :
00424
00425 typedef void (SPMKSimulation::*Action)() ;
00426 typedef std::map<std::string, Action> ActionMap ;
00427 ActionMap action_map ;
00428
00429 void accumulate_surf_descriptors() ;
00430 void compute_surf_vocabulary() ;
00431 void compute_training_histograms() ;
00432 void generate_svm_classifier() ;
00433 void classify_input_images() ;
00434
00435
00436 std::string surf_descriptors_file() {return sd_option.getVal() ;}
00437 std::string surf_vocabulary_file() {return sv_option.getVal() ;}
00438 std::string histograms_file() {return th_option.getVal() ;}
00439 std::string svm_classifier_file() {return sc_option.getVal() ;}
00440 std::string svm_temp_file() {return st_option.getVal() ;}
00441 std::string results_file() {return rf_option.getVal() ;}
00442 std::string image_name() {return in_option.getVal() ;}
00443 std::string segment_number() {return sn_option.getVal() ;}
00444 } ;
00445
00446
00447
00448 SPMKSimulation::SPMKSimulation(const std::string& model_name)
00449 : model_manager(model_name),
00450 configurator(new SimEventQueueConfigurator(model_manager)),
00451 brain(new StdBrain(model_manager)),
00452 input_frame_series(new SimInputFrameSeries(model_manager)),
00453 sd_option(& OPT_SurfDescriptors, & model_manager),
00454 sv_option(& OPT_SurfVocabulary, & model_manager),
00455 th_option(& OPT_HistogramsFile, & model_manager),
00456 sc_option(& OPT_SvmClassifierFile, & model_manager),
00457 st_option(& OPT_SvmTempFile, & model_manager),
00458 rf_option(& OPT_ResultsFile, & model_manager),
00459 in_option(& OPT_ImageName, & model_manager),
00460 sn_option(& OPT_SegmentNumber, & model_manager)
00461 {
00462 model_manager.addSubComponent(configurator) ;
00463 model_manager.addSubComponent(brain) ;
00464 model_manager.addSubComponent(input_frame_series) ;
00465
00466 typedef SPMKSimulation me ;
00467 action_map[SPMK_SURF_CMD] = & me::accumulate_surf_descriptors ;
00468 action_map[SPMK_VOCABULARY_CMD] = & me::compute_surf_vocabulary ;
00469 action_map[SPMK_HISTOGRAM_CMD] = & me::compute_training_histograms ;
00470 action_map[SPMK_SVM_CMD] = & me::generate_svm_classifier ;
00471 action_map[SPMK_CLASSIFY_CMD] = & me::classify_input_images ;
00472 }
00473
00474
00475
00476
00477 void SPMKSimulation::parse_command_line(int argc, const char* argv[])
00478 {
00479 model_manager.setOptionValString(& OPT_GistEstimatorType, "SurfPMK") ;
00480
00481 model_manager.setOptionValString(& OPT_SurfDescriptors,
00482 SPMK_DEFAULT_TRAINING_DESCRIPTORS_FILE) ;
00483 model_manager.setOptionValString(& OPT_SurfVocabulary,
00484 SPMK_DEFAULT_VOCABULARY_FILE) ;
00485 model_manager.setOptionValString(& OPT_HistogramsFile,
00486 SPMK_DEFAULT_TRAINING_HISTOGRAMS_FILE ) ;
00487 model_manager.setOptionValString(& OPT_SvmClassifierFile,
00488 SPMK_DEFAULT_SVM_CLASSIFIER_FILE ) ;
00489 model_manager.setOptionValString(& OPT_SvmTempFile,
00490 SPMK_DEFAULT_SVM_TEMP_FILE ) ;
00491 model_manager.setOptionValString(& OPT_ResultsFile,
00492 SPMK_DEFAULT_CLASSIFICATION_RESULTS_FILE) ;
00493
00494 model_manager.setOptionValString(& OPT_ImageName,
00495 SPMK_DEFAULT_IMAGE_NAME) ;
00496 model_manager.setOptionValString(& OPT_SegmentNumber,
00497 SPMK_DEFAULT_SEGMENT_NUMBER) ;
00498
00499 if (! model_manager.parseCommandLine(argc, argv, SPMK_ACTIONS, 1, 1))
00500 throw std::runtime_error("command line parse error") ;
00501 }
00502
00503
00504
00505 void SPMKSimulation::run()
00506 {
00507 std::string cmd(model_manager.getExtraArg(0)) ;
00508 ActionMap::iterator action = action_map.find(cmd) ;
00509 if (action == action_map.end())
00510 throw std::runtime_error(cmd + ": sorry, unknown action") ;
00511 (this->*(action->second))() ;
00512 }
00513
00514
00515
00516
00517 SPMKSimulation::~SPMKSimulation(){}
00518
00519
00520
00521
00522
00523 class ModelManagerStarter {
00524 ModelManager& mgr ;
00525 public :
00526 ModelManagerStarter(ModelManager& m) : mgr(m) {mgr.start() ;}
00527 ~ModelManagerStarter() {mgr.stop() ;}
00528 } ;
00529
00530 }
00531
00532
00533
00534 int main(int argc, const char* argv[])
00535 {
00536 MYLOGVERB = LOG_INFO ;
00537 try
00538 {
00539 SPMKSimulation S("train-surfpmk Model") ;
00540 S.parse_command_line(argc, argv) ;
00541 S.run() ;
00542 }
00543 catch (std::exception& e)
00544 {
00545 LFATAL("%s", e.what()) ;
00546 return 1 ;
00547 }
00548 return 0 ;
00549 }
00550
00551
00552
00553
00554
00555 namespace {
00556
00557
00558 typedef GistEstimatorSurfPMK::SurfKeypoints SurfKeypoints ;
00559
00560
00561
00562 class surf_descriptors_accumulator {
00563 surf_descriptors_accumulator() ;
00564 ~surf_descriptors_accumulator() ;
00565 public :
00566 static std::string output_file ;
00567 static std::string image_name ;
00568 static int frame_number ;
00569 static std::string segment_number ;
00570
00571 static void write(const SurfKeypoints&) ;
00572 } ;
00573
00574
00575
00576
00577
00578
00579
00580
00581 void SPMKSimulation::accumulate_surf_descriptors()
00582 {
00583
00584 LFATAL("please fix me!");
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602
00603
00604
00605
00606
00607
00608
00609
00610
00611
00612
00613
00614
00615
00616
00617
00618 }
00619
00620
00621
00622
00623 std::string surf_descriptors_accumulator::output_file ;
00624 std::string surf_descriptors_accumulator::image_name ;
00625 int surf_descriptors_accumulator::frame_number ;
00626 std::string surf_descriptors_accumulator::segment_number ;
00627
00628
00629
00630
00631
00632
00633
00634
00635
00636
00637
00638
00639
00640
00641
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654
00655
00656 void surf_descriptors_accumulator::write(const SurfKeypoints& G)
00657 {
00658 if (output_file.empty())
00659 throw std::runtime_error("SURF descriptors accumulator output file "
00660 "not specified") ;
00661
00662 std::ofstream ofs(output_file.c_str(), std::ios::out | std::ios::app) ;
00663 for (unsigned int i = 0; i < G.size(); ++i)
00664 ofs << image_name << ':' << frame_number << ' '
00665 << segment_number << ' ' << i << ' ' << G[i] << '\n' ;
00666 }
00667
00668 }
00669
00670
00671
00672
00673
00674
00675 namespace {
00676
00677
00678 typedef Image<float> Vocabulary ;
00679
00680
00681 int count_lines(const std::string& file_name) ;
00682 void load_surf_descriptors(const std::string& file_name, int num_lines) ;
00683 void kmeans(int K) ;
00684 void save_vocabulary(const std::string& file_name) ;
00685
00686
00687
00688
00689
00690 void SPMKSimulation::compute_surf_vocabulary()
00691 {
00692 LINFO("MVN: counting lines in %s", surf_descriptors_file().c_str()) ;
00693 int num_rows = count_lines(surf_descriptors_file()) ;
00694
00695 LINFO("MVN: reading %d SURF descriptors from %s",
00696 num_rows, surf_descriptors_file().c_str()) ;
00697
00698 load_surf_descriptors(surf_descriptors_file(), num_rows) ;
00699
00700
00701
00702
00703
00704 LINFO("MVN: K-means done; saving SURF vocabulary to %s",
00705 surf_vocabulary_file().c_str()) ;
00706 save_vocabulary(surf_vocabulary_file()) ;
00707 }
00708
00709
00710
00711
00712
00713
00714 void load_surf_descriptors(const std::string& file_name, int num_rows)
00715 {
00716 int num_cols = GistEstimatorSurfPMK::SURF_DESCRIPTOR_SIZE ;
00717
00718 double d ; std::string dummy ;
00719 std::ifstream ifs(file_name.c_str()) ;
00720 for (int i = 0; i < num_rows; ++i)
00721 {
00722 std::string str ;
00723 std::getline(ifs, str) ;
00724 if (! ifs || str.empty()) {
00725 if (i == num_rows - 1)
00726 break ;
00727 else {
00728 throw std::runtime_error(file_name +
00729 ": missing SURF descriptors or other read error") ;
00730 }
00731 }
00732 std::istringstream line(str) ;
00733 line >> dummy >> dummy >> dummy >> dummy ;
00734
00735 for (int j = 0; j < num_cols; ++j) {
00736 if (! line) {
00737 throw std::runtime_error(file_name +
00738 ": missing SURF descriptor values on line " + to_string(i)) ;
00739 }
00740 line >> d ;
00741 }
00742 }
00743
00744
00745
00746 }
00747
00748
00749
00750 void kmeans(int K)
00751 {
00752
00753 LINFO("MVN: computing K-means cluster assignments with libPMK") ;
00754
00755 LINFO("MVN: cluster assignments done; computing centroids...") ;
00756
00757
00758 LFATAL("hum, I did nothing here, please fix my code!");
00759 }
00760
00761
00762 void save_vocabulary(
00763 const std::string& file_name)
00764 {
00765 std::ofstream ofs(file_name.c_str()) ;
00766
00767
00768
00769
00770
00771
00772
00773 }
00774
00775
00776 Vocabulary load_vocabulary(const std::string& file_name)
00777 {
00778 const int M = count_lines(file_name) ;
00779 const int N = GistEstimatorSurfPMK::SURF_DESCRIPTOR_SIZE ;
00780 Vocabulary V(N, M, ZEROS) ;
00781
00782 float f ;
00783 std::ifstream ifs(file_name.c_str()) ;
00784 for (int j = 0; j < M; ++j)
00785 for (int i = 0; i < N; ++i) {
00786 if (! ifs)
00787 throw std::runtime_error(file_name + ": out of data?!?") ;
00788 ifs >> f ;
00789 V.setVal(i, j, f) ;
00790 }
00791
00792 return V ;
00793 }
00794
00795 }
00796
00797
00798
00799
00800
00801
00802
00803
00804 namespace {
00805
00806
00807 typedef Image<double> Histogram ;
00808
00809
00810 void save_histogram(const Histogram& histogram, const std::string& file_name,
00811 const std::string& image_name, int frame_number,
00812 const std::string& segment_number) ;
00813
00814
00815
00816
00817
00818
00819
00820
00821
00822 void SPMKSimulation::compute_training_histograms()
00823 {
00824 LFATAL("please fix me!!");
00825
00826
00827
00828
00829
00830
00831
00832
00833
00834
00835
00836
00837
00838
00839
00840
00841
00842
00843
00844
00845
00846
00847
00848
00849
00850
00851
00852
00853
00854
00855
00856
00857
00858
00859
00860
00861
00862 }
00863
00864
00865
00866
00867
00868
00869
00870
00871 void save_histogram(const Histogram& histogram, const std::string& file_name,
00872 const std::string& image_name, int frame_number,
00873 const std::string& segment_number)
00874 {
00875 std::ofstream ofs(file_name.c_str(), std::ios::out | std::ios::app) ;
00876 ofs << image_name << ':' << frame_number << ' '
00877 << segment_number << ' ' ;
00878 for (int y = 0; y < histogram.getHeight(); ++y)
00879 for (int x = 0; x < histogram.getWidth(); ++x)
00880 ofs << histogram.getVal(x, y) << ' ' ;
00881 ofs << '\n' ;
00882 }
00883
00884 }
00885
00886
00887
00888 namespace {
00889
00890
00891 void create_torch_dataset(const std::string&, const std::string&,
00892 const std::string&) ;
00893 Torch::SVMClassification* create_torch_classifier(const std::string&) ;
00894 std::string temp_file_name() ;
00895
00896
00897 void SPMKSimulation::generate_svm_classifier()
00898 {
00899 create_torch_dataset(histograms_file(), segment_number(), svm_temp_file()) ;
00900 Torch::SVMClassification* svm = create_torch_classifier(svm_temp_file()) ;
00901 svm->save(svm_classifier_file().c_str()) ;
00902 delete svm ;
00903 unlink(svm_temp_file().c_str()) ;
00904 }
00905
00906
00907 struct GistVector {
00908 std::vector<double> values ;
00909 GistVector() ;
00910 } ;
00911
00912 GistVector::GistVector()
00913 : values(GistEstimatorSurfPMK::GIST_VECTOR_SIZE)
00914 {}
00915
00916 std::istream& operator>>(std::istream& is, GistVector& g)
00917 {
00918 for (int i = 0; i < GistEstimatorSurfPMK::GIST_VECTOR_SIZE; ++i)
00919 is >> g.values[i] ;
00920 return is ;
00921 }
00922
00923 std::ostream& operator<<(std::ostream& os, const GistVector& g)
00924 {
00925 for (int i = 0; i < GistEstimatorSurfPMK::GIST_VECTOR_SIZE; ++i)
00926 os << g.values[i] << ' ' ;
00927 return os ;
00928 }
00929
00930
00931
00932
00933
00934
00935
00936 void create_torch_dataset(const std::string& hist_file,
00937 const std::string& target,
00938 const std::string& torch_dataset)
00939 {
00940 const int n = count_lines(hist_file) ;
00941
00942 std::ifstream in(hist_file.c_str()) ;
00943 std::ofstream out(torch_dataset.c_str()) ;
00944
00945 std::string dummy, segment, str ;
00946 GistVector gist_vector ;
00947 out << n << ' ' << (GistEstimatorSurfPMK::GIST_VECTOR_SIZE + 1) << '\n' ;
00948 for (int i = 0; i < n; ++i)
00949 {
00950 std::getline(in, str) ;
00951 if (! in || str.empty()) {
00952 if (i == n - 1)
00953 break ;
00954 else {
00955 out.close() ;
00956 unlink(torch_dataset.c_str()) ;
00957 throw std::runtime_error(hist_file +
00958 ": missing data or other read error") ;
00959 }
00960 }
00961 std::istringstream line(str) ;
00962 line >> dummy >> segment >> gist_vector ;
00963 out << gist_vector << ' ' << ((segment == target) ? +1 : -1) << '\n' ;
00964 }
00965 }
00966
00967
00968
00969 class HistIntKernel : public Torch::Kernel {
00970
00971 real eval(Torch::Sequence*, Torch::Sequence*) ;
00972 } ;
00973
00974
00975 real HistIntKernel::eval(Torch::Sequence* a, Torch::Sequence* b)
00976 {
00977
00978 real sum = 0 ;
00979 for (int i = 0; i < a->frame_size; ++i)
00980
00981 sum += min(a->frames[0][i], b->frames[0][i]) ;
00982 return sum ;
00983 }
00984
00985
00986
00987 Torch::SVMClassification* create_torch_classifier(const std::string& dataset)
00988 {
00989 HistIntKernel kernel ;
00990 Torch::SVMClassification* svm = new Torch::SVMClassification(& kernel) ;
00991 Torch::QCTrainer trainer(svm) ;
00992 Torch::MatDataSet data(dataset.c_str(),
00993 GistEstimatorSurfPMK::GIST_VECTOR_SIZE, 1) ;
00994 trainer.train(& data, 0) ;
00995 return svm ;
00996 }
00997
00998 }
00999
01000
01001
01002 namespace {
01003
01004
01005 typedef std::vector<Torch::SVMClassification*> Classifiers ;
01006
01007
01008 Classifiers load_classifiers(std::string, HistIntKernel*) ;
01009 void classify_image(const Histogram&, const Classifiers&,
01010 const std::string&, int, const std::string&,
01011 const std::string&) ;
01012 void nuke_classifiers(Classifiers&) ;
01013
01014
01015
01016
01017
01018
01019 void SPMKSimulation::classify_input_images()
01020 {
01021 LFATAL("please fixme too!");
01022
01023
01024
01025
01026
01027
01028
01029
01030
01031
01032
01033
01034
01035
01036
01037
01038
01039
01040
01041
01042
01043
01044
01045
01046
01047
01048
01049
01050
01051
01052
01053
01054
01055
01056
01057
01058
01059
01060
01061
01062
01063
01064
01065
01066 }
01067
01068
01069
01070
01071 void classify_image(const Histogram& gist_vector,
01072 const Classifiers& classifiers,
01073 const std::string& image_name, int frame_number,
01074 const std::string& ground_truth,
01075 const std::string& results_file)
01076 {
01077 std::ofstream ofs(results_file.c_str(), std::ios::out | std::ios::app) ;
01078 ofs << image_name << ':' << frame_number << ' ' << ground_truth << ' ' ;
01079
01080 Torch::Sequence gv(1, GistEstimatorSurfPMK::GIST_VECTOR_SIZE) ;
01081 std::copy(gist_vector.begin(), gist_vector.end(), gv.frames[0]) ;
01082
01083 int n = 0 ;
01084 const int N = classifiers.size() ;
01085 for (int i = 0; i < N; ++i) {
01086 classifiers[i]->forward(& gv) ;
01087 if (classifiers[i]->outputs->frames[0][0] > 0) {
01088 ofs << (i+1) << ' ' ;
01089 ++n ;
01090 }
01091 }
01092
01093 if (! n)
01094 ofs << '0' ;
01095 ofs << '\n' ;
01096 }
01097
01098
01099
01100
01101
01102
01103
01104
01105
01106
01107
01108
01109
01110 Classifiers
01111 load_classifiers(std::string classifiers_root_name, HistIntKernel* kernel)
01112 {
01113 classifiers_root_name += ".*" ;
01114 glob_t buf ;
01115 if (glob(classifiers_root_name.c_str(), 0, 0, & buf) != 0)
01116 throw std::runtime_error("couldn't find/load the SVM classifiers") ;
01117
01118 const int N = buf.gl_pathc ;
01119 Classifiers classifiers(N) ;
01120 for (int i = 0; i < N; ++i) {
01121 classifiers[i] = new Torch::SVMClassification(kernel) ;
01122 classifiers[i]->load(buf.gl_pathv[i]) ;
01123 }
01124
01125 globfree(& buf) ;
01126 return classifiers ;
01127 }
01128
01129
01130 void nuke_classifiers(Classifiers& C)
01131 {
01132 const int N = C.size() ;
01133 for (int i = 0; i < N; ++i)
01134 delete C[i] ;
01135 }
01136
01137 }
01138
01139
01140
namespace {

// Returns the number of lines in the named file. A final line without a
// trailing newline still counts, and an empty file yields 0.
// Throws std::runtime_error if the file cannot be opened. Callers use the
// result to size images and datasets, so a sentinel such as -1 would be
// dangerous.
int count_lines(const std::string& file_name)
{
   std::ifstream ifs(file_name.c_str()) ;
   if (! ifs)
      throw std::runtime_error(file_name + ": unable to open file") ;

   int n = 0 ;
   std::string line ;
   while (std::getline(ifs, line))
      ++n ;
   return n ;
}

}
01158
01159
01160
#endif // #if !defined(INVT_HAVE_LIBSURF) || !defined(INVT_HAVE_LIBTORCH)
01162
01163
01164
01165
01166