#include "RegionGrowing.H"
#include "RefinePlane.H"
#include "../PlaneDetectionCommon.H"
#include <PointCloud/Features/RangeImage.H>
#include <nrt/PointCloud2/Features/Normals.H>
#include <PointCloud/Features/VelodyneHelpers.H>
#include <PointCloud/Features/Planes/details/FrameTiming.H>
#include <stack>

#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/connected_components.hpp>

#include "Timing.H"

//! Sentinel stored in a point's group slot when the point is not part of any LiDAR group
//! (the largest value representable by groupid_t, widened to size_t).
size_t const NO_GROUP = static_cast<size_t>(std::numeric_limits<groupid_t>::max());

template<size_t numRows>
void pushNeighbors(size_t const index, size_t const regionId, Eigen::Hyperplane<float, 3> const & plane, std::vector<size_t> & regionAssignments,
    std::vector<groupid_t> const & groupAssignments, std::array<std::vector<LiDARGroup>, numRows> const & groups,
    std::stack<size_t> & stack, std::stack<size_t> & newStack, nrt::PointCloud2 const & cloud, float const thresholdSq)
{
  size_t const row = velodyneIdx2Row<numRows>(index);

  auto checkPoint = [&plane, &groupAssignments, &groups, thresholdSq] (Eigen::Vector3f const & neighborPoint, size_t neighborIdx)
  {
    groupid_t const gid = groupAssignments[neighborIdx];
    if(gid != NO_GROUP && std::abs(groupFromId<numRows>(groups, gid).e3.dot(plane.normal())) > 0.1)
      return false;

    float const dot = plane.normal().dot(neighborPoint);
    if(dot == 0) return false;

    float const s = -plane.offset() / dot;

    if(s < 0) return false;

    Eigen::Vector3f const expectedPoint = neighborPoint * s;

    return (expectedPoint - neighborPoint).squaredNorm() < thresholdSq;
  }; // End: checkPoint

  if(row < 31)
  {
    size_t const up = index + getPositiveRowIncrement<numRows>(row);
    if(regionAssignments[up] != regionId && checkPoint(cloud[up].getVector3Map(), up))
    {
      regionAssignments[up] = regionId;
      stack.push(up);
      newStack.push(up);
    }
  }

  if(row > 0)
  {
    size_t const down = index + getNegativeRowIncrement<numRows>(row);
    if(regionAssignments[down] != regionId && checkPoint(cloud[down].getVector3Map(), down))
    {
      regionAssignments[down] = regionId;
      stack.push(down);
      newStack.push(down);
    }
  }

  size_t const left = (index > numRows) ? index - numRows : cloud.size() - 1 + index - numRows;
  if(regionAssignments[left] != regionId && checkPoint(cloud[left].getVector3Map(), left))
  {
    regionAssignments[left] = regionId;
    stack.push(left);
    newStack.push(left);
  }

  size_t const right = (index + numRows < cloud.size()) ? index + numRows : index + numRows - cloud.size();
  if(regionAssignments[right] != regionId && checkPoint(cloud[right].getVector3Map(), right))
  {
    regionAssignments[right] = regionId;
    stack.push(right);
    newStack.push(right);
  }
};

// Explicit instantiations of pushNeighbors for 32- and 64-row clouds.
// (Top-level const on by-value parameters is not part of the function type, so it is omitted here.)
template void pushNeighbors<32>(size_t, size_t, Eigen::Hyperplane<float, 3> const &, std::vector<size_t> &,
    std::vector<groupid_t> const &, std::array<std::vector<LiDARGroup>, 32> const &,
    std::stack<size_t> &, std::stack<size_t> &, nrt::PointCloud2 const &, float);

template void pushNeighbors<64>(size_t, size_t, Eigen::Hyperplane<float, 3> const &, std::vector<size_t> &,
    std::vector<groupid_t> const &, std::array<std::vector<LiDARGroup>, 64> const &,
    std::stack<size_t> &, std::stack<size_t> &, nrt::PointCloud2 const &, float);

