LCFIVertex  0.7.2
BatchBackPropagationAlgorithm.h
1 #ifndef BATCHBACKPROPAGATIONALGORITHM_H
2 #define BATCHBACKPROPAGATIONALGORITHM_H
3 
4 #include "NeuralNetConfig.h"
5 
6 #include "NeuralNet.h"
7 #include "NeuralNetDataSet.h"
8 
9 #include <vector>
10 
11 #ifdef __CINT__
12 #include "InputNormaliser.h"
13 #else
14 //namespace nnet added 15/08/06 by Mark Grimes (mark.grimes@bristol.ac.uk) for the LCFI vertex package
15 namespace nnet
16 {
17 class InputNormaliser;
18 }
19 #endif
20 
21 //namespace nnet added 15/08/06 by Mark Grimes (mark.grimes@bristol.ac.uk) for the LCFI vertex package
22 namespace nnet
23 {
24 
25 class
26 #ifndef __CINT__
27 NEURALNETDLL
28 #endif
30 {
31 public:
32  BatchBackPropagationAlgorithm(NeuralNet &theNetwork,const double learningRate=0.5,const double momentumConstant=0.5);
34  void setLearningRate(const double newLearningRate)
35  { _learningRate = newLearningRate;}
36  void setMomentumConstant(const double newMomentumConstant)
37  { _momentumConstant = newMomentumConstant;}
38  void setMaxErrorIncrease(const double maxIncrease)
39  { _maxErrorInc = maxIncrease;}
40  double train(const int numberOfEpochs,const NeuralNetDataSet &dataSet,
41  const NeuralNet::InputNormalisationSelect normaliseTrainingData=NeuralNet::PassthroughNormalised);
42  double train(const int numberOfEpochs,const NeuralNetDataSet &dataSet,const std::vector<InputNormaliser *> &inputNormalisers);
43  void setProgressPrintoutFrequency(const int frequency) {_progressPrintoutFrequency = frequency;}
44  void setEpochsToWaitBeforeRestore(const int epochs) {_epochsToWaitBeforeRestore = epochs;}
45  std::vector<double> getTrainingErrorValuesPerEpoch() const {return _savedEpochErrorValues;}
46 
47 protected:
48  double trainWithDataSet(const int numberOfEpochs);
49  std::vector<double> layerOutput(const int layer) const;
50  void calculateLayerOutputs();
51  void calculateDerivativeOutputs();
52  void calculateErrorSignals();
53  void calculateRunningGradientTotal();
54  void calculateDeltaWeights();
55  double error();
56  double newEpoch();
57  double processDataSet();
58 
59 private:
60  typedef std::vector<std::vector<double> > NetMatrix;
61 
62 private:
63  NeuralNet &_theNetwork;
64  double _learningRate;
65  double _maxErrorInc;
66  const std::vector<double> *_inputs,*_target;
67  NetMatrix _neuronErrorSignals;
68  NetMatrix _neuronOutputs;
69  NetMatrix _neuronDerivativeOutputs;
70  NetMatrix _runningGradientTotal;
71  std::vector<double> _momentumWeights;
72  std::vector<double> _previousEpochWeights;
73  double _momentumConstant;
74  int _numberOfTrainingEvents;
75  double _previousEpochError;
76  double _runningEpochErrorTotal;
77  const NeuralNetDataSet *_currentDataSet;
78  int _progressPrintoutFrequency;
79  int _epochsToWaitBeforeRestore;
80  std::vector<double> _savedEpochErrorValues;
81 };
82 
83 }//namespace nnet
84 
85 #endif