#include "FastKtParam.h"
#include "FastKtUtils.h"
#include "FastKtHull.h"
#include "JetCore/JetDistances.hh"
namespace SpartyJet { 

namespace FastKtJet {

  // *****************************************************************  
  // Distances -------------------------------------------------------
  // *****************************************************************  
  // Usual DeltaR distance
  // Construct the standard DeltaR kt distance; its registered name is "DeltaR".
  KtDistanceDeltaR::KtDistanceDeltaR()
    : m_name("DeltaR")
  {}

  // Human-readable name of this distance scheme.
  std::string KtDistanceDeltaR::name() const {
    return m_name;
  }

  // Beam distance of a single jet: its squared transverse momentum.
  float KtDistanceDeltaR::operator()(const KtJetInfo* a) const {
    // For e+e- there is no beam remnant, so the result is ignored anyway.
    const float beamDistance = a->pt2;
    return beamDistance;
  }

  // Pairwise kt distance: min(pt2_a, pt2_b) * DeltaR^2, where
  // DeltaR^2 = (eta_a - eta_b)^2 + deltaPhi(phi_a, phi_b)^2.
  float KtDistanceDeltaR::operator()(const KtJetInfo* a, const KtJetInfo* b) const {
    const float dEta = a->eta - b->eta;
    const float dPhi = JetDistances::deltaPhi(a->phi, b->phi);
    const float dR2 = dEta*dEta + dPhi*dPhi;
    const float softerPt2 = std::min(a->pt2, b->pt2);
    return softerPt2 * dR2;
  }

  // Purely geometric separation DeltaR^2 between two jets (no momentum weighting).
  inline float KtDistanceDeltaR::geodist(const KtJetInfo* a, const KtJetInfo * b) const {
    const float dEta = a->eta - b->eta;
    const float dPhi = JetDistances::deltaPhi(a->phi, b->phi);
    const float dR2 = dEta*dEta + dPhi*dPhi;
    return dR2;
  }


  // Angular distance (selected for the Aachen algorithm in getKtAlgo): pure DeltaR^2, no pt weighting
  // Construct the angular distance; its registered name is "Angular".
  KtDistanceAngular::KtDistanceAngular()
    : m_name("Angular")
  {}

  // Human-readable name of this distance scheme.
  std::string KtDistanceAngular::name() const {
    return m_name;
  }

  // Beam distance for the angular scheme: a constant 1, independent of the jet.
  float KtDistanceAngular::operator()(const KtJetInfo* /*a*/) const {
    const float constantBeamDistance = 1;
    return constantBeamDistance;
  }

  // Pairwise distance for the angular scheme: the purely geometric DeltaR^2,
  // delegated to geodist() (non-virtual, qualified call).
  float KtDistanceAngular::operator()(const KtJetInfo* a, const KtJetInfo* b) const {
    const float separation = KtDistanceAngular::geodist(a, b);
    return separation;
  }

  // Geometric separation DeltaR^2 = dEta^2 + dPhi^2 between two jets.
  inline float KtDistanceAngular::geodist(const KtJetInfo* a, const KtJetInfo * b) const {
    const float dEta = a->eta - b->eta;
    const float dPhi = JetDistances::deltaPhi(a->phi, b->phi);
    const float dR2 = dEta*dEta + dPhi*dPhi;
    return dR2;
  }


  // Reversed DeltaR distance
  // Construct the reversed kt distance; its registered name is "Reversed".
  KtDistanceReversed::KtDistanceReversed()
    : m_name("Reversed")
  {}

  // Human-readable name of this distance scheme.
  std::string KtDistanceReversed::name() const {
    return m_name;
  }

  // Beam distance of a single jet: its squared transverse momentum.
  float KtDistanceReversed::operator()(const KtJetInfo* a) const {
    // For e+e- there is no beam remnant, so the result is ignored anyway.
    const float beamDistance = a->pt2;
    return beamDistance;
  }

