00001 #include <iostream> 00002 #include <fstream> 00003 #include <string> 00004 #include <sstream> 00005 #include <vector> 00006 #include <cmath> 00007 00008 #include "Learn/WiimoteGR/Quantizer.h" 00009 #include "Learn/WiimoteGR/HMMLib.h" 00010 #include "Learn/WiimoteGR/Database.h" 00011 #include "Image/Point2D.H" 00012 00013 using namespace std; 00014 using namespace WiimoteGR; 00015 00016 std::vector<Point2D<int> > square() 00017 { 00018 std::vector<Point2D<int> > acc; 00019 acc.push_back(Point2D<int>(5,0)); 00020 acc.push_back(Point2D<int>(5,0)); 00021 acc.push_back(Point2D<int>(5,0)); 00022 acc.push_back(Point2D<int>(5,0)); 00023 acc.push_back(Point2D<int>(0,0)); 00024 acc.push_back(Point2D<int>(0,0)); 00025 acc.push_back(Point2D<int>(0,0)); 00026 acc.push_back(Point2D<int>(0,0)); 00027 acc.push_back(Point2D<int>(-5,0)); 00028 acc.push_back(Point2D<int>(-5,0)); 00029 acc.push_back(Point2D<int>(-5,0)); 00030 acc.push_back(Point2D<int>(-5,0)); 00031 00032 acc.push_back(Point2D<int>(0,5)); 00033 acc.push_back(Point2D<int>(0,5)); 00034 acc.push_back(Point2D<int>(0,5)); 00035 acc.push_back(Point2D<int>(0,5)); 00036 acc.push_back(Point2D<int>(0,0)); 00037 acc.push_back(Point2D<int>(0,0)); 00038 acc.push_back(Point2D<int>(0,0)); 00039 acc.push_back(Point2D<int>(0,0)); 00040 acc.push_back(Point2D<int>(0,-5)); 00041 acc.push_back(Point2D<int>(0,-5)); 00042 acc.push_back(Point2D<int>(0,-5)); 00043 acc.push_back(Point2D<int>(0,-5)); 00044 00045 acc.push_back(Point2D<int>(-5,0)); 00046 acc.push_back(Point2D<int>(-5,0)); 00047 acc.push_back(Point2D<int>(-5,0)); 00048 acc.push_back(Point2D<int>(-5,0)); 00049 acc.push_back(Point2D<int>(0,0)); 00050 acc.push_back(Point2D<int>(0,0)); 00051 acc.push_back(Point2D<int>(0,0)); 00052 acc.push_back(Point2D<int>(0,0)); 00053 acc.push_back(Point2D<int>(5,0)); 00054 acc.push_back(Point2D<int>(5,0)); 00055 acc.push_back(Point2D<int>(5,0)); 00056 acc.push_back(Point2D<int>(5,0)); 00057 00058 acc.push_back(Point2D<int>(0,-5)); 00059 acc.push_back(Point2D<int>(0,-5)); 00060 acc.push_back(Point2D<int>(0,-5)); 00061 acc.push_back(Point2D<int>(0,-5)); 00062 acc.push_back(Point2D<int>(0,0)); 00063 acc.push_back(Point2D<int>(0,0)); 00064 acc.push_back(Point2D<int>(0,0)); 00065 acc.push_back(Point2D<int>(0,0)); 00066 acc.push_back(Point2D<int>(0,5)); 00067 acc.push_back(Point2D<int>(0,5)); 00068 acc.push_back(Point2D<int>(0,5)); 00069 acc.push_back(Point2D<int>(0,5)); 00070 00071 acc.push_back(Point2D<int>(0,0)); 00072 acc.push_back(Point2D<int>(0,0)); 00073 acc.push_back(Point2D<int>(0,0)); 00074 acc.push_back(Point2D<int>(0,0)); 00075 00076 return acc; 00077 00078 } 00079 00080 std::vector<Point2D<int> > triangle() 00081 { 00082 std::vector<Point2D<int> > acc; 00083 acc.push_back(Point2D<int>(5,0)); 00084 acc.push_back(Point2D<int>(5,0)); 00085 acc.push_back(Point2D<int>(5,0)); 00086 acc.push_back(Point2D<int>(5,0)); 00087 acc.push_back(Point2D<int>(0,0)); 00088 acc.push_back(Point2D<int>(0,0)); 00089 acc.push_back(Point2D<int>(0,0)); 00090 acc.push_back(Point2D<int>(0,0)); 00091 acc.push_back(Point2D<int>(-5,0)); 00092 acc.push_back(Point2D<int>(-5,0)); 00093 acc.push_back(Point2D<int>(-5,0)); 00094 acc.push_back(Point2D<int>(-5,0)); 00095 00096 acc.push_back(Point2D<int>(-5,5)); 00097 acc.push_back(Point2D<int>(-5,5)); 00098 acc.push_back(Point2D<int>(0,0)); 00099 acc.push_back(Point2D<int>(0,0)); 00100 acc.push_back(Point2D<int>(0,0)); 00101 acc.push_back(Point2D<int>(0,0)); 00102 acc.push_back(Point2D<int>(0,0)); 00103 acc.push_back(Point2D<int>(0,0)); 00104 acc.push_back(Point2D<int>(5,-5)); 00105 acc.push_back(Point2D<int>(5,-5)); 00106 00107 acc.push_back(Point2D<int>(-5,-5)); 00108 acc.push_back(Point2D<int>(-5,-5)); 00109 acc.push_back(Point2D<int>(0,0)); 00110 acc.push_back(Point2D<int>(0,0)); 00111 acc.push_back(Point2D<int>(0,0)); 00112 acc.push_back(Point2D<int>(0,0)); 00113 acc.push_back(Point2D<int>(0,0)); 00114 acc.push_back(Point2D<int>(0,0)); 00115 acc.push_back(Point2D<int>(5,5)); 00116 acc.push_back(Point2D<int>(5,5)); 00117 00118 acc.push_back(Point2D<int>(0,0)); 00119 acc.push_back(Point2D<int>(0,0)); 00120 acc.push_back(Point2D<int>(0,0)); 00121 acc.push_back(Point2D<int>(0,0)); 00122 00123 return acc; 00124 00125 } 00126 00127 00128 void Training(Database& db, const Quantizer& quantizer, HMMLib& trainer, HMM& hmm); 00129 00130 00131 int main() 00132 { 00133 /* initial HMM model */ 00134 const char* initGestureName = "unknown"; 00135 //Quantizer used, contain information of M 00136 M32Quantizer defaultQuantizer; 00137 const double OP = 1.0/defaultQuantizer.M; 00138 //style of model 00139 const char* initModelStyle = "5 state left to right"; 00140 //not trained initially 00141 bool initTrained = false; 00142 //num of states 00143 const size_t initN = 5; 00144 //matrices 00145 double initA[] = { 0.5, 0.5, 0.0, 0.0, 0.0, 00146 0.0, 0.5, 0.5, 0.0, 0.0, 00147 0.0, 0.0, 0.5, 0.5, 0.0, 00148 0.0, 0.0, 0.0, 0.5, 0.5, 00149 0.0, 0.0, 0.0, 0.0, 1.0 00150 }; 00151 //B: N=5 * M=32 00152 double initB[] = { OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, 00153 OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, 00154 OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, 00155 OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, 00156 OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP, OP 00157 }; 00158 double initPi[] = { 1.0, 0.0, 0.0, 0.0, 0.0}; 00159 00160 HMM initHMM(initGestureName, defaultQuantizer, initModelStyle, initTrained, initN, initA, initB, initPi); 00161 00162 00163 //database of HMMs, seqs gestures 00164 Database& db = Database::Open(); 00165 00166 //HMM Trainining Library 00167 HMMLib trainer; 00168 00169 ////Train the guesters 00170 //{ 00171 // //push(copy) initial HMM to vector 00172 // HMM tempHMM = initHMM; 00173 // tempHMM.gestureName = "square"; 00174 // //train and save HMM to database 00175 // Training(db, defaultQuantizer, trainer, tempHMM); 00176 //} 00177 00178 //{ 00179 // //push(copy) initial HMM to vector 00180 // HMM tempHMM = initHMM; 00181 // tempHMM.gestureName = "triangle"; 00182 // //train and save HMM to database 00183 // Training(db, defaultQuantizer, trainer, tempHMM); 00184 //} 00185 00186 /* 00187 Testing 00188 */ 00189 cout << endl << "============================TESTING===============================" << endl; 00190 vector<HMM> loadedHMMVec; 00191 00192 //temp sequence for user input 00193 TimeSlot tempSeq("testingGesture",defaultQuantizer); 00194 Acceleration tempAcc; 00195 00196 //load trained HMMsfrom database for testing 00197 db.LoadHMMs(defaultQuantizer, initModelStyle, true, loadedHMMVec); 00198 cout << loadedHMMVec.size() << " pre saved HMM loaded for testing, now gestures are:" << endl; 00199 for(vector<HMM>::iterator i=loadedHMMVec.begin() ; i<loadedHMMVec.end(); i++) 00200 cout << i->gestureName << " "; 00201 cout << endl; 00202 00203 cout << "Start recognition of " << loadedHMMVec.size() << " gesture models." << endl; 00204 00205 std::vector<Point2D<int> > acc = triangle(); 00206 00207 //Get gesture 00208 for(size_t i=0; i<acc.size(); i++) 00209 { 00210 tempAcc.x=acc[i].i; 00211 tempAcc.y=acc[i].j; 00212 tempAcc.z=0; 00213 00214 tempSeq.AddObservableSymbol(defaultQuantizer.Quantize(tempAcc)); 00215 //cout.width(3); 00216 //cout << defaultQuantizer.Quantize(tempAcc) << " "; 00217 cout << trainer.Recognize(loadedHMMVec,tempSeq).gestureName << "-> "; 00218 } 00219 00220 //show probs to trained HMM 00221 cout << endl; 00222 for(size_t i = 0; i < loadedHMMVec.size(); i++) 00223 cout << "Prob to " << loadedHMMVec[i].gestureName << " model = " << exp(trainer.SeqLogProb(loadedHMMVec[i],tempSeq,false)) << endl; 00224 tempSeq.ClearObservableSymbols(); 00225 00226 return 0; 00227 } 00228 00229 void Training(Database& db, const Quantizer& quantizer, 00230 HMMLib& trainer, HMM& hmm){ 00231 00232 size_t M = quantizer.M; 00233 if(hmm.M == M){ 00234 //training gestures 00235 vector<Gesture> gestureVec; 00236 //training sequences 00237 vector<TimeSlot> seqVec; 00238 00239 //Load existing gestures for this model 00240 //db.LoadGestures(hmm.gestureName.c_str(), gestureVec); 00241 //db.LoadObservationSequences(hmm.gestureName.c_str(),quantizer,seqVec); 00242 00243 //cout << gestureVec.size() << " samples of " << hmm.gestureName << " gesture loaded" << endl; 00244 //cout << seqVec.size() << " observation sequences of " << hmm.gestureName << " gesture loaded" << endl; 00245 00246 //Delete the gustures and start from scratch 00247 db.DeleteGestures(hmm.gestureName.c_str()); 00248 db.DeleteObservationSequences(hmm.gestureName.c_str(),quantizer); 00249 cout << "Sequences of gesture " << hmm.gestureName << " deleted from database." << endl; 00250 00251 //temp gesture for user input 00252 Gesture tempGesture(hmm.gestureName.c_str()); 00253 00254 trainer.ShowHMM(hmm); 00255 std::vector<Point2D<int> > acc; 00256 if (hmm.gestureName == "square") 00257 acc = square(); 00258 else 00259 acc = triangle(); 00260 00261 //Feed the points in order 00262 00263 00264 for(size_t j=0; j<acc.size(); j++) 00265 { 00266 00267 for(size_t i=0; i<acc.size(); i++) 00268 { 00269 Acceleration tempAcc; 00270 tempAcc.x=acc[(j+i)%acc.size()].i; 00271 tempAcc.y=acc[(j+i)%acc.size()].j; 00272 tempAcc.z=0; 00273 00274 tempGesture.data.push_back(tempAcc); 00275 cout.width(3); 00276 cout << quantizer.Quantize(tempAcc) << " "; 00277 } 00278 00279 TimeSlot tempSeq(hmm.gestureName.c_str(), quantizer); 00280 00281 quantizer.Quantize(tempGesture,tempSeq); 00282 cout << "/ Length = " << tempSeq.o.size() << endl << endl; 00283 00284 gestureVec.push_back(tempGesture); 00285 seqVec.push_back(tempSeq); 00286 00287 db.SaveGesture(tempGesture); 00288 db.SaveObservationSequence(tempSeq); 00289 tempGesture.data.clear(); 00290 } 00291 00292 00293 /* 00294 Generate .sce file for plotting by scilab 00295 */ 00296 //genSce(hmm.gestureName,gestureVec); 00297 00298 //show probs to initial HMM 00299 for(size_t i = 0; i < seqVec.size(); i++) 00300 cout << "prob of seqs[" << i << "] to " << hmm.gestureName << " model = " << trainer.SeqLogProb(hmm,seqVec[i],false) << endl; 00301 //train by multi seq 00302 cout << hmm.gestureName << " gesture model is trained in " << trainer.EstimateModelBySeqs(hmm, seqVec, 50) << " loops" << endl; 00303 00304 hmm.trained = true; 00305 trainer.ShowHMM(hmm); 00306 db.SaveHMM(hmm); 00307 00308 //show probs to trained HMM 00309 for(size_t i = 0; i < seqVec.size(); i++) 00310 cout << "prob of seqs[" << i << "] to " << hmm.gestureName << " model = " << trainer.SeqLogProb(hmm,seqVec[i],false) << endl; 00311 00312 }else{ 00313 cout << "ERROR hmm.M != quantizer.M" << endl; 00314 } 00315 } 00316 00317 /* 00318 00319 double rho(double x, double y, double z) 00320 { 00321 return sqrt(x*x+y*y+z*z); 00322 } 00323 00324 void genSce(string gestureName, vector<Gesture>& gestureVec) 00325 { 00326 ofstream sceFile; 00327 00328 for(size_t i=0;i<gestureVec.size();i++){ 00329 00330 stringstream ss; 00331 00332 ss << gestureName << i/10 << i%10; 00333 00334 sceFile.open( (ss.str()+".sce").c_str() ); 00335 if(!sceFile) 00336 cout << "error" << endl; 00337 00338 sceFile << "t=[0:" << (gestureVec[i].data.size()-1) << "];" << endl; 00339 00340 sceFile << "x=["; 00341 for(size_t j=0; j<gestureVec[i].data.size(); j++) 00342 sceFile << gestureVec[i].data[j].x << " "; 00343 sceFile << "];" << endl; 00344 00345 sceFile << "y=["; 00346 for(size_t j=0; j<gestureVec[i].data.size(); j++) 00347 sceFile << gestureVec[i].data[j].y << " "; 00348 sceFile << "];" << endl; 00349 00350 sceFile << "z=["; 00351 for(size_t j=0; j<gestureVec[i].data.size(); j++) 00352 sceFile << gestureVec[i].data[j].z << " "; 00353 sceFile << "];" << endl; 00354 00355 sceFile << "r=["; 00356 for(size_t j=0; j<gestureVec[i].data.size(); j++) 00357 sceFile << rho(gestureVec[i].data[j].x,gestureVec[i].data[j].y,gestureVec[i].data[j].z) << " "; 00358 sceFile << "];" << endl; 00359 00360 sceFile << "plot(t,x,\"ko-\",t,y,\"kx-\",t,z,\"k>-\",t,r,\"k.-\");\n" 00361 "a=gca();\n" 00362 "a.x_label.text=\"time(10ms)\";\n" 00363 "a.x_label.font_size = 2;\n" 00364 "a.x_label.font_style = 8;\n" 00365 "a.y_label.text=\"acceleration(g)\";\n" 00366 "a.y_label.font_size = 2;\n" 00367 "a.y_label.font_style = 8;\n" 00368 00369 "a.title.text = \"" << ss.str() << "\";\n" 00370 "a.title.font_size = 4;\n" 00371 "a.title.font_style = 5;\n" 00372 00373 "a.font_size = 2;\n" 00374 "a.x_location = \"middle\";\n" 00375 00376 //"l=legend([\"x\",\"y\",\"z\",\"r\"]);\n" 00377 //"l.children(1).font_style=1;" 00378 << endl; 00379 00380 sceFile.close(); 00381 } 00382 } 00383 */