// ######################################################################
//! Grow each detected plane outward through neighbouring points of the organized cloud.
/*! For every input plane (processed in order, each with its own region id):
      1. The points of its seed groups are marked as belonging to the region and pushed on a stack.
      2. Each pass pops points; a point whose normal has dot product > 0.8 with the plane normal is
         accepted (recorded once in the plane's index list) and its neighbours are claimed via
         pushNeighbors, which continues the flood fill within the same pass.
      3. After a pass the plane is refit from all accepted indices (RefineFrom::Indices), and the
         points claimed during that pass are re-expanded with the refined plane. This repeats for
         at most @p maxIterations passes, stopping early when nothing new was claimed or when fewer
         than 5 points have been accepted. In the latter case the plane keeps its input indices.

    A point already claimed by an earlier plane can be re-claimed by a later one; regionAssignments
    only records the most recent claimant.

    @param detectedPlanes      planes to grow; each is copied, inputs are not modified
    @param groups              LiDAR groups per row; a plane's seed groups are resolved through these
    @param cloud               organized cloud carrying nrt::PointNormal data
    @param pointPlaneThreshold max distance along the sensor ray between a point and the plane
    @param maxIterations       maximum number of grow/refine passes per plane
    @return the grown planes whose final index list holds more than 15 points */
template<size_t numRows>
std::vector<DetectedPlane> growRegions(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>, numRows> const & groups, nrt::PointCloud2 const & cloud, float const pointPlaneThreshold, size_t const maxIterations)
{
  float const pointPlaneThresholdSq = pointPlaneThreshold * pointPlaneThreshold;

  frame_timing::start("setup");
  size_t const NO_REGION = std::numeric_limits<size_t>::max();

  // Which LiDAR group (if any) every point belongs to
  std::vector<groupid_t> groupAssignments(cloud.size(), NO_GROUP);
  for(std::vector<LiDARGroup> const & row : groups)
    for(LiDARGroup const & group : row)
      for(size_t const pid : group.indices)
        groupAssignments[pid] = group.id;

  // Which region (plane) every point has been claimed by during flood filling
  std::vector<size_t> regionAssignments(cloud.size(), NO_REGION);

  // Which region every point has been *accepted* into, i.e. added to that region's index list.
  // Claimed points are revisited on the next pass, and seeds can appear in several groups, so this
  // guard keeps each point from being recorded (and weighted in the refit) more than once.
  std::vector<size_t> acceptedRegion(cloud.size(), NO_REGION);

  frame_timing::stop("setup");
  frame_timing::start("plane loop");

  std::vector<DetectedPlane> finalPlanes;
  finalPlanes.reserve(detectedPlanes.size());

  size_t regionId = 0;
  for(DetectedPlane finalPlane : detectedPlanes) // copied on purpose: refined in place below
  {
    std::stack<size_t> indexStack;

    // Seed the region with every point of the plane's groups
    for(groupid_t const gid : finalPlane.groups)
    {
      LiDARGroup const & group = groupFromId<numRows>(groups, gid);

      for(size_t const pid : group.indices)
      {
        indexStack.push(pid);
        regionAssignments[pid] = regionId;
      }
    }

    nrt::Indices indices; // accepted points, accumulated across passes
    for(size_t iteration = 0; iteration < maxIterations; ++iteration)
    {
      Eigen::Hyperplane<float, 3> const & plane = finalPlane.plane;
      std::stack<size_t> newStack; // points claimed during this pass, re-expanded next pass

      while(!indexStack.empty())
      {
        size_t const pid = indexStack.top();
        indexStack.pop();

        auto const & pointNormal = cloud.at<nrt::PointNormal>(pid).get<nrt::PointNormal>().getVector3Map();
        if(pointNormal.dot(plane.normal()) > .8)
        {
          if(acceptedRegion[pid] != regionId)
          {
            acceptedRegion[pid] = regionId;
            indices.push_back(pid);
          }

          pushNeighbors<numRows>(pid, regionId, plane, regionAssignments, groupAssignments, groups, indexStack, newStack, cloud, pointPlaneThresholdSq);
        }
      }

      if(indices.size() < 5) break;

      // Refit the plane purely from the accepted points
      finalPlane.indices = indices;
      finalPlane.groups.clear();
      finalPlane = refinePlaneFromPoints(finalPlane, groups, cloud, RefineFrom::Indices);

      if(newStack.empty()) break;

      indexStack = std::move(newStack);
    }

    if(finalPlane.indices.size() > 15)
      finalPlanes.push_back(std::move(finalPlane));

    ++regionId;
  }

  frame_timing::stop("plane loop");
  return finalPlanes;
}

