#include "FindGroupsInRow.H"
#include "../Filtering.H"
#include <nrt/PointCloud2/Common/Centroid.H>
#include <nrt/PointCloud2/Features/Normals.H>
#include <PointCloud/Features/Planes/details/FrameTiming.H>

#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/connected_components.hpp>

// TODO: Things to try for better groups:
// 1) Look into 3D edge detection for break point detection?
// 2) Use boost graph for merging
// 3) Look into normals for break point detection.

namespace
{
  // ######################################################################
  //! Build a LiDARGroup, including its metadata, from a set of point indices
  /*! Fills in:
      - the indices,
      - the centroid (nrt::computeCentroid),
      - the variation and eigenvectors from nrt::Normals::computePointNormal
        (e1, e2, e3 are columns 0, 1, 2 of the eigenvector matrix),
      - the end points, i.e. the first and last point referenced by indices.

      The sign of e3 is normalized so that it points in the CLOCKWISE
      direction about the origin in the x/y plane.

      @param indices Indices into cloud making up the group. Must be non-empty:
             indices.front() and indices.back() are dereferenced.
      @param cloud The cloud the indices refer to. */
  LiDARGroup makeGroup(nrt::Indices const & indices, nrt::PointCloud2 const & cloud)
  {
    LiDARGroup group;
    group.indices = indices;
    group.centroid = nrt::PointCloud2::Vector3(nrt::computeCentroid(cloud, indices).getVector3Map());

    nrt::PointCloud2::Matrix3 eigenVectors;
    //nrt::Normals::computePointNormal(cloud, indices, nrt::OptionalEmpty, false, group.variation, eigenVectors);
    nrt::Normals::computePointNormal(cloud, indices, false, nrt::OptionalEmpty, group.variation, eigenVectors);

    group.e1 = eigenVectors.block<3,1>(0,0);
    group.e2 = eigenVectors.block<3,1>(0,1);
    group.e3 = eigenVectors.block<3,1>(0,2);

    group.endPoints.first = cloud[indices.front()].getVector3Map();
    group.endPoints.second = cloud[indices.back()].getVector3Map();

    // z component of (centroid x e3). A positive value means e3 turns
    // counter-clockwise about the origin. This equals the z of
    // centroid x (centroid + e3), because centroid x centroid == 0. Computing
    // it directly avoids large cancelling |centroid|^2 products and a temporary
    // of a hard-coded vector type.
    auto const crossZ = group.centroid.x() * group.e3.y() - group.centroid.y() * group.e3.x();

    // Enforce that e3 points in the CLOCKWISE direction
    if(crossZ > 0.0f)
      group.e3 *= -1.0f;

    return group;
  }

  // ######################################################################
  //! Compute the curvature for a group
  /*! group.variation[1] is divided by 0.005 and saturated at 1. The result is
      then weighted by |normalized(centroid) . e1|.
      The 0.005 normalization constant is used as-is here; its origin is not
      documented in this file. */
  inline float getCurvature(LiDARGroup const & group)
  {
    float const curvature = std::min(1.0f, group.variation[1] / 0.005f);
    // TODO: investigate centroid dot e1 being negative and what that means for curvature
    //return std::max( 0.0f, curvature * group.centroid.normalized().dot(group.e1) );
    return curvature * std::abs( group.centroid.normalized().dot(group.e1) );
  }
}

// ######################################################################
//! Merge groups using boost.graph, and the provided point normals
void mergeBoost(std::vector<LiDARGroup> & groups, nrt::PointCloud2 const & cloud, float const mergeDotProductThreshold, float const mergeDistanceThreshold)
{
  typedef boost::adjacency_list<boost::setS, boost::vecS, boost::undirectedS> Graph;
  Graph graph;
  for(size_t i=0; i<groups.size(); ++i)
  {
    auto const & group_i = groups[i];
    for(size_t j=i; j<groups.size(); ++j)
    {
      auto const & group_j = groups[j];

      bool const direction_ok = group_i.e3.dot(group_j.e3) > mergeDotProductThreshold;
      bool const distance_ok = shortestSquaredDistanceBetweenSegments(
          group_i.endPoints.first, group_i.endPoints.second,
          group_j.endPoints.first, group_j.endPoints.second) < mergeDistanceThreshold;

      if(direction_ok && distance_ok)
        boost::add_edge(i, j, graph);
    }
  }

  std::vector<size_t> components(boost::num_vertices(graph));
  int numComponents = boost::connected_components(graph, components.data());

  std::vector<LiDARGroup> mergedGroups(numComponents);
  for(size_t i=0; i<components.size(); ++i)
  {
    size_t const component = components[i];
    LiDARGroup const & group = groups[i];
    LiDARGroup & mergedGroup = mergedGroups[component];

    mergedGroup.indices.insert(mergedGroup.indices.end(), group.indices.begin(), group.indices.end());
  }

  for(LiDARGroup & group : mergedGroups)
    group = makeGroup(group.indices, cloud);

  groups = mergedGroups;
}

