00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043 #include "Channels/IntensityChannel.H"
00044 #include "Channels/InputFrame.H"
00045 #include "Channels/ChannelOpts.H"
00046 #include "Component/ModelManager.H"
00047 #include "GUI/DebugWin.H"
00048 #include "Image/ColorMap.H"
00049 #include "Image/ColorOps.H"
00050 #include "Image/CutPaste.H"
00051 #include "Image/Dims.H"
00052 #include "Image/Image.H"
00053 #include "Image/LevelSpec.H"
00054 #include "Image/MathOps.H"
00055 #include "Image/Pixels.H"
00056 #include "Image/Point3D.H"
00057 #include "Learn/QuadTree.H"
00058 #include "Util/log.H"
00059 #include "Util/StringConversions.H"
00060
00061 #include <cmath>
00062 #include <vector>
00063 #include <queue>
00064 #include <iostream>
00065 #include <algorithm>
00066
00067
00068 QuadTree::QuadTree(int Nlevels, Dims d) : itsNumLevels(Nlevels)
00069 {
00070
00071 rutz::shared_ptr<QuadNode> root_ref(new QuadNode());
00072 Rectangle thisWindow(Point2D<int>(0,0), d);
00073 addTreeUnder(root_ref, Nlevels, thisWindow);
00074
00075 itsRootNode = root_ref->getChild(0);
00076 initAlphas();
00077 }
00078
00079
00080 QuadTree::QuadTree(int Nlevels, Image<PixRGB<byte> > im) : itsNumLevels(Nlevels)
00081 {
00082
00083 Dims d = im.getDims();
00084
00085
00086 rutz::shared_ptr<QuadNode> root_ref(new QuadNode());
00087 Rectangle thisWindow(Point2D<int>(0,0), d);
00088 addTreeUnder(root_ref, Nlevels, thisWindow);
00089
00090 itsRootNode = root_ref->getChild(0);
00091
00092 itsImage = im;
00093 initAlphas();
00094 }
00095
00096
00097 void QuadTree::addTreeUnder(rutz::shared_ptr<QuadNode> parent, int Nlevel, Rectangle r)
00098 {
00099
00100 rutz::shared_ptr<QuadNode> myNewNode(new QuadNode(parent));
00101 myNewNode->setArea(r);
00102 myNewNode->setDepth(itsNumLevels-Nlevel);
00103
00104
00105 parent->addChild(myNewNode);
00106
00107
00108 itsNodes.push_back(myNewNode);
00109
00110
00111 if (Nlevel > 0) {
00112
00113 Point2D<int> middle = r.center();
00114
00115 Rectangle tl = Rectangle::tlbrO(r.top(),r.left(),middle.j,middle.i);
00116 Rectangle tr = Rectangle::tlbrO(r.top(),middle.i,middle.j,r.rightO());
00117 Rectangle bl = Rectangle::tlbrO(middle.j,r.left(),r.bottomO(),middle.i);
00118 Rectangle br = Rectangle::tlbrO(middle.j,middle.i,r.bottomO(),r.rightO());
00119
00120 addTreeUnder(myNewNode, Nlevel - 1, tl);
00121 addTreeUnder(myNewNode, Nlevel - 1, tr);
00122 addTreeUnder(myNewNode, Nlevel - 1, bl);
00123 addTreeUnder(myNewNode, Nlevel - 1, br);
00124 }
00125 }
00126
00127
00128 void QuadTree::cacheClassifierResult()
00129 {
00130 uint NClasses = itsClassifier->getNumClasses();
00131 Image<double> output(itsImage.getDims(),ZEROS);
00132 Image<double> denom(itsImage.getDims(),ZEROS);
00133 Dims patch_size(5,5);
00134 Image<PixRGB<byte> > patch(patch_size,ZEROS);
00135 Rectangle im_rect(Point2D<int>(0,0), itsImage.getDims());
00136
00137 itsClassifierOutput.clear();
00138 itsBestClassOutput.resize(itsImage.getDims());
00139 for(uint i = 0; i < NClasses; i++)
00140 itsClassifierOutput.push_back(output);
00141
00142 Point2D<int> P_l;
00143 for(P_l.i = 0; P_l.i < itsImage.getWidth(); P_l.i++)
00144 for(P_l.j = 0; P_l.j < itsImage.getHeight(); P_l.j++) {
00145
00146 Rectangle rect_patch = Rectangle::centerDims(P_l, patch_size);
00147 rect_patch = constrainRect(rect_patch, im_rect,0,itsImage.getWidth(),0,itsImage.getHeight());
00148 patch = crop(itsImage, rect_patch);
00149
00150 byte best_class = 0;
00151 for(uint i = 0; i < NClasses; i++) {
00152 double res = itsClassifier->classifyAt(patch, i);
00153 if(res > itsClassifierOutput[best_class][P_l])
00154 best_class = i;
00155
00156 itsClassifierOutput[i][P_l] = res;
00157
00158 denom[P_l] += exp(res);
00159 }
00160 itsBestClassOutput[P_l] = best_class;
00161 }
00162
00163 for(uint i = 0; i < NClasses; i++)
00164 itsClassifierOutput[i] = exp(itsClassifierOutput[i])/denom;
00165 }
00166
00167
00168 double QuadTree::evaluateClassifierAt(rutz::shared_ptr<QuadNode> q) const
00169 {
00170 int Npts = itsImage.getDims().sz();
00171 double E = 0;
00172 Point2D<int> P_l, P_g;
00173 for(P_l.i = 0; P_l.i < itsImage.getWidth(); P_l.i++)
00174 for(P_l.j = 0; P_l.j < itsImage.getHeight(); P_l.j++) {
00175 P_g = q->convertToGlobal(P_l);
00176 E -= log(itsClassifierOutput[q->getObjLabelAt(P_g)][P_l]);
00177 }
00178
00179 return E/Npts;
00180 }
00181
00182
00183 double QuadTree::evaluateCohesionAt(rutz::shared_ptr<QuadNode> q) const
00184 {
00185
00186
00187 Neighborhood one_away;
00188
00189
00190
00191
00192 const float lambda = 1;
00193 const float sigma_col = 75;
00194 const Point2D<int> origin(0,0);
00195 float e_coh = 0;
00196
00197 one_away.push_back(Point2D<int>(0,1));
00198 one_away.push_back(Point2D<int>(1,-1));
00199 one_away.push_back(Point2D<int>(1,0));
00200 one_away.push_back(Point2D<int>(1,1));
00201
00202
00203 Image<byte> segImage = q->getSegImage();
00204
00205 int nedges = 0;
00206
00207
00208
00209 Point2D<int> P_l,Q_l, P_g, Q_g;
00210 for(P_l.i = 0; P_l.i < segImage.getWidth(); P_l.i++) {
00211 for(P_l.j = 0; P_l.j < segImage.getHeight(); P_l.j++) {
00212 for(uint k = 0; k < one_away.size(); k++) {
00213
00214 Q_l = P_l + one_away[k];
00215 P_g = q->convertToGlobal(P_l);
00216 Q_g = q->convertToGlobal(Q_l);
00217
00218 if (!segImage.coordsOk(Q_l)) continue;
00219
00220 nedges++;
00221
00222 if(segImage[P_l] != segImage[Q_l]) continue;
00223
00224
00225 double color_dist = colorDistance(itsImage[P_g], itsImage[Q_g]);
00226 double space_dist = one_away[k].distance(origin);
00227 e_coh -= lambda / space_dist * exp(- color_dist * color_dist / (2 * sigma_col * sigma_col));
00228 }
00229 }
00230 }
00231 return e_coh / nedges;
00232 }
00233
00234
00235 double QuadTree::evaluateCorrespondenceAt(rutz::shared_ptr<QuadNode> q) const
00236 {
00237 if(q->isLeaf()) return 0;
00238
00239 Point2D<int> P_l,Q_l, P_g, Q_g;
00240 Image<byte> parentImage = q->getSegImage();
00241 Image<byte> childImage = q->getChildSegImage();
00242
00243 Image<byte> deltaImage = absDiff(parentImage, childImage);
00244 double ret = double(-emptyArea(deltaImage))/deltaImage.getDims().sz();
00245
00246 return ret;
00247 }
00248
00249
00250 double QuadTree::evaluateTotalEnergyAt(rutz::shared_ptr<QuadNode> q) const
00251 {
00252 double E = evaluateClassifierAt(q) * itsAlphas[0];
00253 E += evaluateCohesionAt(q) * itsAlphas[1];
00254 E += evaluateCorrespondenceAt(q) * itsAlphas[2];
00255
00256 if(q->isLeaf()) return E;
00257 for(uint i = 0; i < 4; i++)
00258 E += evaluateTotalEnergyAt(q->getChild(i));
00259
00260 return E;
00261 }
00262
00263
00264 void QuadTree::printTree() const
00265 {
00266 LINFO("Tree of depth %u, number of nodes %zu", itsNumLevels, itsNodes.size());
00267
00268 for(uint i = 0; i < itsNodes.size(); i++)
00269 LINFO("%s", toStr(*itsNodes[i]).c_str());
00270
00271 }
00272
00273
00274 std::string QuadTree::writeTree() const
00275 {
00276 std::string ret = "";
00277 ret += sformat("Tree of depth %u, number of nodes %zu\n", itsNumLevels, itsNodes.size());
00278
00279 for(uint i = 0; i < itsNodes.size(); i++)
00280 ret += sformat("%s\n", toStr(*itsNodes[i]).c_str());
00281
00282 return ret;
00283 }
00284
00285
00286 std::vector<QuadNode::NodeState> QuadTree::generateProposalsAt(rutz::shared_ptr<QuadNode> q, double thresh)
00287 {
00288
00289 uint NClasses = itsClassifier->getNumClasses();
00290
00291 if(!q->isLeaf()) {
00292 for(uint i = 0; i < 4; i++) {
00293 std::vector<QuadNode::NodeState> child_props = generateProposalsAt(q->getChild(i), thresh);
00294 if(child_props.size() == 0) LFATAL("no proposals made for node %s", toStr(q->getChild(i)).c_str());
00295 q->getChild(i)->setState(child_props[0]);
00296 }
00297 }
00298
00299 QuadNode::NodeState probe(0,0,1,2), realstate = q->getState();
00300 std::vector<QuadNode::NodeState> ret;
00301
00302
00303
00304
00305 byte top_class[3][NClasses];
00306 uint prevalence[3][NClasses];
00307 Rectangle r = q->getArea();
00308
00309 if(1) {
00310 for(; probe.segTemplate < 30; probe.segTemplate++) {
00311
00312
00313 for(uint i = 0; i < 3; i++)
00314 for(uint j= 0; j < NClasses; j++) {
00315 prevalence[i][j]=0;
00316 top_class[i][j]=j;
00317 }
00318
00319
00320 for(uint i = 0; i < 3; i++) probe.objLabels[i] = i;
00321
00322 q->setState(probe);
00323
00324 Point2D<int> P_g;
00325 for(P_g.i = r.left(); P_g.i < r.rightO(); P_g.i++)
00326 for(P_g.j = r.top(); P_g.j < r.bottomO(); P_g.j++)
00327 prevalence[q->getObjLabelAt(P_g)][itsBestClassOutput[P_g]]++;
00328
00329
00330 for(uint i = 0; i < 3; i ++) {
00331 for(uint j = 0; j < NClasses; j++) {
00332 for(uint k = 0; k < NClasses-j-1; k++)
00333 if(prevalence[i][k] < prevalence[i][k+1])
00334 {
00335 std::swap(top_class[i][k], top_class[i][k+1]);
00336 std::swap(prevalence[i][k], prevalence[i][k+1]);
00337 }
00338 }
00339 }
00340
00341
00342 std::queue<Point3D<uint> > tryme;
00343 bool tested[NClasses][NClasses][NClasses];
00344 for(uint i = 0; i < NClasses; i++)
00345 for(uint j = 0; j < NClasses; j++)
00346 for(uint k = 0; k < NClasses; k++)
00347 tested[i][j][k] = false;
00348
00349 uint area = r.area();
00350 const double occ_tol = 0.0001;
00351 for(uint i = 0; i < NClasses; i++) {
00352 if(prevalence[0][0] - prevalence[0][i] > occ_tol * area) break;
00353 for(uint j = 0; j < NClasses; j++) {
00354 if(prevalence[1][0] - prevalence[1][j] > occ_tol * area) break;
00355 for(uint k = 0; k < NClasses; k++) {
00356 if(prevalence[2][0] - prevalence[2][k] > occ_tol * area) break;
00357 tryme.push(Point3D<uint>(i,j,k));
00358 if(probe.isDoubleton()) break;
00359 }
00360 if(probe.isSingleton()) break;
00361 }
00362 }
00363 while(!tryme.empty()) {
00364 Point3D<uint> n = tryme.front();
00365 tryme.pop();
00366
00367 if(tested[n.x][n.y][n.z]) continue;
00368 if(n.x >= NClasses || n.y >= NClasses || n.z >= NClasses) continue;
00369 probe.objLabels[0] = top_class[0][n.x];
00370 probe.objLabels[1] = top_class[1][n.y];
00371 probe.objLabels[2] = top_class[2][n.z];
00372 q->setState(probe);
00373 q->storeEnergy(evaluateTotalEnergyAt(q));
00374
00375 tested[n.x][n.y][n.z]=true;
00376
00377 if(q->getEnergy() < thresh) {
00378 probe.E = q->getEnergy();
00379 ret.push_back(probe);
00380 tryme.push(Point3D<uint>(n.x+1,n.y,n.z));
00381 if(!probe.isSingleton()) tryme.push(Point3D<uint>(n.x,n.y+1,n.z));
00382 if(!probe.isDoubleton()) tryme.push(Point3D<uint>(n.x,n.y,n.z+1));
00383 }
00384 }
00385 }
00386
00387 for(uint j = 0; j < ret.size(); j++)
00388 for(uint k = 0; k < ret.size()-j-1; k++)
00389 if(ret[k].E > ret[k+1].E)
00390 std::swap(ret[k],ret[k+1]);
00391
00392 q->setState(realstate);
00393 }
00394 return ret;
00395 }
00396
00397
00398 QuadNode::QuadNode() : itsIsStale(true), itsState(0)
00399 {
00400 for(uint i = 0; i < 3; i++) itsState.objLabels.push_back(i);
00401 }
00402
00403
00404 QuadNode::QuadNode(rutz::shared_ptr<QuadNode> q)
00405 : itsIsLeaf(true), itsIsStale(true),
00406 itsState(0),
00407 itsParent(q)
00408 {
00409 for(uint i = 0; i < 3; i++) itsState.objLabels.push_back(i);
00410 }
00411
00412
00413 QuadNode::QuadNode(rutz::shared_ptr<QuadNode> q, NodeState n)
00414 : itsIsLeaf(true), itsIsStale(true),
00415 itsState(n),
00416 itsParent(q)
00417 {
00418 }
00419
00420
00421 Image<byte> QuadNode::getChildSegImage()
00422 {
00423 if(isLeaf()) return getSegImage();
00424 Image<byte> ret(getArea().dims(),ZEROS);
00425 for(uint i = 0; i < 4; i++) {
00426 rutz::shared_ptr<QuadNode> child = getChild(i);
00427 inplaceEmbed(ret, child->getSegImage(), child->getArea(),byte(-1));
00428 }
00429 return ret;
00430 }
00431
00432
00433 Image<PixRGB<byte> > QuadNode::getColorizedSegImage()
00434 {
00435 return colorLabels(getSegImage());
00436 }
00437
00438
00439 Image<PixRGB<byte> > QuadNode::getColorizedChildSegImage()
00440 {
00441 return colorLabels(getChildSegImage());
00442 }
00443
00444
00445 void QuadNode::refreshSegImage()
00446 {
00447 ASSERT(itsState.objLabels.size() > 0);
00448
00449 Image<byte> ret(itsArea.dims(),NO_INIT);
00450 Point2D<int> P_local;
00451 for(P_local.i = 0; P_local.i < itsArea.dims().w(); P_local.i++)
00452 for (P_local.j = 0; P_local.j < itsArea.dims().h(); P_local.j++)
00453 ret[P_local] = getObjLabelAt(convertToGlobal(P_local));
00454
00455 itsSegImage = ret;
00456 itsIsStale = false;
00457 }
00458
00459
00460 Image<PixRGB<byte> > QuadNode::colorLabels(Image<byte> im) const
00461 {
00462 ColorMap cm(256);
00463 PixRGB<byte> col;
00464 for(uint i = 2; i <= 2; i--) {
00465 col = PixRGB<byte>(0,0,0);
00466 col[i] = 255;
00467 cm[i] = col;
00468 }
00469 return colorize(im, cm);
00470 }
00471
00472
00473 byte QuadNode::getObjLabelAt(Point2D<int> loc) const
00474 {
00475
00476
00477
00478 if(!itsArea.contains(loc)) {
00479 LINFO("Node at window (%s) does not contain point (%s)",
00480 toStr(itsArea).c_str(), toStr(loc).c_str());
00481 return -1;
00482 }
00483
00484
00485 if(!itsIsStale) {
00486 return itsSegImage[convertToLocal(loc)];
00487 }
00488
00489
00490 Point2D<double> intLoc(double(loc.i-itsArea.left())/(itsArea.width()),
00491 double(loc.j-itsArea.top())/(itsArea.height()));
00492
00493 uint iST = itsState.segTemplate;
00494 byte lvl = 0, lvl1, lvl2;
00495 double keydim1 = 0, keydim2 = 0;
00496
00497 if (iST == 0)
00498 lvl = 0;
00499 else if(iST == 1 || iST == 2) {
00500
00501 if(iST == 1) keydim1 = intLoc.i;
00502 else keydim1 = intLoc.j;
00503 lvl = keydim1 * 3;
00504 }
00505 else if(iST == 3 || iST == 4) {
00506
00507 if(iST == 3) keydim1 = intLoc.i+intLoc.j-1;
00508 else keydim1 = intLoc.j-intLoc.i;
00509 lvl = (keydim1 < 0) ? 0 : 1;
00510 }
00511 else if(iST == 5) {
00512
00513 keydim1 = fabs(intLoc.i - 0.5);
00514 keydim2 = fabs(intLoc.j - 0.5);
00515 lvl1 = keydim1 < 0.25 ? 1 : 0;
00516 lvl2 = keydim2 < 0.25 ? 1 : 0;
00517 lvl = lvl1 * lvl2;
00518 }
00519 else if (iST >= 6 && iST <= 9) {
00520
00521 switch(iST) {
00522 case 6:
00523 keydim1 = intLoc.i;
00524 keydim2 = intLoc.j;
00525 break;
00526 case 7:
00527 keydim1 = 1-intLoc.i;
00528 keydim2 = intLoc.j;
00529 break;
00530 case 8:
00531 keydim1 = intLoc.j;
00532 keydim2 = intLoc.i;
00533 break;
00534 case 9:
00535 keydim1 = 1-intLoc.j;
00536 keydim2 = intLoc.i;
00537 break;
00538 }
00539 lvl1 = (2*keydim2 - keydim1) < 0 ? 0 : 1;
00540 lvl2 = (2*keydim2 - 2 + keydim1) < 0 ? 0 : 1;
00541 lvl = lvl1 + lvl2;
00542 }
00543 else if (iST >= 10 && iST <= 13) {
00544
00545 if(iST == 10 || iST == 12) keydim1 = intLoc.i + intLoc.j - 1;
00546 else keydim1 = intLoc.j - intLoc.i;
00547
00548 if(iST == 10 || iST == 11) {
00549 lvl1 = (keydim1 < -0.5) ? 0 : 1;
00550 lvl2 = (keydim1 < 0.5) ? 0 : 1;
00551 }
00552 else {
00553 lvl1 = (keydim1 < -0.25) ? 0 : 1;
00554 lvl2 = (keydim1 < 0.25) ? 0 : 1;
00555 }
00556 lvl = lvl1 + lvl2;
00557 }
00558 else if (iST >= 18 && iST <= 21) {
00559
00560 switch(iST) {
00561 case 18:
00562 keydim1 = intLoc.i;
00563 keydim2 = intLoc.j-fabs(intLoc.i-0.5);
00564 break;
00565 case 19:
00566 keydim1 = intLoc.j;
00567 keydim2 = intLoc.i-fabs(intLoc.j-0.5);
00568 break;
00569 case 20:
00570 keydim1 = intLoc.i;
00571 keydim2 = 1-intLoc.j-fabs(intLoc.i-0.5);
00572 break;
00573 case 21:
00574 keydim1 = intLoc.j;
00575 keydim2 = 1-intLoc.i-fabs(intLoc.j-0.5);
00576 break;
00577 }
00578 lvl1 = (keydim2 > 0.5) ? 0 : 1;
00579 lvl2 = (keydim1 < 0.5) ? 1 : 2;
00580 lvl = lvl1*lvl2;
00581 }
00582 else {
00583
00584 switch(iST) {
00585 case 14:
00586 case 22:
00587 case 26:
00588 keydim1 = intLoc.j;
00589 keydim2 = intLoc.i;
00590 break;
00591 case 15:
00592 case 25:
00593 case 29:
00594 keydim1 = intLoc.i;
00595 keydim2 = intLoc.j;
00596 break;
00597 case 16:
00598 case 24:
00599 case 28:
00600 keydim1 = 1-intLoc.j;
00601 keydim2 = intLoc.i;
00602 break;
00603 case 17:
00604 case 23:
00605 case 27:
00606 keydim1 = 1-intLoc.i;
00607 keydim2 = intLoc.j;
00608 break;
00609 }
00610
00611 switch(iST) {
00612 case 14:
00613 case 15:
00614 case 16:
00615 case 17:
00616 lvl1 = (keydim1 < 0.5) ? 0 : 1;
00617 lvl2 = (keydim2 < 0.5) ? 1 : 2;
00618 lvl = lvl1 * lvl2;
00619 break;
00620 case 22:
00621 case 23:
00622 case 24:
00623 case 25:
00624 lvl1 = (3 * keydim1 < 2) ? 0 : 1;
00625 lvl2 = (keydim2 < 0.5) ? 1 : 2;
00626 lvl = lvl1 * lvl2;
00627 break;
00628 case 26:
00629 case 27:
00630 case 28:
00631 case 29:
00632 lvl1 = (keydim1 < 0.25) ? 0 : 1;
00633 lvl2 = (keydim2 < 0.5) ? 1 : 2;
00634 lvl = lvl1 * lvl2;
00635 break;
00636 }
00637 }
00638
00639 return itsState.objLabels[lvl];
00640 }
00641
00642
00643
00644
00645
00646 ColorPixelClassifier::ColorPixelClassifier() : PixelClassifier()
00647 {}
00648
00649
00650 double ColorPixelClassifier::classifyAt(Image<PixRGB<byte> > im, uint C)
00651 {
00652 ASSERT(C < itsNumClasses);
00653 ColorCat cc = itsCats[C];
00654 Point2D<int> center(im.getWidth()/2, im.getHeight()/2);
00655 double dist = colorL2Distance(im[center],cc.color);
00656 return -dist / cc.sig_cdist;
00657 }
00658
00659
00660
00661
00662
00663 GistPixelClassifier::GistPixelClassifier() : PixelClassifier()
00664 {
00665
00666 }
00667
00668
00669
00670 void GistPixelClassifier::learnInput(Image<PixRGB<byte> > im, Image<uint> labels)
00671 {
00672 }
00673
00674
00675 double GistPixelClassifier::classifyAt(Image<PixRGB<byte> > im, uint C)
00676 {
00677 return 0;
00678 }
00679
00680
00681
00682
00683 std::string convertToString(const QuadNode &q)
00684 {
00685 std::string ret = "";
00686 for(uint i = 0; i < q.getDepth(); i++) ret+='\t';
00687 return ret + "(" + toStr(q.getArea()) + "): " +
00688 "seg class " + toStr(q.getState());
00689 }
00690
00691
00692 std::string convertToString(const QuadNode::NodeState& n)
00693 {
00694 return convertToString(n.segTemplate) + " :(" +
00695 convertToString(n.objLabels[0]) + "," +
00696 convertToString(n.objLabels[1]) + "," +
00697 convertToString(n.objLabels[2]) + ")";
00698 }
00699
00700
00701
00702
00703