test-HMM.C

00001 #include <iostream>
00002 #include <fstream>
00003 #include <string>
00004 #include <sstream>
00005 #include <vector>
00006 #include <cmath>
00007 #include "Image/Point2D.H"
00008 #include "Image/Dims.H"
00009 #include "Image/DrawOps.H"
00010 #include "Learn/HMM.H"
00011 #include "GUI/DebugWin.H"
00012 
00013 std::vector<Point2D<int> > square(Point2D<int> pos, Dims dim)
00014 {
00015   std::vector<Point2D<int> > lines;
00016   lines.push_back(Point2D<int>(pos.i,pos.j));
00017   lines.push_back(Point2D<int>(pos.i+dim.w(),pos.j));
00018   lines.push_back(Point2D<int>(pos.i+dim.w(),pos.j+dim.h()));
00019   lines.push_back(Point2D<int>(pos.i,pos.j+dim.h()));
00020   lines.push_back(Point2D<int>(pos.i,pos.j));
00021   return lines;
00022 }
00023 
00024 std::vector<Point2D<int> > triangle(Point2D<int> pos)
00025 {
00026   std::vector<Point2D<int> > lines;
00027   lines.push_back(Point2D<int>(pos.i,pos.j));
00028   lines.push_back(Point2D<int>(pos.i+100,pos.j));
00029   lines.push_back(Point2D<int>(pos.i+100-50,pos.j+70));
00030   lines.push_back(Point2D<int>(pos.i,pos.j));
00031 
00032   return lines;
00033 }
00034 
00035 std::vector<Point2D<int> > applelogo()
00036 {
00037   std::vector<Point2D<int> > lines;
00038   lines.push_back(Point2D<int>(7, 71));
00039   lines.push_back(Point2D<int>(17, 88));
00040   lines.push_back(Point2D<int>(19, 90));
00041   lines.push_back(Point2D<int>(27, 95)); 
00042   lines.push_back(Point2D<int>(29, 95));
00043   lines.push_back(Point2D<int>(38, 91));
00044   lines.push_back(Point2D<int>(39, 91));
00045   lines.push_back(Point2D<int>(55, 95));
00046   lines.push_back(Point2D<int>(56, 95));
00047   lines.push_back(Point2D<int>(61, 93));
00048   lines.push_back(Point2D<int>(62, 93));
00049   lines.push_back(Point2D<int>(74, 73)); 
00050   lines.push_back(Point2D<int>(74, 72));
00051   lines.push_back(Point2D<int>(65, 62));
00052   lines.push_back(Point2D<int>(65, 61));
00053   lines.push_back(Point2D<int>(63, 57));
00054   lines.push_back(Point2D<int>(63, 56));
00055   lines.push_back(Point2D<int>(65, 42));
00056   lines.push_back(Point2D<int>(66, 42));
00057   lines.push_back(Point2D<int>(71, 36));
00058   lines.push_back(Point2D<int>(70, 35));
00059   lines.push_back(Point2D<int>(68, 33));
00060   lines.push_back(Point2D<int>(67, 32));
00061   lines.push_back(Point2D<int>(53, 29));
00062   lines.push_back(Point2D<int>(52, 30));
00063   lines.push_back(Point2D<int>(45, 32));
00064   lines.push_back(Point2D<int>(44, 27));
00065   lines.push_back(Point2D<int>(56, 15));
00066   lines.push_back(Point2D<int>(56, 14));
00067   lines.push_back(Point2D<int>(57, 7));
00068   lines.push_back(Point2D<int>(56, 6));
00069   lines.push_back(Point2D<int>(53, 7));
00070   lines.push_back(Point2D<int>(52, 7));
00071   lines.push_back(Point2D<int>(40, 19)); 
00072   lines.push_back(Point2D<int>(40, 20));
00073   lines.push_back(Point2D<int>(42, 26));
00074   lines.push_back(Point2D<int>(44, 33));
00075   lines.push_back(Point2D<int>(25, 29));
00076   lines.push_back(Point2D<int>(24, 29));
00077   lines.push_back(Point2D<int>(17, 31));
00078   lines.push_back(Point2D<int>(15, 32));
00079   lines.push_back(Point2D<int>(6, 43));
00080   lines.push_back(Point2D<int>(5, 45));
00081   lines.push_back(Point2D<int>(5, 64));
00082   lines.push_back(Point2D<int>(6, 65));
00083   lines.push_back(Point2D<int>(7, 70));
00084   
00085   return lines;
00086 }
00087 
00088 
00089 std::vector<Point2D<double> > getVel(const std::vector<Point2D<int> >& lines)
00090 {
00091   std::vector<Point2D<double> > vel;
00092 
00093   for(uint i=0; i<lines.size()-1; i++)
00094   {
00095     Point2D<int> dPos = lines[i+1]-lines[i];
00096     double mag = sqrt((dPos.i*dPos.i) + (dPos.j*dPos.j))/4;
00097     for(int j=0; j<int(mag+0.5); j++)
00098       vel.push_back(Point2D<double>(dPos/mag));
00099   }
00100 
00101   return vel;
00102 
00103 }
00104 
//! Quantize a 3-D vector into one of 32 symbols (0..31).
/*! Bit layout of the result:
    - bits 3-4: magnitude band of sqrt(x^2+y^2+z^2).
      The bands are <=1 -> 0, (1,2] -> 1, (2,3] -> 2, >3 -> 3.
    - bit 2: set when x > y
    - bit 1: set when y > z
    - bit 0: set when z > x
    @return the packed symbol */
