// LCFIVertex 0.7.2
// NeuralNetTrainer.cc
#include "NeuralNetTrainer.h"

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>
9 
10 
11 #include "EVENT/LCCollection.h"
12 #include "IMPL/ReconstructedParticleImpl.h"
13 #include "EVENT/LCFloatVec.h"
14 #include "EVENT/LCIntVec.h"
15 #include "EVENT/LCParameters.h"
16 
17 #include "util/inc/memorymanager.h"
18 #include "util/inc/vector3.h"
19 
20 #include "nnet/inc/NeuralNet.h"
21 #include "nnet/inc/NeuralNetDataSet.h"
22 #include "nnet/inc/SigmoidNeuronBuilder.h"
23 #include "nnet/inc/BackPropagationCGAlgorithm.h"
24 
25 //Needs to be instantiated for Marlin to know about it (I think)
26 NeuralNetTrainerProcessor aNeuralNetTrainerProcessor;
27 
28 NeuralNetTrainerProcessor::NeuralNetTrainerProcessor() : marlin::Processor("NeuralNetTrainer")
29 {
30  _description = "Trains a neural net from the lcio file" ;
31 
32  // register steering parameters: name, description, class-variable, default value
33 
34  //The name of the collection of ReconstructedParticles that is the jet
35  registerInputCollection( lcio::LCIO::RECONSTRUCTEDPARTICLE,
36  "JetCollectionName" ,
37  "Name of the collection of ReconstructedParticles that is the jet" ,
38  _JetCollectionName ,
39  std::string("SGVJets") ) ;
40  registerInputCollection( lcio::LCIO::LCFLOATVEC,
41  "FlavourTagInputsCollection" ,
42  "Name of the LCFloatVec Collection that contains the flavour tag inputs (in same order as jet collection)" ,
43  _FlavourTagInputsCollectionName,
44  "FlavourTagInputs" ) ;
45  registerInputCollection( lcio::LCIO::LCFLOATVEC,
46  "TrueJetFlavourCollection" ,
47  "Name of the LCIntVec Collection that contains the true jet flavours" ,
48  _TrueJetFlavourCollectionName,
49  "TrueJetFlavour" ) ;
50 
51  registerProcessorParameter( "SaveAsXML",
52  "Set this to 0 to output the neural nets in plain text format (default), or 1 (or anything non zero) to save in XML format",
53  _serialiseAsXML,
54  0 );
55 
56  //These are the variables for the output filenames of the trained nets
57  //Default is "" which will switch off training for that net.
58  registerProcessorParameter( "Filename-b_net-1vtx" ,
59  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
60  _filename["b_net-1vtx"],
61  std::string("") ) ;
62  registerProcessorParameter( "Filename-c_net-1vtx" ,
63  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
64  _filename["c_net-1vtx"],
65  std::string("") ) ;
66  registerProcessorParameter( "Filename-bc_net-1vtx" ,
67  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
68  _filename["bc_net-1vtx"],
69  std::string("") ) ;
70  registerProcessorParameter( "Filename-b_net-2vtx" ,
71  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
72  _filename["b_net-2vtx"],
73  std::string("") ) ;
74  registerProcessorParameter( "Filename-c_net-2vtx" ,
75  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
76  _filename["c_net-2vtx"],
77  std::string("") ) ;
78  registerProcessorParameter( "Filename-bc_net-2vtx" ,
79  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
80  _filename["bc_net-2vtx"],
81  std::string("") ) ;
82  registerProcessorParameter( "Filename-b_net-3plusvtx" ,
83  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
84  _filename["b_net-3vtx"],
85  std::string("") ) ;
86  registerProcessorParameter( "Filename-c_net-3plusvtx" ,
87  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
88  _filename["c_net-3vtx"],
89  std::string("") ) ;
90  registerProcessorParameter( "Filename-bc_net-3plusvtx" ,
91  "Output filename for the trained net. If it is blank (default) then this net is not trained" ,
92  _filename["bc_net-3vtx"],
93  std::string("") ) ;
94 }
95 
96 NeuralNetTrainerProcessor::~NeuralNetTrainerProcessor()
97 {
98 }
99 
100 void NeuralNetTrainerProcessor::init()
101 {
102  printParameters();
103  std::cout << _description << std::endl
104  << "-------------------------------------------------" << std::endl
105  << std::endl;
106 
107  _nRun=0;
108 
109  //Have a look through all the net output filenames supplied in the steering file.
110  //If any of them are blank (the default if one isn't supplied) disable training for
111  //that net. Here "(*i).second" is the filename and "(*i).first" is the string key that
112  //identifies each net in all of the maps.
113  for( std::map<std::string,std::string>::iterator i=_filename.begin(); i!=_filename.end(); ++i )
114  {
115  if( (*i).second!="" )
116  {
117  _listOfSelectedNetNames.push_back( (*i).first );//make a list of the selected map names to make looping over them easier later on
118  _trainThisNet[ (*i).first ]=true;//turn on training of this net
119  }
120  else _trainThisNet[ (*i).first ]=false;//turn off training of this net
121  }
122 
123  //Just check that the user hasn't accidently disabled training of all the nets
124  if( _listOfSelectedNetNames.size()==0 )
125  {
126  std::stringstream message;
127  message << std::endl
128  << "############################################################################################\n"
129  << "# NeuralNetTrainerProcessor: #\n"
130  << "# No nets have been enabled for training in the steering file! #\n"
131  << "# Supply at least one output filename in the steering file. For example, put the line #\n"
132  << "# <parameter name=\"Filename-bc_net-1vtx\" type=\"string\"> bc_net-1vtx.xml </parameter> #\n"
133  << "# In with the parameters for this processor. #\n"
134  << "############################################################################################" << std::endl;
135  throw lcio::Exception( message.str() );
136  }
137 
138  //Decide which format (plain text or XML) to save the files as
139  if( _serialiseAsXML==0 ) _outputFormat=nnet::NeuralNet::PlainText;
140  else _outputFormat=nnet::NeuralNet::XML;
141 
142  //Allocate a data set for each of the enabled nets.
143  for( std::vector<std::string>::iterator iName=_listOfSelectedNetNames.begin(); iName<_listOfSelectedNetNames.end(); ++iName )
144  {
145  _dataSet[*iName]=new nnet::NeuralNetDataSet;
146  vertex_lcfi::MemoryManager<nnet::NeuralNetDataSet>::Run()->registerObject( _dataSet[*iName] );
147 
148  //also set all of the signal/background counters to 0
149  _numSignal[*iName]=0;
150  _numBackground[*iName]=0;
151  }
152 
153  // Initialise the event counters
154  _nEvent=0;
155  _nAcceptedEvents=0;
156 }
157 
158 void NeuralNetTrainerProcessor::processRunHeader( LCRunHeader* pRun )
159 {
160  _nRun++;
161 
162  //Get the list of flavour tag inputs Available
163  std::vector<std::string> VarNames;
164  (pRun->parameters()).getStringVals(_FlavourTagInputsCollectionName,VarNames);
165 
166  std::set<std::string> AvailableNames;
167  for (size_t i = 0;i < VarNames.size();++i)
168  {
169  AvailableNames.insert(VarNames[i]);
170  _IndexOf[VarNames[i]] = i;
171  }
172 
173  //Check the required information is in the LCFloatVec
174  std::set<std::string> RequiredNames;
175  RequiredNames.insert( "D0Significance1");
176  RequiredNames.insert( "D0Significance2");
177  RequiredNames.insert( "Z0Significance1" );
178  RequiredNames.insert( "Z0Significance2" );
179  RequiredNames.insert( "JointProbRPhi");
180  RequiredNames.insert( "JointProbZ");
181  RequiredNames.insert( "Momentum1");
182  RequiredNames.insert( "Momentum2");
183  RequiredNames.insert( "DecayLengthSignificance");
184  RequiredNames.insert( "DecayLength");
185  RequiredNames.insert( "PTCorrectedMass");
186  RequiredNames.insert( "RawMomentum");
187  RequiredNames.insert( "NumTracksInVertices" );
188  RequiredNames.insert( "SecondaryVertexProbability");
189 
190  if (!includes(AvailableNames.begin(),AvailableNames.end(),RequiredNames.begin(),RequiredNames.end()))
191  std::cerr << _FlavourTagInputsCollectionName << " does not contain information required by NeuralNetTrainerProcessor";
192 
193 }
194 
195 void NeuralNetTrainerProcessor::processEvent( lcio::LCEvent* pEvent )
196 {
197  //Output the collection names for debugging
198  if( isFirstEvent() ) _displayCollectionNames( pEvent );
199 
200  //Get the collection of jets. Can't do anything if the collection isn't there
201  //so don't bother catching the exception and terminate.
202  lcio::LCCollection* pJetCollection=pEvent->getCollection( _JetCollectionName );
203 
204  //make sure the collection is of the right type
205  if( pJetCollection->getTypeName()!=lcio::LCIO::RECONSTRUCTEDPARTICLE )
206  {
207  std::stringstream message;
208  message << std::endl
209  << "########################################################################################\n"
210  << "# NeuralNetTrainerProcessor: #\n"
211  << "# The jet collection requested (\"" << _JetCollectionName << "\") is not of the type \"" << lcio::LCIO::RECONSTRUCTEDPARTICLE << "\" #\n"
212  << "########################################################################################" << std::endl;
213  throw lcio::EventException( message.str() );
214  }
215 
216  lcio::ReconstructedParticle* pJet;
217  int numJets=pJetCollection->getNumberOfElements();
218 
219  //apply any cuts on the event here
220  if( _passesCuts(pEvent) )
221  {
222  //loop over the jets
223  for( int a=0; a<numJets; ++a )
224  {
225  //Dynamic casts are not the best programming practice in the world, but I can't see another way of doing this
226  //in the LCIO framework. This cast should be safe though because we've already tested the type.
227  pJet=dynamic_cast<lcio::ReconstructedParticle*>( pJetCollection->getElementAt(a) );
228 
229  // Find out the jet energy to work out the correct normalisation constants
230  double jetEnergy=pJet->getEnergy();
231  if( 0==jetEnergy )
232  {
233  jetEnergy=45.5;
234  if( isFirstEvent() ) std::cout << "*** NeuralNetTrainer - Warning: Jet energy undefined, assuming 45.5GeV ***" << std::endl;
235  }
236 
237 /*
238  -----------------------------------
239  -------------IMPORTANT-------------
240  -----------------------------------
241  If any of these normalisation constants are changed, update in the documentation by modifying
242  the main class description in NeuralNetTrainer.h (at around line 20).
243 */
244  // Variables for the normalisation of the inputs
245  double Norm_D0Significance = 100.0;
246  double Norm_Z0Significance = 100.0;
247  double Norm_Momentum = jetEnergy/3.0;
248  double Norm_DecayLengthSignificance = 6.0*jetEnergy;
249  double Norm_DecayLength = 1.0;
250  double Norm_PTMassCorrection = 5.0;
251  double Norm_RawMomentum = jetEnergy;
252  double Norm_NumTracksInVertices = 10.0;
253 
254  //Get the MC Jet type
255  lcio::LCCollection* pTrueJet=pEvent->getCollection( _TrueJetFlavourCollectionName );
256  //make sure the collection is of the right type
257  if( pTrueJet->getTypeName()!=lcio::LCIO::LCFLOATVEC )
258  {
259  std::stringstream message;
260  message << std::endl
261  << "########################################################################################\n"
262  << "# FlavourTagProcessor - #\n"
263  << "# The jet collection requested (\"" << _TrueJetFlavourCollectionName << "\") is not of the type \"" << lcio::LCIO::LCINTVEC << "\" #\n"
264  << "########################################################################################" << std::endl;
265  throw lcio::EventException( message.str() );
266  }
267  float jetType = *((dynamic_cast<lcio::LCFloatVec*>( pTrueJet->getElementAt(a))->begin()));
268 
269  //
270  // See if we can get the required info from the file
271  //
272 
273  lcio::LCCollection* pInputs=pEvent->getCollection( _FlavourTagInputsCollectionName );
274  //make sure the collection is of the right type
275  if( pInputs->getTypeName()!=lcio::LCIO::LCFLOATVEC )
276  {
277  std::stringstream message;
278  message << std::endl
279  << "########################################################################################\n"
280  << "# FlavourTagProcessor - #\n"
281  << "# The jet collection requested (\"" << _FlavourTagInputsCollectionName << "\") is not of the type \"" << lcio::LCIO::LCFLOATVEC << "\" #\n"
282  << "########################################################################################" << std::endl;
283  throw lcio::EventException( message.str() );
284  }
285  LCFloatVec Inputs = *(dynamic_cast<lcio::LCFloatVec*>( pInputs->getElementAt(a) ));
286 
287  std::vector<double> inputs;
288  std::vector<double> target;
289 
290  double NumVertices = Inputs[_IndexOf["NumVertices"]];
291  //TODO Check that the inputs exist in the index
292  if( NumVertices==1 )
293  {
294  inputs.push_back( std::tanh(Inputs[_IndexOf["D0Significance1"]]/Norm_D0Significance) );
295  inputs.push_back( std::tanh(Inputs[_IndexOf["D0Significance2"]]/Norm_D0Significance) );
296  inputs.push_back( std::tanh(Inputs[_IndexOf["Z0Significance1"]]/Norm_Z0Significance) );
297  inputs.push_back( std::tanh(Inputs[_IndexOf["Z0Significance2"]]/Norm_Z0Significance) );
298  inputs.push_back( Inputs[_IndexOf["JointProbRPhi"]] );
299  inputs.push_back( Inputs[_IndexOf["JointProbZ"]] );
300  inputs.push_back( std::tanh(Inputs[_IndexOf["Momentum1"]]/Norm_Momentum) );
301  inputs.push_back( std::tanh(Inputs[_IndexOf["Momentum2"]]/Norm_Momentum) );
302  }
303  else
304  {
305  inputs.push_back( std::tanh(Inputs[_IndexOf["DecayLengthSignificance"]]/Norm_DecayLengthSignificance) );
306  inputs.push_back( std::tanh((Inputs[_IndexOf["DecayLength"]]/10.0)/Norm_DecayLength));
307  inputs.push_back( std::tanh(Inputs[_IndexOf["PTCorrectedMass"]]/Norm_PTMassCorrection) );
308  inputs.push_back( std::tanh(Inputs[_IndexOf["RawMomentum"]]/Norm_RawMomentum) );
309  inputs.push_back( Inputs[_IndexOf["JointProbRPhi"]] );
310  inputs.push_back( Inputs[_IndexOf["JointProbZ"]] );
311  inputs.push_back( std::tanh(Inputs[_IndexOf["NumTracksInVertices"]]/Norm_NumTracksInVertices) );
312  inputs.push_back( Inputs[_IndexOf["SecondaryVertexProbability"]] );
313  }
314 
315  if( jetType==B_JET )
316  {
317  target.clear();
318  target.push_back( 1.0 );
319  if( _trainThisNet["b_net-1vtx"] && NumVertices==1 ){ _dataSet["b_net-1vtx"]->addDataItem( inputs, target );_numSignal["b_net-1vtx"]+=1;}
320  if( _trainThisNet["b_net-2vtx"] && NumVertices==2 ){ _dataSet["b_net-2vtx"]->addDataItem( inputs, target );_numSignal["b_net-2vtx"]+=1;}
321  if( _trainThisNet["b_net-3vtx"] && NumVertices>=3 ){ _dataSet["b_net-3vtx"]->addDataItem( inputs, target );_numSignal["b_net-3vtx"]+=1;}
322  target.clear();
323  target.push_back( 0.0 );
324  if( _trainThisNet["c_net-1vtx"] && NumVertices==1 ){ _dataSet["c_net-1vtx"]->addDataItem( inputs, target );_numBackground["c_net-1vtx"]+=1;}
325  if( _trainThisNet["c_net-2vtx"] && NumVertices==2 ){ _dataSet["c_net-2vtx"]->addDataItem( inputs, target );_numBackground["c_net-2vtx"]+=1;}
326  if( _trainThisNet["c_net-3vtx"] && NumVertices>=3 ){ _dataSet["c_net-3vtx"]->addDataItem( inputs, target );_numBackground["c_net-3vtx"]+=1;}
327  if( _trainThisNet["bc_net-1vtx"] && NumVertices==1 ){ _dataSet["bc_net-1vtx"]->addDataItem( inputs, target );_numBackground["bc_net-1vtx"]+=1;}
328  if( _trainThisNet["bc_net-2vtx"] && NumVertices==2 ){ _dataSet["bc_net-2vtx"]->addDataItem( inputs, target );_numBackground["bc_net-2vtx"]+=1;}
329  if( _trainThisNet["bc_net-3vtx"] && NumVertices>=3 ){ _dataSet["bc_net-3vtx"]->addDataItem( inputs, target );_numBackground["bc_net-3vtx"]+=1;}
330  }
331  else if( jetType==C_JET )
332  {
333  target.clear();
334  target.push_back( 0.0 );
335  if( _trainThisNet["b_net-1vtx"] && NumVertices==1 ){ _dataSet["b_net-1vtx"]->addDataItem( inputs, target );_numBackground["b_net-1vtx"]+=1;}
336  if( _trainThisNet["b_net-2vtx"] && NumVertices==2 ){ _dataSet["b_net-2vtx"]->addDataItem( inputs, target );_numBackground["b_net-2vtx"]+=1;}
337  if( _trainThisNet["b_net-3vtx"] && NumVertices>=3 ){ _dataSet["b_net-3vtx"]->addDataItem( inputs, target );_numBackground["b_net-3vtx"]+=1;}
338  target.clear();
339  target.push_back( 1.0 );
340  if( _trainThisNet["c_net-1vtx"] && NumVertices==1 ){ _dataSet["c_net-1vtx"]->addDataItem( inputs, target );_numSignal["c_net-1vtx"]+=1;}
341  if( _trainThisNet["c_net-2vtx"] && NumVertices==2 ){ _dataSet["c_net-2vtx"]->addDataItem( inputs, target );_numSignal["c_net-2vtx"]+=1;}
342  if( _trainThisNet["c_net-3vtx"] && NumVertices>=3 ){ _dataSet["c_net-3vtx"]->addDataItem( inputs, target );_numSignal["c_net-3vtx"]+=1;}
343  if( _trainThisNet["bc_net-1vtx"] && NumVertices==1 ){ _dataSet["bc_net-1vtx"]->addDataItem( inputs, target );_numSignal["bc_net-1vtx"]+=1;}
344  if( _trainThisNet["bc_net-2vtx"] && NumVertices==2 ){ _dataSet["bc_net-2vtx"]->addDataItem( inputs, target );_numSignal["bc_net-2vtx"]+=1;}
345  if( _trainThisNet["bc_net-3vtx"] && NumVertices>=3 ){ _dataSet["bc_net-3vtx"]->addDataItem( inputs, target );_numSignal["bc_net-3vtx"]+=1;}
346  }
347  else
348  {
349  target.clear();
350  target.push_back( 0.0 );
351  if( _trainThisNet["b_net-1vtx"] && NumVertices==1 ){ _dataSet["b_net-1vtx"]->addDataItem( inputs, target );_numBackground["b_net-1vtx"]+=1;}
352  if( _trainThisNet["b_net-2vtx"] && NumVertices==2 ){ _dataSet["b_net-2vtx"]->addDataItem( inputs, target );_numBackground["b_net-2vtx"]+=1;}
353  if( _trainThisNet["b_net-3vtx"] && NumVertices>=3 ){ _dataSet["b_net-3vtx"]->addDataItem( inputs, target );_numBackground["b_net-3vtx"]+=1;}
354  if( _trainThisNet["c_net-1vtx"] && NumVertices==1 ){ _dataSet["c_net-1vtx"]->addDataItem( inputs, target );_numBackground["c_net-1vtx"]+=1;}
355  if( _trainThisNet["c_net-2vtx"] && NumVertices==2 ){ _dataSet["c_net-2vtx"]->addDataItem( inputs, target );_numBackground["c_net-2vtx"]+=1;}
356  if( _trainThisNet["c_net-3vtx"] && NumVertices>=3 ){ _dataSet["c_net-3vtx"]->addDataItem( inputs, target );_numBackground["c_net-3vtx"]+=1;}
357  //don't fill anything for the bc net because this isn't a b or a c jet
358  }
359 
360  }
361 
362  ++_nAcceptedEvents;
363  }
364  ++_nEvent;
365 
366  //Clear anything that may have been allocated during this event
368 }
369 
/*
-----------------------------------
-------------IMPORTANT-------------
-----------------------------------
If you change the cuts below, make sure you also update the description of the cuts
in the documentation (the main class description in NeuralNetTrainer.h).*/
377 bool NeuralNetTrainerProcessor::_passesCuts( lcio::LCEvent* pEvent )
378 {
379  //Any cuts would go in here. Do a test and then return false if the event fails, let control carry on to the other
380  //tests if it passes. If the event has passed all the cuts then there is a return true at the end.
381  //Currently only a cut on the jet momentum theta.
382 
383  std::vector<vertex_lcfi::util::Vector3> jetMomentums;
384 
385  //Don't want to flood the screen if the data is not available, so count how many times a warning has been given.
386  static int numWarningsNoMomentum=0;//The number of times a warning has been printed that momentum data is not available.
387 
388  try
389  {
390  lcio::LCCollection* pJetCollection=pEvent->getCollection( _JetCollectionName );
391  for( int i=0; i<pJetCollection->getNumberOfElements(); ++i )
392  {
393  const double* mom=(dynamic_cast<lcio::ReconstructedParticle*>( pJetCollection->getElementAt(i) ))->getMomentum();
394  if( mom[0]==0 && mom[1]==0 && mom[2]==0 ) throw lcio::Exception( "Jet momentum not defined" );
395 
396  jetMomentums.push_back( vertex_lcfi::util::Vector3( mom[0], mom[1], mom[2] ) );
397  }
398  }
399  catch( lcio::Exception exception )
400  {
401  //Just print a warning and proceed with the other cuts.
402  if( numWarningsNoMomentum<=2 )
403  {
404  std::cerr << "############################################################################\n"
405  << "# NeuralNetTrainerProcessor: #\n"
406  << "# Unable to get the data for the jet momentum because - #\n"
407  << "# #\n"
408  << exception.what() << std::endl
409  << "# #\n"
410  << "# Training will proceed with no cut on the jet momentum theta. #\n";
411  if ( numWarningsNoMomentum==2 ) std::cerr << "# NO FURTHER WARNINGS WILL BE DISPLAYED #\n";
412  std::cerr << "############################################################################" << std::endl;
413  ++numWarningsNoMomentum;
414  }
415  }
416 
417  //
418  // Fail this event if any of the jets have 30deg < theta < 150deg
419  //
420 
421  vertex_lcfi::util::Vector3 zAxis( 0, 0, 1 );
422 
423  //If the above try block failed then this will be empty and control will move on to the other cuts.
424  for( std::vector<vertex_lcfi::util::Vector3>::iterator iMom=jetMomentums.begin(); iMom<jetMomentums.end(); ++iMom )
425  {
426  (*iMom).makeUnit();
427  double cosTheta=(*iMom).dot( zAxis );
428  //equivalent to "if( theta>150degrees || theta<30degrees )"
429  //Should maybe have this value as a steering file parameter?
430  if( cosTheta>0.866 || cosTheta<-0.866 ) return false;
431  }
432 
433 
434  //other cuts would go in here, returning false if the event fails
435 
436 
437  //If control gets to here then the event has passed all the cuts.
438  return true;
439 }
440 
441 void NeuralNetTrainerProcessor::end()
442 {
443  //
444  //The data sets should all be filled, so train the nets
445  //
446 
447  //Just make one neuron builder and reuse it
448  nnet::SigmoidNeuronBuilder neuronBuilder;
449 
450  //All the nets use the same amount of nodes, so just create one of these and reuse it
451  int nInputs=8;
452  std::vector<int> nodes;
453  nodes.push_back(2 * nInputs - 2);
454  nodes.push_back(1);
455 
456  //print out info on how many events passed the cuts
457  std::cout << "NeuralNetTrainer: " << _nAcceptedEvents << " of " << _nEvent << " events passed the cuts. See the documentation of NeuralNetTrainerProcessor::_passesCuts() for details of the cuts applied." << std::endl;
458 
459  //Train and save any nets that have been selected for training
460  for( std::vector<std::string>::iterator iName=_listOfSelectedNetNames.begin(); iName<_listOfSelectedNetNames.end(); ++iName )
461  {
462  //Not going to need these once the net is trained and saved so have them local to this 'for' loop
463  nnet::NeuralNet thisNeuralNet( nInputs, nodes, &neuronBuilder, 1 );
464  nnet::BackPropagationCGAlgorithm myAlgorithm( thisNeuralNet );
465 
466  std::cout << std::endl << "Training neural net " << *iName << " with " << _dataSet[*iName]->numberOfDataItems()
467  << " jets " << "(" << _numBackground[*iName] << " background, " << _numSignal[*iName] << " signal)..." << std::endl;
468 
469  // Make sure we can open the file before training. Nothing worse than waiting ages to train and then losing the result!
470  std::ofstream outputFile( _filename[*iName].c_str() );
471  if( outputFile.is_open() )
472  {
473  //do the training
474  _trainNet( myAlgorithm, *_dataSet[*iName] );
475 
476  //Set the output format to the one requested in the steering file
477  thisNeuralNet.setSerialisationMode( _outputFormat );
478  thisNeuralNet.serialise( outputFile );
479  outputFile.close();
480  }
481  else
482  {
483  std::cerr << "Unable to open file " << _filename[*iName] << "! Skipping training for this net." << std::endl;
484  }
485  }
486 
487 
488  std::cout << "Finished training all selected nets" << std::endl;
489 
490  //free up stuff
492 }
493 
495 {
496  //This function pretty much just calls backPropCGAlgo.train(...) at the moment, although code can easily be added
497  //to check the errors after each iteration
498  double PrevErr,CurrErr;//Training errors
499  int i=0;
500  bool breakLoop=false;//not actually used at the moment, but will be used to cut the loop early if required
501 
502  while( i<50 && breakLoop==false )
503  {
504  // for CG algorithm with 10 epochs per loop index
505  //A bit silly calling 10 epochs 50 times instead of just 500 epochs, but I'll
506  //hopefully put in some code to cut out early depending on how the training is
507  //going.
508  backPropCGAlgo.train( 10, dataSet );
509 
510  //Have a look at the errors. Not a lot of point at the moment but you could put
511  //something in here to cut out early if the errors aren't getting significantly
512  //smaller (e.g. "if( (CurrErr-PrevErr)/PrevErr < 0.02 ) breakLoop=true")
513  std::vector<double> epochErrors=backPropCGAlgo.getTrainingErrorValuesPerEpoch();
514  CurrErr=epochErrors.back();
515  PrevErr = CurrErr;
516 
517  // 26/Apr/07 - Been having problems with the net not training and getting a NAN
518  //error under certain conditions which are still being looked into. This takes
519  //ages, and the loop keeps trying to do the same thing over and over, so check
520  //for this and cut the loop to save running time. Maybe dump the data set somewhere?
521  if( std::isnan( CurrErr ) )
522  {
523  std::cerr << "NeuralNetTrainer.cc (line 523): Training the net gave an error of NaN! That's not good. Still looking\n"
524  << "into why this happens, most likely there's not enough difference between the tag variables of your\n"
525  << "signal and background data sets. Your net is going to be gibberish. Sorry." << std::endl;
526  breakLoop=true;
527  }
528 
529  ++i;
530  }
531 }
532 
534 {
535  const std::vector<std::string>* pCollectionNames=pEvent->getCollectionNames();
536 
537  std::cout << "The available collections are: (name - type)" << std::endl;
538  for( std::vector<std::string>::const_iterator i=pCollectionNames->begin(); i<pCollectionNames->end(); ++i )
539  {
540  lcio::LCCollection* pCollection=pEvent->getCollection( (*i) );
541  const std::string typeName=pCollection->getTypeName();
542  std::cout << " " << (*i) << " - " << typeName << std::endl;
543  }
544  std::cout << std::endl;
545 }
/* Doxygen tooltip fragments captured along with this source listing (not code):
   void _trainNet(nnet::BackPropagationCGAlgorithm &pBackPropCGAlgo, nnet::NeuralNetDataSet &dataSet)
   static MetaMemoryManager * Event() — returns the Event duration singleton instance of the controller.
   void delAllObjects() — delete all objects of all types held by this instance.
   Trains neural networks to be used for jet flavour tagging.
   static MetaMemoryManager * Run() — returns the Run duration singleton instance of the controller.
   static MemoryManager< T > * Run() — returns the Run duration singleton instance of the MemoryManager for type T.
   bool _passesCuts(lcio::LCEvent *pEvent)
   void _displayCollectionNames(lcio::LCEvent *pEvent)
*/