test-HMMSeq.C

00001 #include <iostream>
00002 #include <fstream>
00003 #include <string>
00004 #include <sstream>
00005 #include <vector>
00006 #include <cmath>
00007 
00008 #include "Learn/WiimoteGR/Quantizer.h"
00009 #include "Learn/WiimoteGR/HMMLib.h"
00010 #include "Learn/WiimoteGR/Database.h"
00011 #include "Image/Point2D.H"
00012 
00013 using namespace std;
00014 using namespace WiimoteGR;
00015 
00016 std::vector<Point2D<int> > square()
00017 {
00018   std::vector<Point2D<int> > acc;
00019   acc.push_back(Point2D<int>(5,0));
00020   acc.push_back(Point2D<int>(5,0));
00021   acc.push_back(Point2D<int>(5,0));
00022   acc.push_back(Point2D<int>(5,0));
00023   acc.push_back(Point2D<int>(0,0));
00024   acc.push_back(Point2D<int>(0,0));
00025   acc.push_back(Point2D<int>(0,0));
00026   acc.push_back(Point2D<int>(0,0));
00027   acc.push_back(Point2D<int>(-5,0));
00028   acc.push_back(Point2D<int>(-5,0));
00029   acc.push_back(Point2D<int>(-5,0));
00030   acc.push_back(Point2D<int>(-5,0));
00031 
00032   acc.push_back(Point2D<int>(0,5));
00033   acc.push_back(Point2D<int>(0,5));
00034   acc.push_back(Point2D<int>(0,5));
00035   acc.push_back(Point2D<int>(0,5));
00036   acc.push_back(Point2D<int>(0,0));
00037   acc.push_back(Point2D<int>(0,0));
00038   acc.push_back(Point2D<int>(0,0));
00039   acc.push_back(Point2D<int>(0,0));
00040   acc.push_back(Point2D<int>(0,-5));
00041   acc.push_back(Point2D<int>(0,-5));
00042   acc.push_back(Point2D<int>(0,-5));
00043   acc.push_back(Point2D<int>(0,-5));
00044 
00045   acc.push_back(Point2D<int>(-5,0));
00046   acc.push_back(Point2D<int>(-5,0));
00047   acc.push_back(Point2D<int>(-5,0));
00048   acc.push_back(Point2D<int>(-5,0));
00049   acc.push_back(Point2D<int>(0,0));
00050   acc.push_back(Point2D<int>(0,0));
00051   acc.push_back(Point2D<int>(0,0));
00052   acc.push_back(Point2D<int>(0,0));
00053   acc.push_back(Point2D<int>(5,0));
00054   acc.push_back(Point2D<int>(5,0));
00055   acc.push_back(Point2D<int>(5,0));
00056   acc.push_back(Point2D<int>(5,0));
00057 
00058   acc.push_back(Point2D<int>(0,-5));
00059   acc.push_back(Point2D<int>(0,-5));
00060   acc.push_back(Point2D<int>(0,-5));
00061   acc.push_back(Point2D<int>(0,-5));
00062   acc.push_back(Point2D<int>(0,0));
00063   acc.push_back(Point2D<int>(0,0));
00064   acc.push_back(Point2D<int>(0,0));
00065   acc.push_back(Point2D<int>(0,0));
00066   acc.push_back(Point2D<int>(0,5));
00067   acc.push_back(Point2D<int>(0,5));
00068   acc.push_back(Point2D<int>(0,5));
00069   acc.push_back(Point2D<int>(0,5));
00070 
00071   acc.push_back(Point2D<int>(0,0));
00072   acc.push_back(Point2D<int>(0,0));
00073   acc.push_back(Point2D<int>(0,0));
00074   acc.push_back(Point2D<int>(0,0));
00075 
00076   return acc;
00077 
00078 }
00079 
00080 std::vector<Point2D<int> > triangle()
00081 {
00082   std::vector<Point2D<int> > acc;
00083   acc.push_back(Point2D<int>(5,0));
00084   acc.push_back(Point2D<int>(5,0));
00085   acc.push_back(Point2D<int>(5,0));
00086   acc.push_back(Point2D<int>(5,0));
00087   acc.push_back(Point2D<int>(0,0));
00088   acc.push_back(Point2D<int>(0,0));
00089   acc.push_back(Point2D<int>(0,0));
00090   acc.push_back(Point2D<int>(0,0));
00091   acc.push_back(Point2D<int>(-5,0));
00092   acc.push_back(Point2D<int>(-5,0));
00093   acc.push_back(Point2D<int>(-5,0));
00094   acc.push_back(Point2D<int>(-5,0));
00095 
00096   acc.push_back(Point2D<int>(-5,5));
00097   acc.push_back(Point2D<int>(-5,5));
00098   acc.push_back(Point2D<int>(0,0));
00099   acc.push_back(Point2D<int>(0,0));
00100   acc.push_back(Point2D<int>(0,0));
00101   acc.push_back(Point2D<int>(0,0));
00102   acc.push_back(Point2D<int>(0,0));
00103   acc.push_back(Point2D<int>(0,0));
00104   acc.push_back(Point2D<int>(5,-5));
00105   acc.push_back(Point2D<int>(5,-5));
00106 
00107   acc.push_back(Point2D<int>(-5,-5));
00108   acc.push_back(Point2D<int>(-5,-5));
00109   acc.push_back(Point2D<int>(0,0));
00110   acc.push_back(Point2D<int>(0,0));
00111   acc.push_back(Point2D<int>(0,0));
00112   acc.push_back(Point2D<int>(0,0));
00113   acc.push_back(Point2D<int>(0,0));
00114   acc.push_back(Point2D<int>(0,0));
00115   acc.push_back(Point2D<int>(5,5));
00116   acc.push_back(Point2D<int>(5,5));
00117 
00118   acc.push_back(Point2D<int>(0,0));
00119   acc.push_back(Point2D<int>(0,0));
00120   acc.push_back(Point2D<int>(0,0));
00121   acc.push_back(Point2D<int>(0,0));
00122 
00123   return acc;
00124 
00125 }
00126 
00127 
00128 void Training(Database& db, const Quantizer& quantizer, HMMLib& trainer, HMM& hmm);
00129 
00130 
00131 int main()
00132 {
00133     /* initial HMM model */
00134     const char* initGestureName = "unknown";
00135     //Quantizer used, contain information of M
00136     M32Quantizer defaultQuantizer;
00137     const double OP = 1.0/defaultQuantizer.M;
00138     //style of model
00139     const char* initModelStyle = "5 state left to right";
00140     //not trained initially
00141     bool initTrained = false;
00142     //num of states
00143     const size_t initN = 5;
00144     //matrices
00145     double initA[] = { 0.5, 0.5, 0.0, 0.0, 0.0,
00146                        0.0, 0.5, 0.5, 0.0, 0.0,
00147                        0.0, 0.0, 0.5, 0.5, 0.0,
00148                        0.0, 0.0, 0.0, 0.5, 0.5,
00149                        0.0, 0.0, 0.0, 0.0, 1.0
00150                      };
00151     //B: N=5 * M=32
00152     double initB[] = { OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP,
00153                        OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP,
00154                        OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP,
00155                        OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP,
00156                        OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP
00157                      };
00158     double initPi[] = { 1.0, 0.0, 0.0, 0.0, 0.0};
00159     
00160     HMM initHMM(initGestureName, defaultQuantizer, initModelStyle, initTrained, initN, initA, initB, initPi);
00161 
00162 
00163     //database of HMMs, seqs gestures
00164     Database& db = Database::Open();
00165 
00166     //HMM Trainining Library
00167     HMMLib trainer;
00168 
00169     ////Train the guesters
00170     //{
00171     //  //push(copy) initial HMM to vector
00172     //  HMM tempHMM = initHMM;
00173     //  tempHMM.gestureName = "square";
00174     //  //train and save HMM to database
00175     //  Training(db, defaultQuantizer, trainer, tempHMM);
00176     //}
00177 
00178     //{
00179     //  //push(copy) initial HMM to vector
00180     //  HMM tempHMM = initHMM;
00181     //  tempHMM.gestureName = "triangle";
00182     //  //train and save HMM to database
00183     //  Training(db, defaultQuantizer, trainer, tempHMM);
00184     //}
00185 
00186     /*
00187         Testing
00188     */
00189     cout << endl << "============================TESTING===============================" << endl;
00190     vector<HMM> loadedHMMVec;
00191 
00192     //temp sequence for user input
00193     TimeSlot tempSeq("testingGesture",defaultQuantizer);
00194     Acceleration tempAcc;
00195 
00196     //load trained HMMsfrom database for testing
00197     db.LoadHMMs(defaultQuantizer, initModelStyle, true, loadedHMMVec);
00198     cout << loadedHMMVec.size() << " pre saved HMM loaded for testing, now gestures are:" << endl;
00199     for(vector<HMM>::iterator i=loadedHMMVec.begin() ; i<loadedHMMVec.end(); i++)
00200         cout << i->gestureName << " ";
00201     cout << endl;
00202 
00203     cout << "Start recognition of " << loadedHMMVec.size() << " gesture models." << endl;
00204 
00205     std::vector<Point2D<int> > acc = triangle();
00206 
00207     //Get gesture
00208     for(size_t i=0; i<acc.size(); i++)
00209     {
00210       tempAcc.x=acc[i].i; 
00211       tempAcc.y=acc[i].j; 
00212       tempAcc.z=0;
00213 
00214       tempSeq.AddObservableSymbol(defaultQuantizer.Quantize(tempAcc));
00215       //cout.width(3);
00216       //cout << defaultQuantizer.Quantize(tempAcc) << " ";
00217       cout << trainer.Recognize(loadedHMMVec,tempSeq).gestureName << "-> ";
00218     }
00219 
00220     //show probs to trained HMM
00221     cout << endl;
00222     for(size_t i = 0; i < loadedHMMVec.size(); i++)
00223       cout << "Prob to " << loadedHMMVec[i].gestureName << " model = " << exp(trainer.SeqLogProb(loadedHMMVec[i],tempSeq,false)) << endl;
00224     tempSeq.ClearObservableSymbols();
00225 
00226     return 0;
00227 }
00228 
00229 void Training(Database& db, const Quantizer& quantizer,
00230     HMMLib& trainer, HMM& hmm){
00231 
00232     size_t M = quantizer.M;
00233     if(hmm.M == M){
00234         //training gestures
00235         vector<Gesture> gestureVec;
00236         //training sequences
00237         vector<TimeSlot> seqVec;
00238 
00239         //Load existing gestures for this model
00240         //db.LoadGestures(hmm.gestureName.c_str(), gestureVec);
00241         //db.LoadObservationSequences(hmm.gestureName.c_str(),quantizer,seqVec);
00242 
00243         //cout << gestureVec.size() << " samples of " << hmm.gestureName << " gesture loaded" << endl;
00244         //cout << seqVec.size() << " observation sequences of " << hmm.gestureName << " gesture loaded" << endl;
00245 
00246         //Delete the gustures and start from scratch
00247         db.DeleteGestures(hmm.gestureName.c_str());
00248         db.DeleteObservationSequences(hmm.gestureName.c_str(),quantizer);
00249         cout << "Sequences of gesture " << hmm.gestureName << " deleted from database." << endl;
00250 
00251         //temp gesture for user input
00252         Gesture tempGesture(hmm.gestureName.c_str());
00253 
00254         trainer.ShowHMM(hmm);
00255         std::vector<Point2D<int> > acc;
00256         if (hmm.gestureName == "square")
00257           acc = square();
00258         else
00259           acc = triangle();
00260         
00261         //Feed the points in order
00262 
00263 
00264         for(size_t j=0; j<acc.size(); j++)
00265         {
00266 
00267           for(size_t i=0; i<acc.size(); i++)
00268           {
00269             Acceleration tempAcc;
00270             tempAcc.x=acc[(j+i)%acc.size()].i; 
00271             tempAcc.y=acc[(j+i)%acc.size()].j; 
00272             tempAcc.z=0; 
00273 
00274             tempGesture.data.push_back(tempAcc);
00275             cout.width(3);
00276             cout << quantizer.Quantize(tempAcc) << " ";
00277           }
00278 
00279           TimeSlot tempSeq(hmm.gestureName.c_str(), quantizer);
00280 
00281           quantizer.Quantize(tempGesture,tempSeq);
00282           cout << "/ Length = " << tempSeq.o.size() << endl << endl;
00283 
00284           gestureVec.push_back(tempGesture);
00285           seqVec.push_back(tempSeq);
00286 
00287           db.SaveGesture(tempGesture);
00288           db.SaveObservationSequence(tempSeq);
00289           tempGesture.data.clear();
00290         }
00291 
00292 
00293         /*
00294         Generate .sce file for plotting by scilab
00295         */
00296         //genSce(hmm.gestureName,gestureVec);
00297 
00298         //show probs to initial HMM
00299         for(size_t i = 0; i < seqVec.size(); i++)
00300             cout << "prob of seqs[" << i << "] to " << hmm.gestureName << " model = " << trainer.SeqLogProb(hmm,seqVec[i],false) << endl;
00301         //train by multi seq
00302         cout << hmm.gestureName << " gesture model is trained in " << trainer.EstimateModelBySeqs(hmm, seqVec, 50) << " loops" << endl;
00303 
00304         hmm.trained = true;
00305         trainer.ShowHMM(hmm);
00306         db.SaveHMM(hmm);
00307 
00308         //show probs to trained HMM
00309         for(size_t i = 0; i < seqVec.size(); i++)
00310             cout << "prob of seqs[" << i << "] to " << hmm.gestureName << " model = " << trainer.SeqLogProb(hmm,seqVec[i],false) << endl;
00311         
00312     }else{
00313         cout << "ERROR hmm.M != quantizer.M" << endl;
00314     }
00315 }
00316 
00317 /*
00318 
00319 double rho(double x, double y, double z)
00320 {
00321     return sqrt(x*x+y*y+z*z);
00322 }
00323 
00324 void genSce(string gestureName, vector<Gesture>& gestureVec)
00325 {
00326     ofstream sceFile;
00327 
00328     for(size_t i=0;i<gestureVec.size();i++){
00329         
00330         stringstream ss;
00331  
00332         ss << gestureName << i/10 << i%10;
00333 
00334         sceFile.open( (ss.str()+".sce").c_str() );
00335         if(!sceFile)
00336             cout << "error" << endl;
00337         
00338         sceFile << "t=[0:" << (gestureVec[i].data.size()-1) << "];" << endl;
00339 
00340         sceFile << "x=[";
00341         for(size_t j=0; j<gestureVec[i].data.size(); j++)
00342             sceFile << gestureVec[i].data[j].x << " ";
00343         sceFile << "];" << endl;
00344 
00345         sceFile << "y=[";
00346         for(size_t j=0; j<gestureVec[i].data.size(); j++)
00347             sceFile << gestureVec[i].data[j].y << " ";
00348         sceFile << "];" << endl;
00349 
00350         sceFile << "z=[";
00351         for(size_t j=0; j<gestureVec[i].data.size(); j++)
00352             sceFile << gestureVec[i].data[j].z << " ";
00353         sceFile << "];" << endl;
00354 
00355         sceFile << "r=[";
00356         for(size_t j=0; j<gestureVec[i].data.size(); j++)
00357             sceFile << rho(gestureVec[i].data[j].x,gestureVec[i].data[j].y,gestureVec[i].data[j].z) << " ";
00358         sceFile << "];" << endl;
00359 
00360         sceFile <<  "plot(t,x,\"ko-\",t,y,\"kx-\",t,z,\"k>-\",t,r,\"k.-\");\n"
00361                     "a=gca();\n"
00362                     "a.x_label.text=\"time(10ms)\";\n"
00363                     "a.x_label.font_size = 2;\n"
00364                     "a.x_label.font_style = 8;\n"
00365                     "a.y_label.text=\"acceleration(g)\";\n"
00366                     "a.y_label.font_size = 2;\n"
00367                     "a.y_label.font_style = 8;\n"
00368 
00369                     "a.title.text = \"" << ss.str() << "\";\n"
00370                     "a.title.font_size = 4;\n"
00371                     "a.title.font_style = 5;\n"
00372 
00373                     "a.font_size = 2;\n"
00374                     "a.x_location = \"middle\";\n"
00375 
00376                     //"l=legend([\"x\",\"y\",\"z\",\"r\"]);\n"
00377                     //"l.children(1).font_style=1;"
00378                     << endl;
00379 
00380         sceFile.close();
00381     }
00382 }
00383 */
Generated on Sun May 8 08:40:59 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3