  // Reversed pairwise distance: max(pt2_a, pt2_b) / DeltaR^2.
  // This is the DeltaR scheme with min -> max and multiply -> divide.
  // NOTE(review): for coincident jets (DeltaR^2 == 0), IEEE float division gives +inf,
  // or NaN if the larger pt2 is also 0. Confirm that the callers' max search tolerates this.
  float KtDistanceReversed::operator()(const KtJetInfo* a, const KtJetInfo* b) const {
    const float dEta = a->eta - b->eta;
    const float dPhi = JetDistances::deltaPhi(a->phi, b->phi);
    const float dR2 = dEta*dEta + dPhi*dPhi;
    const float harderPt2 = std::max(a->pt2, b->pt2);
    return harderPt2 / dR2;
  }

  // Geometric separation DeltaR^2 = dEta^2 + dPhi^2 between two jets.
  inline float KtDistanceReversed::geodist(const KtJetInfo* a, const KtJetInfo * b) const {
    const float dEta = a->eta - b->eta;
    const float dPhi = JetDistances::deltaPhi(a->phi, b->phi);
    const float dR2 = dEta*dEta + dPhi*dPhi;
    return dR2;
  }

  // *****************************************************************  
  // Recomb. schemes  --------------------------------------------------
  // *****************************************************************  

  // E scheme ------
  // Construct the E (four-momentum addition) recombination scheme; name "E".
  KtRecomE::KtRecomE()
    : m_name("E")
  {}

  // Human-readable name of this recombination scheme.
  std::string KtRecomE::name() const {
    return m_name;
  }

  // E-scheme merge of jet_toadd into jet.
  // Four-momenta are summed, then the cached kinematics are refreshed from the sum.
  // Note that the eta field receives the rapidity of the summed four-vector.
  // jet_toadd's constituents are copied (not moved) onto the end of jet's list.
  void KtRecomE::combine( KtJetInfo * jet,  KtJetInfo * jet_toadd) {
    jet->hlv += jet_toadd->hlv;

    // Refresh derived quantities from the combined four-vector.
    jet->eta = jet->hlv.rapidity();
    jet->phi = jet->hlv.phi();
    jet->pt2 = jet->hlv.Perp2();

    std::list<Jet *>::iterator firstAdded = jet_toadd->constit_list.begin();
    std::list<Jet *>::iterator pastAdded  = jet_toadd->constit_list.end();
    jet->constit_list.insert(jet->constit_list.end(), firstAdded, pastAdded);
  }
  // Pt scheme ------
  // Construct the Pt-weighted recombination scheme; name "Pt".
  KtRecomPt::KtRecomPt()
    : m_name("Pt")
  {}

  // Human-readable name of this recombination scheme.
  std::string KtRecomPt::name() const {
    return m_name;
  }

  // Pt-scheme merge of jet_toadd into jet.
  // - hlv: four-momenta are summed.
  // - pt: the merged jet's pt is the scalar sum pt_i + pt_j, stored squared in pt2.
  // - eta, phi: pt-weighted averages of the two inputs.
  // - constituents: jet_toadd's constituents are copied onto the end of jet's list.
  //
  // phi is periodic, so jet_toadd's phi is first shifted by a multiple of 2*pi to lie
  // within pi of jet's phi before averaging. Without this, jets on opposite sides of
  // the +/-pi boundary would average to the wrong side of the detector.
  // NOTE(review): the result is folded back into [-pi, pi], assuming the phi convention
  // of hlv.phi() used by the E scheme. Confirm inputs use the same range.
  void KtRecomPt::combine( KtJetInfo * jet,  KtJetInfo * jet_toadd) {
    jet->hlv += jet_toadd->hlv;

    const float pti = sqrt(jet->pt2);
    const float ptj = sqrt(jet_toadd->pt2);
    const float newPt = pti + ptj;

    // With zero total pt the weights are undefined (0/0). Keep jet's direction
    // instead of producing NaN eta/phi.
    if (newPt > 0) {
      const double pi = 3.14159265358979323846;
      const double twoPi = 2.0 * pi;

      // Put phi_j on the same branch as phi_i.
      double phij = jet_toadd->phi;
      while (phij - jet->phi >  pi) phij -= twoPi;
      while (phij - jet->phi < -pi) phij += twoPi;

      double newPhi = (jet->phi*pti + phij*ptj)/newPt;
      if (newPhi > pi)       newPhi -= twoPi;
      else if (newPhi < -pi) newPhi += twoPi;

      jet->eta = (jet->eta*pti + jet_toadd->eta*ptj)/newPt;
      jet->phi = newPhi;
    }
    jet->pt2 = newPt*newPt;

    std::list<Jet *>::iterator itB = jet_toadd->constit_list.begin();
    std::list<Jet *>::iterator itE = jet_toadd->constit_list.end();
    jet->constit_list.insert(jet->constit_list.end(),itB,itE);
  }

 