// ######################################################################
//! Merge groups using just the dot product of the largest eigenvectors, and the distance between groups
void mergeOriginal(std::vector<LiDARGroup> & groups, nrt::PointCloud2 const & cloud, float const mergeDotProductThreshold, float const mergeDistanceThreshold)
{
  for(auto currentIt = groups.begin(); currentIt < groups.end(); ++currentIt)
  {
    std::vector<std::vector<LiDARGroup>::iterator> merged_groups;

    // Add the candidate group that aligns well with the current group
    for(auto candidateIt = currentIt+1; candidateIt < groups.end(); ++candidateIt)
      if(currentIt->e3.dot(candidateIt->e3) > mergeDotProductThreshold &&
          shortestSquaredDistanceBetweenSegments(currentIt->endPoints.first, currentIt->endPoints.second,
            candidateIt->endPoints.first, candidateIt->endPoints.second) < mergeDistanceThreshold)
        merged_groups.push_back(candidateIt);

    // Merge all of the indices
    for(auto merged_group : merged_groups)
      currentIt->indices.insert(currentIt->indices.end(),
          merged_group->indices.begin(), merged_group->indices.end());

    // Recompute the group from the merged indices
    *currentIt = makeGroup(currentIt->indices, cloud);

    // Remove all of the merged groups
    removeIndices(groups, merged_groups);
  }
}

// ######################################################################
//! Split one row of an interleaved LiDAR cloud into groups of points
/*! The row's points are the cloud indices first_i, first_i + numRows, ...,
    where first_i = getVelodyneFirstRowIdx<numRows>(row).

    Break points along the row are collected from three sources:
    - zero crossings of the first {-1,0,1} derivative of the blurred distances;
      the blur uses the kernel from makeGaussian(7, distanceGaussianSigma);
    - zero crossings of the third derivative where |second derivative| > ddThreshold;
    - positions where planefiltering::difference of the raw distances exceeds
      distanceThreshold in magnitude.

    Points strictly between consecutive breaks form groups. The point just
    before each closing break is excluded as well. The span from the last
    break around to the first break forms one more group, using the same
    boundary rule, so the row is treated as circular. If there are no breaks
    at all, the whole row is a single group.

    Groups with fewer than minGroupPoints points are dropped, and empty groups
    are always dropped. The remainder are merged with mergeOriginal, and each
    result gets its curvature and id filled in.

    @param mergeDistanceThreshold Compared against the SQUARED segment distance
           used for merging. */
