00001 /*!@file Learn/QuadTree.H QuadTree Multi-Class Classifier */ 00002 // //////////////////////////////////////////////////////////////////// // 00003 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the // 00004 // University of Southern California (USC) and the iLab at USC. // 00005 // See http://iLab.usc.edu for information about this project. // 00006 // //////////////////////////////////////////////////////////////////// // 00007 // Major portions of the iLab Neuromorphic Vision Toolkit are protected // 00008 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency // 00009 // in Visual Environments, and Applications'' by Christof Koch and // 00010 // Laurent Itti, California Institute of Technology, 2001 (patent // 00011 // pending; application number 09/912,225 filed July 23, 2001; see // 00012 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status). // 00013 // //////////////////////////////////////////////////////////////////// // 00014 // This file is part of the iLab Neuromorphic Vision C++ Toolkit. // 00015 // // 00016 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can // 00017 // redistribute it and/or modify it under the terms of the GNU General // 00018 // Public License as published by the Free Software Foundation; either // 00019 // version 2 of the License, or (at your option) any later version. // 00020 // // 00021 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope // 00022 // that it will be useful, but WITHOUT ANY WARRANTY; without even the // 00023 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // 00024 // PURPOSE. See the GNU General Public License for more details. // 00025 // // 00026 // You should have received a copy of the GNU General Public License // 00027 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write // 00028 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, // 00029 // Boston, MA 02111-1307 USA. // 00030 // //////////////////////////////////////////////////////////////////// // 00031 // 00032 // Primary maintainer for this file: John Shen <shenjohn@usc.edu> 00033 // $HeadURL$ 00034 // $Id$ 00035 // 00036 ////////////////////////////////////////////////////////////////////////// 00037 // 00038 // Implementation of the segmentation algorithm described in: 00039 // 00040 // Recursive Segmentation and Recognition Templates for 2D Parsing 00041 // Leo Zhu, Yuanhao Chen, Yuan Lin, Chenxi Lin, Alan Yuille 00042 // Advances in Neural Information Processing Systems, 2008 00043 // 00044 00045 #ifndef QUADTREE_H_DEFINED 00046 #define QUADTREE_H_DEFINED 00047 00048 #include "Channels/RawVisualCortex.H" 00049 #include "Image/Image.H" 00050 #include "Image/PixelsTypes.H" 00051 #include "Image/Point2D.H" 00052 #include "Image/Rectangle.H" 00053 #include "rutz/shared_ptr.h" 00054 #include "Util/log.H" 00055 #include "Util/StringConversions.H" 00056 #include "nub/ref.h" 00057 #include <limits> 00058 #include <math.h> 00059 #include <stdio.h> 00060 00061 //! QuadTree for segmentation and recognition from Leo Zhu's 2008 NIPS paper 00062 class PixelClassifier; 00063 class ColorPixelClassifier; 00064 00065 typedef std::vector<Point2D<int> > Neighborhood; 00066 class QuadNode 00067 { 00068 public: 00069 struct NodeState { 00070 uint segTemplate; //region partition 00071 std::vector<byte> objLabels; //3 object labels 00072 double E; 00073 bool evaled; 00074 00075 NodeState(uint st) : segTemplate(st) {} 00076 NodeState(uint st, std::vector<byte> oL) : segTemplate(st), objLabels(oL) {} 00077 NodeState(uint st, byte l1, byte l2, byte l3) : 00078 segTemplate(st) { 00079 objLabels.push_back(l1); 00080 objLabels.push_back(l2); 00081 objLabels.push_back(l3); 00082 } 00083 00084 bool isSingleton() const {return segTemplate==0;} 00085 bool isDoubleton() const {return segTemplate==0 || (segTemplate >= 3 && segTemplate <= 5);} 00086 }; 00087 00088 QuadNode(); 00089 QuadNode(rutz::shared_ptr<QuadNode> q); 00090 QuadNode(rutz::shared_ptr<QuadNode> q, NodeState n); 00091 Image<byte> getSegImage() {if(itsIsStale) refreshSegImage(); return itsSegImage;} 00092 Image<byte> getChildSegImage(); 00093 Image<PixRGB<byte> > getColorizedSegImage(); 00094 Image<PixRGB<byte> > getColorizedChildSegImage(); 00095 byte getObjLabelAt(Point2D<int> loc) const; 00096 00097 bool isLeaf() const {return itsIsLeaf;} 00098 00099 void setArea(Rectangle r) {itsArea = r;} 00100 Rectangle getArea() const {return itsArea;} 00101 00102 void setDepth(uint l) {itsDepth = l;} 00103 uint getDepth() const {return itsDepth;} 00104 00105 void setLabels(std::vector<byte> v) {itsState.objLabels = v; itsIsStale = true;} 00106 void setLabel(uint i, byte v) {itsState.objLabels[i] = v; itsIsStale = true;} 00107 std::vector<byte> getLabels() const {return itsState.objLabels;} 00108 00109 void setSegTemplate(byte c) {itsState.segTemplate = c; itsIsStale = true;} 00110 uint getSegTemplate() const {return itsState.segTemplate;} 00111 00112 void setState(NodeState n) {itsState = n; itsIsStale = true;} 00113 NodeState getState() const {return itsState;} 00114 00115 void storeEnergy(double e) {itsState.E = e; itsState.evaled = true;} 00116 double getEnergy() const {return itsState.E;} 00117 bool energySaved() const {return itsState.evaled;} 00118 00119 void addChild(rutz::shared_ptr<QuadNode> n) { 00120 itsIsLeaf = false; 00121 if(itsChildren.size() < 4) itsChildren.push_back(n); 00122 else LINFO("Tried to add node in excess of 4"); 00123 } 00124 00125 Point2D<int> convertToGlobal(Point2D<int> p) const 00126 {return Point2D<int>(p.i + itsArea.left(), p.j + itsArea.top());} 00127 00128 Point2D<int> convertToLocal(Point2D<int> p) const 00129 {return Point2D<int>(p.i - itsArea.left(), p.j - itsArea.top());} 00130 00131 rutz::shared_ptr<QuadNode> getChild(uint i) const { 00132 return itsChildren[i]; 00133 } 00134 00135 private: 00136 void refreshSegImage(); 00137 Image<PixRGB<byte> > colorLabels(Image<byte> im) const; 00138 00139 uint itsDepth; 00140 00141 // NB: if the node is a leaf, 00142 // then the shared_ptrs to children are uninitialized 00143 bool itsIsLeaf; 00144 bool itsIsStale; 00145 00146 Rectangle itsArea; 00147 00148 NodeState itsState; 00149 Image<byte> itsSegImage; 00150 rutz::shared_ptr<QuadNode> itsParent; 00151 00152 // top left, top right, bottom left, bottom right 00153 // NB: the tree structure is static once initialized 00154 std::vector<rutz::shared_ptr<QuadNode> > itsChildren; 00155 }; 00156 00157 class QuadTree 00158 { 00159 public: 00160 // construct a quad tree on a given depth and image size 00161 QuadTree(int Nlevels, Dims d); 00162 QuadTree(int Nlevels, Image<PixRGB<byte> > im); 00163 00164 void initAlphas() {double weights[] = {1, 2.5, 2.5}; itsAlphas.assign(weights,weights+3);} 00165 void addTreeUnder(rutz::shared_ptr<QuadNode> parent,int Nlevels, Rectangle r); 00166 00167 void cacheClassifierResult(); 00168 double evaluateClassifierAt(rutz::shared_ptr<QuadNode> q) const; // classifier accuracy 00169 00170 double evaluateCohesionAt(rutz::shared_ptr<QuadNode> q) const; //cohesion - similar pixels in similar groups in one region 00171 00172 double evaluateCorrespondenceAt(rutz::shared_ptr<QuadNode> q) const; //labeling is the same at both layers, between a parent and its children 00173 00174 double evaluateTotalEnergyAt(rutz::shared_ptr<QuadNode> q) const; 00175 // byte getClassOutputAt(Point2D<int> P) const; 00176 void printTree() const; 00177 std::string writeTree() const; 00178 00179 void setClassifier(rutz::shared_ptr<PixelClassifier> cc) {itsClassifier = cc;} 00180 rutz::shared_ptr<PixelClassifier> getClassifier() const {return itsClassifier;} 00181 00182 rutz::shared_ptr<QuadNode> getRootNode() const {return itsRootNode;} 00183 00184 std::vector<QuadNode::NodeState> generateProposalsAt(rutz::shared_ptr<QuadNode> q, double thresh); 00185 private: 00186 00187 uint itsNumLevels; 00188 Dims itsDims; 00189 Image<PixRGB<byte> > itsImage; 00190 00191 rutz::shared_ptr<PixelClassifier> itsClassifier; 00192 std::vector<Image<double> > itsClassifierOutput; 00193 Image<byte> itsBestClassOutput; 00194 rutz::shared_ptr<QuadNode> itsRootNode; 00195 std::deque<rutz::shared_ptr<QuadNode> > itsNodes; 00196 std::vector<double> itsAlphas; 00197 }; 00198 00199 00200 // base class for a pixel-based classifier 00201 class PixelClassifier 00202 { 00203 public: 00204 PixelClassifier() {itsNumClasses=0;} 00205 00206 uint getNumClasses() const {return itsNumClasses;} 00207 00208 virtual void learnInput(Image<PixRGB<byte> > im, Image<uint> labels) = 0; 00209 virtual double classifyAt(Image<PixRGB<byte> > im, uint C) = 0; 00210 protected: 00211 uint itsNumClasses; 00212 }; 00213 00214 class ColorPixelClassifier : public PixelClassifier 00215 { 00216 public: 00217 struct ColorCat { 00218 PixRGB<byte> color; 00219 double sig_cdist; 00220 00221 ColorCat() {} 00222 ColorCat(PixRGB<byte> c, double d) {color = c; sig_cdist = d;} 00223 }; 00224 00225 ColorPixelClassifier(); 00226 00227 void learnInput(Image<PixRGB<byte> > im, Image<uint> labels) {} 00228 void addCategory(ColorCat cc) {itsCats.push_back(cc); 00229 itsNumClasses++;} 00230 00231 // just uses color of center pixel for now 00232 double classifyAt(Image<PixRGB<byte> > im, uint C); 00233 private: 00234 std::vector<ColorCat> itsCats; 00235 }; 00236 00237 class GistPixelClassifier : public PixelClassifier 00238 { 00239 public: 00240 // calculates local gist in a sliding window around each pixel 00241 GistPixelClassifier(); 00242 00243 // void setVC(nub::ref<RawVisualCortex> VC) {itsVC = VC;} 00244 void learnInput(Image<PixRGB<byte> > im, Image<uint> labels); 00245 double classifyAt(Image<PixRGB<byte> > im, uint C); 00246 00247 private: 00248 // nub::ref<RawVisualCortex> itsVC; 00249 }; 00250 00251 //////////////////////////////////////////////////// 00252 // free functions 00253 00254 // just the manhattan distance (L1) 00255 uint inline colorDistance(PixRGB<byte> A, PixRGB<byte>B) { 00256 return abs(A[0]-B[0])+abs(A[1]-B[1])+abs(A[2]-B[2]); 00257 } 00258 00259 // euclidean distance in (L2) space 00260 double inline colorL2Distance(PixRGB<byte> A, PixRGB<byte>B) { 00261 return sqrt((A[0]-B[0])*(A[0]-B[0])+(A[1]-B[1])*(A[1]-B[1])+(A[2]-B[2])*(A[2]-B[2])); 00262 } 00263 00264 00265 std::string convertToString(const QuadNode &q); 00266 std::string convertToString(const QuadNode::NodeState& n); 00267 00268 #endif // QUADTREE_H_DEFINED 00269 00270 // ###################################################################### 00271 /* So things look consistent in everyone's emacs... */ 00272 /* Local Variables: */ 00273 /* indent-tabs-mode: nil */ 00274 /* End: */