#include "SplitPlanes.H"
#include "../PlaneDetectionCommon.H"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <utility>
#include <vector>

#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/connected_components.hpp>

namespace
{
  //! Squared Euclidean distance between two 3D line segments
  /*! Closest-point-of-approach method for segments (Dan Sunday, "Distance between 3D Lines & Segments").
      The segments are parameterised as
        S1(s) = s1_begin + s * (s1_end - s1_begin),  s in [0, 1]
        S2(t) = s2_begin + t * (s2_end - s2_begin),  t in [0, 1]
      The closest points of the two infinite lines are computed first and then clamped back onto the
      segments, re-solving for the other parameter whenever one of them had to be clamped. Nearly
      parallel segments, and zero-length segments (which behave like points), are handled by fixing
      s = 0 and solving for t.

      @param s1_begin First endpoint of segment 1
      @param s1_end Second endpoint of segment 1
      @param s2_begin First endpoint of segment 2
      @param s2_end Second endpoint of segment 2
      @return Squared distance between the closest points of the two segments */
  float sqSegmentDist(
      Eigen::Vector3f const & s1_begin,
      Eigen::Vector3f const & s1_end,
      Eigen::Vector3f const & s2_begin,
      Eigen::Vector3f const & s2_end)
  {
    // Segments with sin^2(angle) at or below this are treated as parallel. Since D = a*c*sin^2(angle),
    // the test is D <= tol * a * c, which does not depend on segment length. An absolute threshold on D
    // (a quantity that scales with length^4) classifies short, clearly crossing segments as parallel
    // and overestimates their distance.
    constexpr float PARALLEL_TOLERANCE = 1e-5f;

    Eigen::Vector3f const u = s1_end - s1_begin;
    Eigen::Vector3f const v = s2_end - s2_begin;
    Eigen::Vector3f const w = s1_begin - s2_begin;
    float const a = u.dot(u); // squared length of segment 1, >= 0
    float const b = u.dot(v);
    float const c = v.dot(v); // squared length of segment 2, >= 0
    float const d = u.dot(w);
    float const e = v.dot(w);
    float const D = a*c - b*b; // = a*c*sin^2(angle), >= 0 up to rounding

    // The closest-point parameters are sc = sN / sD and tc = tN / tD
    float sN, sD = D;
    float tN, tD = D;

    // compute the line parameters of the closest points
    if(D <= PARALLEL_TOLERANCE * a * c) // lines (almost) parallel, or at least one segment has zero length
    {
      sN = 0.0f; // fix s = 0 ...
      sD = 1.0f; // ... with a denominator that cannot be zero
      tN = e;
      tD = c;
    }
    else // get the closest points on the infinite lines
    {
      sN = (b*e - c*d);
      tN = (a*e - b*d);

      if(sN < 0.0f) // sc < 0 -> the s=0 edge is visible
      {
        sN = 0.0f;
        tN = e;
        tD = c;
      }
      else if(sN > sD) // sc > 1 -> the s=1 edge is visible
      {
        sN = sD;
        tN = e + b;
        tD = c;
      }
    }

    if(tN < 0.0f) // tc < 0 -> the t=0 edge is visible: re-solve s against s2_begin
    {
      tN = 0.0f;
      if(-d < 0.0f)
        sN = 0.0f;
      else if(-d > a)
        sN = sD;
      else
      {
        sN = -d;
        sD = a;
      }
    }
    else if(tN > tD) // tc > 1 -> the t=1 edge is visible: re-solve s against s2_end
    {
      tN = tD;
      if((-d + b) < 0.0f)
        sN = 0.0f;
      else if((-d + b) > a)
        sN = sD;
      else
      {
        sN = (-d + b);
        sD = a;
      }
    }

    // After clamping, 0 <= sN <= sD and 0 <= tN <= tD. A denominator can only be zero when the
    // corresponding segment has zero length; the numerator is then zero too and any parameter is valid.
    float const sc = sD > 0.0f ? sN / sD : 0.0f;
    float const tc = tD > 0.0f ? tN / tD : 0.0f;

    // Difference of the two closest points: S1(sc) - S2(tc)
    Eigen::Vector3f const dP = w + (sc * u) - (tc * v);

    return dP.squaredNorm();
  }

}

// ######################################################################
  template<size_t numRows>
std::vector<DetectedPlane> splitPlanes(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>, numRows> const & groups, nrt::PointCloud2 const & cloud)
{
  float const dotThreshold = 0.90;
  size_t const minGroups   = 2;

  std::vector<DetectedPlane> results;

  typedef boost::adjacency_list<boost::setS, boost::vecS, boost::undirectedS> GroupGraph;

  for(DetectedPlane const & detectedPlane : detectedPlanes)
  {
    GroupGraph graph;

    // Create a graph of all groups that have a similar dot product and that are close together
    for(size_t idx1=0; idx1<detectedPlane.groups.size(); ++idx1)
    {
      groupid_t gid1 = detectedPlane.groups[idx1];
      LiDARGroup const & group1 = groupFromId(groups, gid1);
      float const group1Cdist = group1.centroid.squaredNorm();

      for(size_t idx2=idx1+1; idx2<detectedPlane.groups.size(); ++idx2)
      {
        groupid_t gid2 = detectedPlane.groups[idx2];
        LiDARGroup const & group2 = groupFromId(groups, gid2);

        if(group1.e3.dot(group2.e3) >= dotThreshold)
        {
          float const dist = sqSegmentDist(group1.endPoints.first, group1.endPoints.second, group2.endPoints.first, group2.endPoints.second);
          float const segDistThresh = std::min(2.5f * 2.5f, std::min(group1Cdist, group2.centroid.squaredNorm()) / 100.0f);

          if(dist < segDistThresh)
            boost::add_edge(idx1, idx2, graph);
        }
      }
    }

    // Find all of the connected components in the graph
    std::vector<size_t> componentIds(boost::num_vertices(graph));
    int numComponents = boost::connected_components(graph, componentIds.data());
    std::vector<std::vector<groupid_t>> components(numComponents);
    for(size_t i=0; i<componentIds.size(); ++i)
      components[componentIds[i]].push_back(detectedPlane.groups[i]);

    // Filter out all components that have less than <minGroups> groups
    std::vector<std::vector<groupid_t>> filteredComponents;
    std::copy_if(components.begin(), components.end(), std::back_inserter(filteredComponents),
        [minGroups](std::vector<groupid_t> const & component)
        { return component.size() > minGroups; });

    // Add all of the newly split components into the results vector
    for(std::vector<groupid_t> const & componentGroups : filteredComponents)
    {
      DetectedPlane componentPlane = detectedPlane;
      componentPlane.groups = componentGroups;
      results.push_back(componentPlane);
    }

  }

  return results;
}

// Explicit instantiations of splitPlanes for numRows = 32 and numRows = 64 (presumably the supported
// LiDAR row counts; confirm against callers). The template is defined in this file, so if the header
// only declares it, any other numRows value needs a matching instantiation here to link.
template
std::vector<DetectedPlane> splitPlanes<32>(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>, 32> const & groups, nrt::PointCloud2 const & cloud);

template
std::vector<DetectedPlane> splitPlanes<64>(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>, 64> const & groups, nrt::PointCloud2 const & cloud);
