LCFIVertex  0.7.2
NeuralNetTrainer.h
1 #ifndef NeuralNetTrainer_h
2 #define NeuralNetTrainer_h
3 
4 //Marlin and LCIO includes
5 #include "marlin/Processor.h"
6 #include "lcio.h"
7 #include "EVENT/ReconstructedParticle.h"
8 
9 //Neural Net includes
10 #include "nnet/inc/NeuralNet.h"
11 #include "nnet/inc/NeuralNetDataSet.h"
12 #include "nnet/inc/BackPropagationCGAlgorithm.h"
13 
//NeuralNetTrainerProcessor - Marlin processor that trains neural networks to be used for jet flavour tagging.
//Several nets are handled at once; each is identified by a string key such as "c_net-2vtx" (see the per-net maps below).
//NOTE(review): std::string, std::map and std::vector are used below, but <string>, <map> and <vector> are not
//included directly by this header; they presumably arrive transitively via the Marlin/LCIO/nnet includes - consider including them explicitly.
71 class NeuralNetTrainerProcessor : public marlin::Processor
72 {
73 public:
74  //The usual Marlin processor methods
 //Returns a new, default-constructed instance of this processor (the Marlin factory hook).
75  virtual Processor* newProcessor() { return new NeuralNetTrainerProcessor; }
77  virtual ~NeuralNetTrainerProcessor();
78  virtual void init();
 //NOTE(review): LCRunHeader / LCEvent are used unqualified here, while the private helpers below spell out lcio::LCEvent.
 //This presumably relies on a using-directive in one of the included headers - confirm.
79  virtual void processRunHeader( LCRunHeader* pRun );
80  virtual void processEvent( LCEvent* pEvent );
81  //don't need this
82  //virtual void check( LCEvent* pEvent );
83  virtual void end();
84 protected:
85  //variables for the steering file options
 //NOTE(review): names beginning with an underscore followed by an uppercase letter (_JetCollectionName,
 //_FlavourTagInputsCollectionName, _TrueJetFlavourCollectionName, _IndexOf) are reserved for the implementation in C++.
 //Renaming them would also require changing the implementation file.
86  std::string _JetCollectionName; //The name of the collection of ReconstructedParticles that is the jet (comes from the steering file)
 //Names of the flavour-tag input and true-jet-flavour collections, presumably also set from the steering file.
87  std::string _FlavourTagInputsCollectionName;
88  std::string _TrueJetFlavourCollectionName;
 //Steering-file flag stored as an int; presumably non-zero selects XML output and _outputFormat is derived from it - confirm.
89  int _serialiseAsXML;
90  nnet::NeuralNet::SerialisationMode _outputFormat;
91 
92  //These maps all use the same string keys to distinguish between the different nets.
93  //The strings are of the form "c_net-2vtx", "bc_net-3vtx" etcetera.
94  std::map<std::string,std::string> _filename;
95  std::map<std::string,bool> _trainThisNet;
 //NOTE(review): raw pointers - ownership is not visible in this header. Presumably the destructor or end() deletes
 //each data set; verify this in the implementation.
97  std::map<std::string,nnet::NeuralNetDataSet*> _dataSet;
 //Presumably the number of signal / background entries added to each net's data set - confirm.
98  std::map<std::string,int> _numSignal;
100  std::map<std::string,int> _numBackground;
102  //List of strings for all the nets enabled for training
103  //Strings are the same as the keys used in the maps above
104  std::vector<std::string> _listOfSelectedNetNames;
107  //This map holds the position of the Inputs in the LCFloatVec
108  std::map<std::string,unsigned int> _IndexOf;
 //Run and event counters, presumably incremented in processRunHeader() / processEvent().
111  int _nRun;
112  int _nEvent;
115  //useful constants
 //Presumably the PDG codes compared against the true jet flavour (4 = charm, 5 = bottom) - confirm.
116  static const int C_JET=4;
117  static const int B_JET=5;
119  //The following functions are just code that has been split off so that the code doesn't look quite so cluttered.
120  void _displayCollectionNames( lcio::LCEvent* pEvent );
 //NOTE(review): pBackPropCGAlgo is a reference, not a pointer, despite its 'p' prefix.
121  void _trainNet( nnet::BackPropagationCGAlgorithm& pBackPropCGAlgo, nnet::NeuralNetDataSet& dataSet );
122  bool _passesCuts( lcio::LCEvent* pEvent );
123 };
124 
125 #endif //ifndef NeuralNetTrainer_h
std::map< std::string, int > _numBackground
void _trainNet(nnet::BackPropagationCGAlgorithm &pBackPropCGAlgo, nnet::NeuralNetDataSet &dataSet)
Trains neural networks to be used for jet flavour tagging.
std::vector< std::string > _listOfSelectedNetNames
std::map< std::string, std::string > _filename
std::map< std::string, unsigned int > _IndexOf
std::map< std::string, bool > _trainThisNet
std::map< std::string, nnet::NeuralNetDataSet * > _dataSet
std::map< std::string, int > _numSignal
bool _passesCuts(lcio::LCEvent *pEvent)
void _displayCollectionNames(lcio::LCEvent *pEvent)