nn.h
00001 #ifndef __NN_
00002 #define __NN_
00003
00004 #include "StreamOutput.h"
00005 #include "Framework.h"
00006
00007 #include <math.h>
00008 #include <assert.h>
00009
00010 using namespace std;
00011
00038 class NN : public Framework
00039 {
00040 public:
00041 NN();
00042 ~NN();
00043
00044
00045 void setNrTargets ( int n );
00046 void setNrInputs ( int n );
00047 void setNrExamplesTrain ( int n );
00048 void setNrExamplesProbe ( int n );
00049 void setTrainInputs ( REAL* inputs );
00050 void setTrainTargets ( REAL* targets );
00051 void setProbeInputs ( REAL* inputs );
00052 void setProbeTargets ( REAL* targets );
00053
00054
00055 void setInitWeightFactor ( REAL factor );
00056 void setLearnrate ( REAL learnrate );
00057 void setLearnrateMinimum ( REAL learnrateMin );
00058 void setLearnrateSubtractionValueAfterEverySample ( REAL learnrateDecreaseRate );
00059 void setLearnrateSubtractionValueAfterEveryEpoch ( REAL learnrateDecreaseRate );
00060 void setMomentum ( REAL momentum );
00061 void setWeightDecay ( REAL weightDecay );
00062 void setBatchSize ( int size );
00063 void setMinUpdateErrorBound ( REAL minUpdateBound );
00064 void setMaxEpochs ( int epochs );
00065 void setRPROPPosNeg ( REAL etaPos, REAL etaNeg );
00066 void setRPROPMinMaxUpdate ( REAL min, REAL max );
00067 void setL1Regularization ( bool en );
00068 void initNNWeights ( time_t seed );
00069 void enableErrorFunctionMAE ( bool en );
00070 void setActivationFunctionType( int type );
00071
00072
00073 void setNNStructure ( int nrLayer, int* neuronsPerLayer );
00074
00075
00076 void printLearnrate();
00077 void setScaleOffset ( REAL scale, REAL offset );
00078 void setNormalTrainStopping ( bool en );
00079 void setGlobalEpochs ( int e );
00080 void enableRPROP ( bool en );
00081 void useBLASforTraining ( bool enable );
00082 void trainOneEpoch();
00083 int trainNN();
00084 REAL getRMSETrain();
00085 REAL getRMSEProbe();
00086 void predictSingleInput ( REAL* input, REAL* output );
00087 REAL* getWeightPtr();
00088 void setWeights ( REAL* w );
00089 int getNrWeights();
00090 double m_sumSquaredError;
00091 double m_sumSquaredErrorSamples;
00092
00093
00094 int getWeightIndex ( int layer, int neuron, int weight );
00095 int getBiasIndex ( int layer, int neuron );
00096 int getOutputIndex ( int layer, int neuron );
00097
00098 private:
00099
00100 void saveWeights();
00101 REAL calcRMSE ( REAL* inputs, REAL* targets, int examples );
00102 void forwardCalculation ( REAL* input );
00103 void forwardCalculationBLAS ( REAL* input );
00104 void backpropBLAS ( REAL* input, REAL* target );
00105 void backprop ( REAL* input, REAL* target );
00106
00107
00108 REAL getInitWeight ( int fanIn );
00109
00110
00111 int m_nrTargets;
00112 int m_nrInputs;
00113 int m_nrExamplesTrain;
00114 int m_nrExamplesProbe;
00115 REAL* m_inputsTrain;
00116 REAL* m_inputsProbe;
00117 REAL* m_targetsTrain;
00118 REAL* m_targetsProbe;
00119
00120
00121 REAL m_initWeightFactor;
00122 int m_globalEpochs;
00123 REAL m_RPROP_etaPos;
00124 REAL m_RPROP_etaNeg;
00125 REAL m_RPROP_updateMin;
00126 REAL m_RPROP_updateMax;
00127 REAL m_learnRate;
00128 REAL m_learnRateMin;
00129 REAL m_learnrateDecreaseRate;
00130 REAL m_learnrateDecreaseRateEpoch;
00131 REAL m_momentum;
00132 REAL m_weightDecay;
00133 REAL m_minUpdateBound;
00134 int m_batchSize;
00135 int m_activationFunctionType;
00136
00137
00138 REAL m_scaleOutputs;
00139 REAL m_offsetOutputs;
00140 int m_maxEpochs;
00141 bool m_useBLAS;
00142 bool m_enableRPROP;
00143 bool m_normalTrainStopping;
00144 bool m_enableL1Regularization;
00145 bool m_errorFunctionMAE;
00146
00147
00148 int m_nrLayer;
00149 int* m_neuronsPerLayer;
00150 int m_nrWeights;
00151 int m_nrOutputs;
00152 int* m_nrLayWeights;
00153 int* m_nrLayWeightOffsets;
00154 REAL* m_outputs;
00155 REAL* m_outputsTmp;
00156 REAL* m_derivates;
00157 REAL* m_d1;
00158 REAL* m_weights;
00159 REAL* m_weightsTmp0;
00160 REAL* m_weightsTmp1;
00161 REAL* m_weightsTmp2;
00162 REAL* m_weightsBatchUpdate;
00163 REAL* m_weightsOld;
00164 REAL* m_weightsOldOld;
00165 REAL* m_deltaW;
00166 REAL* m_deltaWOld;
00167 REAL* m_adaptiveRPROPlRate;
00168
00169 };
00170
00171 #endif