// @(#)root/tmva $Id$ // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Or Cohen /***************************************************************************** * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * * Package: TMVA * * Class : MethodCompositeBase * * Web : http://tmva.sourceforge.net * * * * Description: * * Virtual base class for MVA methods that combine several classifiers * * * * Authors (alphabetical): * * Andreas Hoecker - CERN, Switzerland * * Joerg Stelzer - MSU, USA * * Helge Voss - MPI-K Heidelberg, Germany * * Kai Voss - U. of Victoria, Canada * * Or Cohen - Weizmann Inst., Israel * * * * Copyright (c) 2005: * * CERN, Switzerland * * U. of Victoria, Canada * * MPI-K Heidelberg, Germany * * LAPP, Annecy, France * * * * Redistribution and use in source and binary forms, with or without * * modification, are permitted according to the terms listed in LICENSE * * (http://tmva.sourceforge.net/LICENSE) * *****************************************************************************/ //_______________________________________________________________________ // // This virtual class is meant to combine more than one classifier // // together. The training of the classifiers is done by classes that are // // derived from this one, while the saving and loading of the weights file // // and the evaluation are done here. 
// NOTE(review): The code region below is damaged and is not compilable as shown.
//  - Every span between a '<' and the following '>' appears to have been
//    stripped, as an HTML tag filter would do, and the original line breaks
//    were collapsed. Symptoms visible below:
//      * three bare "#include" directives with no header names;
//      * "std::vector::const_iterator", "std::vector::iterator" and
//        "dynamic_cast(...)" without their template arguments;
//      * GetMethod(index) breaks off at "if (itrMethod" and continues directly
//        inside another function, whose signature and opening lines are gone.
//        That function uses "methxml" and "method", but their creation is
//        not visible;
//      * ReadWeightsFromXML jumps from "for (UInt_t i=0;i" straight to
//        "needs to be fixed later", leaves an unmatched "}" after BookMethod,
//        and uses ch, methodWeight, methodSigCut and methodSigCutOrientation,
//        whose declarations are missing;
//      * ReadWeightsFromStream's loop header is cut to "for (UInt_t i=0;i> dummy".
//  - Because the line breaks are gone, each "//" now comments out the rest of
//    its physical line:
//      * the first original line below starts with "//", so its includes,
//        constructors and both GetMethod overloads are dead text;
//      * the text after "} } //___" and after "} //Log()" is also dead;
//      * "NoErrorCalc(err, errUpper); return mvaValue; }" sits after
//        "// cannot determine error".
//  - The lost text presumably includes XML node/attribute names that the
//    weight files depend on. It therefore cannot be reconstructed safely from
//    this copy. Restore this file from the upstream TMVA sources rather than
//    hand-patching it.
//
// Contents (in order):
//  - two MethodCompositeBase constructors (both set fCurrentMethodIdx(0) and
//    fCurrentMethod(0));
//  - GetMethod(title): a linear search over fMethods by GetMethodName();
//  - GetMethod(index);
//  - a per-method XML writer. It writes Index, Weight, MethodSigCut,
//    MethodSigCutOrientation, MethodTypeName, MethodName, JobName, Options and
//    UseMainMethodTransformation, then calls method->AddWeightsXMLTo(methxml);
//  - the destructor, which logs and deletes every entry of fMethods;
//  - ReadWeightsFromXML;
//  - ReadWeightsFromStream;
//  - GetMvaValue: the sum of fMethods[i]->GetMvaValue() * fMethodWeight[i],
//    with errors disabled via NoErrorCalc.
//
// Issues visible in the surviving code, to address once the file is restored:
//  - GetMethod(index) forms fMethods.begin()+index before any range check.
//    A negative or too-large index makes that iterator arithmetic undefined;
//    check 0 <= index < fMethods.size() first.
//  - GetMethod(title) calls mva->GetMethodName() on the dynamic_cast result
//    without checking it for null.
//  - ReadWeightsFromStream only calls m->ReadWeightsFromStream(istr) when the
//    dynamic_cast succeeds. On failure it silently skips that method's weights,
//    leaving the stream positioned at the wrong place for the next method.
//  - ReadWeightsFromStream push_back()s into fMethods/fMethodWeight, and no
//    clearing of previously loaded methods is visible in that function.
//    Confirm a second read cannot duplicate entries.
//  - ReadWeightsFromStream uses a C-style cast of `this` to MethodBoost, guarded
//    only by GetMethodType() == Types::kBoost.
//  - GetMvaValue reads fMethodWeight[i] for every i < fMethods.size(). It
//    assumes both vectors have the same length.
// //_______________________________________________________________________ #include #include #include #include "Riostream.h" #include "TRandom3.h" #include "TMath.h" #include "TObjString.h" #include "TMVA/MethodCompositeBase.h" #include "TMVA/MethodBoost.h" #include "TMVA/MethodBase.h" #include "TMVA/Tools.h" #include "TMVA/Types.h" #include "TMVA/Factory.h" #include "TMVA/ClassifierFactory.h" using std::vector; ClassImp(TMVA::MethodCompositeBase) //_______________________________________________________________________ TMVA::MethodCompositeBase::MethodCompositeBase( const TString& jobName, Types::EMVA methodType, const TString& methodTitle, DataSetInfo& theData, const TString& theOption, TDirectory* theTargetDir ) : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption, theTargetDir ), fCurrentMethodIdx(0), fCurrentMethod(0) {} //_______________________________________________________________________ TMVA::MethodCompositeBase::MethodCompositeBase( Types::EMVA methodType, DataSetInfo& dsi, const TString& weightFile, TDirectory* theTargetDir ) : TMVA::MethodBase( methodType, dsi, weightFile, theTargetDir ), fCurrentMethodIdx(0), fCurrentMethod(0) {} //_______________________________________________________________________ TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const TString &methodTitle ) const { // returns pointer to MVA that corresponds to given method title std::vector::const_iterator itrMethod = fMethods.begin(); std::vector::const_iterator itrMethodEnd = fMethods.end(); for (; itrMethod != itrMethodEnd; itrMethod++) { MethodBase* mva = dynamic_cast(*itrMethod); if ( (mva->GetMethodName())==methodTitle ) return mva; } return 0; } //_______________________________________________________________________ TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const Int_t index ) const { // returns pointer to MVA that corresponds to given method index std::vector::const_iterator itrMethod = fMethods.begin()+index; if 
(itrMethod(fMethods[i]); gTools().AddAttr(methxml,"Index", i ); gTools().AddAttr(methxml,"Weight", fMethodWeight[i]); gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut()); gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation()); gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName()); gTools().AddAttr(methxml,"MethodName", method->GetMethodName() ); gTools().AddAttr(methxml,"JobName", method->GetJobName()); gTools().AddAttr(methxml,"Options", method->GetOptions()); if (method->fTransformationPointer) gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true")); else gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false")); method->AddWeightsXMLTo(methxml); } } //_______________________________________________________________________ TMVA::MethodCompositeBase::~MethodCompositeBase( void ) { // delete methods std::vector::iterator itrMethod = fMethods.begin(); for (; itrMethod != fMethods.end(); itrMethod++) { Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl; delete (*itrMethod); } fMethods.clear(); } //_______________________________________________________________________ void TMVA::MethodCompositeBase::ReadWeightsFromXML( void* wghtnode ) { // XML streamer UInt_t nMethods; TString methodName, methodTypeName, jobName, optionString; for (UInt_t i=0;i needs to be fixed later ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString ); } fMethods.push_back(ClassifierFactory::Instance().Create( std::string(methodTypeName),jobName, methodName,DataInfo(),optionString)); fMethodWeight.push_back(methodWeight); MethodBase* meth = dynamic_cast(fMethods.back()); if(meth==0) Log() << kFATAL << "Could not read method from XML" << Endl; void* methXML = gTools().GetChild(ch); meth->SetupMethod(); meth->SetMsgType(kWARNING); meth->ParseOptions(); meth->ProcessSetup(); meth->CheckSetup(); 
meth->ReadWeightsFromXML(methXML); meth->SetSignalReferenceCut(methodSigCut); meth->SetSignalReferenceCutOrientation(methodSigCutOrientation); meth->RerouteTransformationHandler (&(this->GetTransformationHandler())); ch = gTools().GetNextChild(ch); } //Log() << kINFO << "Reading methods from XML done " << Endl; } //_______________________________________________________________________ void TMVA::MethodCompositeBase::ReadWeightsFromStream( std::istream& istr ) { // text streamer TString var, dummy; TString methodName, methodTitle=GetMethodName(), jobName=GetJobName(),optionString=GetOptions(); UInt_t methodNum; Double_t methodWeight; // and read the Weights (BDT coefficients) // coverity[tainted_data_argument] istr >> dummy >> methodNum; Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl; for (UInt_t i=0;i> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight; if ((UInt_t)fCurrentMethodIdx != i) { Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex=" << fCurrentMethodIdx << " i=" << i << " MethodName " << methodName << " dummy " << dummy << " MethodWeight= " << methodWeight << Endl; } if (GetMethodType() != Types::kBoost || i==0) { istr >> dummy >> jobName; istr >> dummy >> methodTitle; istr >> dummy >> optionString; if (GetMethodType() == Types::kBoost) ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString ); } else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx); fMethods.push_back(ClassifierFactory::Instance().Create( std::string(methodName), jobName, methodTitle,DataInfo(), optionString) ); fMethodWeight.push_back( methodWeight ); if(MethodBase* m = dynamic_cast(fMethods.back()) ) m->ReadWeightsFromStream(istr); } } //_______________________________________________________________________ Double_t TMVA::MethodCompositeBase::GetMvaValue( Double_t* err, Double_t* errUpper ) { // return composite MVA response Double_t mvaValue 
= 0; for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i]; // cannot determine error NoErrorCalc(err, errUpper); return mvaValue; }