  // *****************************************************************  
  // Kt Alg
  // *****************************************************************
  // One step of the standard kt algorithm.
  // - Reads the smallest pair distance and the smallest beam distance.
  // - Records the candidate jet and the indices of the closest pair.
  // - Either extracts the candidate jet (beam distance * R^2 <= pair distance)
  //   or merges the closest pair.
  inline void KtAlgoStandard::processStep(){
    m_lastDPair = m_ktList->getMinDPair();
    m_lastDJet  = m_ktList->getMinDJet();
    m_doExtractJet = (m_lastDPair >= m_lastDJet * m_rParameterSq);
    m_extractedJet = m_ktList->getMinJet();

    const std::pair<int,int> closestPair = m_ktList->getMinPairIndex();
    m_lastIndex1 = closestPair.first;
    m_lastIndex2 = closestPair.second;

    if (!m_doExtractJet) {
      m_ktList->mergeMinJets();
      return;
    }
    m_ktList->killMinJet();
  }

  // One step of the reversed kt algorithm. It mirrors the standard step using maxima:
  // the candidate jet is extracted when beam distance * R^2 is strictly greater than
  // the largest pair distance; otherwise the maximal pair is merged.
  inline void KtAlgoReversed::processStep(){
    m_lastDPair = m_ktList->getMaxDPair();
    m_lastDJet  = m_ktList->getMaxDJet();
    m_doExtractJet = (m_lastDPair < m_lastDJet * m_rParameterSq);
    m_extractedJet = m_ktList->getMaxJet();

    const std::pair<int,int> farthestPair = m_ktList->getMaxPairIndex();
    m_lastIndex1 = farthestPair.first;
    m_lastIndex2 = farthestPair.second;

    if (!m_doExtractJet) {
      m_ktList->mergeMaxJets();
      return;
    }
    m_ktList->killMaxJet();
  }

  // Default constructor. No distance scheme, recombination scheme or jet list yet;
  // they are supplied via setDistanceType(), setRecomType() and init().
  // R^2 defaults to 1 (R = 1). It was previously left uninitialized even though
  // processStep() reads it.
  KtAlgoStandard::KtAlgoStandard() {
    m_rParameterSq = 1.0 ;
    m_ktDist = NULL ;
    m_ktRecom= NULL ;
    m_ktList = NULL ;
  }
  // Construct with radius parameter R. Only R^2 is stored because it is compared
  // against squared distances in processStep().
  KtAlgoStandard::KtAlgoStandard(double rParameter) { 
    m_rParameterSq = (rParameter*rParameter) ;
    m_ktDist = NULL ;
    m_ktRecom= NULL ;
    m_ktList = NULL ;
  }
  // Releases the owned jet list and schemes (delete on NULL is a no-op).
  // NOTE(review): this class owns raw pointers. Unless the header disables copying,
  // a copied instance would double-delete them; confirm copy ctor/assignment are private.
  KtAlgoStandard::~KtAlgoStandard(){
    delete m_ktList;
    delete m_ktDist;
    delete m_ktRecom;
  }
  