template<size_t numRows>
std::vector<DetectedPlane> growRegions2(std::vector<DetectedPlane> const & detectedPlanes, nrt::PointCloud2 cloud)
{
  typedef boost::adjacency_list<boost::setS, boost::vecS, boost::undirectedS> Graph;

  frame_timing::start("setup");
  Graph graph;

  std::vector<DetectedPlane> new_planes = detectedPlanes;

  size_t const numplanes = new_planes.size();

  for(size_t plane_idx=0; plane_idx<numplanes; ++plane_idx)
    for(auto i : new_planes[plane_idx].indices)
      boost::add_edge(plane_idx, numplanes + i, graph);

  float const distance_threshold = 0.0001;
  float const normal_threshold = 0.95;

  auto plane_distance = [](Eigen::Vector3f const & p0, Eigen::Vector3f const & n0, Eigen::Vector3f const & p1)
  {
    return std::abs((p1 - p0).dot(n0));
  };

  size_t const numCols = cloud.size()/int(numRows);
  for(size_t row=0; row<numRows-1; ++row)
  {
    for(size_t col=0; col<numCols-1; ++col)
    {
      size_t const index = velodyne32::index(row,col);
      auto p = cloud.get<nrt::PointNormal>(index);
      auto point = p.geometry().getVector3Map();
      if(point.norm() < 0.1)  continue;
      auto normal = p.get<nrt::PointNormal>().getVector3Map().normalized();

      size_t const index_right = velodyne32::index(row,col+1);
      auto r = cloud.get<nrt::PointNormal>(index_right);
      auto right = r.geometry().getVector3Map();
      if(right.norm() > 0.1)
      {
        auto right_normal = r.get<nrt::PointNormal>().getVector3Map();
        if((std::abs(right_normal.dot(normal)) > normal_threshold) &&
            (plane_distance(point, normal, right) < distance_threshold))
          boost::add_edge(index + numplanes, index_right + numplanes, graph);
      }

      size_t const index_down = velodyne32::index(row+1,col);
      auto d = cloud.get<nrt::PointNormal>(index_down);
      auto down = d.geometry().getVector3Map();
      if(down.norm() > 0.1)
      {
        auto down_normal = d.get<nrt::PointNormal>().getVector3Map();
        if((std::abs(down_normal.dot(normal)) > normal_threshold) &&
            (plane_distance(point, normal, down) < distance_threshold))
          boost::add_edge(index + numplanes, index_down + numplanes, graph);
      }
    }
  }
  frame_timing::stop("setup");

  frame_timing::start("solve");
  std::vector<size_t> components(numplanes + cloud.size(), -1);
  boost::connected_components(graph, components.data());
  frame_timing::stop("solve");

  frame_timing::start("organize");

  // A map from component ID to plane ID
  std::unordered_map<size_t, std::vector<size_t>> plane_components;
  for(size_t plane_idx=0; plane_idx<numplanes; ++plane_idx)
    plane_components[components[plane_idx]].push_back(plane_idx);

  std::vector<nrt::Indices> new_indices(numplanes);
  for(size_t i=numplanes; i<components.size(); ++i)
  {
    size_t const index = i - numplanes;
    size_t const component = components[i];
    auto found_component = plane_components.find(component);
    if(found_component == plane_components.end()) continue;

    auto normal = cloud.get<nrt::PointNormal>(index).get<nrt::PointNormal>().getVector3Map();

    for(size_t plane_idx : found_component->second)
    {
      if(std::abs(new_planes[plane_idx].plane.normal().dot(normal)) > 0.9)
        new_indices[plane_idx].push_back(index);
    }
  }

  std::vector<DetectedPlane> ret_planes;
  for(size_t plane_idx=0; plane_idx<numplanes; ++plane_idx)
  {
    if(new_indices[plane_idx].size() > 3)
    {
      new_planes[plane_idx].indices = new_indices[plane_idx];
      new_planes[plane_idx] = refinePlaneFromPoints<numRows>(new_planes[plane_idx], {}, cloud, RefineFrom::Indices);
      ret_planes.push_back(new_planes[plane_idx]);
    }
  }


  frame_timing::stop("organize");

  return ret_planes;
}

  // Explicit instantiations of the region-growing entry points for 32- and 64-row clouds.
  // (Top-level const on by-value parameters is not part of the function type, so it is omitted here.)
  template std::vector<DetectedPlane> growRegions<32>(std::vector<DetectedPlane> const &,
      std::array<std::vector<LiDARGroup>, 32> const &, nrt::PointCloud2 const &, float, size_t);

  template std::vector<DetectedPlane> growRegions<64>(std::vector<DetectedPlane> const &,
      std::array<std::vector<LiDARGroup>, 64> const &, nrt::PointCloud2 const &, float, size_t);

  template std::vector<DetectedPlane> growRegions2<32>(std::vector<DetectedPlane> const &, nrt::PointCloud2);

  template std::vector<DetectedPlane> growRegions2<64>(std::vector<DetectedPlane> const &, nrt::PointCloud2);
