QuadTree.H
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045 #ifndef QUADTREE_H_DEFINED
00046 #define QUADTREE_H_DEFINED
00047
#include "Channels/RawVisualCortex.H"
#include "Image/Image.H"
#include "Image/PixelsTypes.H"
#include "Image/Point2D.H"
#include "Image/Rectangle.H"
#include "rutz/shared_ptr.h"
#include "Util/log.H"
#include "Util/StringConversions.H"
#include "nub/ref.h"
#include <cstdlib>
#include <deque>
#include <limits>
#include <math.h>
#include <stdio.h>
#include <string>
#include <vector>
00060
00061
00062 class PixelClassifier;
00063 class ColorPixelClassifier;
00064
00065 typedef std::vector<Point2D<int> > Neighborhood;
00066 class QuadNode
00067 {
00068 public:
00069 struct NodeState {
00070 uint segTemplate;
00071 std::vector<byte> objLabels;
00072 double E;
00073 bool evaled;
00074
00075 NodeState(uint st) : segTemplate(st) {}
00076 NodeState(uint st, std::vector<byte> oL) : segTemplate(st), objLabels(oL) {}
00077 NodeState(uint st, byte l1, byte l2, byte l3) :
00078 segTemplate(st) {
00079 objLabels.push_back(l1);
00080 objLabels.push_back(l2);
00081 objLabels.push_back(l3);
00082 }
00083
00084 bool isSingleton() const {return segTemplate==0;}
00085 bool isDoubleton() const {return segTemplate==0 || (segTemplate >= 3 && segTemplate <= 5);}
00086 };
00087
00088 QuadNode();
00089 QuadNode(rutz::shared_ptr<QuadNode> q);
00090 QuadNode(rutz::shared_ptr<QuadNode> q, NodeState n);
00091 Image<byte> getSegImage() {if(itsIsStale) refreshSegImage(); return itsSegImage;}
00092 Image<byte> getChildSegImage();
00093 Image<PixRGB<byte> > getColorizedSegImage();
00094 Image<PixRGB<byte> > getColorizedChildSegImage();
00095 byte getObjLabelAt(Point2D<int> loc) const;
00096
00097 bool isLeaf() const {return itsIsLeaf;}
00098
00099 void setArea(Rectangle r) {itsArea = r;}
00100 Rectangle getArea() const {return itsArea;}
00101
00102 void setDepth(uint l) {itsDepth = l;}
00103 uint getDepth() const {return itsDepth;}
00104
00105 void setLabels(std::vector<byte> v) {itsState.objLabels = v; itsIsStale = true;}
00106 void setLabel(uint i, byte v) {itsState.objLabels[i] = v; itsIsStale = true;}
00107 std::vector<byte> getLabels() const {return itsState.objLabels;}
00108
00109 void setSegTemplate(byte c) {itsState.segTemplate = c; itsIsStale = true;}
00110 uint getSegTemplate() const {return itsState.segTemplate;}
00111
00112 void setState(NodeState n) {itsState = n; itsIsStale = true;}
00113 NodeState getState() const {return itsState;}
00114
00115 void storeEnergy(double e) {itsState.E = e; itsState.evaled = true;}
00116 double getEnergy() const {return itsState.E;}
00117 bool energySaved() const {return itsState.evaled;}
00118
00119 void addChild(rutz::shared_ptr<QuadNode> n) {
00120 itsIsLeaf = false;
00121 if(itsChildren.size() < 4) itsChildren.push_back(n);
00122 else LINFO("Tried to add node in excess of 4");
00123 }
00124
00125 Point2D<int> convertToGlobal(Point2D<int> p) const
00126 {return Point2D<int>(p.i + itsArea.left(), p.j + itsArea.top());}
00127
00128 Point2D<int> convertToLocal(Point2D<int> p) const
00129 {return Point2D<int>(p.i - itsArea.left(), p.j - itsArea.top());}
00130
00131 rutz::shared_ptr<QuadNode> getChild(uint i) const {
00132 return itsChildren[i];
00133 }
00134
00135 private:
00136 void refreshSegImage();
00137 Image<PixRGB<byte> > colorLabels(Image<byte> im) const;
00138
00139 uint itsDepth;
00140
00141
00142
00143 bool itsIsLeaf;
00144 bool itsIsStale;
00145
00146 Rectangle itsArea;
00147
00148 NodeState itsState;
00149 Image<byte> itsSegImage;
00150 rutz::shared_ptr<QuadNode> itsParent;
00151
00152
00153
00154 std::vector<rutz::shared_ptr<QuadNode> > itsChildren;
00155 };
00156
00157 class QuadTree
00158 {
00159 public:
00160
00161 QuadTree(int Nlevels, Dims d);
00162 QuadTree(int Nlevels, Image<PixRGB<byte> > im);
00163
00164 void initAlphas() {double weights[] = {1, 2.5, 2.5}; itsAlphas.assign(weights,weights+3);}
00165 void addTreeUnder(rutz::shared_ptr<QuadNode> parent,int Nlevels, Rectangle r);
00166
00167 void cacheClassifierResult();
00168 double evaluateClassifierAt(rutz::shared_ptr<QuadNode> q) const;
00169
00170 double evaluateCohesionAt(rutz::shared_ptr<QuadNode> q) const;
00171
00172 double evaluateCorrespondenceAt(rutz::shared_ptr<QuadNode> q) const;
00173
00174 double evaluateTotalEnergyAt(rutz::shared_ptr<QuadNode> q) const;
00175
00176 void printTree() const;
00177 std::string writeTree() const;
00178
00179 void setClassifier(rutz::shared_ptr<PixelClassifier> cc) {itsClassifier = cc;}
00180 rutz::shared_ptr<PixelClassifier> getClassifier() const {return itsClassifier;}
00181
00182 rutz::shared_ptr<QuadNode> getRootNode() const {return itsRootNode;}
00183
00184 std::vector<QuadNode::NodeState> generateProposalsAt(rutz::shared_ptr<QuadNode> q, double thresh);
00185 private:
00186
00187 uint itsNumLevels;
00188 Dims itsDims;
00189 Image<PixRGB<byte> > itsImage;
00190
00191 rutz::shared_ptr<PixelClassifier> itsClassifier;
00192 std::vector<Image<double> > itsClassifierOutput;
00193 Image<byte> itsBestClassOutput;
00194 rutz::shared_ptr<QuadNode> itsRootNode;
00195 std::deque<rutz::shared_ptr<QuadNode> > itsNodes;
00196 std::vector<double> itsAlphas;
00197 };
00198
00199
00200
00201 class PixelClassifier
00202 {
00203 public:
00204 PixelClassifier() {itsNumClasses=0;}
00205
00206 uint getNumClasses() const {return itsNumClasses;}
00207
00208 virtual void learnInput(Image<PixRGB<byte> > im, Image<uint> labels) = 0;
00209 virtual double classifyAt(Image<PixRGB<byte> > im, uint C) = 0;
00210 protected:
00211 uint itsNumClasses;
00212 };
00213
00214 class ColorPixelClassifier : public PixelClassifier
00215 {
00216 public:
00217 struct ColorCat {
00218 PixRGB<byte> color;
00219 double sig_cdist;
00220
00221 ColorCat() {}
00222 ColorCat(PixRGB<byte> c, double d) {color = c; sig_cdist = d;}
00223 };
00224
00225 ColorPixelClassifier();
00226
00227 void learnInput(Image<PixRGB<byte> > im, Image<uint> labels) {}
00228 void addCategory(ColorCat cc) {itsCats.push_back(cc);
00229 itsNumClasses++;}
00230
00231
00232 double classifyAt(Image<PixRGB<byte> > im, uint C);
00233 private:
00234 std::vector<ColorCat> itsCats;
00235 };
00236
00237 class GistPixelClassifier : public PixelClassifier
00238 {
00239 public:
00240
00241 GistPixelClassifier();
00242
00243
00244 void learnInput(Image<PixRGB<byte> > im, Image<uint> labels);
00245 double classifyAt(Image<PixRGB<byte> > im, uint C);
00246
00247 private:
00248
00249 };
00250
00251
00252
00253
00254
00255 uint inline colorDistance(PixRGB<byte> A, PixRGB<byte>B) {
00256 return abs(A[0]-B[0])+abs(A[1]-B[1])+abs(A[2]-B[2]);
00257 }
00258
00259
00260 double inline colorL2Distance(PixRGB<byte> A, PixRGB<byte>B) {
00261 return sqrt((A[0]-B[0])*(A[0]-B[0])+(A[1]-B[1])*(A[1]-B[1])+(A[2]-B[2])*(A[2]-B[2]));
00262 }
00263
00264
00265 std::string convertToString(const QuadNode &q);
00266 std::string convertToString(const QuadNode::NodeState& n);
00267
00268 #endif // QUADTREE_H_DEFINED
00269
00270
00271
00272
00273
00274