#include "RefinePlane.H"
#include "LinearAlgebra.H"
#include <nrt/PointCloud2/Common/Centroid.H>
#include <nrt/PointCloud2/Features/Normals.H>
#include <nrt/PointCloud2/Common/Covariance.H>
#include <nrt/PointCloud2/Common/Plane.H>
#include <nrt/Eigen/EigenDecomposition.H>
#include <smmintrin.h>

NRT_BEGIN_UNCHECKED_INCLUDES;
#include <Eigen/Eigenvalues>
NRT_END_UNCHECKED_INCLUDES;


namespace {
  using namespace nrt;

  // Computes the covariance matrix and the centroid of the points in [iter, end)
  // in a single pass, accumulating the required sums with SSE.
  //
  // For every point the products xx, xy, xz, yy, yz, zz and the coordinates
  // x, y, z are summed. The sums are divided by `size` and combined as
  //   cov(a, b) = E[ab] - E[a] * E[b]
  // as per http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
  //
  // iter, end : range of points to accumulate. Each point's geometry is read
  //             with an aligned four-float load, so geometry().begin() must be
  //             16-byte aligned with four readable floats (the fourth value is
  //             ignored and replaced by 1).
  // size      : divisor used to normalize the sums. It is not checked against
  //             the number of points actually iterated; passing 0 divides by
  //             zero.
  //
  // Returns the pair {covariance matrix, centroid}.
  auto computeCovarianceAndCentroidSSE( PointCloud2::ConstIterator<> && iter, PointCloud2::ConstIterator<> && end, const size_t size ) -> std::pair<PointCloud2::Matrix3, PointCloud2::Geometry>
  {
    PointCloud2::Matrix3 covarianceMatrix;
    PointCloud2::Geometry centroid;

    // Lane indices inside the homogeneous point register [x, y, z, 1].
    // Scoped constants instead of macros so nothing leaks into the rest of the file.
    constexpr int laneX = 0;
    constexpr int laneY = 1;
    constexpr int laneZ = 2;
    constexpr int laneOne = 3;

    // Running sums:
    //   sumsLow  = [ Sum(xx), Sum(xy), Sum(xz), Sum(yy) ]
    //   sumsHigh = [ Sum(yz), Sum(zz), Sum(x),  Sum(y)  ]
    //   sumZ     =   Sum(z)
    __m128 sumsLow = _mm_setzero_ps();
    __m128 sumsHigh = _mm_setzero_ps();
    float sumZ = 0.0f;

    __m128 const ones = _mm_set1_ps(1.0f);

    for( ; iter != end; ++iter )
    {
      // Load [x, y, z, ?] and overwrite the last lane with 1 (blend mask 0x8 selects lane 3 from `ones`)
      __m128 const point = _mm_blend_ps(_mm_load_ps(iter->geometry().begin()), ones, 0x8);

      // _MM_SHUFFLE lists lanes from highest to lowest, so the last argument is lane 0
      __m128 const lhsLow  = _mm_shuffle_ps(point, point, _MM_SHUFFLE(laneY, laneZ, laneY, laneX));     // [x, y, z, y]
      __m128 const rhsLow  = _mm_shuffle_ps(point, point, _MM_SHUFFLE(laneY, laneX, laneX, laneX));     // [x, x, x, y]
      __m128 const lhsHigh = _mm_shuffle_ps(point, point, _MM_SHUFFLE(laneOne, laneOne, laneZ, laneZ)); // [z, z, 1, 1]
      __m128 const rhsHigh = _mm_shuffle_ps(point, point, _MM_SHUFFLE(laneY, laneX, laneZ, laneY));     // [y, z, x, y]

      sumsLow = _mm_add_ps(sumsLow, _mm_mul_ps(lhsLow, rhsLow));
      sumsHigh = _mm_add_ps(sumsHigh, _mm_mul_ps(lhsHigh, rhsHigh));
      sumZ += iter->geometry().z();
    }

    // buffer layout: 0:xx 1:xy 2:xz 3:yy 4:yz 5:zz 6:x 7:y 8:z
    // Explicitly aligned so that both aligned stores (offsets 0 and 4 floats) are valid.
    alignas(16) Eigen::Matrix<PointCloud2::BaseType, 1, 9, Eigen::RowMajor> buffer;
    _mm_store_ps(&buffer[0], sumsLow);
    _mm_store_ps(&buffer[4], sumsHigh);
    buffer[8] = sumZ;

    // normalize: every sum becomes a mean
    buffer /= static_cast<PointCloud2::BaseType>( size );

    // fill out results
    centroid.x() = buffer[6];
    centroid.y() = buffer[7];
    centroid.z() = buffer[8];

    // Upper triangle: E[ab] - E[a]E[b], i.e. (Sum(ab) - Sum(a)*Sum(b) / n) / n
    covarianceMatrix.coeffRef( 0 ) = buffer[0] - buffer[6] * buffer[6];
    covarianceMatrix.coeffRef( 1 ) = buffer[1] - buffer[6] * buffer[7];
    covarianceMatrix.coeffRef( 2 ) = buffer[2] - buffer[6] * buffer[8];
    covarianceMatrix.coeffRef( 4 ) = buffer[3] - buffer[7] * buffer[7];
    covarianceMatrix.coeffRef( 5 ) = buffer[4] - buffer[7] * buffer[8];
    covarianceMatrix.coeffRef( 8 ) = buffer[5] - buffer[8] * buffer[8];

    // Mirror into the lower triangle; the matrix is symmetric
    covarianceMatrix.coeffRef( 3 ) = covarianceMatrix.coeff( 1 );
    covarianceMatrix.coeffRef( 6 ) = covarianceMatrix.coeff( 2 );
    covarianceMatrix.coeffRef( 7 ) = covarianceMatrix.coeff( 5 );

    return std::make_pair( covarianceMatrix, centroid );
  }
}



