nn.h

00001 #ifndef __NN_
00002 #define __NN_
00003 
#include "StreamOutput.h"
#include "Framework.h"

#include <math.h>
#include <assert.h>
#include <time.h>   // time_t (used by NN::initNNWeights)

// NOTE(review): a using-directive in a header leaks all std:: names into every
// file that includes nn.h; kept because including files may rely on it.
using namespace std;
00011 
00038 class NN : public Framework
00039 {
00040 public:
00041     NN();
00042     ~NN();
00043 
00044     // input/output data definition methods
00045     void setNrTargets ( int n );
00046     void setNrInputs ( int n );
00047     void setNrExamplesTrain ( int n );
00048     void setNrExamplesProbe ( int n );
00049     void setTrainInputs ( REAL* inputs );
00050     void setTrainTargets ( REAL* targets );
00051     void setProbeInputs ( REAL* inputs );
00052     void setProbeTargets ( REAL* targets );
00053 
00054     // learn parameters
00055     void setInitWeightFactor ( REAL factor );
00056     void setLearnrate ( REAL learnrate );
00057     void setLearnrateMinimum ( REAL learnrateMin );
00058     void setLearnrateSubtractionValueAfterEverySample ( REAL learnrateDecreaseRate );
00059     void setLearnrateSubtractionValueAfterEveryEpoch ( REAL learnrateDecreaseRate );
00060     void setMomentum ( REAL momentum );
00061     void setWeightDecay ( REAL weightDecay );
00062     void setBatchSize ( int size );
00063     void setMinUpdateErrorBound ( REAL minUpdateBound );
00064     void setMaxEpochs ( int epochs );
00065     void setRPROPPosNeg ( REAL etaPos, REAL etaNeg );
00066     void setRPROPMinMaxUpdate ( REAL min, REAL max );
00067     void setL1Regularization ( bool en );
00068     void initNNWeights ( time_t seed );
00069     void enableErrorFunctionMAE ( bool en );
00070     void setActivationFunctionType( int type );
00071 
00072     // set net inner stucture
00073     void setNNStructure ( int nrLayer, int* neuronsPerLayer );
00074 
00075     // all around training
00076     void printLearnrate();
00077     void setScaleOffset ( REAL scale, REAL offset );
00078     void setNormalTrainStopping ( bool en );
00079     void setGlobalEpochs ( int e );
00080     void enableRPROP ( bool en );
00081     void useBLASforTraining ( bool enable );
00082     void trainOneEpoch();
00083     int trainNN();
00084     REAL getRMSETrain();
00085     REAL getRMSEProbe();
00086     void predictSingleInput ( REAL* input, REAL* output );
00087     REAL* getWeightPtr();
00088     void setWeights ( REAL* w );
00089     int getNrWeights();
00090     double m_sumSquaredError;
00091     double m_sumSquaredErrorSamples;
00092 
00093     // access to weights and neuron outputs of the net
00094     int getWeightIndex ( int layer, int neuron, int weight );
00095     int getBiasIndex ( int layer, int neuron );
00096     int getOutputIndex ( int layer, int neuron );
00097 
00098 private:
00099     // training
00100     void saveWeights();
00101     REAL calcRMSE ( REAL* inputs, REAL* targets, int examples );
00102     void forwardCalculation ( REAL* input );
00103     void forwardCalculationBLAS ( REAL* input );
00104     void backpropBLAS ( REAL* input, REAL* target );
00105     void backprop ( REAL* input, REAL* target );
00106 
00107     // weight init
00108     REAL getInitWeight ( int fanIn );
00109 
00110     // data description
00111     int m_nrTargets;
00112     int m_nrInputs;
00113     int m_nrExamplesTrain;
00114     int m_nrExamplesProbe;
00115     REAL* m_inputsTrain;
00116     REAL* m_inputsProbe;
00117     REAL* m_targetsTrain;
00118     REAL* m_targetsProbe;
00119 
00120     // learn params
00121     REAL m_initWeightFactor;
00122     int m_globalEpochs;
00123     REAL m_RPROP_etaPos;
00124     REAL m_RPROP_etaNeg;
00125     REAL m_RPROP_updateMin;
00126     REAL m_RPROP_updateMax;
00127     REAL m_learnRate;
00128     REAL m_learnRateMin;
00129     REAL m_learnrateDecreaseRate;
00130     REAL m_learnrateDecreaseRateEpoch;
00131     REAL m_momentum;
00132     REAL m_weightDecay;
00133     REAL m_minUpdateBound;
00134     int m_batchSize;
00135     int m_activationFunctionType;
00136 
00137     // training
00138     REAL m_scaleOutputs;
00139     REAL m_offsetOutputs;
00140     int m_maxEpochs;
00141     bool m_useBLAS;
00142     bool m_enableRPROP;
00143     bool m_normalTrainStopping;
00144     bool m_enableL1Regularization;
00145     bool m_errorFunctionMAE;
00146 
00147     // net description, inner structure
00148     int m_nrLayer;
00149     int* m_neuronsPerLayer;
00150     int m_nrWeights;
00151     int m_nrOutputs;
00152     int* m_nrLayWeights;
00153     int* m_nrLayWeightOffsets;
00154     REAL* m_outputs;
00155     REAL* m_outputsTmp;
00156     REAL* m_derivates;
00157     REAL* m_d1;
00158     REAL* m_weights;
00159     REAL* m_weightsTmp0;
00160     REAL* m_weightsTmp1;
00161     REAL* m_weightsTmp2;
00162     REAL* m_weightsBatchUpdate;
00163     REAL* m_weightsOld;
00164     REAL* m_weightsOldOld;
00165     REAL* m_deltaW;
00166     REAL* m_deltaWOld;
00167     REAL* m_adaptiveRPROPlRate;
00168 
00169 };
00170 
00171 #endif

Generated on Tue Jan 26 09:20:59 2010 for ELF by  doxygen 1.5.8