00001 /*!@file Learn/QuadTree.C QuadTree Multi-Class Classifier */ 00002 // //////////////////////////////////////////////////////////////////// // 00003 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the // 00004 // University of Southern California (USC) and the iLab at USC. // 00005 // See http://iLab.usc.edu for information about this project. // 00006 // //////////////////////////////////////////////////////////////////// // 00007 // Major portions of the iLab Neuromorphic Vision Toolkit are protected // 00008 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency // 00009 // in Visual Environments, and Applications'' by Christof Koch and // 00010 // Laurent Itti, California Institute of Technology, 2001 (patent // 00011 // pending; application number 09/912,225 filed July 23, 2001; see // 00012 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status). // 00013 // //////////////////////////////////////////////////////////////////// // 00014 // This file is part of the iLab Neuromorphic Vision C++ Toolkit. // 00015 // // 00016 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can // 00017 // redistribute it and/or modify it under the terms of the GNU General // 00018 // Public License as published by the Free Software Foundation; either // 00019 // version 2 of the License, or (at your option) any later version. // 00020 // // 00021 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope // 00022 // that it will be useful, but WITHOUT ANY WARRANTY; without even the // 00023 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // 00024 // PURPOSE. See the GNU General Public License for more details. // 00025 // // 00026 // You should have received a copy of the GNU General Public License // 00027 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00028 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00029 // Boston, MA 02111-1307 USA. // 00030 // //////////////////////////////////////////////////////////////////// // 00031 // 00032 // Primary maintainer for this file: John Shen <shenjohn@usc.edu> 00033 // $HeadURL$ 00034 // $Id$ 00035 // 00036 // Implementation of the segmentation algorithm described in: 00037 // 00038 // Recursive Segmentation and Recognition Templates for 2D Parsing 00039 // Leo Zhu, Yuanhao Chen, Yuan Lin, Chenxi Lin, Alan Yuille 00040 // Advances in Neural Information Processing Systems, 2008 00041 // 00042 00043 #include "Channels/IntensityChannel.H" 00044 #include "Channels/InputFrame.H" 00045 #include "Channels/ChannelOpts.H" 00046 #include "Component/ModelManager.H" 00047 #include "GUI/DebugWin.H" 00048 #include "Image/ColorMap.H" // for colorize() 00049 #include "Image/ColorOps.H" // for colorize() 00050 #include "Image/CutPaste.H" // for inplaceEmbed() 00051 #include "Image/Dims.H" 00052 #include "Image/Image.H" 00053 #include "Image/LevelSpec.H" 00054 #include "Image/MathOps.H" // for absDiff() 00055 #include "Image/Pixels.H" 00056 #include "Image/Point3D.H" 00057 #include "Learn/QuadTree.H" 00058 #include "Util/log.H" 00059 #include "Util/StringConversions.H" 00060 00061 #include <cmath> 00062 #include <vector> 00063 #include <queue> // for DP queue 00064 #include <iostream> 00065 #include <algorithm> // for std::swap 00066 00067 // ###################################################################### 00068 QuadTree::QuadTree(int Nlevels, Dims d) : itsNumLevels(Nlevels) 00069 { 00070 // initialize tree 00071 rutz::shared_ptr<QuadNode> root_ref(new QuadNode()); //this is the temporary top, no ptrs initialized 00072 Rectangle thisWindow(Point2D<int>(0,0), d); 00073 addTreeUnder(root_ref, Nlevels, thisWindow); 00074 00075 itsRootNode = root_ref->getChild(0); 00076 initAlphas(); 00077 } 00078 00079 // ###################################################################### 00080 QuadTree::QuadTree(int Nlevels, Image<PixRGB<byte> > im) : itsNumLevels(Nlevels) 00081 { 00082 // QuadTree::QuadTree(Nlevels, im.getDims()); 00083 Dims d = im.getDims(); 00084 00085 // initialize tree 00086 rutz::shared_ptr<QuadNode> root_ref(new QuadNode()); //this is the temporary top, no ptrs initialized 00087 Rectangle thisWindow(Point2D<int>(0,0), d); 00088 addTreeUnder(root_ref, Nlevels, thisWindow); 00089 00090 itsRootNode = root_ref->getChild(0); 00091 00092 itsImage = im; 00093 initAlphas(); 00094 } 00095 00096 // ###################################################################### 00097 void QuadTree::addTreeUnder(rutz::shared_ptr<QuadNode> parent, int Nlevel, Rectangle r) 00098 { 00099 // initialize node 00100 rutz::shared_ptr<QuadNode> myNewNode(new QuadNode(parent)); 00101 myNewNode->setArea(r); 00102 myNewNode->setDepth(itsNumLevels-Nlevel); 00103 00104 // add to tree 00105 parent->addChild(myNewNode); 00106 00107 // add to internal deque 00108 itsNodes.push_back(myNewNode); 00109 00110 // if there are levels below, add sub-trees to this node 00111 if (Nlevel > 0) { 00112 // find 4 smaller rectangles 00113 Point2D<int> middle = r.center(); 00114 00115 Rectangle tl = Rectangle::tlbrO(r.top(),r.left(),middle.j,middle.i); 00116 Rectangle tr = Rectangle::tlbrO(r.top(),middle.i,middle.j,r.rightO()); 00117 Rectangle bl = Rectangle::tlbrO(middle.j,r.left(),r.bottomO(),middle.i); 00118 Rectangle br = Rectangle::tlbrO(middle.j,middle.i,r.bottomO(),r.rightO()); 00119 00120 addTreeUnder(myNewNode, Nlevel - 1, tl); 00121 addTreeUnder(myNewNode, Nlevel - 1, tr); 00122 addTreeUnder(myNewNode, Nlevel - 1, bl); 00123 addTreeUnder(myNewNode, Nlevel - 1, br); 00124 } 00125 } 00126 00127 // ###################################################################### 00128 void QuadTree::cacheClassifierResult() 00129 { 00130 uint NClasses = itsClassifier->getNumClasses(); 00131 Image<double> output(itsImage.getDims(),ZEROS); 00132 Image<double> denom(itsImage.getDims(),ZEROS); 00133 Dims patch_size(5,5); 00134 Image<PixRGB<byte> > patch(patch_size,ZEROS); 00135 Rectangle im_rect(Point2D<int>(0,0), itsImage.getDims()); 00136 00137 itsClassifierOutput.clear(); 00138 itsBestClassOutput.resize(itsImage.getDims()); 00139 for(uint i = 0; i < NClasses; i++) 00140 itsClassifierOutput.push_back(output); 00141 00142 Point2D<int> P_l; // local coords 00143 for(P_l.i = 0; P_l.i < itsImage.getWidth(); P_l.i++) 00144 for(P_l.j = 0; P_l.j < itsImage.getHeight(); P_l.j++) { 00145 00146 Rectangle rect_patch = Rectangle::centerDims(P_l, patch_size); 00147 rect_patch = constrainRect(rect_patch, im_rect,0,itsImage.getWidth(),0,itsImage.getHeight()); 00148 patch = crop(itsImage, rect_patch); 00149 00150 byte best_class = 0; 00151 for(uint i = 0; i < NClasses; i++) { 00152 double res = itsClassifier->classifyAt(patch, i); 00153 if(res > itsClassifierOutput[best_class][P_l]) 00154 best_class = i; 00155 00156 itsClassifierOutput[i][P_l] = res; 00157 // get partition denominator 00158 denom[P_l] += exp(res); 00159 } 00160 itsBestClassOutput[P_l] = best_class; 00161 } 00162 00163 for(uint i = 0; i < NClasses; i++) 00164 itsClassifierOutput[i] = exp(itsClassifierOutput[i])/denom; 00165 } 00166 00167 // ###################################################################### 00168 double QuadTree::evaluateClassifierAt(rutz::shared_ptr<QuadNode> q) const 00169 { 00170 int Npts = itsImage.getDims().sz(); 00171 double E = 0; 00172 Point2D<int> P_l, P_g; // local coords 00173 for(P_l.i = 0; P_l.i < itsImage.getWidth(); P_l.i++) 00174 for(P_l.j = 0; P_l.j < itsImage.getHeight(); P_l.j++) { 00175 P_g = q->convertToGlobal(P_l); 00176 E -= log(itsClassifierOutput[q->getObjLabelAt(P_g)][P_l]); 00177 } 00178 00179 return E/Npts; 00180 } 00181 00182 // ###################################################################### 00183 double QuadTree::evaluateCohesionAt(rutz::shared_ptr<QuadNode> q) const 00184 { 00185 // a negative energy term where pixels that belong to the same partitions have similar appearance 00186 // loop over all pairs of points in the image at that quadnode 00187 Neighborhood one_away; 00188 // one_away.push_back(Point2D<int>(-1,-1)); 00189 //one_away.push_back(Point2D<int>(-1,0)); 00190 //one_away.push_back(Point2D<int>(-1,1)); 00191 //one_away.push_back(Point2D<int>(0,-1)); 00192 const float lambda = 1; // term weighting the importance of "cohesion" relative to the other energy terms, might depend on scale 00193 const float sigma_col = 75; // the gaussian st-dev in color space 00194 const Point2D<int> origin(0,0); 00195 float e_coh = 0; 00196 00197 one_away.push_back(Point2D<int>(0,1)); 00198 one_away.push_back(Point2D<int>(1,-1)); 00199 one_away.push_back(Point2D<int>(1,0)); 00200 one_away.push_back(Point2D<int>(1,1)); 00201 00202 00203 Image<byte> segImage = q->getSegImage(); 00204 00205 int nedges = 0; 00206 // int n_same = 0; 00207 //int N_bad = 0; 00208 00209 Point2D<int> P_l,Q_l, P_g, Q_g; // local and global coords 00210 for(P_l.i = 0; P_l.i < segImage.getWidth(); P_l.i++) { 00211 for(P_l.j = 0; P_l.j < segImage.getHeight(); P_l.j++) { 00212 for(uint k = 0; k < one_away.size(); k++) { 00213 // calculate the points 00214 Q_l = P_l + one_away[k]; 00215 P_g = q->convertToGlobal(P_l); 00216 Q_g = q->convertToGlobal(Q_l); 00217 00218 if (!segImage.coordsOk(Q_l)) continue; // if Q is out of bounds, skip 00219 00220 nedges++; // count the edge 00221 //if P and Q are not in the same segment, skip 00222 if(segImage[P_l] != segImage[Q_l]) continue; 00223 00224 // the computation of the energy term 00225 double color_dist = colorDistance(itsImage[P_g], itsImage[Q_g]); 00226 double space_dist = one_away[k].distance(origin); 00227 e_coh -= lambda / space_dist * exp(- color_dist * color_dist / (2 * sigma_col * sigma_col)); 00228 } 00229 } 00230 } 00231 return e_coh / nedges; 00232 } 00233 00234 // ###################################################################### 00235 double QuadTree::evaluateCorrespondenceAt(rutz::shared_ptr<QuadNode> q) const 00236 { 00237 if(q->isLeaf()) return 0; 00238 00239 Point2D<int> P_l,Q_l, P_g, Q_g; // local and global coords 00240 Image<byte> parentImage = q->getSegImage(); 00241 Image<byte> childImage = q->getChildSegImage(); 00242 00243 Image<byte> deltaImage = absDiff(parentImage, childImage); 00244 double ret = double(-emptyArea(deltaImage))/deltaImage.getDims().sz(); 00245 00246 return ret; 00247 } 00248 00249 // ###################################################################### 00250 double QuadTree::evaluateTotalEnergyAt(rutz::shared_ptr<QuadNode> q) const 00251 { 00252 double E = evaluateClassifierAt(q) * itsAlphas[0]; 00253 E += evaluateCohesionAt(q) * itsAlphas[1]; 00254 E += evaluateCorrespondenceAt(q) * itsAlphas[2]; 00255 00256 if(q->isLeaf()) return E; 00257 for(uint i = 0; i < 4; i++) 00258 E += evaluateTotalEnergyAt(q->getChild(i)); 00259 00260 return E; 00261 } 00262 00263 // ###################################################################### 00264 void QuadTree::printTree() const 00265 { 00266 LINFO("Tree of depth %u, number of nodes %zu", itsNumLevels, itsNodes.size()); 00267 // NB: not a traversal for now, just reading off the queue 00268 for(uint i = 0; i < itsNodes.size(); i++) 00269 LINFO("%s", toStr(*itsNodes[i]).c_str()); 00270 00271 } 00272 00273 // ###################################################################### 00274 std::string QuadTree::writeTree() const 00275 { 00276 std::string ret = ""; 00277 ret += sformat("Tree of depth %u, number of nodes %zu\n", itsNumLevels, itsNodes.size()); 00278 // NB: not a traversal for now, just reading off the queue 00279 for(uint i = 0; i < itsNodes.size(); i++) 00280 ret += sformat("%s\n", toStr(*itsNodes[i]).c_str()); 00281 00282 return ret; 00283 } 00284 00285 // ###################################################################### 00286 std::vector<QuadNode::NodeState> QuadTree::generateProposalsAt(rutz::shared_ptr<QuadNode> q, double thresh) 00287 { 00288 // this code is just for the leaf nodes right now 00289 uint NClasses = itsClassifier->getNumClasses(); 00290 00291 if(!q->isLeaf()) { //combine proposals 00292 for(uint i = 0; i < 4; i++) { // children loop 00293 std::vector<QuadNode::NodeState> child_props = generateProposalsAt(q->getChild(i), thresh); 00294 if(child_props.size() == 0) LFATAL("no proposals made for node %s", toStr(q->getChild(i)).c_str()); 00295 q->getChild(i)->setState(child_props[0]); 00296 } 00297 } 00298 00299 QuadNode::NodeState probe(0,0,1,2), realstate = q->getState(); 00300 std::vector<QuadNode::NodeState> ret; 00301 00302 // try fitting each template first 00303 // uint common[3][NClasses]; 00304 00305 byte top_class[3][NClasses]; 00306 uint prevalence[3][NClasses]; 00307 Rectangle r = q->getArea(); 00308 00309 if(1) { 00310 for(; probe.segTemplate < 30; probe.segTemplate++) { 00311 00312 //clear prevalence 00313 for(uint i = 0; i < 3; i++) 00314 for(uint j= 0; j < NClasses; j++) { 00315 prevalence[i][j]=0; 00316 top_class[i][j]=j; 00317 } 00318 00319 // find the frequency of each label for each region in the classifier 00320 for(uint i = 0; i < 3; i++) probe.objLabels[i] = i; 00321 00322 q->setState(probe); //initialize just for counting purposes 00323 00324 Point2D<int> P_g; 00325 for(P_g.i = r.left(); P_g.i < r.rightO(); P_g.i++) 00326 for(P_g.j = r.top(); P_g.j < r.bottomO(); P_g.j++) 00327 prevalence[q->getObjLabelAt(P_g)][itsBestClassOutput[P_g]]++; 00328 00329 // sort each entry by the class prevalence 00330 for(uint i = 0; i < 3; i ++) { 00331 for(uint j = 0; j < NClasses; j++) { 00332 for(uint k = 0; k < NClasses-j-1; k++) 00333 if(prevalence[i][k] < prevalence[i][k+1]) 00334 { 00335 std::swap(top_class[i][k], top_class[i][k+1]); 00336 std::swap(prevalence[i][k], prevalence[i][k+1]); 00337 } 00338 } 00339 } 00340 00341 // finding proposals - DP setup 00342 std::queue<Point3D<uint> > tryme; 00343 bool tested[NClasses][NClasses][NClasses]; 00344 for(uint i = 0; i < NClasses; i++) 00345 for(uint j = 0; j < NClasses; j++) 00346 for(uint k = 0; k < NClasses; k++) 00347 tested[i][j][k] = false; 00348 00349 uint area = r.area(); 00350 const double occ_tol = 0.0001; 00351 for(uint i = 0; i < NClasses; i++) { 00352 if(prevalence[0][0] - prevalence[0][i] > occ_tol * area) break; 00353 for(uint j = 0; j < NClasses; j++) { 00354 if(prevalence[1][0] - prevalence[1][j] > occ_tol * area) break; 00355 for(uint k = 0; k < NClasses; k++) { 00356 if(prevalence[2][0] - prevalence[2][k] > occ_tol * area) break; 00357 tryme.push(Point3D<uint>(i,j,k)); 00358 if(probe.isDoubleton()) break; //the last label doesn't matter 00359 } 00360 if(probe.isSingleton()) break; // the 2nd to last label doesn't matter 00361 } 00362 } 00363 while(!tryme.empty()) { 00364 Point3D<uint> n = tryme.front(); 00365 tryme.pop(); 00366 00367 if(tested[n.x][n.y][n.z]) continue; 00368 if(n.x >= NClasses || n.y >= NClasses || n.z >= NClasses) continue; 00369 probe.objLabels[0] = top_class[0][n.x]; 00370 probe.objLabels[1] = top_class[1][n.y]; 00371 probe.objLabels[2] = top_class[2][n.z]; 00372 q->setState(probe); 00373 q->storeEnergy(evaluateTotalEnergyAt(q)); 00374 // probe.evaled = true; 00375 tested[n.x][n.y][n.z]=true; 00376 00377 if(q->getEnergy() < thresh) { 00378 probe.E = q->getEnergy(); 00379 ret.push_back(probe); 00380 tryme.push(Point3D<uint>(n.x+1,n.y,n.z)); 00381 if(!probe.isSingleton()) tryme.push(Point3D<uint>(n.x,n.y+1,n.z)); 00382 if(!probe.isDoubleton()) tryme.push(Point3D<uint>(n.x,n.y,n.z+1)); 00383 } 00384 } 00385 } //end seg template loop 00386 00387 for(uint j = 0; j < ret.size(); j++) 00388 for(uint k = 0; k < ret.size()-j-1; k++) 00389 if(ret[k].E > ret[k+1].E) 00390 std::swap(ret[k],ret[k+1]); 00391 00392 q->setState(realstate); 00393 } 00394 return ret; 00395 } 00396 00397 // ###################################################################### 00398 QuadNode::QuadNode() : itsIsStale(true), itsState(0) 00399 { 00400 for(uint i = 0; i < 3; i++) itsState.objLabels.push_back(i); 00401 } 00402 00403 // ###################################################################### 00404 QuadNode::QuadNode(rutz::shared_ptr<QuadNode> q) 00405 : itsIsLeaf(true), itsIsStale(true), 00406 itsState(0), 00407 itsParent(q) 00408 { 00409 for(uint i = 0; i < 3; i++) itsState.objLabels.push_back(i); 00410 } 00411 00412 // ###################################################################### 00413 QuadNode::QuadNode(rutz::shared_ptr<QuadNode> q, NodeState n) 00414 : itsIsLeaf(true), itsIsStale(true), 00415 itsState(n), 00416 itsParent(q) 00417 { 00418 } 00419 00420 // ###################################################################### 00421 Image<byte> QuadNode::getChildSegImage() 00422 { 00423 if(isLeaf()) return getSegImage(); 00424 Image<byte> ret(getArea().dims(),ZEROS); 00425 for(uint i = 0; i < 4; i++) { 00426 rutz::shared_ptr<QuadNode> child = getChild(i); 00427 inplaceEmbed(ret, child->getSegImage(), child->getArea(),byte(-1)); 00428 } 00429 return ret; 00430 } 00431 00432 // ###################################################################### 00433 Image<PixRGB<byte> > QuadNode::getColorizedSegImage() 00434 { 00435 return colorLabels(getSegImage()); 00436 } 00437 00438 // ###################################################################### 00439 Image<PixRGB<byte> > QuadNode::getColorizedChildSegImage() 00440 { 00441 return colorLabels(getChildSegImage()); 00442 } 00443 00444 // ###################################################################### 00445 void QuadNode::refreshSegImage() 00446 { 00447 ASSERT(itsState.objLabels.size() > 0); 00448 00449 Image<byte> ret(itsArea.dims(),NO_INIT); 00450 Point2D<int> P_local; 00451 for(P_local.i = 0; P_local.i < itsArea.dims().w(); P_local.i++) 00452 for (P_local.j = 0; P_local.j < itsArea.dims().h(); P_local.j++) 00453 ret[P_local] = getObjLabelAt(convertToGlobal(P_local)); 00454 00455 itsSegImage = ret; 00456 itsIsStale = false; 00457 } 00458 00459 // ###################################################################### 00460 Image<PixRGB<byte> > QuadNode::colorLabels(Image<byte> im) const 00461 { 00462 ColorMap cm(256); 00463 PixRGB<byte> col; 00464 for(uint i = 2; i <= 2; i--) { // NB: once we get a classifier, we will need a real colormap instead of an adhoc one, this 00465 col = PixRGB<byte>(0,0,0); 00466 col[i] = 255; 00467 cm[i] = col; 00468 } 00469 return colorize(im, cm); 00470 } 00471 00472 // ###################################################################### 00473 byte QuadNode::getObjLabelAt(Point2D<int> loc) const 00474 { 00475 // TODO: move this logic to drawing the template all at once 00476 00477 // check if point resides in the area 00478 if(!itsArea.contains(loc)) { 00479 LINFO("Node at window (%s) does not contain point (%s)", 00480 toStr(itsArea).c_str(), toStr(loc).c_str()); 00481 return -1; 00482 } 00483 00484 // result is already memoized 00485 if(!itsIsStale) { 00486 return itsSegImage[convertToLocal(loc)]; 00487 } 00488 00489 //convert point to [0,1] x [0,1] (scaled) coordinates; 00490 Point2D<double> intLoc(double(loc.i-itsArea.left())/(itsArea.width()), 00491 double(loc.j-itsArea.top())/(itsArea.height())); 00492 00493 uint iST = itsState.segTemplate; 00494 byte lvl = 0, lvl1, lvl2; 00495 double keydim1 = 0, keydim2 = 0; 00496 00497 if (iST == 0) 00498 lvl = 0; 00499 else if(iST == 1 || iST == 2) { 00500 // horizontal/vertical edges 00501 if(iST == 1) keydim1 = intLoc.i; 00502 else keydim1 = intLoc.j; 00503 lvl = keydim1 * 3; 00504 } 00505 else if(iST == 3 || iST == 4) { 00506 //diagonal edges 00507 if(iST == 3) keydim1 = intLoc.i+intLoc.j-1; 00508 else keydim1 = intLoc.j-intLoc.i; 00509 lvl = (keydim1 < 0) ? 0 : 1; 00510 } 00511 else if(iST == 5) { 00512 //box inside another box 00513 keydim1 = fabs(intLoc.i - 0.5); 00514 keydim2 = fabs(intLoc.j - 0.5); 00515 lvl1 = keydim1 < 0.25 ? 1 : 0; 00516 lvl2 = keydim2 < 0.25 ? 1 : 0; 00517 lvl = lvl1 * lvl2; 00518 } 00519 else if (iST >= 6 && iST <= 9) { 00520 //V-junctions 00521 switch(iST) { 00522 case 6: 00523 keydim1 = intLoc.i; 00524 keydim2 = intLoc.j; 00525 break; 00526 case 7: 00527 keydim1 = 1-intLoc.i; 00528 keydim2 = intLoc.j; 00529 break; 00530 case 8: 00531 keydim1 = intLoc.j; 00532 keydim2 = intLoc.i; 00533 break; 00534 case 9: 00535 keydim1 = 1-intLoc.j; 00536 keydim2 = intLoc.i; 00537 break; 00538 } 00539 lvl1 = (2*keydim2 - keydim1) < 0 ? 0 : 1; 00540 lvl2 = (2*keydim2 - 2 + keydim1) < 0 ? 0 : 1; 00541 lvl = lvl1 + lvl2; 00542 } 00543 else if (iST >= 10 && iST <= 13) { 00544 //diagonal orientations 00545 if(iST == 10 || iST == 12) keydim1 = intLoc.i + intLoc.j - 1; 00546 else keydim1 = intLoc.j - intLoc.i; 00547 00548 if(iST == 10 || iST == 11) { 00549 lvl1 = (keydim1 < -0.5) ? 0 : 1; 00550 lvl2 = (keydim1 < 0.5) ? 0 : 1; 00551 } 00552 else { 00553 lvl1 = (keydim1 < -0.25) ? 0 : 1; 00554 lvl2 = (keydim1 < 0.25) ? 0 : 1; 00555 } 00556 lvl = lvl1 + lvl2; 00557 } 00558 else if (iST >= 18 && iST <= 21) { 00559 //Y-junctions 00560 switch(iST) { 00561 case 18: 00562 keydim1 = intLoc.i; 00563 keydim2 = intLoc.j-fabs(intLoc.i-0.5); 00564 break; 00565 case 19: 00566 keydim1 = intLoc.j; 00567 keydim2 = intLoc.i-fabs(intLoc.j-0.5); 00568 break; 00569 case 20: 00570 keydim1 = intLoc.i; 00571 keydim2 = 1-intLoc.j-fabs(intLoc.i-0.5); 00572 break; 00573 case 21: 00574 keydim1 = intLoc.j; 00575 keydim2 = 1-intLoc.i-fabs(intLoc.j-0.5); 00576 break; 00577 } 00578 lvl1 = (keydim2 > 0.5) ? 0 : 1; 00579 lvl2 = (keydim1 < 0.5) ? 1 : 2; 00580 lvl = lvl1*lvl2; 00581 } 00582 else { 00583 // T-junctions 00584 switch(iST) { 00585 case 14: 00586 case 22: 00587 case 26: 00588 keydim1 = intLoc.j; 00589 keydim2 = intLoc.i; 00590 break; 00591 case 15: 00592 case 25: 00593 case 29: 00594 keydim1 = intLoc.i; 00595 keydim2 = intLoc.j; 00596 break; 00597 case 16: 00598 case 24: 00599 case 28: 00600 keydim1 = 1-intLoc.j; 00601 keydim2 = intLoc.i; 00602 break; 00603 case 17: 00604 case 23: 00605 case 27: 00606 keydim1 = 1-intLoc.i; 00607 keydim2 = intLoc.j; 00608 break; 00609 } 00610 00611 switch(iST) { 00612 case 14: 00613 case 15: 00614 case 16: 00615 case 17: 00616 lvl1 = (keydim1 < 0.5) ? 0 : 1; 00617 lvl2 = (keydim2 < 0.5) ? 1 : 2; 00618 lvl = lvl1 * lvl2; 00619 break; 00620 case 22: 00621 case 23: 00622 case 24: 00623 case 25: 00624 lvl1 = (3 * keydim1 < 2) ? 0 : 1; 00625 lvl2 = (keydim2 < 0.5) ? 1 : 2; 00626 lvl = lvl1 * lvl2; 00627 break; 00628 case 26: 00629 case 27: 00630 case 28: 00631 case 29: 00632 lvl1 = (keydim1 < 0.25) ? 0 : 1; 00633 lvl2 = (keydim2 < 0.5) ? 1 : 2; 00634 lvl = lvl1 * lvl2; 00635 break; 00636 } 00637 } 00638 00639 return itsState.objLabels[lvl]; 00640 } 00641 00642 // ###################################################################### 00643 // #### ColorPixelClassifier 00644 // ###################################################################### 00645 00646 ColorPixelClassifier::ColorPixelClassifier() : PixelClassifier() 00647 {} 00648 00649 // ###################################################################### 00650 double ColorPixelClassifier::classifyAt(Image<PixRGB<byte> > im, uint C) 00651 { 00652 ASSERT(C < itsNumClasses); 00653 ColorCat cc = itsCats[C]; 00654 Point2D<int> center(im.getWidth()/2, im.getHeight()/2); 00655 double dist = colorL2Distance(im[center],cc.color); 00656 return -dist / cc.sig_cdist; 00657 } 00658 00659 // ###################################################################### 00660 // #### GistPixelClassifier 00661 // ###################################################################### 00662 00663 GistPixelClassifier::GistPixelClassifier() : PixelClassifier() 00664 { 00665 00666 } 00667 00668 // ###################################################################### 00669 00670 void GistPixelClassifier::learnInput(Image<PixRGB<byte> > im, Image<uint> labels) 00671 { 00672 } 00673 00674 // ###################################################################### 00675 double GistPixelClassifier::classifyAt(Image<PixRGB<byte> > im, uint C) 00676 { 00677 return 0; 00678 } 00679 // Free functions: 00680 // ###################################################################### 00681 00682 // ###################################################################### 00683 std::string convertToString(const QuadNode &q) 00684 { 00685 std::string ret = ""; 00686 for(uint i = 0; i < q.getDepth(); i++) ret+='\t'; 00687 return ret + "(" + toStr(q.getArea()) + "): " + 00688 "seg class " + toStr(q.getState()); 00689 } 00690 00691 // ###################################################################### 00692 std::string convertToString(const QuadNode::NodeState& n) 00693 { 00694 return convertToString(n.segTemplate) + " :(" + 00695 convertToString(n.objLabels[0]) + "," + 00696 convertToString(n.objLabels[1]) + "," + 00697 convertToString(n.objLabels[2]) + ")"; 00698 } 00699 00700 /* So things look consistent in everyone's emacs... */ 00701 /* Local Variables: */ 00702 /* indent-tabs-mode: nil */ 00703 /* End: */