// ######################################################################
// Re-estimates the plane equation of `plane` from its member groups only.
//
// Every group listed in plane.groups contributes two synthetic points: its
// centroid, and its centroid displaced by the group's e3 vector. A plane is
// then fitted to those points with nrt::Normals::computePointNormal and its
// normal is oriented with respect to a viewpoint at the origin.
//
// plane  : plane to refine; only plane.groups is read.
// groups : per-row group storage that group ids are resolved against.
// cloud  : not used by this function.
//
// Returns a copy of `plane` whose `plane` member holds the refitted hyperplane.
template<size_t numRows>
DetectedPlane refinePlaneFromGroups(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, numRows> const & groups, nrt::PointCloud2 const & cloud)
{
  // Two sample points per group
  nrt::PointCloud2 samples;
  samples.resize(2 * plane.groups.size());

  auto sampleIt = samples.geometry_begin();
  for(groupid_t const id : plane.groups)
  {
    LiDARGroup const & member = groupFromId(groups, id);

    // First sample: the group centroid itself
    sampleIt->getVector3Map() = member.centroid;
    ++sampleIt;

    // Second sample: the centroid shifted along e3
    sampleIt->getVector3Map() = member.centroid + member.e3;
    ++sampleIt;
  }

  // Fit a plane to the samples and make its normal face the origin
  Eigen::Vector4f coefficients;
  nrt::PointCloud2::Geometry sampleCentroid = nrt::computeCentroid(samples);
  nrt::Normals::computePointNormal(samples, false, coefficients);
  nrt::Normals::faceNormalToViewPoint(sampleCentroid, {0,0,0}, coefficients);

  // coefficients = [normal.x, normal.y, normal.z, offset]
  DetectedPlane result = plane;
  result.plane = Eigen::Hyperplane<float, 3>(coefficients.head<3>(), coefficients[3]);
  return result;
}

// Explicit instantiations for the supported row counts
template
DetectedPlane refinePlaneFromGroups<32>(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, 32> const & groups, nrt::PointCloud2 const & cloud);

template
DetectedPlane refinePlaneFromGroups<64>(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, 64> const & groups, nrt::PointCloud2 const & cloud);

// ######################################################################
// Re-estimates the plane equation of `plane` from the cloud points that belong to it.
//
// The point set is chosen by `refinefrom`:
//   RefineFrom::Groups  - the concatenated indices of every group in plane.groups
//   RefineFrom::Indices - plane.indices, used in place without copying
// Any other value leaves the point set empty, which makes the covariance
// helper divide by zero.
//
// The covariance and centroid of the selected points are computed in one pass
// (replacing the nrt::computeCentroid / nrt::Normals::computePointNormal pair),
// the plane parameters are solved from them, and the normal is oriented with
// respect to a viewpoint at the origin.
//
// plane      : plane to refine.
// groups     : per-row group storage that group ids are resolved against.
// cloud      : cloud the indices refer to.
// refinefrom : which index source to use (see above).
//
// Returns a copy of `plane` whose `plane` member holds the refitted hyperplane
// with a normalized normal.
template<size_t numRows>
DetectedPlane refinePlaneFromPoints(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, numRows> const & groups, nrt::PointCloud2 const & cloud, RefineFrom const refinefrom)
{
  // Only filled when refining from groups; otherwise plane.indices is referenced directly
  nrt::Indices gathered;
  nrt::Indices const * indices = &gathered;

  if(refinefrom == RefineFrom::Groups)
  {
    assert(plane.groups.size());
    for(groupid_t gid : plane.groups)
    {
      LiDARGroup const & group = groupFromId(groups, gid);
      gathered.insert(gathered.end(), group.indices.begin(), group.indices.end());
    }
  }
  else if(refinefrom == RefineFrom::Indices)
  {
    assert(plane.indices.size());
    indices = &plane.indices; // avoid copying the whole index list just to read it
  }

  Eigen::Vector4f planeParameters;

  // solve plane parameters
  auto covarianceCentroid = computeCovarianceAndCentroidSSE(cloud.subset_begin(*indices), cloud.subset_end(*indices), indices->size());
  nrt::solvePlaneParameters( covarianceCentroid.first, covarianceCentroid.second, planeParameters);
  nrt::Normals::faceNormalToViewPoint(covarianceCentroid.second, {0,0,0}, planeParameters);

  DetectedPlane refinedPlane = plane;
  refinedPlane.plane = Eigen::Hyperplane<float, 3>(planeParameters.block<3,1>(0,0).normalized(), planeParameters[3]);

  return refinedPlane;
}

