StandardAlgorithm.h

00001 #ifndef __STANDARD_ALGORITHM_
00002 #define __STANDARD_ALGORITHM_
00003 
00004 #include "Algorithm.h"
00005 #include "AutomaticParameterTuner.h"
00006 #include "BlendStopping.h"
00007 #include "Framework.h"
00008 #include "AUC.h"
00009 #include "DatasetReader.h"
00010 
00011 #include <math.h>
00012 #include <stdio.h>
00013 #include <stdlib.h>
00014 
00015 using namespace std;
00016 
00023 class StandardAlgorithm : public Algorithm, public AutomaticParameterTuner, public AUC, public Framework
00024 {
00025 public:
00026     StandardAlgorithm();
00027     ~StandardAlgorithm();
00028 
00029     virtual double calcRMSEonProbe();
00030     virtual double calcRMSEonBlend();
00031     void saveBestPrediction();
00032 
00033     virtual void setPredictionMode ( int cross );
00034     virtual double train();
00035 
00036     //int predictOutput(REAL* rawInput, REAL* effect, REAL* output);
00037     virtual void predictMultipleOutputs ( REAL* rawInput, REAL* effect, REAL* output, int* label, int nSamples, int crossRun );
00038 
00039     // must be implemented in the particular algorithm
00040     virtual void modelInit() = 0;
00041     virtual void modelUpdate ( REAL* input, REAL* target, uint nSamples, uint crossRun ) = 0;
00042     virtual void predictAllOutputs ( REAL* rawInputs, REAL* outputs,  uint nSamples, uint crossRun ) = 0;
00043     virtual void readSpecificMaps() = 0;
00044     virtual void saveWeights ( int cross ) = 0;
00045     virtual void loadWeights ( int cross ) = 0;
00046     virtual void loadMetaWeights ( int cross ) = 0;
00047 
00048 protected:
00049 
00050     void init();
00051     void readMaps();
00052 
00053     void calculateFullPrediction();
00054     void writeFullPrediction(int nSamples);
00055 
00056     BlendStopping* m_blendStop;
00057 
00058     // params
00059     vector<int*> paramEpochValues;
00060     vector<string> paramEpochNames;
00061     vector<double*> paramDoubleValues;
00062     vector<string> paramDoubleNames;
00063     vector<int*> paramIntValues;
00064     vector<string> paramIntNames;
00065     double m_maxSwing;
00066 
00067     // tmp fields
00068     REAL* m_crossValidationPrediction;
00069     REAL* m_prediction;
00070     REAL* m_predictionBest;
00071     REAL** m_predictionProbe;
00072     REAL* m_singlePrediction;
00073     int* m_labelPrediction;
00074     int* m_wrongLabelCnt;
00075     REAL* m_outOfBagEstimate;
00076     int* m_outOfBagEstimateCnt;
00077 
00078     // dsc file
00079     int m_maxTuninigEpochs;
00080     int m_minTuninigEpochs;
00081     bool m_enableClipping;
00082     bool m_enableTuneSwing;
00083     bool m_minimzeProbe;
00084     bool m_minimzeProbeClassificationError;
00085     bool m_minimzeBlend;
00086     bool m_minimzeBlendClassificationError;
00087     double m_initMaxSwing;
00088     string m_weightFile;
00089     string m_fullPrediction;
00090 
00091 };
00092 
00093 
00094 #endif

Generated on Tue Jan 26 09:20:59 2010 for ELF by  doxygen 1.5.8