  // (Re)build the working jet list from the input constituents, using the current
  // distance and recombination schemes. Any previous list is destroyed first.
  void KtAlgoStandard::init(jetcollection_t *constituents){
    delete m_ktList; // deleting a NULL pointer is a no-op, so no guard is needed
    m_ktList = new KtLists(constituents, m_ktDist, m_ktRecom, false);
  }
  // Run one clustering step if at least two jets remain.
  // An extracted jet is appended to the final jets.
  // Returns false (and does nothing) once one or zero jets are left.
  bool KtAlgoStandard::continueClustering(){
    if (m_ktList->getNJets() <= 1) return false;

    processStep();
    if (m_doExtractJet) m_finalJets->push_back(m_extractedJet);
    return true;
  }
  
  // Finish clustering: the jet returned by getMinJet() is added to the final jets.
  // The remaining jet is not removed from the list.
  // NOTE(review): m_lastDJet is filled from getMinDPair(), not getMinDJet(). Confirm intended.
  // NOTE(review): not guarded against an empty list; behaviour of getMinJet() then is unknown here.
  void KtAlgoStandard::endClustering(){
    m_lastDJet = m_ktList->getMinDPair();
    m_extractedJet = m_ktList->getMinJet();
    m_finalJets->push_back(m_extractedJet);
  }

  // Replace the owned distance scheme with a freshly created one of the given type.
  void KtAlgoStandard::setDistanceType(KtDistance::KtDistanceType distType) {
    delete m_ktDist; // no-op when NULL
    m_ktDist = getDistanceScheme(distType);
  }
  // Replace the owned recombination scheme with a freshly created one of the given type.
  void KtAlgoStandard::setRecomType(KtRecom::KtRecomType recomType) {
    delete m_ktRecom; // no-op when NULL
    m_ktRecom = getRecomScheme(recomType);
  }

  // -----------------------------------------------------

  // Run one reversed-kt step while at least two jets remain.
  // An extracted jet is appended to the final jets.
  bool KtAlgoReversed::continueClustering(){
    if (m_ktList->getNJets() <= 1) return false;

    processStep();
    if (m_doExtractJet) {
      m_finalJets->push_back(m_extractedJet);
    }
    return true;
  }
  
  // Finish reversed clustering: the jet returned by getMaxJet() is added to the final
  // jets and left in the list.
  // NOTE(review): m_lastDJet is filled from getMaxDPair(), not getMaxDJet(). Confirm intended.
  void KtAlgoReversed::endClustering(){
    m_lastDJet = m_ktList->getMaxDPair();
    m_extractedJet = m_ktList->getMaxJet();
    m_finalJets->push_back(m_extractedJet);
  }
  // The reversed algorithm always uses the Reversed distance; the requested type is ignored.
  void KtAlgoReversed::setDistanceType(KtDistance::KtDistanceType /*distType*/) {
    delete m_ktDist; // no-op when NULL
    m_ktDist = getDistanceScheme(KtDistance::Reversed);
  }

  // -----------------------------------------------------
  // Exclusive-N mode: keep clustering while more than m_Njet jets remain.
  // Jets extracted along the way are recorded as rejected, not final.
  bool KtAlgoStandardExclN::continueClustering(){
    if (!(m_ktList->getNJets() > m_Njet)) return false;

    processStep();
    if (m_doExtractJet) {
      m_rejectedJets->push_back(m_extractedJet);
    }
    return true;
  }

  // Move every jet still in the list to the final jets, in getMinJet() order.
  // The last jet is pushed without being removed.
  void KtAlgoStandardExclN::endClustering(){
    for (; m_ktList->getNJets() > 1; m_ktList->killMinJet()) {
      m_finalJets->push_back(m_ktList->getMinJet());
    }
    m_finalJets->push_back(m_ktList->getMinJet());
  }
  // -----------------------------------------------------
  // Exclusive-Ycut mode: keep clustering while the smallest pair distance is below m_Ycut.
  // Jets extracted along the way are recorded as rejected.
  // Clustering also stops when fewer than two jets remain: there is then no pair to
  // merge, so processStep() must not run. KtAlgoStandard::continueClustering uses the
  // same guard; without it a large Ycut could drive processStep() on a single-jet list.
  bool KtAlgoStandardExclD::continueClustering(){
    bool docontinue = (m_ktList->getNJets() > 1) && (m_ktList->getMinDPair() < m_Ycut);
    if (!docontinue) return false;

    processStep();
    if(m_doExtractJet) m_rejectedJets->push_back(m_extractedJet);
    return true;
  }

