GBDT.h

00001 #ifndef __GBDT_
00002 #define __GBDT_
00003 
00004 #include <deque>
00005 
00006 #include "StandardAlgorithm.h"
00007 #include "Framework.h"
00008 
00012 typedef struct node_
00013 {
00014     int m_featureNr;                // decision on this feature
00015     REAL m_value;                   // the prediction value
00016     struct node_* m_toSmallerEqual; // pointer to node, if:  feature[m_featureNr] <=  m_value
00017     struct node_* m_toLarger;       // pointer to node, if:  feature[m_featureNr] > m_value
00018     int* m_trainSamples;            // a list of indices of the training samples in this node
00019     int m_nSamples;                 // the length of m_trainSamples
00020 } node;
00021 
00025 typedef struct nodeReduced_
00026 {
00027     node* m_node;
00028     uint m_size;
00029 } nodeReduced;
00030 
00042 class GBDT : public StandardAlgorithm, public Framework
00043 {
00044 public:
00045     GBDT();
00046     ~GBDT();
00047 
00048     virtual void modelInit();
00049     virtual void modelUpdate ( REAL* input, REAL* target, uint nSamples, uint crossRun );
00050     virtual void predictAllOutputs ( REAL* rawInputs, REAL* outputs, uint nSamples, uint crossRun );
00051     virtual void readSpecificMaps();
00052     virtual void saveWeights ( int cross );
00053     virtual void loadWeights ( int cross );
00054     virtual void loadMetaWeights ( int cross );
00055 
00056     static string templateGenerator ( int id, string preEffect, int nameID, bool blendStop );
00057 
00058 private:
00059     void trainSingleTree ( node* n, deque<nodeReduced> &largestNodes, REAL* input, REAL* inputTmp, REAL* inputTargetsSort, REAL* singleTarget, bool* usedFeatures, int nSamples, int* sortIndex, int* radixTmp0, REAL* radixTmp1 );
00060     void cleanTree ( node* n );
00061     REAL predictSingleTree ( node* n, REAL* input );
00062     void saveTreeRecursive ( node* n, fstream &f );
00063     void loadTreeRecursive ( node* n, fstream &f );
00064 
00065     // tree members
00066     node*** m_trees;   // [nCross][target][boostingStep]
00067     int m_epoch;
00068     REAL** m_globalMean;  // [nCross][target]
00069     REAL** m_treeTargets;  // [nCross][target x samples]
00070     double** m_validationPredictions;  // [nCross][nClass*nDomain x m_probeSize]
00071 
00072     // dsc file
00073     int m_featureSubspaceSize;
00074     int m_maxTreeLeafes;
00075     bool m_useOptSplitPoint;
00076     bool m_calculateGlobalMean;
00077     double m_lRate;
00078 };
00079 
00080 
00081 #endif

Generated on Tue Jan 26 09:20:59 2010 for ELF by  doxygen 1.5.8