GBDT.h
00001 #ifndef __GBDT_
00002 #define __GBDT_
00003
#include <deque>
#include <fstream>
#include <string>

#include "StandardAlgorithm.h"
#include "Framework.h"
00008
00012 typedef struct node_
00013 {
00014 int m_featureNr;
00015 REAL m_value;
00016 struct node_* m_toSmallerEqual;
00017 struct node_* m_toLarger;
00018 int* m_trainSamples;
00019 int m_nSamples;
00020 } node;
00021
00025 typedef struct nodeReduced_
00026 {
00027 node* m_node;
00028 uint m_size;
00029 } nodeReduced;
00030
00042 class GBDT : public StandardAlgorithm, public Framework
00043 {
00044 public:
00045 GBDT();
00046 ~GBDT();
00047
00048 virtual void modelInit();
00049 virtual void modelUpdate ( REAL* input, REAL* target, uint nSamples, uint crossRun );
00050 virtual void predictAllOutputs ( REAL* rawInputs, REAL* outputs, uint nSamples, uint crossRun );
00051 virtual void readSpecificMaps();
00052 virtual void saveWeights ( int cross );
00053 virtual void loadWeights ( int cross );
00054 virtual void loadMetaWeights ( int cross );
00055
00056 static string templateGenerator ( int id, string preEffect, int nameID, bool blendStop );
00057
00058 private:
00059 void trainSingleTree ( node* n, deque<nodeReduced> &largestNodes, REAL* input, REAL* inputTmp, REAL* inputTargetsSort, REAL* singleTarget, bool* usedFeatures, int nSamples, int* sortIndex, int* radixTmp0, REAL* radixTmp1 );
00060 void cleanTree ( node* n );
00061 REAL predictSingleTree ( node* n, REAL* input );
00062 void saveTreeRecursive ( node* n, fstream &f );
00063 void loadTreeRecursive ( node* n, fstream &f );
00064
00065
00066 node*** m_trees;
00067 int m_epoch;
00068 REAL** m_globalMean;
00069 REAL** m_treeTargets;
00070 double** m_validationPredictions;
00071
00072
00073 int m_featureSubspaceSize;
00074 int m_maxTreeLeafes;
00075 bool m_useOptSplitPoint;
00076 bool m_calculateGlobalMean;
00077 double m_lRate;
00078 };
00079
00080
00081 #endif