  // Move every jet still in the list to the final jets, in getMinJet() order.
  // The last jet is pushed without being removed.
  void KtAlgoStandardExclD::endClustering(){
    for (; m_ktList->getNJets() > 1; m_ktList->killMinJet()) {
      m_finalJets->push_back(m_ktList->getMinJet());
    }
    m_finalJets->push_back(m_ktList->getMinJet());
  }



  // *****************************************************************  
  // Get functions
  // *****************************************************************  

  // Factory for distance schemes. The caller owns the returned object.
  // Unknown types print a warning and fall back to KtDistanceDeltaR.
  KtDistance* getDistanceScheme(int distType ) {
    if (distType == KtDistance::DeltaR)   return new KtDistanceDeltaR();
    if (distType == KtDistance::Angular)  return new KtDistanceAngular();
    if (distType == KtDistance::Reversed) return new KtDistanceReversed();

    std::cout << "FastKtJet::KtDistance WARNING, unreconised distance scheme specified! : " << distType<<std::endl;
    std::cout << "                      Distance Scheme set to KtDistanceDeltaR" << std::endl;
    return new KtDistanceDeltaR();
  }

  // Factory for recombination schemes. The caller owns the returned object.
  // Unknown types print a warning and fall back to KtRecomE.
  KtRecom* getRecomScheme(int recom) {
    if (recom == KtRecom::E)  return new KtRecomE();
    if (recom == KtRecom::Pt) return new KtRecomPt();

    std::cout << "WARNING, unreconised recombination scheme specified!" << std::endl;
    std::cout << "Recombination Scheme set to KtRecomE" << std::endl;
    return new KtRecomE();
  }

  /// Returns a default KtAlgo for the given KtAlgo type and radius parameter.
  /// The algorithm gets the type's default distance scheme and the E recombination
  /// scheme; it can be re-parametrized afterwards if needed. The caller owns the result.
  /// Unknown types print a warning and fall back to KtAlgo::Standard, consistent with
  /// getDistanceScheme/getRecomScheme. Previously `a` stayed NULL and was dereferenced.
  KtAlgo * getKtAlgo(int type, double rparam){
    KtAlgo * a=0;
    switch(type){
    case KtAlgo::Standard :
      a = new KtAlgoStandard(rparam);
      a->setDistanceType(KtDistance::DeltaR);
      break;
    case KtAlgo::Aachen :
      a = new KtAlgoStandard(rparam);
      a->setDistanceType(KtDistance::Angular);
      break;
    case KtAlgo::Reversed :
      a = new KtAlgoReversed(rparam);
      a->setDistanceType(KtDistance::Reversed);
      break;
    case KtAlgo::StandardExcluN :
      a = new KtAlgoStandardExclN(rparam,5); 
      a->setDistanceType(KtDistance::DeltaR);
      break;
    case KtAlgo::StandardExcluD :
      a = new KtAlgoStandardExclD(rparam,0.0); 
      a->setDistanceType(KtDistance::DeltaR);
      break;
    case KtAlgo::Hull :
      a = new KtAlgoHull(rparam);
      a->setDistanceType(KtDistance::Angular);
      break;
    default:
      std::cout << "FastKtJet::getKtAlgo WARNING, unrecognised algorithm type specified! : " << type << std::endl;
      std::cout << "                      Algorithm set to KtAlgoStandard" << std::endl;
      a = new KtAlgoStandard(rparam);
      a->setDistanceType(KtDistance::DeltaR);
      break;
    }
    a->setRecomType(KtRecom::E);
    return a;
  }

}  // namespace FastKtJet
}  // namespace SpartyJet
