#include "VoteForGroup.H"
#include <PointCloud/Features/Planes/details/VonMisesFischer.H>
#include <nrt/PointCloud2/Features/FeatureTypes/PointNormal.H>

namespace
{
  // ######################################################################
  //! Converts a direction into (theta, phi) polar angles
  /*! @param n The direction; presumably unit length (only n.z() is used for phi) — TODO confirm with callers
      @return (theta, phi) where phi = acos(n.z) is the angle from the +z axis and theta is
              atan2(n.y, n.x), shifted up by 2 pi when negative. theta is forced to 0 when phi is
              within 1e-4 rad of either pole, where the azimuth is undefined. */
  std::tuple<float, float> toPolar(Eigen::Vector3f const & n)
  {
    // Clamp before acos: after normalization |n.z()| can exceed 1 by a rounding error, and
    // acos would then return nan. The argument order keeps a nan z as nan instead of
    // silently turning it into a pole.
    float const z = std::max(std::min(n.z(), 1.0f), -1.0f);
    float const phi = std::acos(z);

    // The azimuth is undefined at the poles, so pin it to 0 there
    float theta = 0.0f;
    if(!((std::abs(phi) < 0.0001) || (std::abs(M_PI - phi) < 0.0001)))
    {
      theta = std::atan2(n.y(), n.x());
      if(theta < 0) theta += M_PI*2.0;
    }

    return std::make_tuple(theta, phi);
  }

  // ######################################################################
  //! Maps a similarity score to a vote strength, blending two response curves by curvature
  /*! The curved response is the Kumaraswamy CDF with shape parameters (2, 3): 1 - (1 - s^2)^3.
      The linear response is curvature * s + linear_offset * (1 - curvature).
      The result is curvature * curved + (1 - curvature) * linear.
      @param similarity The similarity s (presumably in [0, 1] — confirm with callers)
      @param curvature The blend weight between the two responses
      @param linear_offset Constant term of the linear response, weighted by (1 - curvature)
      @return The blended vote strength */
  float tuningCurve(float const similarity, float const curvature, float const linear_offset = 0.85)
  {
    // Curved response: 1 - (1 - s^2)^3
    float const squared = similarity * similarity;
    float const complement = 1.0f - squared;
    float const cubed = complement * complement * complement;
    float const curvedResponse = 1.0f - cubed;

    // Linear response
    float const linearResponse = curvature * similarity + linear_offset * (1.0f - curvature);

    // The final blend is evaluated in double precision, then narrowed back to float
    float const curvedPart = curvature * curvedResponse;
    double const linearWeight = 1.0 - curvature;
    return static_cast<float>(curvedPart + linearWeight * linearResponse);
  }

  // ######################################################################
  //! Figures out the smallest rotation about the global z axis that gives the group line
  //! a shortest distance to the origin of rho
  /*! The line passes through the centroid with direction l(eta) = Rz(eta) * e3. Assuming e3 and CO
      are unit length (TODO confirm with callers; the rotation preserves length), the line is at
      distance rho from the origin exactly when
        CO . l(eta) = +-cosAngle,   cosAngle = sqrt(d_c^2 - rho^2) / d_c,   d_c = ||centroid||.
      Expanding the rotation gives a sin(eta) + b cos(eta) + CO.z e3.z = +-cosAngle, i.e.
        cos(eta - alpha) = (+-cosAngle - CO.z e3.z) / c,   alpha = atan2(a, b),  c = hypot(a, b).

    @param rho The desired rho
    @param e3 The original group e3 vector
    @param CO The vector pointing from centroid to origin
    @param centroid The centroid
    @param etaMax The maximum rotation, in radians
    @return The eta of smallest magnitude within [-etaMax, etaMax] matching the desired rho,
            or nan if there is none */
  float findEta(float const rho, Eigen::Vector3f const & e3,
      Eigen::Vector3f const & CO, Eigen::Vector3f const & centroid,
      float const etaMax)
  {
    float const a = CO.y() * e3.x() - CO.x() * e3.y();
    float const b = CO.y() * e3.y() + CO.x() * e3.x();
    float const c = std::hypot(a, b);
    float const d_c = centroid.norm();
    float const d_i = rho;
    // |CO . l| required for distance rho; nan when rho exceeds d_c (no solution then)
    float const cosAngle = std::sqrt(d_c*d_c - d_i*d_i) / d_c;
    float const alpha = std::atan2(a, b);

    float best = std::numeric_limits<float>::quiet_NaN();

    // A line has no preferred direction, so CO . l may equal either +cosAngle or -cosAngle
    for( float const target : { cosAngle, -cosAngle } )
    {
      // nan when the target cannot be reached by rotating about z (or c == 0)
      float const delta = std::acos( (target - CO.z()*e3.z()) / c );

      for( float const candidate : { alpha + delta, alpha - delta } )
      {
        // Wrap into [-pi, pi] so that e.g. 2*pi - 0.1 is treated as the rotation -0.1
        float const eta = std::remainder( candidate, static_cast<float>(2*M_PI) );

        // nan candidates fail both bound comparisons and are skipped
        if( eta <= etaMax && eta >= -etaMax &&
            ( std::isnan(best) || std::abs(eta) < std::abs(best) ) )
          best = eta;
      }
    }

    // best stays nan when no candidate lies within [-etaMax, etaMax]. This happens when
    // reaching rho would need a rotation larger than etaMax, or rho is unreachable at all.
    return best;
  }

