#include "SmartMerge.H"
#include "RefinePlane.H"
#include "../PlaneDetectionCommon.H"
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/connected_components.hpp>

//! Read-only iterator into a list of detected planes; used throughout this file as a handle to one plane
using PlaneIterator = std::vector<DetectedPlane>::const_iterator;

namespace
{
  // ######################################################################
  //! Detect whether the bounding box of the normals of the given groups is
  //  larger than the threshold on any side
  bool detectMultiModalNormals(std::vector<PlaneIterator> const & groupVotes, float const threshold)
  {
    float minx = std::numeric_limits<float>::max();
    float maxx = std::numeric_limits<float>::lowest();
    float miny = std::numeric_limits<float>::max();
    float maxy = std::numeric_limits<float>::lowest();
    float minz = std::numeric_limits<float>::max();
    float maxz = std::numeric_limits<float>::lowest();

    for(PlaneIterator plane : groupVotes)
    {
      Eigen::Vector3f const & n = plane->plane.normal();

      minx = std::min(n.x(), minx);
      maxx = std::max(n.x(), maxx);

      miny = std::min(n.y(), miny);
      maxy = std::max(n.y(), maxy);

      minz = std::min(n.z(), minz);
      maxz = std::max(n.z(), maxz);
    }

    return (maxx-minx > threshold) || (maxy-miny > threshold) || (maxz-minz > threshold);
  }

  // ######################################################################
  //! Pick the candidate plane that lies closest to a group's row neighbors
  /*! For every candidate plane, the absolute point-to-plane distance is summed
    over the cloud points reached from each point of the group through
    getPositiveRowIncrement() (skipped for points on the last row) and
    getNegativeRowIncrement() (skipped for points on row 0). The candidate with
    the smallest sum wins; on a tie the earlier candidate is kept.

    @param group Group whose point indices select the neighbors to test
    @param groupVotes Candidate planes; if empty, OptionalEmpty is returned
    @param cloud Point cloud that the group's indices refer to
    @param maxNeighborDist Limit on the winning sum divided by the number of
    points in the group (note: the divisor is the group size, not the number of
    neighbors that were visited). Above this limit OptionalEmpty is returned.
    @return The winning candidate, or OptionalEmpty */
  template<size_t numRows>
    nrt::Optional<PlaneIterator> findBestNormal(LiDARGroup const & group, std::vector<PlaneIterator> const & groupVotes, nrt::PointCloud2 const & cloud, float const maxNeighborDist)
    {
      if(groupVotes.size() == 0) return nrt::OptionalEmpty;

      PlaneIterator winner = groupVotes[0];
      float lowestSum = std::numeric_limits<float>::max();

      for(PlaneIterator candidate : groupVotes)
      {
        Eigen::Hyperplane<float, 3> const & surface = candidate->plane;

        float sum = 0;
        for(size_t idx : group.indices)
        {
          size_t const row = velodyneIdx2Row<numRows>(idx);

          if(row < numRows-1)
            sum += surface.absDistance(cloud[idx + getPositiveRowIncrement<numRows>(row)].getVector3Map());

          if(row > 0)
            sum += surface.absDistance(cloud[idx + getNegativeRowIncrement<numRows>(row)].getVector3Map());

          // This candidate can no longer beat the current winner, stop summing
          if(sum > lowestSum) break;
        }

        if(sum < lowestSum)
        {
          winner = candidate;
          lowestSum = sum;
        }
      }

      // Reject the winner if, per group point, it is still too far from the neighbors
      if(lowestSum/group.indices.size() > maxNeighborDist)
        return nrt::OptionalEmpty;

      return winner;
    }

  // Explicitly instantiate findBestNormal for numRows = 32 and numRows = 64,
  // the same two row counts that smartMerge is instantiated for at the end of
  // this file.
  template
    nrt::Optional<PlaneIterator> findBestNormal<32>(LiDARGroup const & group, std::vector<PlaneIterator> const & groupVotes, nrt::PointCloud2 const & cloud, float const maxNeighborDist);

  template
    nrt::Optional<PlaneIterator> findBestNormal<64>(LiDARGroup const & group, std::vector<PlaneIterator> const & groupVotes, nrt::PointCloud2 const & cloud, float const maxNeighborDist);
}


