00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
#include "plugins/SceneUnderstanding/POMDP.H"

#include "Image/FilterOps.H"
#include "Image/PixelsTypes.H"
#include "Image/MathOps.H"
#include "Image/Kernels.H"
#include "Image/DrawOps.H"
#include "GUI/DebugWin.H"
#include "Raster/Raster.H"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>
00049
00050
00051 POMDP::POMDP() :
00052 itsRetinaSize(640,480),
00053 itsAgentState(-1,-1),
00054 itsGoalState(-1,-1),
00055 itsLastDistance(10000),
00056 itsLastAction(-1)
00057 {
00058
00059 }
00060
00061
00062 POMDP::~POMDP()
00063 {
00064
00065 }
00066
00067
00068 bool POMDP::makeObservation(const Image<PixRGB<byte> > &img)
00069 {
00070 const char* personObj = "/home/lior/saliency/etc/objects/person.ppm";
00071 const char* doorObj = "/home/lior/saliency/etc/objects/door.ppm";
00072 const char* blockObj = "/home/lior/saliency/etc/objects/wall.ppm";
00073
00074
00075 itsAgentState = findObject(img, personObj);
00076
00077
00078
00079 {
00080 itsStateSpace = Image<byte>(img.getDims(), ZEROS);
00081
00082 itsGoalState = findObject(img, doorObj);
00083 drawFilledRect(itsStateSpace,
00084 Rectangle(itsGoalState-(31/2), Dims(31,31)),
00085 (byte)GOAL);
00086
00087
00088 std::vector<Point2D<int> > wallState = findMultipleObjects(img, blockObj);
00089 for(uint i=0; i<wallState.size(); i++)
00090 {
00091 drawFilledRect(itsStateSpace,
00092 Rectangle(wallState[i]-(31/2), Dims(31,31)),
00093 (byte)WALL);
00094 }
00095
00096 itsStateSpace = scaleBlock(itsStateSpace, itsStateSpace.getDims()/31);
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111 itsCurrentPercep[0] = Image<float>(itsStateSpace.getDims(), ZEROS);
00112 itsCurrentPercep[0].setVal(itsAgentState/31, 1.0F);
00113
00114
00115 return true;
00116
00117 }
00118
00119 return false;
00120
00121
00122 }
00123
00124 void POMDP::init()
00125 {
00126 Image<float> percep(itsRetinaSize, ZEROS);
00127 percep.clear(1.0e-5);
00128 itsCurrentPercep.push_back(percep);
00129
00130
00131
00132 const char* personObj = "/home/lior/saliency/etc/objects/person.ppm";
00133 Image<PixRGB<byte> > obj = Raster::ReadRGB(personObj);
00134 itsObjects.push_back(obj);
00135
00136 const char* doorObj = "/home/lior/saliency/etc/objects/door.ppm";
00137 obj = Raster::ReadRGB(doorObj);
00138 itsObjects.push_back(obj);
00139
00140 const char* blockObj = "/home/lior/saliency/etc/objects/wall.ppm";
00141 obj = Raster::ReadRGB(blockObj);
00142 itsObjects.push_back(obj);
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159 }
00160
00161
00162 Image<float> POMDP::getPerception(const uint obj)
00163 {
00164 if (obj < itsCurrentPercep.size())
00165 return itsCurrentPercep[obj];
00166
00167 return Image<float>();
00168 }
00169
00170 bool POMDP::goalReached()
00171 {
00172 Point2D<int> loc; float maxVal;
00173 findMax(itsCurrentPercep[0], loc, maxVal);
00174
00175 int currentState = loc.j*itsCurrentPercep[0].getWidth() + loc.i;
00176 LINFO("Current State %i\n", currentState);
00177
00178 if (itsStateSpace.getVal(currentState) == GOAL)
00179 return true;
00180
00181
00182 return false;
00183 }
00184
00185 Image<float> POMDP::makePrediction(const ACTIONS action)
00186 {
00187
00188 Point2D<int> loc; float maxVal;
00189 findMax(itsCurrentPercep[0], loc, maxVal);
00190
00191 int currentState = loc.j*itsCurrentPercep[0].getWidth() + loc.i;
00192 int newState = doAction(currentState, action);
00193
00194 Image<float> prediction(itsStateSpace.getDims(), ZEROS);
00195 if (newState != -1)
00196 prediction.setVal(newState, 1.0F);
00197
00198 itsPrediction = prediction;
00199 return prediction;
00200 }
00201
00202 float POMDP::updatePerception(const Image<PixRGB<byte> > &img)
00203 {
00204 const char* personObj = "/home/lior/saliency/etc/objects/person.ppm";
00205
00206 itsAgentState = findObject(img, personObj);
00207
00208 itsPreviousPercep = itsCurrentPercep[0];
00209
00210 itsCurrentPercep[0].clear(0.0F);
00211 itsCurrentPercep[0].setVal(itsAgentState/31, 1.0F);
00212
00213
00214
00215
00216
00217
00218 float klDist = 0;
00219 for(int i=0; i<itsCurrentPercep[0].getSize(); i++)
00220 {
00221 float posterior = itsCurrentPercep[0].getVal(i);
00222 float prior = itsPrediction.getVal(i);
00223
00224 if (prior == 0)
00225 prior = 1.0e-10;
00226 if (posterior == 0)
00227 posterior = 1.0e-10;
00228
00229 klDist += posterior * log(posterior/prior);
00230 }
00231
00232 return klDist;
00233 }
00234
00235 void POMDP::updateStateTransitions(const ACTIONS action)
00236 {
00237
00238 Point2D<int> loc; float maxVal;
00239
00240
00241
00242
00243 findMax(itsPrediction, loc, maxVal);
00244 int previousState = loc.j*itsPrediction.getWidth() + loc.i;
00245
00246 findMax(itsCurrentPercep[0], loc, maxVal);
00247 int currentState = loc.j*itsCurrentPercep[0].getWidth() + loc.i;
00248
00249
00250
00251
00252 LINFO("Previous %i action %i current %i",
00253 previousState,
00254 action,
00255 currentState);
00256
00257 }
00258
00259 Image<byte> POMDP::getStateSpace()
00260 {
00261 return itsStateSpace;
00262 }
00263
00264 Point2D<int> POMDP::getAgentState()
00265 {
00266 return itsAgentState;
00267 }
00268
00269 Point2D<int> POMDP::getGoalState()
00270 {
00271 return itsGoalState;
00272 }
00273
00274 void POMDP::showTransitions()
00275 {
00276 int nActions = 4;
00277 int nStates = itsStateSpace.getSize();
00278
00279 for(int state=0; state<nStates; state++)
00280 for(int act=0; act<nActions; act++)
00281 {
00282 printf("State %i ", state);
00283 switch(act)
00284 {
00285 case NORTH: printf("Action=North "); break;
00286 case EAST: printf("Action=East "); break;
00287 case WEST: printf("Action=West "); break;
00288 case SOUTH: printf("Action=South "); break;
00289 }
00290
00291 for(int newState=0; newState<nStates; newState++)
00292 if (itsTransitions[state][act][newState] > 0)
00293 printf("%i=%f ", newState, itsTransitions[state][act][newState]);
00294 printf("\n");
00295 }
00296
00297 }
00298
00299
00300 int POMDP::doAction(const int state, const int act)
00301 {
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312 int newState = state;
00313
00314
00315
00316
00317 if (false)
00318 {
00319
00320 } else {
00321 Point2D<int> currentPos;
00322 currentPos.j = state / itsStateSpace.getWidth();
00323 currentPos.i = state - (currentPos.j*itsStateSpace.getWidth());
00324
00325 switch(act)
00326 {
00327 case NORTH: currentPos.j -= 1; break;
00328 case EAST: currentPos.i += 1; break;
00329 case WEST: currentPos.i -= 1; break;
00330 case SOUTH: currentPos.j += 1; break;
00331 }
00332
00333
00334 if (itsStateSpace.coordsOk(currentPos) &&
00335 itsStateSpace.getVal(currentPos) != WALL)
00336 {
00337 newState = currentPos.j*itsStateSpace.getWidth() + currentPos.i;
00338 }
00339 }
00340
00341 return newState;
00342
00343 }
00344
00345 void POMDP::valueIteration()
00346 {
00347 int nActions = 4;
00348 int nStates = itsStateSpace.getSize();
00349
00350 Image<float> newUtility(itsStateSpace.getDims(), ZEROS);
00351
00352
00353 for(int state=0; state<newUtility.getSize(); state++)
00354 newUtility.setVal(state, getReward(state));
00355
00356 float thresh = 0.1;
00357 float discount = 1;
00358 float lemda = 1;
00359 while(lemda > thresh*(1-discount)/discount + thresh)
00360 {
00361
00362 itsUtility = newUtility;
00363 lemda = 0;
00364
00365 for(int state=0; state<nStates; state++)
00366 {
00367
00368
00369 float maxActVal = -std::numeric_limits<float>::max();
00370 for(int act=0; act<nActions; act++)
00371 {
00372 float sum = 0;
00373 for(int newState=0; newState<nStates; newState++)
00374 sum += getTransProb(state,act,newState) * itsUtility.getVal(newState);
00375
00376 if (sum > maxActVal)
00377 maxActVal = sum;
00378 }
00379
00380 float u = getReward(state) + discount*maxActVal;
00381 newUtility.setVal(state,u);
00382
00383 if (fabs(u - itsUtility.getVal(state)) > lemda)
00384 lemda = fabs(u - itsUtility.getVal(state));
00385 }
00386 LINFO("Lmeda %f\n", lemda);
00387
00388 }
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400 }
00401
00402 float POMDP::getTransProb(int state, int action, int newState)
00403 {
00404
00405 float prob = 0;
00406
00407 if (doAction(state,action) == newState)
00408 prob += 0.8;
00409
00410
00411
00412 switch(action)
00413 {
00414 case NORTH:
00415 case SOUTH:
00416 if (doAction(state,EAST) == newState)
00417 prob += 0.1;
00418 if (doAction(state,WEST) == newState)
00419 prob += 0.1;
00420 break;
00421 case EAST:
00422 case WEST:
00423 if (doAction(state,NORTH) == newState)
00424 prob += 0.1;
00425 if (doAction(state,SOUTH) == newState)
00426 prob += 0.1;
00427 break;
00428 }
00429
00430 return prob;
00431
00432 }
00433
00434
00435 float POMDP::getReward(int state)
00436 {
00437
00438 switch(itsStateSpace.getVal(state))
00439 {
00440 case GOAL:
00441 return 1.0F;
00442 case HOLE:
00443 case WALL:
00444 return -1.0F;
00445 default:
00446 return -0.01;
00447 }
00448 return 0;
00449 }
00450
00451 void POMDP::doPolicy(const Point2D<int>& startPos)
00452 {
00453
00454 LINFO("Do policy from %ix%i\n", startPos.i, startPos.j);
00455 Image<byte> ssImg = itsStateSpace;
00456 inplaceNormalize(ssImg, (byte)0, (byte)255);
00457
00458 Image<PixRGB<byte> > disp = ssImg;
00459
00460 int state = (startPos.j/31)*itsStateSpace.getWidth() + (startPos.i/31);
00461 disp.setVal(state, PixRGB<byte>(255,0,0));
00462
00463 Image<float> utilDisp = itsUtility;
00464 inplaceNormalize(utilDisp, 0.0F, 255.0F);
00465 SHOWIMG(scaleBlock(utilDisp, itsUtility.getDims()*5));
00466 for(int i=0; i<1000 && state != -1; i++)
00467 {
00468 int action = getAction(state);
00469 state = doAction(state, action);
00470
00471 if (state != -1)
00472 disp.setVal(state, PixRGB<byte>(255,255,0));
00473 }
00474 if (state != -1)
00475 {
00476 LINFO("Can not solve, exploring");
00477 itsExploring = true;
00478 } else {
00479 itsExploring = false;
00480 }
00481
00482 SHOWIMG(scaleBlock(disp, disp.getDims()*5));
00483
00484
00485
00486 }
00487
00488 int POMDP::getPropAction()
00489 {
00490 if (itsCurrentPercep.size() == 0)
00491 return -1;
00492 Point2D<int> loc; float maxVal;
00493 findMax(itsCurrentPercep[0], loc, maxVal);
00494
00495 int currentState = loc.j*itsCurrentPercep[0].getWidth() + loc.i;
00496
00497 return getAction(currentState);
00498
00499 }
00500
00501
00502 POMDP::ACTIONS POMDP::getAction(int state)
00503 {
00504
00505 int nActions = 4;
00506 int nStates = itsStateSpace.getSize();
00507
00508 float maxActVal = -std::numeric_limits<float>::max();
00509 int action = -1;
00510
00511 for(int act=0; act<nActions; act++)
00512 {
00513 float sum = 0;
00514 for(int newState=0; newState<nStates; newState++)
00515 {
00516 if (newState < itsUtility.getSize())
00517 sum += getTransProb(state,act,newState) * itsUtility.getVal(newState);
00518 }
00519
00520 if (sum > maxActVal)
00521 {
00522 maxActVal = sum;
00523 action = act;
00524 }
00525 }
00526
00527 return (ACTIONS)action;
00528 }
00529
00530 Image<float> POMDP::locateObject(const Image<float>& src, Image<float>& filter)
00531 {
00532
00533 Image<float> result(src.getDims(), ZEROS);
00534 const int src_w = src.getWidth();
00535 const int src_h = src.getHeight();
00536
00537 Image<float>::const_iterator fptrBegin = filter.begin();
00538 const int fil_w = filter.getWidth();
00539 const int fil_h = filter.getHeight();
00540
00541 Image<float>::const_iterator sptr = src.begin();
00542
00543 const int srow_skip = src_w-fil_w;
00544
00545 float maxDiff = 256*fil_w*fil_h;
00546
00547 for (int dst_y = 0; dst_y < src_h-fil_h; dst_y++) {
00548
00549 for (int dst_x = 0; dst_x < src_w-fil_w; dst_x++) {
00550
00551 float dst_val = 0.0f;
00552
00553 Image<float>::const_iterator fptr = fptrBegin;
00554 Image<float>::const_iterator srow_ptr = sptr + (src_w*dst_y) + dst_x;
00555 for (int f_y = 0; f_y < fil_h; ++f_y)
00556 {
00557 for (int f_x = 0; f_x < fil_w; ++f_x){
00558 dst_val += fabs((*srow_ptr++) - (*fptr++));
00559 }
00560 srow_ptr += srow_skip;
00561 }
00562 float prob = 1-dst_val/(maxDiff * 0.25);
00563 if (prob < 0) prob = 0;
00564
00565 result.setVal(dst_x, dst_y, dst_val);
00566 }
00567 }
00568
00569 return result;
00570
00571 }
00572
00573
00574
00575 Point2D<int> POMDP::findObject(const Image<PixRGB<byte> > &img, const char* filename)
00576 {
00577
00578
00579
00580 Image<PixRGB<byte> > obj = Raster::ReadRGB(filename);
00581
00582 Image<float> objLum = luminance(obj);
00583 Image<float> imgLum = luminance(img);
00584
00585 Image<float> result = locateObject(imgLum, objLum);
00586
00587 Point2D<int> loc; float maxVal;
00588 findMax(result, loc, maxVal);
00589 loc.i += (objLum.getWidth()/2);
00590 loc.j += (objLum.getHeight()/2);
00591
00592 return loc;
00593 }
00594
00595
00596 std::vector<Point2D<int> > POMDP::findMultipleObjects(const Image<PixRGB<byte> > &img, const char* filename)
00597 {
00598
00599 std::vector<Point2D<int> > objectLocations;
00600
00601
00602 Image<PixRGB<byte> > obj = Raster::ReadRGB(filename);
00603
00604 Image<float> objLum = luminance(obj);
00605 Image<float> imgLum = luminance(img);
00606
00607 Image<float> result = locateObject(imgLum, objLum);
00608
00609 Point2D<int> loc; float maxVal;
00610 findMax(result, loc, maxVal);
00611 result.setVal(loc, 0.0F);
00612 loc.i += (objLum.getWidth()/2);
00613 loc.j += (objLum.getHeight()/2);
00614 objectLocations.push_back(loc);
00615
00616 float objMaxVal = maxVal;
00617
00618 while(1)
00619 {
00620 Point2D<int> loc; float maxVal;
00621 findMax(result, loc, maxVal);
00622 if (maxVal > objMaxVal*0.8)
00623 {
00624 result.setVal(loc, 0.0F);
00625 loc.i += (objLum.getWidth()/2);
00626 loc.j += (objLum.getHeight()/2);
00627 objectLocations.push_back(loc);
00628 } else {
00629 break;
00630 }
00631 }
00632
00633
00634 return objectLocations;
00635 }
00636
00637 float POMDP::bayesFilter(const int action, const Image<PixRGB<byte> > &img)
00638 {
00639
00640
00641
00642
00643
00644
00645 LINFO("Making prediction");
00646 Image<float> prevBelif = itsCurrentPercep[0];
00647 float entropy = getEntropy(prevBelif);
00648 LINFO("Entorpy %f", entropy);
00649 SHOWIMG(prevBelif);
00650
00651 Image<float> newBelif = prevBelif;
00652
00653
00654 if (entropy < 10)
00655 {
00656
00657
00658
00659 for(int state=0; state<prevBelif.getSize(); state++)
00660 {
00661 float sum=0;
00662 for(int i=0; i<prevBelif.getSize(); i++)
00663 {
00664
00665 if (prevBelif[i] > 0)
00666 sum += getTransProb(i, action, state) * prevBelif[i];
00667 }
00668 newBelif[state] = sum;
00669 }
00670 }
00671
00672 LINFO("Done");
00673
00674 SHOWIMG(newBelif);
00675
00676
00677 int objID = 0;
00678
00679 Image<float> objLum = luminance(itsObjects[objID]);
00680 Image<float> imgLum = luminance(img);
00681
00682 Image<float> result = locateObject(imgLum, objLum);
00683 result /= sum(result);
00684 result = rescale(result, newBelif.getDims());
00685
00686
00687 itsCurrentPercep[0] = result;
00688 itsCurrentPercep[0] /= sum(itsCurrentPercep[0]);
00689
00690 SHOWIMG(itsCurrentPercep[0]);
00691 entropy = getEntropy(itsCurrentPercep[0]);
00692 LINFO("new Entorpy %f", entropy);
00693
00694
00695 return 0;
00696 }
00697
/// Experimental particle-filter step (incomplete).
///
/// On the first call, itsParticleStateSpace is seeded with 10 particles
/// constructed as State(2*i, 2*i), i = 0..9. The particles are not used
/// further in this function. The agent template itsObjects[0] is multiplied
/// by a Gaussian blob, and both template and frame are rescaled to half
/// size. OpenCV's cvMatchTemplate (CV_TM_CCOEFF_NORMED) then correlates
/// them. The absolute correlation map is displayed and its getEntropy()
/// value logged.
///
/// @param action  currently unused
/// @param img     current camera frame
/// @return always 0
float POMDP::particleFilter(const int action, const Image<PixRGB<byte> > &img)
{
  // Lazily seed the particle set on first use.
  if (itsParticleStateSpace.size() == 0)
  {
    int nParticles = 10;

    for(int i=0; i< nParticles; i++)
    {
      State state(2*i, 2*i);
      itsParticleStateSpace.push_back(state);
    }
  }
  // Requires init() to have loaded itsObjects.
  Image<float> objLum = luminance(itsObjects[0]);
  // Centre-weight the template. The blob's last two arguments are the
  // template width and height; NOTE(review): confirm they are the Gaussian
  // spreads in gaussianBlob's signature.
  Image<float> objBlob = gaussianBlob<float>(
    objLum.getDims(),
    Point2D<int>(objLum.getWidth()/2, objLum.getHeight()/2),
    (float)objLum.getWidth(), (float)objLum.getHeight());
  objLum *= objBlob;

  Image<float> imgLum = luminance(img);

  // Work at half resolution.
  objLum = rescale(objLum, objLum.getDims()/2);
  imgLum = rescale(imgLum, imgLum.getDims()/2);

  // cvMatchTemplate expects a (W-w+1) x (H-h+1) result. If imgLum is
  // smaller than objLum these dimensions are non-positive.
  Image<float> result(imgLum.getWidth()-objLum.getWidth()+1,
                      imgLum.getHeight()-objLum.getHeight()+1,
                      NO_INIT);

  // NOTE(review): img2ipl() presumably wraps each image in a newly created
  // IplImage header that is never released here; confirm its ownership
  // semantics and free the headers if so.
  cvMatchTemplate(img2ipl(imgLum),
                  img2ipl(objLum),
                  img2ipl(result),

                  CV_TM_CCOEFF_NORMED);

  // CV_TM_CCOEFF_NORMED yields values in [-1,1]; keep the magnitude.
  result = abs(result);
  // result is not normalised to sum to 1, so this is not the entropy of a
  // proper distribution.
  float entropy = getEntropy(result);
  LINFO("new Entorpy %f", entropy);

  SHOWIMG(result);


  return 0;

}
00743
00744 float POMDP::getEntropy(Image<float> &belif)
00745 {
00746 float sum = 0;
00747 for(int i=0; i<belif.getSize(); i++)
00748 sum += belif[i] * log((belif[i] != 0) ? belif[i] : 1.0e-5);
00749 return -1.0*sum;
00750 }
00751
00752
00753
00754
00755 float POMDP::getObjProb(const Image<PixRGB<byte> > &img,
00756 const State state, const int objID)
00757 {
00758 Image<float> objLum = luminance(itsObjects[objID]);
00759 Image<float> imgLum = luminance(img);
00760
00761 Image<float>::const_iterator imgLumPtr = imgLum.begin();
00762 Image<float>::const_iterator fptrBegin = objLum.begin();
00763 const int fil_w = objLum.getWidth();
00764 const int fil_h = objLum.getHeight();
00765 const int srow_skip = imgLum.getWidth()-fil_w;
00766
00767 float prob = 1.0f;
00768
00769
00770 Image<float>::const_iterator fptr = fptrBegin;
00771 Image<float>::const_iterator srow_ptr = imgLumPtr + (imgLum.getWidth()*state.y) + state.x;
00772 for (int f_y = 0; f_y < fil_h; ++f_y)
00773 {
00774 for (int f_x = 0; f_x < fil_w; ++f_x){
00775 prob += fabs((*srow_ptr++) - (*fptr++));
00776 }
00777 srow_ptr += srow_skip;
00778 }
00779
00780 return 1/prob;
00781
00782 }
00783
00784
00785
00786
00787
00788
00789
00790