#include "NeuralNetTrainer.h"

#include <iostream>
#include <sstream>
#include <fstream>
#include <set>
#include <cmath>
#include <algorithm>

#include "EVENT/LCCollection.h"
#include "IMPL/ReconstructedParticleImpl.h"
#include "EVENT/LCFloatVec.h"
#include "EVENT/LCIntVec.h"
#include "EVENT/LCParameters.h"

#include "util/inc/memorymanager.h"
#include "util/inc/vector3.h"

#include "nnet/inc/NeuralNet.h"
#include "nnet/inc/NeuralNetDataSet.h"
#include "nnet/inc/SigmoidNeuronBuilder.h"
#include "nnet/inc/BackPropagationCGAlgorithm.h"

// Create an instance of the processor so that Marlin knows about it.
NeuralNetTrainerProcessor aNeuralNetTrainerProcessor;

NeuralNetTrainerProcessor::NeuralNetTrainerProcessor() : marlin::Processor( "NeuralNetTrainer" )
{
	_description = "Trains a neural net from the lcio file" ;

	// Register the input collections and steering parameters.
	registerInputCollection( lcio::LCIO::RECONSTRUCTEDPARTICLE,
				"JetCollectionName" ,
				"Name of the collection of ReconstructedParticles that is the jet" ,
				_JetCollectionName ,
				std::string("SGVJets") ) ;
	registerInputCollection( lcio::LCIO::LCFLOATVEC,
				"FlavourTagInputsCollection" ,
				"Name of the LCFloatVec Collection that contains the flavour tag inputs (in same order as jet collection)" ,
				_FlavourTagInputsCollectionName ,
				std::string("FlavourTagInputs") ) ;
	registerInputCollection( lcio::LCIO::LCFLOATVEC,
				"TrueJetFlavourCollection" ,
				"Name of the LCFloatVec Collection that contains the true jet flavours" ,
				_TrueJetFlavourCollectionName ,
				std::string("TrueJetFlavour") ) ;

	registerProcessorParameter( "SaveAsXML",
				"Set this to 0 to output the neural nets in plain text format (default), or 1 (or anything non zero) to save in XML format",
				_serialiseAsXML,
				0 ) ;

	// One output filename parameter for each of the nine nets. A net is only trained
	// if its filename is non-blank.
	registerProcessorParameter( "Filename-b_net-1vtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["b_net-1vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-c_net-1vtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["c_net-1vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-bc_net-1vtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["bc_net-1vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-b_net-2vtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["b_net-2vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-c_net-2vtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["c_net-2vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-bc_net-2vtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["bc_net-2vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-b_net-3plusvtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["b_net-3vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-c_net-3plusvtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["c_net-3vtx"],
				std::string("") ) ;
	registerProcessorParameter( "Filename-bc_net-3plusvtx" ,
				"Output filename for the trained net. If it is blank (default) then this net is not trained" ,
				_filename["bc_net-3vtx"],
				std::string("") ) ;
}

NeuralNetTrainerProcessor::~NeuralNetTrainerProcessor()
{
}

void NeuralNetTrainerProcessor::init()
{
	std::cout << _description << std::endl
		<< "-------------------------------------------------" << std::endl
		<< std::endl;

	_nEvent=0;
	_nAcceptedEvents=0;

	// Work out which nets have been selected for training: any net whose output
	// filename parameter is non-blank gets trained.
	for( std::map<std::string,std::string>::iterator i=_filename.begin(); i!=_filename.end(); ++i )
	{
		if( (*i).second!="" )
		{
			_listOfSelectedNetNames.push_back( (*i).first );
			_trainThisNet[ (*i).first ]=true;
		}
		else _trainThisNet[ (*i).first ]=false;
	}

	if( _listOfSelectedNetNames.size()==0 )
	{
		std::stringstream message;
		message << std::endl
			<< "############################################################################################\n"
			<< "#   NeuralNetTrainerProcessor:                                                             #\n"
			<< "#   No nets have been enabled for training in the steering file!                          #\n"
			<< "#   Supply at least one output filename in the steering file. For example, put the line   #\n"
			<< "#   <parameter name=\"Filename-bc_net-1vtx\" type=\"string\"> bc_net-1vtx.xml </parameter>    #\n"
			<< "#   In with the parameters for this processor.                                             #\n"
			<< "############################################################################################" << std::endl;
		throw lcio::Exception( message.str() );
	}

	if( _serialiseAsXML==0 ) _outputFormat=nnet::NeuralNet::PlainText;
	else _outputFormat=nnet::NeuralNet::XML;

	// Create an empty data set for each selected net and zero its signal and background
	// counters. Ownership of the data sets is handed to the run-duration MemoryManager
	// so that they are cleaned up automatically.
	for( std::vector<std::string>::iterator iName=_listOfSelectedNetNames.begin(); iName<_listOfSelectedNetNames.end(); ++iName )
	{
		_dataSet[*iName]=new nnet::NeuralNetDataSet;
		vertex_lcfi::MemoryManager<nnet::NeuralNetDataSet>::Run()->registerObject( _dataSet[*iName] );

		_numSignal[*iName]=0;
		_numBackground[*iName]=0;
	}
}

void NeuralNetTrainerProcessor::processRunHeader( LCRunHeader* pRun )
{
	// The flavour tag inputs are stored as one LCFloatVec per jet, so find out from the
	// run header which position in the vector each named variable occupies (see the note
	// after this function).
	std::vector<std::string> VarNames;
	(pRun->parameters()).getStringVals(_FlavourTagInputsCollectionName,VarNames);

	std::set<std::string> AvailableNames;
	for( size_t i = 0; i < VarNames.size(); ++i )
	{
		AvailableNames.insert(VarNames[i]);
		_IndexOf[VarNames[i]] = i;
	}

	// Check that all of the variables this processor needs are present.
	std::set<std::string> RequiredNames;
	RequiredNames.insert( "D0Significance1" );
	RequiredNames.insert( "D0Significance2" );
	RequiredNames.insert( "Z0Significance1" );
	RequiredNames.insert( "Z0Significance2" );
	RequiredNames.insert( "JointProbRPhi" );
	RequiredNames.insert( "JointProbZ" );
	RequiredNames.insert( "Momentum1" );
	RequiredNames.insert( "Momentum2" );
	RequiredNames.insert( "DecayLengthSignificance" );
	RequiredNames.insert( "DecayLength" );
	RequiredNames.insert( "PTCorrectedMass" );
	RequiredNames.insert( "RawMomentum" );
	RequiredNames.insert( "NumTracksInVertices" );
	RequiredNames.insert( "SecondaryVertexProbability" );

	if( !includes(AvailableNames.begin(),AvailableNames.end(),RequiredNames.begin(),RequiredNames.end()) )
		std::cerr << _FlavourTagInputsCollectionName << " does not contain information required by NeuralNetTrainerProcessor" << std::endl;
}

void NeuralNetTrainerProcessor::processEvent( lcio::LCEvent* pEvent )
{
	// On the first event print out all of the available collections as a debugging aid.
	if( isFirstEvent() ) _displayCollectionNames( pEvent );

	lcio::LCCollection* pJetCollection=pEvent->getCollection( _JetCollectionName );

	// Make sure the collection is of the right type.
	if( pJetCollection->getTypeName()!=lcio::LCIO::RECONSTRUCTEDPARTICLE )
	{
		std::stringstream message;
		message << std::endl
			<< "########################################################################################\n"
			<< "#   NeuralNetTrainerProcessor:                                                         #\n"
			<< "#   The jet collection requested (\"" << _JetCollectionName << "\") is not of the type \"" << lcio::LCIO::RECONSTRUCTEDPARTICLE << "\"   #\n"
			<< "########################################################################################" << std::endl;
		throw lcio::EventException( message.str() );
	}

	lcio::ReconstructedParticle* pJet;
	int numJets=pJetCollection->getNumberOfElements();

	++_nEvent;

	// Only use events that pass the selection cuts (see _passesCuts() below).
	if( _passesCuts(pEvent) )
	{
		++_nAcceptedEvents;

		// Loop over the jets in this event.
		for( int a=0; a<numJets; ++a )
		{
			pJet=dynamic_cast<lcio::ReconstructedParticle*>( pJetCollection->getElementAt(a) );

			double jetEnergy=pJet->getEnergy();
			if( jetEnergy==0 )
			{
				// Some input files do not have the jet energy filled in, so fall back on the value quoted below.
				jetEnergy=45.5;
				if( isFirstEvent() ) std::cout << "*** NeuralNetTrainer - Warning: Jet energy undefined, assuming 45.5GeV ***" << std::endl;
			}

			// Normalisation constants used to scale the inputs before they are fed to the net.
			double Norm_D0Significance = 100.0;
			double Norm_Z0Significance = 100.0;
			double Norm_Momentum = jetEnergy/3.0;
			double Norm_DecayLengthSignificance = 6.0*jetEnergy;
			double Norm_DecayLength = 1.0;
			double Norm_PTMassCorrection = 5.0;
			double Norm_RawMomentum = jetEnergy;
			double Norm_NumTracksInVertices = 10.0;

			// Get the true flavour of this jet from the matching entry in the true jet flavour collection.
			lcio::LCCollection* pTrueJet=pEvent->getCollection( _TrueJetFlavourCollectionName );
			if( pTrueJet->getTypeName()!=lcio::LCIO::LCFLOATVEC )
			{
				std::stringstream message;
				message << std::endl
					<< "########################################################################################\n"
					<< "#   NeuralNetTrainerProcessor:                                                         #\n"
					<< "#   The true flavour collection requested (\"" << _TrueJetFlavourCollectionName << "\") is not of the type \"" << lcio::LCIO::LCFLOATVEC << "\"   #\n"
					<< "########################################################################################" << std::endl;
				throw lcio::EventException( message.str() );
			}
			float jetType = *( dynamic_cast<lcio::LCFloatVec*>( pTrueJet->getElementAt(a) )->begin() );
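
			// jetType is the true flavour code of this jet, taken from the first entry of the
			// matching LCFloatVec. It is compared against the B_JET and C_JET constants below
			// (assumed to be defined in NeuralNetTrainer.h); anything else is treated as a light jet.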

			// Get the flavour tag inputs for this jet.
			lcio::LCCollection* pInputs=pEvent->getCollection( _FlavourTagInputsCollectionName );
			if( pInputs->getTypeName()!=lcio::LCIO::LCFLOATVEC )
			{
				std::stringstream message;
				message << std::endl
					<< "########################################################################################\n"
					<< "#   NeuralNetTrainerProcessor:                                                         #\n"
					<< "#   The flavour tag input collection requested (\"" << _FlavourTagInputsCollectionName << "\") is not of the type \"" << lcio::LCIO::LCFLOATVEC << "\"   #\n"
					<< "########################################################################################" << std::endl;
				throw lcio::EventException( message.str() );
			}
			LCFloatVec Inputs = *( dynamic_cast<lcio::LCFloatVec*>( pInputs->getElementAt(a) ) );

			std::vector<double> inputs;
			std::vector<double> target;

			double NumVertices = Inputs[_IndexOf["NumVertices"]];

			inputs.push_back( std::tanh(Inputs[_IndexOf["D0Significance1"]]/Norm_D0Significance) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["D0Significance2"]]/Norm_D0Significance) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["Z0Significance1"]]/Norm_Z0Significance) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["Z0Significance2"]]/Norm_Z0Significance) );
			inputs.push_back( Inputs[_IndexOf["JointProbRPhi"]] );
			inputs.push_back( Inputs[_IndexOf["JointProbZ"]] );
			inputs.push_back( std::tanh(Inputs[_IndexOf["Momentum1"]]/Norm_Momentum) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["Momentum2"]]/Norm_Momentum) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["DecayLengthSignificance"]]/Norm_DecayLengthSignificance) );
			inputs.push_back( std::tanh((Inputs[_IndexOf["DecayLength"]]/10.0)/Norm_DecayLength) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["PTCorrectedMass"]]/Norm_PTMassCorrection) );
			inputs.push_back( std::tanh(Inputs[_IndexOf["RawMomentum"]]/Norm_RawMomentum) );
			inputs.push_back( Inputs[_IndexOf["JointProbRPhi"]] );
			inputs.push_back( Inputs[_IndexOf["JointProbZ"]] );
			inputs.push_back( std::tanh(Inputs[_IndexOf["NumTracksInVertices"]]/Norm_NumTracksInVertices) );
			inputs.push_back( Inputs[_IndexOf["SecondaryVertexProbability"]] );

			// Fill the appropriate data sets depending on the true flavour and on how many
			// vertices were reconstructed in the jet (the target convention is summarised in
			// the table after this function).
			if( jetType==B_JET )
			{
				// b jets are signal for the b nets...
				target.push_back( 1.0 );
				if( _trainThisNet["b_net-1vtx"] && NumVertices==1 ){ _dataSet["b_net-1vtx"]->addDataItem( inputs, target ); _numSignal["b_net-1vtx"]+=1; }
				if( _trainThisNet["b_net-2vtx"] && NumVertices==2 ){ _dataSet["b_net-2vtx"]->addDataItem( inputs, target ); _numSignal["b_net-2vtx"]+=1; }
				if( _trainThisNet["b_net-3vtx"] && NumVertices>=3 ){ _dataSet["b_net-3vtx"]->addDataItem( inputs, target ); _numSignal["b_net-3vtx"]+=1; }

				// ...and background for the c and bc nets.
				target.clear();
				target.push_back( 0.0 );
				if( _trainThisNet["c_net-1vtx"] && NumVertices==1 ){ _dataSet["c_net-1vtx"]->addDataItem( inputs, target ); _numBackground["c_net-1vtx"]+=1; }
				if( _trainThisNet["c_net-2vtx"] && NumVertices==2 ){ _dataSet["c_net-2vtx"]->addDataItem( inputs, target ); _numBackground["c_net-2vtx"]+=1; }
				if( _trainThisNet["c_net-3vtx"] && NumVertices>=3 ){ _dataSet["c_net-3vtx"]->addDataItem( inputs, target ); _numBackground["c_net-3vtx"]+=1; }
				if( _trainThisNet["bc_net-1vtx"] && NumVertices==1 ){ _dataSet["bc_net-1vtx"]->addDataItem( inputs, target ); _numBackground["bc_net-1vtx"]+=1; }
				if( _trainThisNet["bc_net-2vtx"] && NumVertices==2 ){ _dataSet["bc_net-2vtx"]->addDataItem( inputs, target ); _numBackground["bc_net-2vtx"]+=1; }
				if( _trainThisNet["bc_net-3vtx"] && NumVertices>=3 ){ _dataSet["bc_net-3vtx"]->addDataItem( inputs, target ); _numBackground["bc_net-3vtx"]+=1; }
			}
			else if( jetType==C_JET )
			{
				// c jets are background for the b nets...
				target.push_back( 0.0 );
				if( _trainThisNet["b_net-1vtx"] && NumVertices==1 ){ _dataSet["b_net-1vtx"]->addDataItem( inputs, target ); _numBackground["b_net-1vtx"]+=1; }
				if( _trainThisNet["b_net-2vtx"] && NumVertices==2 ){ _dataSet["b_net-2vtx"]->addDataItem( inputs, target ); _numBackground["b_net-2vtx"]+=1; }
				if( _trainThisNet["b_net-3vtx"] && NumVertices>=3 ){ _dataSet["b_net-3vtx"]->addDataItem( inputs, target ); _numBackground["b_net-3vtx"]+=1; }

				// ...and signal for the c and bc nets.
				target.clear();
				target.push_back( 1.0 );
				if( _trainThisNet["c_net-1vtx"] && NumVertices==1 ){ _dataSet["c_net-1vtx"]->addDataItem( inputs, target ); _numSignal["c_net-1vtx"]+=1; }
				if( _trainThisNet["c_net-2vtx"] && NumVertices==2 ){ _dataSet["c_net-2vtx"]->addDataItem( inputs, target ); _numSignal["c_net-2vtx"]+=1; }
				if( _trainThisNet["c_net-3vtx"] && NumVertices>=3 ){ _dataSet["c_net-3vtx"]->addDataItem( inputs, target ); _numSignal["c_net-3vtx"]+=1; }
				if( _trainThisNet["bc_net-1vtx"] && NumVertices==1 ){ _dataSet["bc_net-1vtx"]->addDataItem( inputs, target ); _numSignal["bc_net-1vtx"]+=1; }
				if( _trainThisNet["bc_net-2vtx"] && NumVertices==2 ){ _dataSet["bc_net-2vtx"]->addDataItem( inputs, target ); _numSignal["bc_net-2vtx"]+=1; }
				if( _trainThisNet["bc_net-3vtx"] && NumVertices>=3 ){ _dataSet["bc_net-3vtx"]->addDataItem( inputs, target ); _numSignal["bc_net-3vtx"]+=1; }
			}
			else
			{
				// Light jets are background for both the b and the c nets; no entries are added for the bc nets.
				target.push_back( 0.0 );
				if( _trainThisNet["b_net-1vtx"] && NumVertices==1 ){ _dataSet["b_net-1vtx"]->addDataItem( inputs, target ); _numBackground["b_net-1vtx"]+=1; }
				if( _trainThisNet["b_net-2vtx"] && NumVertices==2 ){ _dataSet["b_net-2vtx"]->addDataItem( inputs, target ); _numBackground["b_net-2vtx"]+=1; }
				if( _trainThisNet["b_net-3vtx"] && NumVertices>=3 ){ _dataSet["b_net-3vtx"]->addDataItem( inputs, target ); _numBackground["b_net-3vtx"]+=1; }
				if( _trainThisNet["c_net-1vtx"] && NumVertices==1 ){ _dataSet["c_net-1vtx"]->addDataItem( inputs, target ); _numBackground["c_net-1vtx"]+=1; }
				if( _trainThisNet["c_net-2vtx"] && NumVertices==2 ){ _dataSet["c_net-2vtx"]->addDataItem( inputs, target ); _numBackground["c_net-2vtx"]+=1; }
				if( _trainThisNet["c_net-3vtx"] && NumVertices>=3 ){ _dataSet["c_net-3vtx"]->addDataItem( inputs, target ); _numBackground["c_net-3vtx"]+=1; }
			}
		} // end of the loop over the jets
	} // end of "if( _passesCuts(pEvent) )"
}

bool NeuralNetTrainerProcessor::_passesCuts( lcio::LCEvent* pEvent )
{
	//
	// Cut on the polar angle of the jet momenta.
	//
	std::vector<vertex_lcfi::util::Vector3> jetMomentums;

	// Some input files do not have the jet momentum filled in. Print a warning the first
	// few times this happens, then keep quiet about it.
	static int numWarningsNoMomentum=0;

	try
	{
		lcio::LCCollection* pJetCollection=pEvent->getCollection( _JetCollectionName );
		for( int i=0; i<pJetCollection->getNumberOfElements(); ++i )
		{
			const double* mom=(dynamic_cast<lcio::ReconstructedParticle*>( pJetCollection->getElementAt(i) ))->getMomentum();
			if( mom[0]==0 && mom[1]==0 && mom[2]==0 ) throw lcio::Exception( "Jet momentum not defined" );
			else jetMomentums.push_back( vertex_lcfi::util::Vector3( mom[0], mom[1], mom[2] ) );
		}
	}
	catch( lcio::Exception& exception )
	{
		if( numWarningsNoMomentum<=2 )
		{
			std::cerr << "############################################################################\n"
				<< "#   NeuralNetTrainerProcessor:                                             #\n"
				<< "#   Unable to get the data for the jet momentum because -                  #\n"
				<< exception.what() << std::endl
				<< "#   Training will proceed with no cut on the jet momentum theta.           #\n";
			if( numWarningsNoMomentum==2 ) std::cerr << "#   NO FURTHER WARNINGS WILL BE DISPLAYED                                  #\n";
			std::cerr << "############################################################################" << std::endl;
			++numWarningsNoMomentum;
		}
	}

	// Reject the event if any jet is within 30 degrees of the beam axis (see the note after this function).
	vertex_lcfi::util::Vector3 zAxis( 0, 0, 1 );
	for( std::vector<vertex_lcfi::util::Vector3>::iterator iMom=jetMomentums.begin(); iMom<jetMomentums.end(); ++iMom )
	{
		(*iMom).makeUnit(); // normalise so that the dot product below is cos(theta) (helper name assumed from util/inc/vector3.h)
		double cosTheta=(*iMom).dot( zAxis );
		if( cosTheta>0.866 || cosTheta<-0.866 ) return false;
	}

	// If control gets this far then all of the cuts have been passed.
	return true;
}
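
// The 0.866 cut value above is cos(30 degrees), i.e. sqrt(3)/2 to three decimal places.
// A single jet pointing within 30 degrees of either end of the beam (z) axis is enough
// to reject the whole event, so only events with all jets in the central region enter
// the training samples.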

void NeuralNetTrainerProcessor::end()
{
	// All of the nets share the same architecture: the 16 inputs filled in processEvent(),
	// one hidden layer of 2*nInputs-2 sigmoid neurons and a single output neuron.
	const int nInputs=16;
	std::vector<int> nodes;
	nodes.push_back(2 * nInputs - 2);
	nodes.push_back(1);
	nnet::SigmoidNeuronBuilder neuronBuilder;

	std::cout << "NeuralNetTrainer: " << _nAcceptedEvents << " of " << _nEvent
		<< " events passed the cuts. See the documentation of NeuralNetTrainerProcessor::_passesCuts() for details of the cuts applied." << std::endl;

	for( std::vector<std::string>::iterator iName=_listOfSelectedNetNames.begin(); iName<_listOfSelectedNetNames.end(); ++iName )
	{
		std::cout << std::endl << "Training neural net " << *iName << " with " << _dataSet[*iName]->numberOfDataItems()
			<< " jets " << "(" << _numBackground[*iName] << " background, " << _numSignal[*iName] << " signal)..." << std::endl;

		// Build the net and its conjugate gradient back propagation trainer. (The constructor
		// arguments - number of inputs, layer sizes, neuron builder and random number seed -
		// are assumed from the nnet headers included above.)
		nnet::NeuralNet thisNeuralNet( nInputs, nodes, &neuronBuilder, 1 );
		nnet::BackPropagationCGAlgorithm myAlgorithm( thisNeuralNet );

		std::ofstream outputFile( _filename[*iName].c_str() );
		if( outputFile.is_open() )
		{
			_trainNet( myAlgorithm, *_dataSet[*iName] );

			// Write the trained net to the requested file, either as plain text or as XML.
			thisNeuralNet.setSerialisationMode( _outputFormat );
			thisNeuralNet.serialise( outputFile );
			outputFile.close();
		}
		else
		{
			std::cerr << "Unable to open file " << _filename[*iName] << "! Skipping training for this net." << std::endl;
		}
	}

	std::cout << "Finished training all selected nets" << std::endl;
}

void NeuralNetTrainerProcessor::_trainNet( nnet::BackPropagationCGAlgorithm& backPropCGAlgo, nnet::NeuralNetDataSet& dataSet )
{
	double PrevErr,CurrErr;
	int i=0;
	bool breakLoop=false;

	// Train in rounds of 10 epochs, up to a maximum of 50 rounds.
	while( i<50 && breakLoop==false )
	{
		++i;
		backPropCGAlgo.train( 10, dataSet );

		// Examine the error from the most recent epoch.
		std::vector<double> epochErrors=backPropCGAlgo.getTrainingErrorValuesPerEpoch();
		CurrErr=epochErrors.back();

		// Stop early once the error stops changing between rounds (tolerance value assumed).
		if( i>1 && std::fabs(CurrErr-PrevErr)<1E-9 ) breakLoop=true;
		PrevErr=CurrErr;

		if( std::isnan( CurrErr ) )
		{
			std::cerr << "NeuralNetTrainer.cc (line 523): Training the net gave an error of NaN! That's not good. Still looking\n"
				<< "into why this happens, most likely there's not enough difference between the tag variables of your\n"
				<< "signal and background data sets. Your net is going to be gibberish. Sorry." << std::endl;
			breakLoop=true;
		}
	}
}

void NeuralNetTrainerProcessor::_displayCollectionNames( lcio::LCEvent* pEvent )
{
	const std::vector<std::string>* pCollectionNames=pEvent->getCollectionNames();

	std::cout << "The available collections are: (name - type)" << std::endl;
	for( std::vector<std::string>::const_iterator i=pCollectionNames->begin(); i<pCollectionNames->end(); ++i )
	{
		lcio::LCCollection* pCollection=pEvent->getCollection( (*i) );
		const std::string typeName=pCollection->getTypeName();
		std::cout << "  " << (*i) << " - " << typeName << std::endl;
	}
	std::cout << std::endl;
}