template
DetectedPlane refinePlaneFromPoints<32>(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, 32> const & groups, nrt::PointCloud2 const & cloud, RefineFrom const refinefrom);

template
DetectedPlane refinePlaneFromPoints<64>(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, 64> const & groups, nrt::PointCloud2 const & cloud, RefineFrom const refinefrom);

// ######################################################################
// Re-estimates the plane equation of `plane` with a closed-form least-squares fit.
//
// With M = Sum(p * p^T) and b = Sum(p) over the selected points, the solution of
// M * n' = b is the least-squares fit of n' . p = 1; its direction is the plane
// normal. The point set is chosen by `refinefrom`:
//   RefineFrom::Groups  - the indices of every group in plane.groups
//   RefineFrom::Indices - plane.indices
// Any other value selects no points, and the result is then not meaningful
// (division by a zero count).
//
// The returned hyperplane follows Eigen's convention n . x + d = 0: the normal
// is oriented toward the origin and d = -n . centroid, so the plane passes
// through the centroid of the selected points. An earlier form of this code
// compared -normal.dot(normal) with 0 (true for every non-zero normal) and used
// |centroid| as the offset, which is only correct when the centroid lies on the
// perpendicular from the origin.
//
// plane      : plane to refine.
// groups     : per-row group storage that group ids are resolved against.
// cloud      : cloud the indices refer to.
// refinefrom : which index source to use (see above).
//
// Returns a copy of `plane` whose `plane` member holds the refitted hyperplane.
template<size_t numRows>
DetectedPlane refinePlaneFromPointsFast(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, numRows> const & groups, nrt::PointCloud2 const & cloud, RefineFrom const refinefrom)
{
  Eigen::Matrix3f M_tilde = Eigen::Matrix3f::Zero(); // Sum(p * p^T)
  Eigen::Vector3f b_tilde = Eigen::Vector3f::Zero(); // Sum(p)
  size_t count = 0;

  if(refinefrom == RefineFrom::Groups)
  {
    assert(plane.groups.size());
    for(groupid_t gid : plane.groups)
    {
      LiDARGroup const & group = groupFromId(groups, gid);
      for(size_t index : group.indices)
      {
        auto const geo = cloud[index].getVector3Map();

        M_tilde += geo * geo.transpose();
        b_tilde += geo;
        ++count;
      }
    }
  }
  else if(refinefrom == RefineFrom::Indices)
  {
    assert(plane.indices.size());

    for(auto it = cloud.subset_begin(plane.indices), end = cloud.subset_end(plane.indices); it != end; ++it)
    {
      auto const geo = it->geometry().getVector3Map();
      M_tilde += geo * geo.transpose();
      b_tilde += geo;
    }
    count = plane.indices.size();
  }

  Eigen::Vector3f normal = (M_tilde.inverse() * b_tilde).normalized();
  Eigen::Vector3f const centroid = b_tilde / static_cast<float>(count);

  // Make the normal face the viewpoint at the origin
  if(normal.dot(centroid) > 0)
    normal = -normal;

  // Offset such that normal . centroid + rho = 0
  float const rho = -normal.dot(centroid);

  DetectedPlane refinedPlane = plane;
  refinedPlane.plane = Eigen::Hyperplane<float, 3>(normal, rho);

  return refinedPlane;
}

template
DetectedPlane refinePlaneFromPointsFast<32>(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, 32> const & groups, nrt::PointCloud2 const & cloud, RefineFrom const refinefrom);

template
DetectedPlane refinePlaneFromPointsFast<64>(DetectedPlane const & plane,
    std::array<std::vector<LiDARGroup>, 64> const & groups, nrt::PointCloud2 const & cloud, RefineFrom const refinefrom);