// ######################################################################
template<size_t numRows>
std::vector<DetectedPlane> smartMerge(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>, numRows> const & groups, nrt::PointCloud2 const & cloud,
    float const normalModeThreshold, float const maxNeighborDist)
{
  if(detectedPlanes.size() == 0) return {};
  float const normalModeThresholdSq = normalModeThreshold * normalModeThreshold;
  static google::dense_hash_map<groupid_t, std::vector<PlaneIterator>> reverseVoteMap;
  static std::once_flag flag;
  std::call_once(flag, []()
      {
        reverseVoteMap.set_empty_key(std::numeric_limits<groupid_t>::max());
        reverseVoteMap.set_deleted_key(std::numeric_limits<groupid_t>::max()-1);
      });
  reverseVoteMap.clear_no_resize();

  // Create a map to keep track of every plane that each group voted for
  for(PlaneIterator planeIt = detectedPlanes.begin(); planeIt != detectedPlanes.end(); ++planeIt)
    for(groupid_t gid : planeIt->groups)
      reverseVoteMap[gid].push_back(planeIt);

  // For each group, go through all of the planes that it voted for and try to
  // find the maximal mode of the distribution of normals. Throw out any votes
  // whose normals are too far from this mode.
  for(std::pair<const groupid_t, std::vector<PlaneIterator>> & groupVotesPair : reverseVoteMap)
  {
    LiDARGroup const & group = groupFromId<numRows>(groups, groupVotesPair.first);
    std::vector<PlaneIterator> & groupVotes = groupVotesPair.second;

    // Detect if the distribution of normal vectors is too spread out
    if(detectMultiModalNormals(groupVotes, normalModeThreshold))
    {
      // Find the plane that is closest to the points above and below this group
      nrt::Optional<PlaneIterator> bestPlane = findBestNormal<numRows>(group, groupVotes, cloud, maxNeighborDist);

      // Find all planes that are near the best plane
      std::vector<PlaneIterator> closePlanes;
      if(bestPlane)
      {
        Eigen::Vector3f const & bestNormal = (*bestPlane)->plane.normal();
        for(PlaneIterator otherPlane : groupVotes)
        {
          if((bestNormal - otherPlane->plane.normal()).squaredNorm() < normalModeThresholdSq)
            closePlanes.push_back(otherPlane);
        }
      }
      groupVotes = std::move(closePlanes);
    }
  }

  // Create a new list of detected planes with now unimodal groups lists
  std::vector<DetectedPlane> filteredPlanes = detectedPlanes;
  for(DetectedPlane & p : filteredPlanes) p.groups.clear();
  for(std::pair<groupid_t, std::vector<PlaneIterator>> const & groupVotesPair : reverseVoteMap)
  {
    groupid_t const gid = groupVotesPair.first;
    std::vector<PlaneIterator> const & groupVotes = groupVotesPair.second;
    for(PlaneIterator planeIt : groupVotes)
    {
      size_t idx = std::distance(detectedPlanes.begin(), planeIt);
      filteredPlanes[idx].groups.push_back(gid);
    }
  }

  // Create an undirected graph that has VoteResults as nodes, which are
  // connected by edges if they both voted for the same group
  typedef boost::adjacency_list<boost::setS, boost::vecS, boost::undirectedS> VoteGraph;
  VoteGraph voteGraph;
  for(std::pair<groupid_t, std::vector<PlaneIterator>> const & groupVotesPair : reverseVoteMap)
  {
    std::vector<PlaneIterator> const & groupVotes = groupVotesPair.second;

    // TODO: Why the hell is this ever empty?
    if(groupVotes.empty()) continue;

    size_t idx1 = std::distance(detectedPlanes.begin(), groupVotes.front());

    for(PlaneIterator planeIt : groupVotes)
      boost::add_edge(idx1, std::distance(detectedPlanes.begin(), planeIt), voteGraph);
  }

  // Find all of the connected components in the graph and cluster them together
  std::vector<size_t> components(boost::num_vertices(voteGraph));
  int numComponents = boost::connected_components(voteGraph, components.data());
  std::vector<std::vector<DetectedPlane>> clusteredPlanes(numComponents);
  for(size_t i=0; i<components.size(); ++i)
    clusteredPlanes[components[i]].push_back(filteredPlanes[i]);

  // Fit a new plane to the points in each cluster
  std::vector<DetectedPlane> mergedPlanes;
  for(std::vector<DetectedPlane> const & cluster : clusteredPlanes)
  {
    std::vector<size_t> clusterGroupIds;
    for(DetectedPlane const & dp : cluster)
      clusterGroupIds.insert(clusterGroupIds.end(), dp.groups.begin(), dp.groups.end());

    std::sort(clusterGroupIds.begin(), clusterGroupIds.end());
    clusterGroupIds.erase(std::unique(clusterGroupIds.begin(), clusterGroupIds.end()), clusterGroupIds.end());

    if(clusterGroupIds.size() < 2) continue;

    std::vector<LiDARGroup> clusterGroups;
    for(size_t gid : clusterGroupIds)
      clusterGroups.push_back(groupFromId(groups, gid));

    DetectedPlane clusterPlane;
    clusterPlane.groups = clusterGroupIds;
    clusterPlane = refinePlaneFromPoints(clusterPlane, groups, cloud, RefineFrom::Groups);

    mergedPlanes.push_back(clusterPlane);
  }

  return mergedPlanes;
}

// Explicitly instantiate smartMerge for numRows = 32 and numRows = 64, so that
// these two specializations are emitted from this translation unit.
template
std::vector<DetectedPlane> smartMerge<32>(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>,32> const & groups, nrt::PointCloud2 const & cloud,
    float const normalModeThreshold, float const maxNeighborDist);

template
std::vector<DetectedPlane> smartMerge<64>(std::vector<DetectedPlane> const & detectedPlanes,
    std::array<std::vector<LiDARGroup>,64> const & groups, nrt::PointCloud2 const & cloud,
    float const normalModeThreshold, float const maxNeighborDist);
