/* newinterp.inc
 * Daniel S. Roche, January 2011
 * See COPYING.txt for permissions.
 *
 * Black-box interpolation over a finite field using diversification.
 * See "Diversification improves interpolation", Giesbrecht & Roche, 2011
 *
 * Include file (template implementations)
 */

#include <cmath>
#include <utility>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <NTL/ZZX.h>
#include <NTL/lzz_pX.h>
#include <NTL/ZZ_pE.h>
#include <NTL/lzz_pX.h>
#include <NTL/lzz_pE.h>
#include <NTL/ZZXFactoring.h>
#include "misc.h"

NTL_OPEN_NNS

// For STL sorting of pointers to objects. Avoids potentially costly copying.
// Orders two pointers by the values they point to, using operator< on T;
// the explicit specializations below supply an ordering for NTL field
// types that have no operator<.
// The adaptable-function typedefs formerly inherited from
// std::binary_function are declared directly, since that base class was
// deprecated in C++11 and removed in C++17.
template <typename T>
struct PointerSort {
  typedef T* first_argument_type;
  typedef T* second_argument_type;
  typedef bool result_type;

  bool operator() (const T* a, const T* b) const {
    return *a < *b;
  }
};

// Specialization of PointerSort for ZZ_p: orders by the integer
// representatives rep(*a), rep(*b).
// Declared inline: an explicit specialization is an ordinary function
// definition, so without inline every translation unit including this
// file would emit its own definition (ODR violation / link error).
template <>
inline bool PointerSort<ZZ_p>::operator() (const ZZ_p* a, const ZZ_p* b) const 
  { return rep(*a) < rep(*b); }

// Specialization of PointerSort for zz_p: orders by the single-precision
// representatives rep(*a), rep(*b).
// Declared inline: an explicit specialization is an ordinary function
// definition, so without inline every translation unit including this
// file would emit its own definition (ODR violation / link error).
template <>
inline bool PointerSort<zz_p>::operator() (const zz_p* a, const zz_p* b) const 
  { return rep(*a) < rep(*b); }

// Specialization of PointerSort for ZZ_pE.
// Elements are compared through their ZZ_pX representatives: first by
// degree, then lexicographically by coefficient (lowest degree first),
// comparing each coefficient's integer representative. This is a strict
// weak ordering in which only equal elements are equivalent.
// Declared inline to avoid multiple definitions when this file is
// included from more than one translation unit.
template <>
inline bool PointerSort<ZZ_pE>::operator() (const ZZ_pE* a, const ZZ_pE* b) const  {
  const ZZ_pX& ra = rep(*a);
  const ZZ_pX& rb = rep(*b);
  const long da = deg(ra);
  const long db = deg(rb);
  if (da != db)
    return da < db;
  for (long i=0; i<=da; ++i) {
    const ZZ& ca = rep(ra.rep[i]);
    const ZZ& cb = rep(rb.rep[i]);
    // Compare a's coefficient against b's (previously compared a with a).
    if (ca != cb)
      return ca < cb;
  }
  return false;
}

// Specialization of PointerSort for zz_pE.
// Elements are compared through their zz_pX representatives: first by
// degree, then lexicographically by coefficient (lowest degree first).
// The previous implementation always returned false, making all elements
// equivalent, so is_diverse rejected every zz_pE image with two or more
// terms and coefficient matching by value was impossible.
// Declared inline to avoid multiple definitions when this file is
// included from more than one translation unit.
template <>
inline bool PointerSort<zz_pE>::operator() (const zz_pE* a, const zz_pE* b) const {
  const zz_pX& ra = rep(*a);
  const zz_pX& rb = rep(*b);
  const long da = deg(ra);
  const long db = deg(rb);
  if (da != db)
    return da < db;
  for (long i=0; i<=da; ++i) {
    const long ca = rep(ra.rep[i]);
    const long cb = rep(rb.rep[i]);
    if (ca != cb)
      return ca < cb;
  }
  return false;
}

/* Tests whether the given polynomial is "diverse", i.e.
 * has all coefficients pairwise distinct.
 */
template <typename Poly, typename Base>
bool is_diverse (const Poly& f) {
  std::set< const Base*,PointerSort<Base> > coeffs;
  for (long i=0; i<=deg(f); ++i) {
    if (!IsZero(coeff(f,i)) &&
        !coeffs.insert(&(f.rep[i])).second) 
      return false;
  }
  return true;
}

