QuadTree.H

Go to the documentation of this file.
00001 /*!@file Learn/QuadTree.H QuadTree Multi-Class Classifier */
00002 // //////////////////////////////////////////////////////////////////// //
00003 // The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
00004 // University of Southern California (USC) and the iLab at USC.         //
00005 // See http://iLab.usc.edu for information about this project.          //
00006 // //////////////////////////////////////////////////////////////////// //
00007 // Major portions of the iLab Neuromorphic Vision Toolkit are protected //
00008 // under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
00009 // in Visual Environments, and Applications'' by Christof Koch and      //
00010 // Laurent Itti, California Institute of Technology, 2001 (patent       //
00011 // pending; application number 09/912,225 filed July 23, 2001; see      //
00012 // http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
00013 // //////////////////////////////////////////////////////////////////// //
00014 // This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
00015 //                                                                      //
00016 // The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
00017 // redistribute it and/or modify it under the terms of the GNU General  //
00018 // Public License as published by the Free Software Foundation; either  //
00019 // version 2 of the License, or (at your option) any later version.     //
00020 //                                                                      //
00021 // The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
00022 // that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
00023 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
00024 // PURPOSE.  See the GNU General Public License for more details.       //
00025 //                                                                      //
00026 // You should have received a copy of the GNU General Public License    //
00027 // along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
00028 // to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
00029 // Boston, MA 02111-1307 USA.                                           //
00030 // //////////////////////////////////////////////////////////////////// //
00031 //
00032 // Primary maintainer for this file: John Shen <shenjohn@usc.edu>
00033 // $HeadURL$
00034 // $Id$
00035 //
00036 //////////////////////////////////////////////////////////////////////////
00037 //
00038 // Implementation of the segmentation algorithm described in:
00039 //
00040 // Recursive Segmentation and Recognition Templates for 2D Parsing
00041 // Leo Zhu, Yuanhao Chen, Yuan Lin, Chenxi Lin, Alan Yuille
00042 // Advances in Neural Information Processing Systems, 2008
00043 // 
00044 
00045 #ifndef QUADTREE_H_DEFINED
00046 #define QUADTREE_H_DEFINED
00047 
#include "Channels/RawVisualCortex.H"
#include "Image/Image.H"
#include "Image/PixelsTypes.H"
#include "Image/Point2D.H"
#include "Image/Rectangle.H"
#include "rutz/shared_ptr.h"
#include "Util/log.H"
#include "Util/StringConversions.H"
#include "nub/ref.h"
#include <deque>
#include <limits>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>
00060 
00061 //! QuadTree for segmentation and recognition from Leo Zhu's 2008 NIPS paper
00062 class PixelClassifier;
00063 class ColorPixelClassifier;
00064 
00065 typedef std::vector<Point2D<int> > Neighborhood;
00066 class QuadNode
00067 {
00068 public:
00069   struct NodeState {
00070     uint segTemplate; //region partition
00071     std::vector<byte> objLabels; //3 object labels
00072     double E;
00073     bool evaled;
00074     
00075     NodeState(uint st) : segTemplate(st) {}
00076     NodeState(uint st, std::vector<byte> oL) : segTemplate(st), objLabels(oL) {}
00077     NodeState(uint st, byte l1, byte l2, byte l3) : 
00078       segTemplate(st) {
00079       objLabels.push_back(l1); 
00080       objLabels.push_back(l2);
00081       objLabels.push_back(l3);
00082     }
00083     
00084     bool isSingleton() const {return segTemplate==0;}
00085     bool isDoubleton() const {return segTemplate==0 || (segTemplate >= 3 && segTemplate <= 5);}
00086   };
00087 
00088   QuadNode();
00089   QuadNode(rutz::shared_ptr<QuadNode> q);
00090   QuadNode(rutz::shared_ptr<QuadNode> q, NodeState n);
00091   Image<byte> getSegImage() {if(itsIsStale) refreshSegImage(); return itsSegImage;}
00092   Image<byte> getChildSegImage();
00093   Image<PixRGB<byte> > getColorizedSegImage();
00094   Image<PixRGB<byte> > getColorizedChildSegImage();
00095   byte getObjLabelAt(Point2D<int> loc) const;
00096 
00097   bool isLeaf() const {return itsIsLeaf;}
00098 
00099   void setArea(Rectangle r) {itsArea = r;}
00100   Rectangle getArea() const {return itsArea;}
00101 
00102   void setDepth(uint l) {itsDepth = l;}
00103   uint getDepth() const {return itsDepth;}
00104 
00105   void setLabels(std::vector<byte> v) {itsState.objLabels = v; itsIsStale = true;}
00106   void setLabel(uint i, byte v) {itsState.objLabels[i] = v; itsIsStale = true;}
00107   std::vector<byte> getLabels() const {return itsState.objLabels;}
00108   
00109   void setSegTemplate(byte c) {itsState.segTemplate = c; itsIsStale = true;}
00110   uint getSegTemplate() const {return itsState.segTemplate;}
00111 
00112   void setState(NodeState n) {itsState = n; itsIsStale = true;}
00113   NodeState getState() const {return itsState;}
00114 
00115   void storeEnergy(double e) {itsState.E = e; itsState.evaled = true;}
00116   double getEnergy() const {return itsState.E;}
00117   bool energySaved() const {return itsState.evaled;}
00118 
00119   void addChild(rutz::shared_ptr<QuadNode> n) {
00120     itsIsLeaf = false;
00121     if(itsChildren.size() < 4) itsChildren.push_back(n);
00122     else LINFO("Tried to add node in excess of 4");
00123   }
00124 
00125   Point2D<int> convertToGlobal(Point2D<int> p) const
00126   {return Point2D<int>(p.i + itsArea.left(), p.j + itsArea.top());}
00127 
00128   Point2D<int> convertToLocal(Point2D<int> p) const
00129   {return Point2D<int>(p.i - itsArea.left(), p.j - itsArea.top());}
00130 
00131   rutz::shared_ptr<QuadNode> getChild(uint i) const {
00132     return itsChildren[i];
00133   }
00134 
00135 private:
00136   void refreshSegImage();
00137   Image<PixRGB<byte> > colorLabels(Image<byte> im) const;
00138 
00139   uint itsDepth;
00140 
00141   // NB: if the node is a leaf, 
00142   // then the shared_ptrs to children are uninitialized
00143   bool itsIsLeaf;
00144   bool itsIsStale;
00145   
00146   Rectangle itsArea;
00147 
00148   NodeState itsState;
00149   Image<byte> itsSegImage;
00150   rutz::shared_ptr<QuadNode> itsParent;
00151   
00152   // top left, top right, bottom left, bottom right
00153   // NB: the tree structure is static once initialized
00154   std::vector<rutz::shared_ptr<QuadNode> > itsChildren;   
00155 };
00156 
00157 class QuadTree
00158 {
00159 public:
00160   // construct a quad tree on a given depth and image size
00161   QuadTree(int Nlevels, Dims d); 
00162   QuadTree(int Nlevels, Image<PixRGB<byte> > im);
00163 
00164   void initAlphas() {double weights[] = {1, 2.5, 2.5}; itsAlphas.assign(weights,weights+3);}
00165   void addTreeUnder(rutz::shared_ptr<QuadNode> parent,int Nlevels, Rectangle r);
00166 
00167   void cacheClassifierResult();
00168   double evaluateClassifierAt(rutz::shared_ptr<QuadNode> q) const; // classifier accuracy
00169 
00170   double evaluateCohesionAt(rutz::shared_ptr<QuadNode> q) const; //cohesion - similar pixels in similar groups in one region
00171 
00172   double evaluateCorrespondenceAt(rutz::shared_ptr<QuadNode> q) const; //labeling is the same at both layers, between a parent and its children
00173 
00174   double evaluateTotalEnergyAt(rutz::shared_ptr<QuadNode> q) const;
00175   //  byte getClassOutputAt(Point2D<int> P) const;
00176   void printTree() const;
00177   std::string writeTree() const;
00178   
00179   void setClassifier(rutz::shared_ptr<PixelClassifier> cc) {itsClassifier = cc;}
00180   rutz::shared_ptr<PixelClassifier> getClassifier() const {return itsClassifier;}
00181 
00182   rutz::shared_ptr<QuadNode> getRootNode() const {return itsRootNode;}
00183 
00184   std::vector<QuadNode::NodeState> generateProposalsAt(rutz::shared_ptr<QuadNode> q, double thresh);
00185 private:
00186 
00187   uint itsNumLevels;
00188   Dims itsDims;
00189   Image<PixRGB<byte> > itsImage;
00190 
00191   rutz::shared_ptr<PixelClassifier> itsClassifier;
00192   std::vector<Image<double> > itsClassifierOutput;
00193   Image<byte> itsBestClassOutput;
00194   rutz::shared_ptr<QuadNode> itsRootNode;
00195   std::deque<rutz::shared_ptr<QuadNode> > itsNodes;  
00196   std::vector<double> itsAlphas;
00197 };
00198 
00199 
00200 // base class for a pixel-based classifier
00201 class PixelClassifier 
00202 {
00203 public:
00204   PixelClassifier() {itsNumClasses=0;}
00205 
00206   uint getNumClasses() const {return itsNumClasses;}
00207 
00208   virtual void learnInput(Image<PixRGB<byte> > im, Image<uint> labels) = 0;
00209   virtual double classifyAt(Image<PixRGB<byte> > im, uint C) = 0;
00210 protected:
00211   uint itsNumClasses;
00212 };
00213 
00214 class ColorPixelClassifier : public PixelClassifier
00215 {
00216 public:
00217   struct ColorCat {
00218     PixRGB<byte> color;
00219     double sig_cdist;
00220     
00221     ColorCat() {}
00222     ColorCat(PixRGB<byte> c, double d) {color = c; sig_cdist = d;}
00223   };
00224 
00225   ColorPixelClassifier();
00226 
00227   void learnInput(Image<PixRGB<byte> > im, Image<uint> labels) {}
00228   void addCategory(ColorCat cc) {itsCats.push_back(cc);
00229     itsNumClasses++;}
00230   
00231   // just uses color of center pixel for now
00232   double classifyAt(Image<PixRGB<byte> > im, uint C); 
00233 private:
00234   std::vector<ColorCat> itsCats;
00235 };
00236 
00237 class GistPixelClassifier : public PixelClassifier
00238 {
00239 public:  
00240   // calculates local gist in a sliding window around each pixel
00241   GistPixelClassifier();
00242 
00243   //  void setVC(nub::ref<RawVisualCortex> VC) {itsVC = VC;}
00244   void learnInput(Image<PixRGB<byte> > im, Image<uint> labels);
00245   double classifyAt(Image<PixRGB<byte> > im, uint C);
00246   
00247 private:
00248   //  nub::ref<RawVisualCortex> itsVC;
00249 };
00250 
00251   ////////////////////////////////////////////////////
00252 // free functions 
00253 
00254 // just the manhattan distance (L1)
00255 uint inline colorDistance(PixRGB<byte> A, PixRGB<byte>B) {
00256   return abs(A[0]-B[0])+abs(A[1]-B[1])+abs(A[2]-B[2]);
00257 }
00258 
00259 // euclidean distance in (L2) space
00260 double inline colorL2Distance(PixRGB<byte> A, PixRGB<byte>B) {
00261   return sqrt((A[0]-B[0])*(A[0]-B[0])+(A[1]-B[1])*(A[1]-B[1])+(A[2]-B[2])*(A[2]-B[2]));
00262 }
00263 
00264 
00265 std::string convertToString(const QuadNode &q);
00266 std::string convertToString(const QuadNode::NodeState& n); 
00267 
00268 #endif // QUADTREE_H_DEFINED
00269 
00270 // ######################################################################
00271 /* So things look consistent in everyone's emacs... */
00272 /* Local Variables: */
00273 /* indent-tabs-mode: nil */
00274 /* End: */
Generated on Sun May 8 08:40:58 2011 for iLab Neuromorphic Vision Toolkit by  doxygen 1.6.3