int quantize(float x, float y, float z)
{
  const double rho = sqrt(x*x+y*y+z*z);

  int band = 0;
  if (rho > 3.0)      band = 3;
  else if (rho > 2.0) band = 2;
  else if (rho > 1.0) band = 1;

  // Pairwise component orderings fill the low three bits.
  const int xy = (x > y) ? 4 : 0;
  const int yz = (y > z) ? 2 : 0;
  const int zx = (z > x) ? 1 : 0;

  return (band << 3) | xy | yz | zx;
}
00127 
00128 //Train an HMM with specific observations
00129 HMM<uint> getHMM(const std::string& name, const std::vector<Point2D<int> >& lines)
00130 {
00131   //Add an hmm with 5 states and 32 possible observations
00132   std::vector<uint> states; //5 States
00133   for(uint i=0; i<5; i++) 
00134     states.push_back(i);
00135 
00136   std::vector<uint> posibleObservations; //32
00137   for(uint i=0; i<32; i++) 
00138     posibleObservations.push_back(i);
00139 
00140   HMM<uint> hmm(states, posibleObservations, name);
00141 
00142   //Set the default transitions
00143   hmm.setStateTransition(0, 0, 0.5);
00144   hmm.setStateTransition(0, 1, 0.5);
00145   hmm.setStateTransition(1, 1, 0.5);
00146   hmm.setStateTransition(1, 2, 0.5);
00147   hmm.setStateTransition(2, 2, 0.5);
00148   hmm.setStateTransition(2, 3, 0.5);
00149   hmm.setStateTransition(3, 3, 0.5);
00150   hmm.setStateTransition(3, 4, 0.5);
00151   hmm.setStateTransition(4, 4, 1);
00152 
00153   //set the initial sstate
00154   hmm.setCurrentState(0, 1); //We start at the first state
00155 
00156   //Quantize the acc values into 32 numbers to represent the observations
00157 
00158   std::vector<Point2D<double> > vel = getVel(lines);
00159 
00160   std::vector< std::vector<uint> > observations;
00161   for(size_t j=0; j<vel.size(); j++)
00162   {
00163     std::vector<uint> observation;
00164     printf("InputValue ");
00165     for(size_t i=0; i<vel.size(); i++)
00166     {
00167       uint value = quantize(vel[(i+j)%vel.size()].i,
00168                           vel[(i+j)%vel.size()].j, 0);
00169       printf("%i ", value);
00170       observation.push_back(value);
00171     }
00172     printf("\n");
00173     observations.push_back(observation);
00174   }
00175   LINFO("Train");
00176   hmm.train(observations, 50);
00177   LINFO("Done");
00178 
00179   //hmm.show(); //Show the internal state of the HMM
00180 
00181   return hmm;
00182 }
00183 
00184 int main()
00185 {
00186 
00187   //Testing the viterbi algorithm from wikipidia
00188   {
00189     //The posible states we can be in
00190     std::vector<std::string> states;
00191     states.push_back("Rainy");
00192     states.push_back("Sunny");
00193 
00194     //The posible observations, observed in each state
00195     std::vector<std::string> posibleObservations;
00196     posibleObservations.push_back("walk");
00197     posibleObservations.push_back("shop");
00198     posibleObservations.push_back("clean");
00199 
00200     //Initialize the hmm
00201     HMM<std::string> hmm(states, posibleObservations);
00202 
00203     //The Transition probability matrix;
00204     hmm.setStateTransition("Rainy", "Rainy", 0.7);
00205     hmm.setStateTransition("Rainy", "Sunny", 0.3);
00206     hmm.setStateTransition("Sunny", "Rainy", 0.4);
00207     hmm.setStateTransition("Sunny", "Sunny", 0.6);
00208 
00209 
00210     //////The state emission probability
00211     hmm.setStateEmission("Rainy", "walk", 0.1);
00212     hmm.setStateEmission("Rainy", "shop", 0.4);
00213     hmm.setStateEmission("Rainy", "clean", 0.5);
00214     hmm.setStateEmission("Sunny", "walk", 0.6);
00215     hmm.setStateEmission("Sunny", "shop", 0.3);
00216     hmm.setStateEmission("Sunny", "clean", 0.1);
00217 
00218     //Set our current state
00219     hmm.setCurrentState("Rainy", 0.6);
00220     hmm.setCurrentState("Sunny", 0.4);
00221 
00222     std::vector<std::string> observations;
00223     observations.push_back("walk");
00224     observations.push_back("shop");
00225     observations.push_back("clean");
00226 
00227     double prob;
00228     std::vector<std::string> path = hmm.getLikelyStates(observations, prob); 
00229     LINFO("FinalState prob %f", prob);
00230     printf("Path: ");
00231     for(uint i=0; i<path.size(); i++)
00232       printf("%s ", path[i].c_str());
00233     printf("\n");
00234   }
00235 
00236 
00237   //////////////////////////////////////////////////////////////////////////////////////////////////
00238   {
00239     //std::vector<Point2D<int> > lines = square(Point2D<int>(50,50), Dims(150,50));
00240     //std::vector<Point2D<int> > lines = triangle(Point2D<int>(50,50));
00241     std::vector<Point2D<int> > lines = applelogo(); //triangle(Point2D<int>(50,50));
00242 
00243     Image<PixRGB<byte> > img(320,240,ZEROS);
00244     for(size_t i=0; i<lines.size()-1; i++)
00245       drawLine(img, lines[i], lines[i+1], PixRGB<byte>(255,0,0));
00246 
00247     std::vector<Point2D<double> > vel = getVel(lines);
00248     //Show the vel
00249     Point2D<double> pos(150,150);
00250     for(uint i=0; i<vel.size(); i++)
00251     {
00252       if (img.coordsOk(Point2D<int>(pos)))
00253           img.setVal(Point2D<int>(pos), PixRGB<byte>(0,255,0));
00254       //LINFO("V: %f %f P: %f %f", vel[i].i, vel[i].j, pos.i, pos.j);
00255       pos += vel[i];
00256     }
00257 
00258     SHOWIMG(img);
00259 
00260 
00261     //Test the HMM for sequence recognition 
00262     std::vector<HMM<uint> > hmms;
00263 
00264     hmms.push_back(getHMM("Square", square(Point2D<int>(50,50), Dims(50,50))));
00265     hmms.push_back(getHMM("Triangle", triangle(Point2D<int>(50,50))));
00266 
00267     LINFO("Test the hmm");
00268     lines = triangle(Point2D<int>(70,70));
00269     //lines = square(Point2D<int>(5,5), Dims(100,150));
00270     vel = getVel(lines);
00271 
00272     for(uint j=0; j<vel.size(); j++)
00273     {
00274       std::vector<uint> observations; 
00275       LINFO("Observations");
00276       for(size_t i=0; i<vel.size(); i++)
00277       {
00278         uint value = quantize(vel[(i+j)%vel.size()].i,
00279                             vel[(i+j)%vel.size()].j, 0);
00280         printf("%i ", value);
00281         observations.push_back(value);
00282       }
00283       printf("\n");
00284 
00285       //Check each HMM to see if it has the probability of being the sequence
00286       double maxProb = -1e100;
00287       std::string name = "";
00288 
00289       for(size_t i=0; i<hmms.size(); i++)
00290       {
00291         double prob = hmms[i].forward(observations);
00292         LINFO("HMM %s prob %e", hmms[i].getName().c_str(), exp(prob));
00293         if (prob > maxProb)
00294         {
00295           maxProb = prob;
00296           name = hmms[i].getName();
00297         }
00298       }
00299 
00300       LINFO("Max is %s", name.c_str());
00301     }
00302   }
00303 
00304 }
Generated on Sun May 8 08:40:59 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3