/* Probabilistic interpolation method "A".
 * The verification step is not included; the algorithm is Monte Carlo.
 * f: will hold the output
 * BBT: should be a subclass of UniModBB
 * bb: uni-modular black box for unknown polynomial f
 * D: upper bound on degree of f
 * T: upper bound on sparsity of f
 *
 * Outline: evaluate f(alpha*x) mod (x^p - 1) for random primes p. Only
 * images of the largest sparsity t seen so far, taken with an alpha that
 * makes them diverse, are kept. Once the product of the kept primes exceeds
 * D, each term's exponent is recovered by Chinese remaindering its residues
 * (terms are matched across images by coefficient value), and each
 * coefficient is multiplied by alpha^-e to undo the diversification.
 */
template <typename BBT>
void new_interpA 
  (SparsePoly< typename BBT::PolyT, typename BBT::BaseT >& f, 
   BBT& bb, ZZ D, long T)
{
  long i,j;

  // Compute lambda s.t. a prime in the range lambda..2*lambda is good w.h.p.
  long lambda = goodp_bound (D, T);
  long goodplen = NumBits(lambda) + 1;

  // Polynomials for the black box evaluation
  typename BBT::PolyT alphax;
  SetX(alphax);
  typename BBT::PolyT xpm1;
  typename BBT::BaseT minusone;
  conv (minusone, -1);
  SetCoeff (xpm1, 0, minusone);
  typename BBT::BaseT zero;
  conv (zero, 0);

  // Use NTL to choose random primes that are good w.h.p.,
  // as well as random alpha s.t. f(alpha x) is diverse w.h.p.
  // goodevals holds black box evaluations diverse with sparsity t.
  // Once the product of primes is larger than D, we move on.
  typename BBT::PolyT fp;
  long t = 0;
  long p;
  long sfp;
  std::vector< std::pair<long, typename BBT::PolyT> > goodevals;
  ZZ goodprod;
  conv(goodprod,1); // product of good primes
  // alpha is a reference to the linear coefficient of alphax, so
  // re-randomizing alpha changes the evaluation point in place.
  // alpha must be nonzero: it is inverted when recovering coefficients.
  typename BBT::BaseT& alpha = alphax.rep[1];
  do { random(alpha); } while (IsZero(alpha));
  bool diverse = false;
#ifdef VERBOSE
  long nbad = 0;
  long ngood = 0;
  long nbadalpha = 0;
  long ntotal = 0;
#endif

  // goodprod starts at 1, so when D < 1 the product test alone would skip
  // sampling entirely and leave goodevals empty; require at least one kept
  // evaluation so the reconstruction below has data to work from.
  while (goodprod <= D || goodevals.empty()) {
    do { p = GenPrime_long (goodplen); } while (divide(goodprod,p));
    SetCoeff (xpm1, p);
    bb.eval(fp, alphax, xpm1);
    SetCoeff (xpm1, p, zero);
    sfp = sparsity(fp);
    if (sfp > t) {
      // A larger sparsity means every earlier image had colliding terms.
#ifdef VERBOSE
      nbad += goodevals.size();
#endif
      goodevals.clear();
      conv (goodprod,1);
      t = sfp;
      diverse = false;
    }
    if (sfp == t) {
      if (!diverse && is_diverse<typename BBT::PolyT,typename BBT::BaseT>(fp))
        diverse = true;
      if (diverse) {
        goodevals.resize (goodevals.size()+1);
        goodevals.back().first = p;
        goodevals.back().second = fp;
        goodprod *= p;
      }
#ifdef VERBOSE
      else ++ngood;
#endif
    }
#ifdef VERBOSE
    else ++nbad;
    ++ntotal;
    if (!diverse) ++nbadalpha;
#endif
    if (!diverse) {
      do { random(alpha); } while (IsZero(alpha));
    }
  }
#ifdef VERBOSE
  std::cout << "Sampled mod x^p-1 for " 
            << ngood+goodevals.size() << " good primes and "
            << nbad << " bad primes" << std::endl
            << "in the range " << (1L<<(goodplen-1)) << " <= p <= "
            << ((1L<<goodplen)-1) << std::endl;
  std::cout << "Tried " << nbadalpha << " bad alphas before finding "
            << "f(" << alpha << "x) is diverse." << std::endl;
  std::cout << (ngood ? "Only " : "All ") << goodevals.size()
            << " good prime samples were actually used." << std::endl;
#endif

  // NTL provides CRT for polynomials but not for vector types,
  // so we build a ZZX object to hold the exponents of f
  // (as coefficients, not as roots).
  ZZX expons;
  zz_pX expons_p;
  expons_p.rep.SetLength(t);
  std::vector<long> expons_vec(t);
  zz_pBak bak;
  ZZ pprod;
  typename std::vector< std::pair<long, typename BBT::PolyT> >::const_iterator
    iter = goodevals.begin();

  // Use the first good evaluation to build an stl map of
  // coefficients to indices. This will be used to match up the terms
  // in different evaluations.
  // We also initialize expons with this first image modulo p.
  std::map< const typename BBT::BaseT*, long, PointerSort<typename BBT::BaseT> >
    cmap;
  const typename BBT::BaseT* cptr;
  i = 0;
  for (j=0; j<=deg(iter->second); ++j) {
    cptr = &(iter->second.rep[j]);
    if (!IsZero(*cptr)) {
      SetCoeff(expons,i,j);
      cmap[cptr] = i++;
    }
  }
  conv(pprod,iter->first);

  while (++iter != goodevals.end()) {
    for (j=0; j<=deg(iter->second); ++j) {
      cptr = &(iter->second.rep[j]);
      if (!IsZero(*cptr))
        expons_vec[ cmap[cptr] ] = j;
    }
    // Now we have to change the global NTL prime by "pushing" onto the "stack"
    bak.save();
    zz_p::init(iter->first);
    // Now copy expons_vec to expons_p
    for (i=0; i<t; ++i)
      conv (expons_p.rep[i], expons_vec[i]);
    // Update expons with the new modular image at each coefficient
    CRT (expons, pprod, expons_p);
    bak.restore(); // "pop" our prime off the "stack"
  }

  // NTL keeps expons normalized, so a zero exponent in the top position
  // (a lone constant term, t == 1) leaves expons.rep shorter than t.
  // Pad with zeros so that rep[0..t-1] can be addressed below.
  if (expons.rep.length() < t)
    expons.rep.SetLength(t);

  // NTL does Chinese remaindering in the symmetric range, so we might need
  // to "fix" negative exponents (which are coefficients of expons).
  for (i=0; i<t; ++i) {
    if (sign(expons.rep[i]) < 0)
      expons.rep[i] += goodprod;
  }

  // We want to get the exponents in sorted order.
  // So make a vector of pointers to each coefficient of expons, then sort it.
  std::vector<ZZ*> exptrs(t);
  for (i=0; i<t; ++i) exptrs[i] = &(expons.rep[i]);
  std::sort (exptrs.begin(), exptrs.end(), PointerSort<ZZ>() );

  // Now store the sorted exponents in f.
  // Simultaneously, use a single good-prime evaluation to
  // get the coefficients.
  // Each coefficient must be divided by alpha^e for its exponent e.
  std::vector<typename BBT::BaseT> ainvpows(1); // ainvpows[i] = alpha^-(2^i)
  inv (ainvpows[0], alpha);
  typename BBT::BaseT curaipow;
  conv(curaipow,1);
  ZZ curpow; // invariant: curaipow = alpha^-curpow
  ZZ diffpow;

  f.rep.resize(t);
  typename SparsePoly<typename BBT::PolyT, typename BBT::BaseT>::RepT::iterator 
    fiter = f.rep.begin();
  for (std::vector<ZZ*>::const_iterator eiter = exptrs.begin();
       eiter != exptrs.end(); ++eiter) {
    fiter->second = **eiter;
    fiter->first = 
      coeff (goodevals.front().second, 
             rem(fiter->second,goodevals.front().first));
    
    // Compute alpha^-(e_i) incrementally from alpha^-(e_{i-1}); exponents
    // are visited in increasing order, so diffpow is never negative.
    diffpow = fiter->second;
    diffpow -= curpow;
    while (static_cast<unsigned long>(NumBits(diffpow)) > ainvpows.size()) {
      // grow ainvpows if necessary for this step
      ainvpows.resize(ainvpows.size()+1);
      sqr (ainvpows[ainvpows.size()-1], ainvpows[ainvpows.size()-2]);
    }
    while (!IsZero(diffpow)) {
      curaipow *= ainvpows[NumBits(diffpow)-1];
      SwitchBit (diffpow, NumBits(diffpow)-1);
    }
    curpow = fiter->second;
    fiter->first *= curaipow;
    ++fiter;
  }
}