template<size_t numRows>
std::vector<LiDARGroup> findGroupsInRow(nrt::PointCloud2 const & cloud, size_t row,
    float const distanceGaussianSigma,
    float const ddThreshold,
    float const distanceThreshold,
    size_t const minGroupPoints,
    float const mergeDotProductThreshold,
    float const mergeDistanceThreshold)
{
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Get the distances to each point in the row, as well as a reverse index table
  size_t const first_i = getVelodyneFirstRowIdx<numRows>(row);
  planefiltering::AlignedVector rowDistances; rowDistances.reserve(cloud.size() / numRows);
  std::vector<size_t> rowReverseIndices; rowReverseIndices.reserve(cloud.size() / numRows);
  if(first_i < cloud.size())
  {
    auto cloudDistanceIterator = cloud.begin<Distance>() + first_i;
    for(size_t i=first_i; i<cloud.size(); i+=numRows)
    {
      rowDistances.push_back(cloudDistanceIterator->get<Distance>().value);
      rowReverseIndices.push_back(i);

      // Only step when another sample exists: moving an iterator past end() is undefined
      if(i + numRows < cloud.size())
        cloudDistanceIterator += numRows;
    }
  }

  // An empty row has no groups
  if(rowReverseIndices.empty())
    return {};

  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Find the breaks in the row distance vector.
  // The kernel is rebuilt on every call. A function-local static would keep
  // the sigma from the very first call and ignore later distanceGaussianSigma values.
  planefiltering::AlignedVector const gaussian = planefiltering::makeGaussian(7, distanceGaussianSigma);
  planefiltering::AlignedVector blurred     = planefiltering::convolve(rowDistances, gaussian);
  planefiltering::AlignedVector d_blurred   = planefiltering::convolve(blurred,      {-1, 0, 1});
  planefiltering::AlignedVector dd_blurred  = planefiltering::convolve(d_blurred,    {-1, 0, 1});
  planefiltering::AlignedVector ddd_blurred = planefiltering::convolve(dd_blurred,   {-1, 0, 1});

  std::vector<size_t> breaks = planefiltering::zeroCrossings(d_blurred);

  size_t const ddd_breaks_begin = breaks.size();

  // Capture dd_blurred by reference; a by-value capture copies the whole vector
  std::vector<size_t> const ddd_crossings = planefiltering::zeroCrossings(ddd_blurred);
  std::copy_if(ddd_crossings.begin(), ddd_crossings.end(), std::back_inserter(breaks),
      [&dd_blurred, ddThreshold](size_t crossing) { return std::abs(dd_blurred[crossing]) > ddThreshold; });

  std::inplace_merge(breaks.begin(), breaks.begin() + ddd_breaks_begin, breaks.end());

  size_t const dist_breaks_begin = breaks.size();

  planefiltering::AlignedVector differences = planefiltering::difference(rowDistances);
  for(size_t i=0; i<differences.size(); ++i)
    if(std::abs(differences[i]) > distanceThreshold)
      breaks.push_back(i);

  std::inplace_merge(breaks.begin(), breaks.begin() + dist_breaks_begin, breaks.end());
  breaks.erase(std::unique(breaks.begin(), breaks.end()), breaks.end());

  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Find all of the group indices between the breaks
  std::vector<nrt::Indices> groupIndices;
  if(breaks.empty())
  {
    // No discontinuities anywhere (front()/back() below would be undefined):
    // the whole row is one group
    groupIndices.emplace_back();
    for(size_t const idx : rowReverseIndices)
      groupIndices.back().push_back(idx);
  }
  else
  {
    for(size_t i=1; i<breaks.size(); ++i)
    {
      groupIndices.emplace_back();
      for(size_t j=breaks[i-1]+1; j<breaks[i]-1; ++j)
        groupIndices.back().push_back(rowReverseIndices[j]);
    }

    // Handle the group that crosses the zero boundary, using the same
    // boundary rule as above: start after the last break, and stop before
    // the point preceding the first break.
    groupIndices.emplace_back();
    for(size_t j = breaks.back()+1; j<rowReverseIndices.size(); ++j)
      groupIndices.back().push_back(rowReverseIndices[j]);
    for(size_t j = 0; j+1<breaks.front(); ++j)
      groupIndices.back().push_back(rowReverseIndices[j]);
  }

  // Erase groups with fewer than minGroupPoints elements. Empty groups are
  // always erased because makeGroup dereferences front()/back().
  size_t const minPoints = std::max<size_t>(minGroupPoints, 1);
  groupIndices.erase(std::remove_if(groupIndices.begin(), groupIndices.end(),
        [minPoints](nrt::Indices const & g) { return g.size() < minPoints; }),
      groupIndices.end());

  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Turn the indices into groups with metadata.
  // Capture cloud by reference: std::bind would store a full copy of the cloud.
  std::vector<LiDARGroup> groups(groupIndices.size());
  std::transform(groupIndices.begin(), groupIndices.end(), groups.begin(),
      [&cloud](nrt::Indices const & indices) { return makeGroup(indices, cloud); });

  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Merge neighboring groups whose dot products and distance thresholds are good enough
  frame_timing::start("Merge");
  //mergeBoost(groups, cloud, mergeDotProductThreshold, mergeDistanceThreshold);
  mergeOriginal(groups, cloud, mergeDotProductThreshold, mergeDistanceThreshold);
  frame_timing::stop("Merge");

  // Fill in the curvatures and IDs for all of the groups
  size_t groupIdx = 0;
  for(LiDARGroup & group : groups)
  {
    group.curvature = getCurvature(group);
    group.id = groupId<numRows>(row, groupIdx++);
  }

  return groups;
}

// Explicit instantiations of findGroupsInRow for the two supported row
// counts (numRows = 32 and numRows = 64). Parameter names and top-level
// const are not part of the function type, so they are omitted here.
template std::vector<LiDARGroup> findGroupsInRow<32>(
    nrt::PointCloud2 const &, size_t, float, float, float, size_t, float, float);

template std::vector<LiDARGroup> findGroupsInRow<64>(
    nrt::PointCloud2 const &, size_t, float, float, float, size_t, float, float);