  // ######################################################################
  /*! Votes for the planes that contain the line through the group centroid with direction `line`.

    For every rho_i in [minRho, maxdist], stepping by the accumulator's rho step, it solves for the
    unit normals n with n . centroid = rho_i and n . line = 0. Each resulting plane is accumulated
    with a strength from tuningCurve applied to |n . group.e1|. maxdist is the distance from the
    origin to that line, the largest rho such a plane can have.

    @param cloud The point cloud of the group (only used by the disabled normal-distribution weighting)
    @param line Direction of the line the voted planes must contain
    @param group The group being voted; its centroid, e1, curvature and id are read
    @param accumulator The accumulator receiving the votes
    @param minRho The minimum value of rho to vote for
    @param voteSpread If true votes go through accumulateSpread, otherwise through accumulate
    @return The cell ids reported by the accumulator for every vote cast */
  std::vector<AccumulatorBall::cellid_t> voteForLine(nrt::PointCloud2 const & cloud, Eigen::Vector3f const & line, LiDARGroup const & group, AccumulatorBall & accumulator,
      float const minRho, bool const voteSpread)
  {
    std::vector<AccumulatorBall::cellid_t> successfulVotes;

    // Create distribution of normals for this group
    //VonMisesFischer groupNormalDistribution;
    //for( auto const normal : cloud.subset_range<nrt::PointNormal>( group.indices ) )
    //  groupNormalDistribution.add( normal.get<nrt::PointNormal>().getVector3Map() );

    Eigen::Vector3f const & centroid = group.centroid;

    // Nullspace vector of our underconstrained problem (orthogonal to both centroid and line)
    Eigen::Vector3f const x = line.cross(centroid);
    float const x_n_2 = x.squaredNorm();

    // A line parallel to the centroid (or a zero line) defines no plane family: A*A^T below is
    // singular and the quadratic divides by zero, so there is nothing meaningful to vote
    if( !(x_n_2 > 0.0F) ) return successfulVotes;

    // Distance from the origin to the line through the centroid
    float const maxdist = x.norm() / line.norm();

    // Fill out our underconstrained problem matrix: A * n = [rho, 0]^T
    Eigen::Matrix<float,2,3> A;
    A.block<1,3>(0,0) = centroid.transpose();
    A.block<1,3>(1,0) = line.transpose();

    // Find the pseudo inverse of A (minimum-norm solution operator)
    Eigen::Matrix<float,3,2> const Ai = A.transpose() * (A*A.transpose()).inverse();

    // Accumulates a single vote for the plane with unit normal n at distance rho_i
    auto const castVote = [&](Eigen::Vector3f const & n, float const rho_i)
    {
      float const normal_similarity = std::abs(n.dot(group.e1));                // similarity to group smallest eigenvector
      float const normal_dist_similarity = 1.0; //groupNormalDistribution.evaluate(n); // similarity to dist of normals in group
      float const vote_strength = tuningCurve(normal_similarity, group.curvature) * normal_dist_similarity;

      float theta, phi;
      std::tie(theta, phi) = toPolar(n);
      if(voteSpread)
      {
        for(AccumulatorBall::cellid_t cellid : accumulator.accumulateSpread(theta, phi, rho_i, vote_strength, group.id))
          successfulVotes.push_back(cellid);
      }
      else
      {
        nrt::Optional<AccumulatorBall::cellid_t> accResult = accumulator.accumulate(theta, phi, rho_i, vote_strength, group.id);
        if(accResult) successfulVotes.push_back(*accResult);
      }
    };

    // Votes for a single rho_i. The normals are n = vi + a*x with ||n|| = 1, i.e. the roots of
    //   x_n_2 * a^2 + 2 (vi . x) a + (||vi||^2 - 1) = 0  =>  a = (-(vi . x) +- sqrt(disc)) / x_n_2.
    // With bothSides false only the '+' root is voted (at rho_i == maxdist the two roots coincide).
    auto const dovote = [&](float const rho_i, bool const bothSides)
    {
      Eigen::Vector3f const vi = Ai * Eigen::Vector2f(rho_i, 0);

      float const v_dot_x = vi.dot(x);
      float const v_n_2 = vi.squaredNorm();

      // ||vi|| = rho_i / maxdist, so the discriminant is non-negative for rho_i <= maxdist up to
      // rounding; the clamp absorbs that rounding near maxdist
      float const root = std::sqrt(std::max(0.0F, v_dot_x*v_dot_x - x_n_2 * (v_n_2 - 1)));

      // Vote for positive rho
      castVote((vi + x * ((-v_dot_x + root) / x_n_2)).normalized(), rho_i);

      // Vote for negative rho (the other root)
      if(bothSides)
        castVote((vi + x * ((-v_dot_x - root) / x_n_2)).normalized(), rho_i);
    };

    // Sweep rho over [minRho, maxdist]. Each rho comes from an integer step count rather than
    // repeated addition, so rounding error does not accumulate; a non-positive step skips the sweep.
    float const rhoStep = accumulator.getRhoStep();
    bool votedAny = false;
    float lastRho = 0.0F;
    if(rhoStep > 0.0F)
    {
      for(std::size_t step = 0; ; ++step)
      {
        float const rhoVote = minRho + static_cast<float>(step) * rhoStep;
        if(!(rhoVote <= maxdist)) break;
        dovote(rhoVote, rhoVote < maxdist);
        votedAny = true;
        lastRho = rhoVote;
      }
    }

    // Finish with the tangent plane at maxdist, unless it was already voted or lies below minRho
    if(maxdist >= minRho && (!votedAny || lastRho < maxdist))
      dovote(maxdist, false);

    return successfulVotes;
  }

}


