import ROOT
import os, os.path
import logging
# Module-wide logger shared by every helper in this file.
m_log = logging.getLogger('SpartyLog')
# Stream handler created with no explicit stream argument.
outStr = logging.StreamHandler()
# Fixed-width columns: module name, level name, then the message text.
# NOTE(review): the name `format` shadows the builtin; it is left unchanged
# because it is a module-level name other code may import.
format = logging.Formatter("%(module)-18s  %(levelname)-8s  %(message)s")
outStr.setFormatter(format)
m_log.addHandler(outStr)


# Locate the SpartyJet shared libraries through the SPARTYJETDIR environment
# variable. Only a missing variable (KeyError) is treated as a configuration
# error; anything else (e.g. KeyboardInterrupt) is allowed to propagate.
try:
    sj_dir = os.environ['SPARTYJETDIR'] + '/libs/'
except KeyError:
    m_log.error('Environment variable SPARTYJETDIR not set. You need to do source setup.sh in top directory')
    import sys
    sys.exit(1)

ROOT.gSystem.Load('libPhysics.so')

# Load SpartyJet
# Libraries are loaded in the order listed; each one is looked up inside
# sj_dir and skipped with a warning when the file is not present.
_sj_libs = [
    'libJetCore.so',
    'libJetTools.so',
    'libFastJet.so',
    'libATLASJet.so',
    'libCellJet.so',
    'libCDFJet.so',
    'libD0Jet.so',
    'libEventShape.so',
    'libSpartyDisplay.so',
]
for lib in _sj_libs:
    lib_path = sj_dir + lib
    if os.path.exists(lib_path):
        ROOT.gSystem.Load(lib_path)
    else:
        # Report through the module logger rather than a hand-formatted
        # print statement: same "module / level / message" layout, and the
        # file no longer depends on Python 2 print syntax.
        m_log.warning('missing %s', lib_path)

from ROOT import SpartyJet as SJ 
from ROOT import fastjet   as fj

def syncMessageLevel():
    '''Set the level of the module logger from SpartyJet's global level.

    The value returned by SJ.GlobalLevel() is translated to the matching
    constant of the logging module and applied to m_log. A level that is
    not one of DEBUG / INFO / WARNING / ERROR raises KeyError.
    '''
    # SpartyJet level -> Python logging level
    level_map = dict([
        (SJ.DEBUG, logging.DEBUG),
        (SJ.INFO, logging.INFO),
        (SJ.WARNING, logging.WARNING),
        (SJ.ERROR, logging.ERROR),
    ])

    current = SJ.GlobalLevel()
    m_log.setLevel(level_map[current])


