00001 #include <iostream>
00002 #include <fstream>
00003 #include <string>
00004 #include <sstream>
00005 #include <vector>
00006 #include <cmath>
00007 #include "Image/Point2D.H"
00008 #include "Image/Dims.H"
00009 #include "Image/DrawOps.H"
00010 #include "Learn/HMM.H"
00011 #include "GUI/DebugWin.H"
00012
00013 std::vector<Point2D<int> > square(Point2D<int> pos, Dims dim)
00014 {
00015 std::vector<Point2D<int> > lines;
00016 lines.push_back(Point2D<int>(pos.i,pos.j));
00017 lines.push_back(Point2D<int>(pos.i+dim.w(),pos.j));
00018 lines.push_back(Point2D<int>(pos.i+dim.w(),pos.j+dim.h()));
00019 lines.push_back(Point2D<int>(pos.i,pos.j+dim.h()));
00020 lines.push_back(Point2D<int>(pos.i,pos.j));
00021 return lines;
00022 }
00023
00024 std::vector<Point2D<int> > triangle(Point2D<int> pos)
00025 {
00026 std::vector<Point2D<int> > lines;
00027 lines.push_back(Point2D<int>(pos.i,pos.j));
00028 lines.push_back(Point2D<int>(pos.i+100,pos.j));
00029 lines.push_back(Point2D<int>(pos.i+100-50,pos.j+70));
00030 lines.push_back(Point2D<int>(pos.i,pos.j));
00031
00032 return lines;
00033 }
00034
00035 std::vector<Point2D<int> > applelogo()
00036 {
00037 std::vector<Point2D<int> > lines;
00038 lines.push_back(Point2D<int>(7, 71));
00039 lines.push_back(Point2D<int>(17, 88));
00040 lines.push_back(Point2D<int>(19, 90));
00041 lines.push_back(Point2D<int>(27, 95));
00042 lines.push_back(Point2D<int>(29, 95));
00043 lines.push_back(Point2D<int>(38, 91));
00044 lines.push_back(Point2D<int>(39, 91));
00045 lines.push_back(Point2D<int>(55, 95));
00046 lines.push_back(Point2D<int>(56, 95));
00047 lines.push_back(Point2D<int>(61, 93));
00048 lines.push_back(Point2D<int>(62, 93));
00049 lines.push_back(Point2D<int>(74, 73));
00050 lines.push_back(Point2D<int>(74, 72));
00051 lines.push_back(Point2D<int>(65, 62));
00052 lines.push_back(Point2D<int>(65, 61));
00053 lines.push_back(Point2D<int>(63, 57));
00054 lines.push_back(Point2D<int>(63, 56));
00055 lines.push_back(Point2D<int>(65, 42));
00056 lines.push_back(Point2D<int>(66, 42));
00057 lines.push_back(Point2D<int>(71, 36));
00058 lines.push_back(Point2D<int>(70, 35));
00059 lines.push_back(Point2D<int>(68, 33));
00060 lines.push_back(Point2D<int>(67, 32));
00061 lines.push_back(Point2D<int>(53, 29));
00062 lines.push_back(Point2D<int>(52, 30));
00063 lines.push_back(Point2D<int>(45, 32));
00064 lines.push_back(Point2D<int>(44, 27));
00065 lines.push_back(Point2D<int>(56, 15));
00066 lines.push_back(Point2D<int>(56, 14));
00067 lines.push_back(Point2D<int>(57, 7));
00068 lines.push_back(Point2D<int>(56, 6));
00069 lines.push_back(Point2D<int>(53, 7));
00070 lines.push_back(Point2D<int>(52, 7));
00071 lines.push_back(Point2D<int>(40, 19));
00072 lines.push_back(Point2D<int>(40, 20));
00073 lines.push_back(Point2D<int>(42, 26));
00074 lines.push_back(Point2D<int>(44, 33));
00075 lines.push_back(Point2D<int>(25, 29));
00076 lines.push_back(Point2D<int>(24, 29));
00077 lines.push_back(Point2D<int>(17, 31));
00078 lines.push_back(Point2D<int>(15, 32));
00079 lines.push_back(Point2D<int>(6, 43));
00080 lines.push_back(Point2D<int>(5, 45));
00081 lines.push_back(Point2D<int>(5, 64));
00082 lines.push_back(Point2D<int>(6, 65));
00083 lines.push_back(Point2D<int>(7, 70));
00084
00085 return lines;
00086 }
00087
00088
00089 std::vector<Point2D<double> > getVel(const std::vector<Point2D<int> >& lines)
00090 {
00091 std::vector<Point2D<double> > vel;
00092
00093 for(uint i=0; i<lines.size()-1; i++)
00094 {
00095 Point2D<int> dPos = lines[i+1]-lines[i];
00096 double mag = sqrt((dPos.i*dPos.i) + (dPos.j*dPos.j))/4;
00097 for(int j=0; j<int(mag+0.5); j++)
00098 vel.push_back(Point2D<double>(dPos/mag));
00099 }
00100
00101 return vel;
00102
00103 }
00104
/// Map a 3-D vector to one of 32 discrete observation symbols (0..31).
///
/// The symbol packs two fields:
///  - bits 3..4: magnitude band of |(x, y, z)|: 0 for <= 1, 1 for (1, 2],
///    2 for (2, 3], 3 for > 3;
///  - bits 0..2: pairwise orderings: bit 2 = (x > y), bit 1 = (y > z),
///    bit 0 = (z > x).
int quantize(float x, float y, float z)
{
  const double length = sqrt(x * x + y * y + z * z);

  // The thresholds increase, so the number of them exceeded is the band index.
  const int band = int(length > 1.0) + int(length > 2.0) + int(length > 3.0);

  int direction = 0;
  if (x > y) direction += 4;
  if (y > z) direction += 2;
  if (z > x) direction += 1;

  return (band << 3) | direction;
}
00127
00128
00129 HMM<uint> getHMM(const std::string& name, const std::vector<Point2D<int> >& lines)
00130 {
00131
00132 std::vector<uint> states;
00133 for(uint i=0; i<5; i++)
00134 states.push_back(i);
00135
00136 std::vector<uint> posibleObservations;
00137 for(uint i=0; i<32; i++)
00138 posibleObservations.push_back(i);
00139
00140 HMM<uint> hmm(states, posibleObservations, name);
00141
00142
00143 hmm.setStateTransition(0, 0, 0.5);
00144 hmm.setStateTransition(0, 1, 0.5);
00145 hmm.setStateTransition(1, 1, 0.5);
00146 hmm.setStateTransition(1, 2, 0.5);
00147 hmm.setStateTransition(2, 2, 0.5);
00148 hmm.setStateTransition(2, 3, 0.5);
00149 hmm.setStateTransition(3, 3, 0.5);
00150 hmm.setStateTransition(3, 4, 0.5);
00151 hmm.setStateTransition(4, 4, 1);
00152
00153
00154 hmm.setCurrentState(0, 1);
00155
00156
00157
00158 std::vector<Point2D<double> > vel = getVel(lines);
00159
00160 std::vector< std::vector<uint> > observations;
00161 for(size_t j=0; j<vel.size(); j++)
00162 {
00163 std::vector<uint> observation;
00164 printf("InputValue ");
00165 for(size_t i=0; i<vel.size(); i++)
00166 {
00167 uint value = quantize(vel[(i+j)%vel.size()].i,
00168 vel[(i+j)%vel.size()].j, 0);
00169 printf("%i ", value);
00170 observation.push_back(value);
00171 }
00172 printf("\n");
00173 observations.push_back(observation);
00174 }
00175 LINFO("Train");
00176 hmm.train(observations, 50);
00177 LINFO("Done");
00178
00179
00180
00181 return hmm;
00182 }
00183
00184 int main()
00185 {
00186
00187
00188 {
00189
00190 std::vector<std::string> states;
00191 states.push_back("Rainy");
00192 states.push_back("Sunny");
00193
00194
00195 std::vector<std::string> posibleObservations;
00196 posibleObservations.push_back("walk");
00197 posibleObservations.push_back("shop");
00198 posibleObservations.push_back("clean");
00199
00200
00201 HMM<std::string> hmm(states, posibleObservations);
00202
00203
00204 hmm.setStateTransition("Rainy", "Rainy", 0.7);
00205 hmm.setStateTransition("Rainy", "Sunny", 0.3);
00206 hmm.setStateTransition("Sunny", "Rainy", 0.4);
00207 hmm.setStateTransition("Sunny", "Sunny", 0.6);
00208
00209
00210
00211 hmm.setStateEmission("Rainy", "walk", 0.1);
00212 hmm.setStateEmission("Rainy", "shop", 0.4);
00213 hmm.setStateEmission("Rainy", "clean", 0.5);
00214 hmm.setStateEmission("Sunny", "walk", 0.6);
00215 hmm.setStateEmission("Sunny", "shop", 0.3);
00216 hmm.setStateEmission("Sunny", "clean", 0.1);
00217
00218
00219 hmm.setCurrentState("Rainy", 0.6);
00220 hmm.setCurrentState("Sunny", 0.4);
00221
00222 std::vector<std::string> observations;
00223 observations.push_back("walk");
00224 observations.push_back("shop");
00225 observations.push_back("clean");
00226
00227 double prob;
00228 std::vector<std::string> path = hmm.getLikelyStates(observations, prob);
00229 LINFO("FinalState prob %f", prob);
00230 printf("Path: ");
00231 for(uint i=0; i<path.size(); i++)
00232 printf("%s ", path[i].c_str());
00233 printf("\n");
00234 }
00235
00236
00237
00238 {
00239
00240
00241 std::vector<Point2D<int> > lines = applelogo();
00242
00243 Image<PixRGB<byte> > img(320,240,ZEROS);
00244 for(size_t i=0; i<lines.size()-1; i++)
00245 drawLine(img, lines[i], lines[i+1], PixRGB<byte>(255,0,0));
00246
00247 std::vector<Point2D<double> > vel = getVel(lines);
00248
00249 Point2D<double> pos(150,150);
00250 for(uint i=0; i<vel.size(); i++)
00251 {
00252 if (img.coordsOk(Point2D<int>(pos)))
00253 img.setVal(Point2D<int>(pos), PixRGB<byte>(0,255,0));
00254
00255 pos += vel[i];
00256 }
00257
00258 SHOWIMG(img);
00259
00260
00261
00262 std::vector<HMM<uint> > hmms;
00263
00264 hmms.push_back(getHMM("Square", square(Point2D<int>(50,50), Dims(50,50))));
00265 hmms.push_back(getHMM("Triangle", triangle(Point2D<int>(50,50))));
00266
00267 LINFO("Test the hmm");
00268 lines = triangle(Point2D<int>(70,70));
00269
00270 vel = getVel(lines);
00271
00272 for(uint j=0; j<vel.size(); j++)
00273 {
00274 std::vector<uint> observations;
00275 LINFO("Observations");
00276 for(size_t i=0; i<vel.size(); i++)
00277 {
00278 uint value = quantize(vel[(i+j)%vel.size()].i,
00279 vel[(i+j)%vel.size()].j, 0);
00280 printf("%i ", value);
00281 observations.push_back(value);
00282 }
00283 printf("\n");
00284
00285
00286 double maxProb = -1e100;
00287 std::string name = "";
00288
00289 for(size_t i=0; i<hmms.size(); i++)
00290 {
00291 double prob = hmms[i].forward(observations);
00292 LINFO("HMM %s prob %e", hmms[i].getName().c_str(), exp(prob));
00293 if (prob > maxProb)
00294 {
00295 maxProb = prob;
00296 name = hmms[i].getName();
00297 }
00298 }
00299
00300 LINFO("Max is %s", name.c_str());
00301 }
00302 }
00303
00304 }