/* Probabilistic interpolation method "B".
 * f: will hold the output
 * BBT: should be a subclass of UniModBB
 * bb: uni-modular black box for unknown polynomial f
 * D: upper bound on degree of f
 * T: upper bound on sparsity of f
 * Return: false if a known failure occurred, otherwise true
 *
 * Outline: one evaluation of f(alpha*x) mod (x^p - 1) at a large random
 * prime fixes the sparsity t and must be diverse. Consecutive primes from
 * guessp_bound upward then supply further exponent residues for images of
 * sparsity t, which are combined by Chinese remaindering until the product
 * of primes exceeds D. Known failures (returning false): the first image is
 * not diverse, a later image has more than t terms, a later image contains
 * a coefficient absent from the first one, or the prime sequence runs out.
 */
template <typename BBT>
bool new_interpB 
  (SparsePoly< typename BBT::PolyT, typename BBT::BaseT >& f, 
   BBT& bb, ZZ D, long T)
{
  long i,j;

  // Compute lambda s.t. a prime in the range lambda..2*lambda is good w.h.p.
  long lambda = goodp_bound (D, T);
  long goodplen = NumBits(lambda) + 1;

  // Compute a guess as to where we should start searching for good primes
  long guess = guessp_bound (D, T);

  // Polynomials for the black box evaluation
  typename BBT::PolyT alphax;
  SetX(alphax);
  typename BBT::PolyT xpm1;
  typename BBT::BaseT minusone;
  conv (minusone, -1);
  SetCoeff (xpm1, 0, minusone);
  typename BBT::BaseT zero;
  conv (zero, 0);

  // Use NTL to choose random primes that are good w.h.p.,
  // as well as random alpha s.t. f(alpha x) is diverse w.h.p.
  // Once the product of primes is larger than D, we move on.
  typename BBT::PolyT firstfp;
  long t = 0;
  long p = 0; // initialized: VERBOSE output reads it even if no prime is sampled
  long sfp;
  ZZ goodprod;
  // alpha is a reference to the linear coefficient of alphax.
  // It must be nonzero: it is inverted when recovering coefficients.
  typename BBT::BaseT& alpha = alphax.rep[1];
  do { random(alpha); } while (IsZero(alpha));
#ifdef VERBOSE
  long nbad = 0;
  long nbadalpha = 0;
  long ngood = 0;
#endif

  // First we do one "big prime" evaluation
  // which gets the sparsity t correct w.h.p.
  // and chooses a good alpha as well.
  long firstp = GenPrime_long (goodplen);
  conv(goodprod,firstp);
  SetCoeff (xpm1, firstp);
  bb.eval(firstfp, alphax, xpm1);
  SetCoeff (xpm1, firstp, zero);
  if (!is_diverse<typename BBT::PolyT,typename BBT::BaseT>(firstfp)) 
    return false;

  // NTL provides CRT for polynomials but not for vector types,
  // so we build a ZZX object to hold the exponents of f
  // (as coefficients, not as roots).
  ZZX expons;

  // Use the first good evaluation to build an stl map of
  // coefficients to indices. This will be used to match up the terms
  // in different evaluations.
  // We also initialize expons with this first image modulo p.
  typedef std::map< const typename BBT::BaseT*, long,
                    PointerSort<typename BBT::BaseT> > CoeffMap;
  CoeffMap cmap;
  const typename BBT::BaseT* cptr;
  i = 0;
  for (j=0; j<=deg(firstfp); ++j) {
    cptr = &(firstfp.rep[j]);
    if (!IsZero(*cptr)) {
      ++t;
      SetCoeff(expons,i,j);
      cmap[cptr] = i++;
    }
  }
  zz_pX expons_p;
  expons_p.rep.SetLength(t);
  std::vector<long> expons_vec(t);
  zz_pBak bak;
  typename BBT::PolyT fp;

  PrimeSeq ps;
  ps.reset (guess);
  while (goodprod <= D) {
    p = ps.next();
    if (p == 0)
      return false; // PrimeSeq exhausted: no more single-precision primes
    // The sequence may reach firstp, which already divides goodprod;
    // CRT requires coprime moduli, so skip any repeated prime.
    if (divide(goodprod, p))
      continue;
    SetCoeff (xpm1, p);
    bb.eval(fp, alphax, xpm1);
    SetCoeff (xpm1, p, zero);
    sfp = sparsity(fp);
    if (sfp > t) return false;
    if (sfp == t) {
#ifdef VERBOSE
      ++ngood;
#endif
      for (j=0; j<=deg(fp); ++j) {
        cptr = &(fp.rep[j]);
        if (!IsZero(*cptr)) {
          // Look up without inserting: a coefficient missing from the first
          // image means the images cannot be matched consistently.
          typename CoeffMap::const_iterator cit = cmap.find(cptr);
          if (cit == cmap.end())
            return false;
          expons_vec[ cit->second ] = j;
        }
      }
      // Now we have to change the global NTL prime by "pushing" onto the stack
      bak.save();
      zz_p::init(p);
      // Now copy expons_vec to expons_p
      for (i=0; i<t; ++i)
        conv (expons_p.rep[i], expons_vec[i]);
      // Update expons with the new modular image at each coefficient
      CRT (expons, goodprod, expons_p);
      bak.restore(); // "pop" our prime off the "stack"
    }
#ifdef VERBOSE
    else ++nbad;
#endif
  }
#ifdef VERBOSE
  std::cout << "Sampled mod x^p-1 for a single good";
  std::cout << " prime p=" << firstp << ',' << std::endl
            << "as well as " << ngood << " good primes and "
            << nbad << " bad primes" << std::endl
            << "in the range " << guess << " <= p <= "
            << p << std::endl;
  std::cout << "Tried " << nbadalpha << " bad alphas before finding "
            << "f(" << alpha << "x) is diverse." << std::endl;
  std::cout << "All " << ngood
            << " good prime samples were actually used." << std::endl;
#endif

  // NTL keeps expons normalized, so a zero exponent in the top position
  // (a lone constant term, t == 1) leaves expons.rep shorter than t.
  // Pad with zeros so that rep[0..t-1] can be addressed below.
  if (expons.rep.length() < t)
    expons.rep.SetLength(t);

  // NTL does Chinese remaindering in the symmetric range, so we might need
  // to "fix" negative exponents (which are coefficients of expons).
  for (i=0; i<t; ++i) {
    if (sign(expons.rep[i]) < 0)
      expons.rep[i] += goodprod;
  }

  // We want to get the exponents in sorted order.
  // So make a vector of pointers to each coefficient of expons, then sort it.
  std::vector<ZZ*> exptrs(t);
  for (i=0; i<t; ++i) exptrs[i] = &(expons.rep[i]);
  std::sort (exptrs.begin(), exptrs.end(), PointerSort<ZZ>() );

  // Now store the sorted exponents in f.
  // Simultaneously, use the first (big-prime) evaluation to
  // get the coefficients.
  // Each coefficient must be divided by alpha^e for its exponent e.
  std::vector<typename BBT::BaseT> ainvpows(1); // ainvpows[i] = alpha^-(2^i)
  inv (ainvpows[0], alpha);
  typename BBT::BaseT curaipow;
  conv(curaipow,1);
  ZZ curpow; // invariant: curaipow = alpha^-curpow
  ZZ diffpow;

  f.rep.resize(t);
  typename SparsePoly<typename BBT::PolyT, typename BBT::BaseT>::RepT::iterator 
    fiter = f.rep.begin();
  for (std::vector<ZZ*>::const_iterator eiter = exptrs.begin();
       eiter != exptrs.end(); ++eiter) {
    fiter->second = **eiter;
    fiter->first = 
      coeff (firstfp, rem(fiter->second,firstp));
    
    // Compute alpha^-(e_i) incrementally from alpha^-(e_{i-1}); exponents
    // are visited in increasing order, so diffpow is never negative.
    diffpow = fiter->second;
    diffpow -= curpow;
    while (static_cast<unsigned long>(NumBits(diffpow)) > ainvpows.size()) {
      // grow ainvpows if necessary for this step
      ainvpows.resize(ainvpows.size()+1);
      sqr (ainvpows[ainvpows.size()-1], ainvpows[ainvpows.size()-2]);
    }
    while (!IsZero(diffpow)) {
      curaipow *= ainvpows[NumBits(diffpow)-1];
      SwitchBit (diffpow, NumBits(diffpow)-1);
    }
    curpow = fiter->second;
    fiter->first *= curaipow;
    ++fiter;
  }

  return true;
}

NTL_CLOSE_NNS


