DecisionTree.H
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef DECISIONTREE_H_DEFINED
00038 #define DECISIONTREE_H_DEFINED
00039
00040 #include "rutz/shared_ptr.h"
00041 #include <cstdlib>
00042 #include <deque>
00043 #include <map>
00044 #include <list>
00045 #include <vector>
00046 #include <iostream>
00047 #include <fstream>
00048
00049 class DecisionNode;
00050
00051 class DecisionTree
00052 {
00053 public:
00054 DecisionTree(int maxSplits=1);
00055
00056 void train(const std::vector<std::vector<float> >& data, const std::vector<int>& labels,std::vector<float> weights=std::vector<float>());
00057
00058 std::vector<int> predict(const std::vector<std::vector<float> >& data, std::vector<float> weights=std::vector<float>());
00059 void printTree();
00060 std::deque<rutz::shared_ptr<DecisionNode> > getNodes();
00061 void addNode(rutz::shared_ptr<DecisionNode> node);
00062 private:
00063 size_t itsMaxSplits;
00064 std::deque<rutz::shared_ptr<DecisionNode> > itsNodes;
00065 };
00066
00067
00068 class DecisionNode
00069 {
00070 public:
00071 DecisionNode();
00072 std::vector<int> decide(const std::vector<std::vector<float> >& data);
00073 float split(const std::vector<std::vector<float> >& data, const std::vector<int>& labels, const std::vector<float>& weights, rutz::shared_ptr<DecisionNode>& left, rutz::shared_ptr<DecisionNode>& right, const rutz::shared_ptr<DecisionNode> parent=rutz::shared_ptr<DecisionNode>(NULL));
00074 size_t getDim();
00075 bool isValid();
00076 int printNode(std::string& output, int depth=0);
00077 void writeNode(std::ostream& outstream, bool needEnd=true);
00078 rutz::shared_ptr<DecisionNode> readNode(std::istream& instream);
00079 void setDim(size_t dim);
00080 void setLeaf(bool isLeaf);
00081 void setParent(rutz::shared_ptr<DecisionNode> parent);
00082 void setLeftConstraint(float constraint);
00083 void setRightConstraint(float constraint);
00084 void setClass(int classId);
00085 int getClass();
00086 private:
00087 int itsDim;
00088
00089
00090 bool itsLeaf;
00091 float itsLeftConstraint;
00092 float itsRightConstraint;
00093 int itsClass;
00094
00095 rutz::shared_ptr<DecisionNode> itsParent;
00096 };
00097
00098
00099 #endif
00100
00101
00102
00103
00104
00105
00106
00107