// ######################################################################

//! Vote the given group in the accumulator, and return the cell IDs of all cells that passed the accumulator threshold.
/*! Steps:
    - The line through the group centroid with direction e3 is voted first (via voteForLine).
    - Distances from the origin are then swept upwards, and afterwards downwards, from that line's
      distance in increments of the accumulator's rho step.
    - For each distance, findEta gives a rotation about the global z axis. e3 is rotated by it and
      the rotated line is voted.
    - A sweep stops when findEta returns nan (no rotation within etaMax). The downward sweep also
      stops before the distance would become negative.

    @param cloud The point cloud of the group (forwarded to voteForLine)
    @param group The group to vote
    @param accumulator The accumulator receiving the votes
    @param voteSpread Forwarded to voteForLine
    @param etaMax The maximum eta rotation in radians
    @param minRho The minimum rho to vote for (forwarded to voteForLine)
    @return The cell ids reported by the accumulator for every vote cast */
std::vector<AccumulatorBall::cellid_t> voteForGroup(nrt::PointCloud2 const & cloud, LiDARGroup const & group, AccumulatorBall & accumulator,
    bool const voteSpread,
    float const etaMax, float const minRho)
{
  std::vector<AccumulatorBall::cellid_t> successfulVotes;

  Eigen::Vector3f const & e3       = group.e3;
  Eigen::Vector3f const & centroid = group.centroid;
  Eigen::Vector3f const CO         = -centroid.normalized();

  // Votes one line and appends the resulting cell ids
  auto const voteLine = [&](Eigen::Vector3f const & line)
  {
    std::vector<AccumulatorBall::cellid_t> const votes = voteForLine(cloud, line, group, accumulator, minRho, voteSpread);
    successfulVotes.insert(successfulVotes.end(), votes.begin(), votes.end());
  };

  // Rotates e3 about the global z axis so the line lies at distance rho from the origin and votes
  // it. Returns false when findEta finds no such rotation within etaMax.
  auto const voteAtRho = [&](float const rho) -> bool
  {
    float const eta = findEta(rho, e3, CO, centroid, etaMax);
    if(std::isnan(eta)) return false;
    voteLine(Eigen::AngleAxisf(eta, Eigen::Vector3f::UnitZ()) * e3);
    return true;
  };

  // First vote on the original group with no rotation
  voteLine(e3);

  float const distanceToGroupLine = e3.cross(centroid).norm() / e3.norm();
  float const rhoStep = accumulator.getRhoStep();

  // With a non-positive (or nan) step the sweeps below could never advance
  if(!(rhoStep > 0.0f)) return successfulVotes;

  // Do upwards sweep of eta. It ends at the latest once rho exceeds ||centroid||, because
  // findEta's square root then yields nan. rho is derived from a step count to avoid drift.
  for(std::size_t step = 1; ; ++step)
    if(!voteAtRho(distanceToGroupLine + static_cast<float>(step) * rhoStep)) break;

  // Do downwards sweep of eta. findEta only depends on rho^2, so a negative rho would re-vote the
  // lines already voted for |rho|; stop before that happens.
  for(std::size_t step = 1; ; ++step)
  {
    float const rho = distanceToGroupLine - static_cast<float>(step) * rhoStep;
    if(rho < 0.0f || !voteAtRho(rho)) break;
  }

  return successfulVotes;
}
