StandardAlgorithm.h
00001 #ifndef __STANDARD_ALGORITHM_
00002 #define __STANDARD_ALGORITHM_
00003
00004 #include "Algorithm.h"
00005 #include "AutomaticParameterTuner.h"
00006 #include "BlendStopping.h"
00007 #include "Framework.h"
00008 #include "AUC.h"
00009 #include "DatasetReader.h"
00010
00011 #include <math.h>
00012 #include <stdio.h>
00013 #include <stdlib.h>
00014
00015 using namespace std;
00016
00023 class StandardAlgorithm : public Algorithm, public AutomaticParameterTuner, public AUC, public Framework
00024 {
00025 public:
00026 StandardAlgorithm();
00027 ~StandardAlgorithm();
00028
00029 virtual double calcRMSEonProbe();
00030 virtual double calcRMSEonBlend();
00031 void saveBestPrediction();
00032
00033 virtual void setPredictionMode ( int cross );
00034 virtual double train();
00035
00036
00037 virtual void predictMultipleOutputs ( REAL* rawInput, REAL* effect, REAL* output, int* label, int nSamples, int crossRun );
00038
00039
00040 virtual void modelInit() = 0;
00041 virtual void modelUpdate ( REAL* input, REAL* target, uint nSamples, uint crossRun ) = 0;
00042 virtual void predictAllOutputs ( REAL* rawInputs, REAL* outputs, uint nSamples, uint crossRun ) = 0;
00043 virtual void readSpecificMaps() = 0;
00044 virtual void saveWeights ( int cross ) = 0;
00045 virtual void loadWeights ( int cross ) = 0;
00046 virtual void loadMetaWeights ( int cross ) = 0;
00047
00048 protected:
00049
00050 void init();
00051 void readMaps();
00052
00053 void calculateFullPrediction();
00054 void writeFullPrediction(int nSamples);
00055
00056 BlendStopping* m_blendStop;
00057
00058
00059 vector<int*> paramEpochValues;
00060 vector<string> paramEpochNames;
00061 vector<double*> paramDoubleValues;
00062 vector<string> paramDoubleNames;
00063 vector<int*> paramIntValues;
00064 vector<string> paramIntNames;
00065 double m_maxSwing;
00066
00067
00068 REAL* m_crossValidationPrediction;
00069 REAL* m_prediction;
00070 REAL* m_predictionBest;
00071 REAL** m_predictionProbe;
00072 REAL* m_singlePrediction;
00073 int* m_labelPrediction;
00074 int* m_wrongLabelCnt;
00075 REAL* m_outOfBagEstimate;
00076 int* m_outOfBagEstimateCnt;
00077
00078
00079 int m_maxTuninigEpochs;
00080 int m_minTuninigEpochs;
00081 bool m_enableClipping;
00082 bool m_enableTuneSwing;
00083 bool m_minimzeProbe;
00084 bool m_minimzeProbeClassificationError;
00085 bool m_minimzeBlend;
00086 bool m_minimzeBlendClassificationError;
00087 double m_initMaxSwing;
00088 string m_weightFile;
00089 string m_fullPrediction;
00090
00091 };
00092
00093
00094 #endif