def createNtupleInputMaker(filename, treename="", inputprefix="GUESS", Nvar="",
                           momentumVars=(), inputType=None, masslessMode=False, inputsuffix=""):
    """Create and configure an SJ.NtupleInputMaker for a ROOT ntuple.

    Every setting that is not given explicitly is guessed from the branch
    names found in the file.

    Parameters:
      filename     -- path of the ROOT file to read.
      treename     -- name of the TTree; "" selects the first TTree found
                      among the keys of the file.
      inputprefix  -- prefix of the input branches; "GUESS" looks for a
                      branch starting with 'input' and containing '_'
                      (prefix = text before the first '_'), or for a bare
                      'eta' / 'px' branch (empty prefix). A trailing '_'
                      is appended when missing.
      Nvar         -- name (without prefix) of the multiplicity branch;
                      "" tries 'n', 'num' and 'nparticle'.
      momentumVars -- the four kinematic branch names (without prefix);
                      empty tries (px,py,pz,e), (eta,phi,pt,e),
                      (eta,phi,p_t,e) and (eta,phi,pt,m) in that order.
      inputType    -- an SJ.NtupleInputMaker input-type value; None derives
                      it from the first and last momentum variable names and
                      the storage type of the first momentum branch.
      masslessMode -- forwarded to set_masslessMode().
      inputsuffix  -- accepted for interface compatibility; not used here.

    Returns the configured SJ.NtupleInputMaker, or None (after logging an
    error) when the file cannot be opened, no tree is found, or a guess
    fails. An unsupported variable/type combination raises KeyError.
    """

    syncMessageLevel()

    f = ROOT.TFile(filename)
    if not f.IsOpen():
        m_log.error('Could not open ROOT file with name %s', filename)
        return

    # ---------------------------------
    # get tree
    tree = None
    if treename == "":
        # use the first TTree found
        keys = [k.GetName() for k in f.GetListOfKeys()]
        for k in keys:
            t = f.Get(k)
            if isinstance(t, ROOT.TTree):
                tree = t
                treename = k
                break
    else:
        # ROOT's lookup method is Get (capital G).
        tree = f.Get(treename)
    if not tree:
        m_log.error("Couldn't find tree in %s", filename)
        return
    # ---------------------------------

    branches = [b.GetName() for b in tree.GetListOfBranches()]

    # ---------------------------------
    # Guess input prefix if not set
    if inputprefix == "GUESS":
        for bn in branches:
            lowered = bn.lower()
            # look for vars starting with 'input'
            if lowered.startswith('input') and '_' in bn:
                inputprefix = bn[:bn.find('_')]  # get the part before '_'
                break
            # look for vars without prefix
            if lowered == 'eta' or lowered == 'px':
                inputprefix = ''
                break
        if inputprefix == "GUESS":
            m_log.error("Couldn't guess proper prefix for input variables")
            return
        else:
            m_log.info("Found prefix input = " + inputprefix)
    # ---------------------------------

    # Map lower-cased variable name -> variable name as spelled in the tree
    # (both without the prefix).
    if inputprefix == '':
        # un-prefixed inputs: keep the branches that contain no '_'
        branches = [b for b in branches if '_' not in b]
        var_names = dict((b.lower(), b) for b in branches)
    else:
        if not inputprefix.endswith('_'):
            inputprefix += '_'
        # retrieve all variables starting with inputprefix
        branches = [b for b in branches if b.startswith(inputprefix)]
        # Strip exactly the prefix, so that a prefix which itself contains
        # '_' still yields the bare variable name.
        skip = len(inputprefix)
        var_names = dict((b[skip:].lower(), b[skip:]) for b in branches)

    # ---------------------------------
    # Guess the multiplicity variable
    if Nvar == "":
        # No early exit: when several candidates exist, the last one in
        # this list wins.
        for nName in ['n', 'num', 'nparticle']:
            if nName in var_names:
                Nvar = var_names[nName]
        if Nvar == "":
            m_log.error("Couldn't guess proper input_n variable")
            return
        else:
            m_log.info("Found prefix input_n = " + inputprefix + Nvar)
    # ---------------------------------

    # ---------------------------------
    # Guess kinematic variables
    if not momentumVars:
        candidates = (
            ('px', 'py', 'pz', 'e'),
            ('eta', 'phi', 'pt', 'e'),
            ('eta', 'phi', 'p_t', 'e'),
            ('eta', 'phi', 'pt', 'm'),
        )
        for names in candidates:
            if all(name in var_names for name in names):
                momentumVars = tuple(var_names[name] for name in names)
                break

        if not momentumVars:
            m_log.error("Couldn't guess  kinematic input variables")
            return
        else:
            m_log.info("Found kinematic input = " + str(momentumVars))
    # ---------------------------------
    # Guess input type
    if inputType is None:
        vtype = _branchType(tree.GetBranch(inputprefix + momentumVars[0]))
        # First + last variable name identifies the representation:
        # 'pxe', 'etae' or 'etam'.
        momkey = (momentumVars[0] + momentumVars[3]).lower()
        inputType = {
            'pxevector_double': SJ.NtupleInputMaker.PxPyPzE_vector_double,
            'pxevector_float': SJ.NtupleInputMaker.PxPyPzE_vector_float,
            'pxearray_double': SJ.NtupleInputMaker.PxPyPzE_array_double,
            'pxearray_float': SJ.NtupleInputMaker.PxPyPzE_array_float,
            'etaevector_double': SJ.NtupleInputMaker.EtaPhiPtE_vector_double,
            'etaevector_float': SJ.NtupleInputMaker.EtaPhiPtE_vector_float,
            'etaearray_double': SJ.NtupleInputMaker.EtaPhiPtE_array_double,
            'etaearray_float': SJ.NtupleInputMaker.EtaPhiPtE_array_float,
            'etamvector_double': SJ.NtupleInputMaker.EtaPhiPtM_vector_double,
            'etamvector_float': SJ.NtupleInputMaker.EtaPhiPtM_vector_float,
            'etamarray_double': SJ.NtupleInputMaker.EtaPhiPtM_array_double,
            'etamarray_float': SJ.NtupleInputMaker.EtaPhiPtM_array_float,
            }[momkey + vtype]
        m_log.info("Input variables type = " + vtype)

    maker = SJ.NtupleInputMaker(inputType)
    maker.set_prefix(inputprefix)
    maker.set_n_name(Nvar)
    maker.set_variables(*momentumVars)
    maker.setFileTree(filename, treename)
    maker.set_name("InputJet")
    maker.set_masslessMode(masslessMode)

    return maker



# Every tool created by addFinalCut is also stored here.
# NOTE(review): presumably this keeps the Python-side object referenced
# while the builder still uses the tool -- confirm the ownership rules.
_keep_alive = []

def addFinalCut(builder, ptCut ):
    """Append a JetPtSelectorTool built from ptCut to builder's jet tools.

    The new tool is also recorded in the module-level _keep_alive list.
    """
    pt_selector = SJ.JetPtSelectorTool(ptCut)
    builder.add_jetTool(pt_selector)
    _keep_alive.append(pt_selector)




# utils ...
def _branchType(b):
    """Return the storage tag of branch b as a string.

    A ROOT.TBranchElement gives 'vector_double' when its type name contains
    'double' and 'vector_float' otherwise. Any other object is treated as a
    plain TBranch: 'array_float' when its title contains '/F', otherwise
    'array_double'.
    """
    if not isinstance(b, ROOT.TBranchElement):
        # assume it's a TBranch
        return 'array_float' if '/F' in b.GetTitle() else 'array_double'
    return 'vector_double' if 'double' in b.GetTypeName() else 'vector_float'


# Wrapper for functions that take vectors as arguments
def stdVector(*args):
    '''Build a std::vector holding the positional arguments, in order.

    The vector is instantiated with float as element type when at least one
    argument is a Python float, and with int otherwise.
    '''
    from ROOT import std
    has_float = any(isinstance(value, float) for value in args)
    element_type = float if has_float else int

    vec = std.vector(element_type)()
    for value in args:
        vec.push_back(value)
    return vec
        
            