// ######################################################################
// Weighted least-squares refinement of `plane` over the points in plane.indices.
//
// Each point gets the weight 1 / (0.02 + distance / 1000), where distance is the
// point's Distance field (units not visible here - TODO confirm). The weighted
// centroid and weighted scatter matrix are eigen-decomposed to obtain the plane
// normal, and the Hessian of the fit with respect to (normal, offset) is used to
// fill in the uncertainty members of the result.
//
// plane : plane to refine; plane.indices selects the points.
// cloud : cloud the indices refer to; must carry a Distance field.
//
// Returns a copy of `plane` with plane, centroid, eigenValues, hessian,
// covariance, rhoCovariance, normalCovariance and pointCovariance updated.
// An empty plane.indices yields a zero weight sum and a division by zero.
DetectedPlane refinePlaneLeastSquares(DetectedPlane const & plane, nrt::PointCloud2 const & cloud)
{
  DetectedPlane result = plane;

  // Pass 1: per-point weights and the weighted centroid
  std::vector<float> weights(plane.indices.size());
  std::size_t slot = 0;
  Eigen::Vector3f centroid(0,0,0);
  float weightSum = 0.0f;

  for(auto it = cloud.subset_begin<Distance>(plane.indices), last = cloud.subset_end<Distance>(plane.indices); it != last; ++it)
  {
    Eigen::Vector3f const position = it->geometry().getVector3Map();

    float const weight = 1.0f / (0.02f + it->get<Distance>().value * (1.0f / 1000.0f));
    centroid += weight * position;
    weightSum += weight;
    weights[slot] = weight;
    ++slot;
  }

  centroid /= weightSum;

  // Pass 2: weighted scatter matrix about the centroid
  Eigen::Matrix3f M = Eigen::Matrix3f::Zero();
  slot = 0;
  for(auto p : cloud.subset_range(plane.indices))
  {
    Eigen::Vector3f const offset = (p.geometry().getVector3Map() - centroid);
    M += weights[slot] * offset * offset.transpose();
    ++slot;
  }

  // Plane normal: first eigenvector column of the scatter matrix
  // (presumably the smallest-eigenvalue direction - depends on the ordering
  // returned by nrt::eigenDecomposition, TODO confirm)
  Eigen::Vector3d const centroidD = centroid.cast<double>();
  Eigen::Matrix3d const MD = M.cast<double>();
  Eigen::Matrix3d eigenVectors;
  Eigen::Vector3d eigenValues;
  nrt::eigenDecomposition<double>(MD, eigenValues, eigenVectors);
  Eigen::Vector3d n = eigenVectors.col(0);
  result.eigenValues = eigenValues;

  // Hessian blocks with respect to the normal (n) and the offset (d)
  double const H_dd = -weightSum;
  Eigen::Vector3d const H_nd = -H_dd * centroidD;
  double const nMn = (n.transpose()*MD*n)[0];
  Eigen::Matrix3d const H_nn = -MD + H_dd * centroidD * centroidD.transpose() + nMn*Eigen::Matrix3d::Identity();

  // Assemble the 4x4 Hessian: [ H_nn  H_nd ; H_nd^T  H_dd ]
  result.hessian.block<3,3>(0,0) = H_nn;
  result.hessian.block<3,1>(0,3) = H_nd;
  result.hessian.block<1,3>(3,0) = H_nd.transpose();
  result.hessian(3,3) = H_dd;

  result.covariance = -pseudoInverse(result.hessian);

  // Variance of the offset
  Eigen::Matrix3d H_nn_inv = H_nn.inverse();
  double const den = (n.transpose() * H_nn_inv * H_nd)(0,0);
  result.rhoCovariance = -(n.transpose() * H_nn_inv * n)(0,0) / (den*den);

  // Hessian of the normal with the offset eliminated
  Eigen::Matrix3d const H_nn_prime = H_nn - (1.0 / H_dd) * (H_nd * H_nd.transpose());

  // From here on eigenValues / eigenVectors describe -H_nn_prime
  nrt::eigenDecomposition<double>(-H_nn_prime, eigenValues, eigenVectors);

  // Normal covariance built from eigenpairs 2 and 1 only; eigenpair 0 is left out.
  // TODO: look this over, use this or the pseudo-inverse, i.e.
  //   result.normalCovariance = -pseudoInverse(H_nn_prime);
  Eigen::Matrix3d normalCov = Eigen::Matrix3d::Zero();
  normalCov += eigenVectors.col(2) * eigenVectors.col(2).transpose() / eigenValues(2);
  normalCov += eigenVectors.col(1) * eigenVectors.col(1).transpose() / eigenValues(1);
  result.normalCovariance = normalCov;

  // Orient the normal toward the origin, then take the offset so that
  // n . centroid + rho = 0
  if(n.dot(centroidD) > 0)
    n = -n;
  double const rho = -n.dot(centroidD);

  result.plane = Eigen::Hyperplane<float, 3>(n.cast<float>(), rho);
  result.centroid = centroidD;
  result.pointCovariance = MD / weightSum;

  return result;
}
