#include "CeresFunctor.H"
#include <ceres/ceres.h>
#include <PointCloud/Features/Planes/details/FrameTiming.H>
#include <Config/Config.H>
#include <PointCloud/Features/Planes/details/PolygonHelpers.H>

// ######################################################################
  //! Builds a rigid transform from a 6-element parameter vector.
  //!
  //! @param x  Pointer to six values: x[0..2] is the translation, x[3], x[4] and x[5]
  //!           are the rotation angles about the X, Y and Z axes respectively.
  //! @return   Transform whose rotation block is Rz(x[5]) * Ry(x[4]) * Rx(x[3]) and
  //!           whose translation column is (x[0], x[1], x[2]).
  template<class T>
Eigen::Transform<T,3,Eigen::Isometry> makeTransform(T const * const x)
{
  // Sines and cosines of the three angles, named after the axis they rotate about.
  T const sinZ = ceres::sin(x[5]);
  T const sinY = ceres::sin(x[4]);
  T const sinX = ceres::sin(x[3]);

  T const cosZ = ceres::cos(x[5]);
  T const cosY = ceres::cos(x[4]);
  T const cosX = ceres::cos(x[3]);

  Eigen::Transform<T, 3, Eigen::Isometry> result;

  // Rotation: closed form of Rz * Ry * Rx
  result.linear() <<
    cosZ*cosY,  cosZ*sinY*sinX - sinZ*cosX,  sinZ*sinX + cosZ*sinY*cosX,
    sinZ*cosY,  sinZ*sinY*sinX + cosZ*cosX,  sinZ*sinY*cosX - cosZ*sinX,
    -sinY,      cosY*sinX,                   cosY*cosX;

  // Translation
  result.translation() << x[0], x[1], x[2];

  // Homogeneous bottom row
  result.matrix().row(3) << T(0), T(0), T(0), T(1);

  return result;
}

// ######################################################################
  //! Copies the x/y/z coordinates of a point-cloud geometry into a 3x1 column vector,
  //! converting each coordinate to the scalar type T.
  template<class T>
Eigen::Matrix<T,3,1> makeVector(nrt::PointCloud2::Geometry const & p)
{
  return Eigen::Matrix<T,3,1>(T(p.x()), T(p.y()), T(p.z()));
}

// ######################################################################
//! Ceres cost function with a single residual and a single 6-element parameter block.
//!
//! The residual is the mean, over all correspondences, of d^T * M * d where
//! d = T * pSrc - pTgt, T is the transform built by makeTransform() from the parameter
//! block, and M is the per-correspondence 3x3 matrix supplied as `mahalanobis`.
//! The analytic Jacobian of that residual is provided as well.
struct PointsCostFunctor : ceres::SizedCostFunction<1, 6>
{
  using Vector3 = Eigen::Matrix<double,3,1>;
  using Vector6 = Eigen::Matrix<double,6,1>;
  using Matrix3 = Eigen::Matrix<double,3,3>;

  //! Caches (source point, target point, weight matrix) for every correspondence.
  //!
  //! @param sourceCloud      Cloud indexed by each correspondence's sourceIndex.
  //! @param targetCloud      Cloud indexed by each correspondence's targetIndex.
  //! @param correspondences  Source/target index pairs.
  //! @param mahalanobis      Weight matrices; entry i is paired with correspondences[i], so the
  //!                         vector must hold at least correspondences.size() entries.
  PointsCostFunctor(nrt::PointCloud2 const & sourceCloud, nrt::PointCloud2 const & targetCloud,
      nrt::Correspondences const & correspondences, std::vector<Eigen::Matrix3d> const & mahalanobis) :
    n(correspondences.size())
  {
    data.reserve(correspondences.size());
    for(size_t i=0; i<correspondences.size(); ++i)
      data.emplace_back(
          makeVector<double>(sourceCloud[correspondences[i].sourceIndex]),
          makeVector<double>(targetCloud[correspondences[i].targetIndex]),
          mahalanobis[i]
          );
  }

  //! Computes the residual and, when requested, its 1x6 Jacobian.
  //!
  //! @param parameters  parameters[0] points at the six transform parameters.
  //! @param residuals   residuals[0] receives the mean weighted squared distance.
  //! @param jacobians   May be null, and jacobians[0] may be null; the Jacobian is only
  //!                    written when both are non-null.
  //! @return Always true.
  bool Evaluate(double const * const * parameters, double * residuals, double** jacobians) const override
  {
    // Ceres passes jacobians == nullptr when no derivatives are wanted, and a null
    // jacobians[0] when the derivative for this particular parameter block is not wanted.
    bool const wantJacobian = (jacobians != nullptr) && (jacobians[0] != nullptr);

    double & f = residuals[0];
    f = 0;

    // With nothing to average over, report zero cost and a zero gradient instead of
    // dividing by zero below.
    if(data.empty())
    {
      if(wantJacobian)
        Eigen::Map<Vector6>(jacobians[0]).setZero();
      return true;
    }

    auto x = Eigen::Map<const Vector6>(parameters[0]);

    auto transform = makeTransform(parameters[0]);

    Vector6 J = Vector6::Zero();
    Matrix3 R = Matrix3::Zero();

    for(auto const & entry : data)
    {
      Vector3 const & pSrc = std::get<0>(entry);
      Vector3 const & pTgt = std::get<1>(entry);
      Matrix3 const & M    = std::get<2>(entry);

      Vector3 const pTransformed = transform * pSrc; // T * a

      Vector3 const d = pTransformed - pTgt;

      Vector3 const temp = M * d;

      // Increment the error function
      f += d.dot(temp);

      if(wantJacobian)
      {
        // Increment the translation gradient
        // NOTE(review): M*d is the gradient of d^T*M*d (up to the factor 2 applied
        // below) only if M is symmetric - confirm against the caller.
        J.head<3>() += temp;

        // Increment the rotation gradient
        R += pSrc * temp.transpose();
      }
    }

    f /= double(n);

    if(wantJacobian)
    {
      double const norm = 2.0 / double(n);
      J.head<3>() *= norm;
      R *= norm;

      J.tail<3>() = computeRDerivative( x, R );

      auto J_map = Eigen::Map<Vector6>(jacobians[0]);
      J_map = J;
    }

    return true;
  }

  //! Contracts the accumulated matrix R with the derivative of the rotation matrix
  //! with respect to each of the three rotation angles.
  //!
  //! @param x  Full parameter vector; only x[3] (phi), x[4] (theta) and x[5] (psi) are read.
  //! @param R  Accumulated, normalised sum of pSrc * (M*d)^T from Evaluate().
  //! @return   (trace(dR/dphi * R), trace(dR/dtheta * R), trace(dR/dpsi * R)).
  Vector3 computeRDerivative( Vector6 const & x, Matrix3 const & R) const
  {
    //! Computes trace(mat1 * mat2), i.e. the sum over i,j of mat1(j,i) * mat2(i,j)
    auto matricesInnerProd = [](Matrix3 const & mat1, Matrix3 const & mat2)
    {
      double r = 0.0;
      for( size_t i = 0; i < 3; ++i )
        for( size_t j = 0; j < 3; ++j )
          r += mat1(j,i) * mat2(i,j);

      return r;
    };

    Matrix3 dR_dPhi;
    Matrix3 dR_dTheta;
    Matrix3 dR_dPsi;

    double phi = x[3], theta = x[4], psi = x[5];

    double cphi   = std::cos(phi),    sphi   = std::sin(phi);
    double ctheta = std::cos(theta),  stheta = std::sin(theta);
    double cpsi   = std::cos(psi),    spsi   = std::sin(psi);

    // Element-wise derivatives of the rotation matrix assembled in makeTransform().
    dR_dPhi(0,0) = 0.;
    dR_dPhi(1,0) = 0.;
    dR_dPhi(2,0) = 0.;

    dR_dPhi(0,1) = sphi*spsi + cphi*cpsi*stheta;
    dR_dPhi(1,1) = -cpsi*sphi + cphi*spsi*stheta;
    dR_dPhi(2,1) = cphi*ctheta;

    dR_dPhi(0,2) = cphi*spsi - cpsi*sphi*stheta;
    dR_dPhi(1,2) = -cphi*cpsi - sphi*spsi*stheta;
    dR_dPhi(2,2) = -ctheta*sphi;

    dR_dTheta(0,0) = -cpsi*stheta;
    dR_dTheta(1,0) = -spsi*stheta;
    dR_dTheta(2,0) = -ctheta;

    dR_dTheta(0,1) = cpsi*ctheta*sphi;
    dR_dTheta(1,1) = ctheta*sphi*spsi;
    dR_dTheta(2,1) = -sphi*stheta;

    dR_dTheta(0,2) = cphi*cpsi*ctheta;
    dR_dTheta(1,2) = cphi*ctheta*spsi;
    dR_dTheta(2,2) = -cphi*stheta;

    dR_dPsi(0,0) = -ctheta*spsi;
    dR_dPsi(1,0) = cpsi*ctheta;
    dR_dPsi(2,0) = 0.;

    dR_dPsi(0,1) = -cphi*cpsi - sphi*spsi*stheta;
    dR_dPsi(1,1) = -cphi*spsi + cpsi*sphi*stheta;
    dR_dPsi(2,1) = 0.;

    dR_dPsi(0,2) = cpsi*sphi - cphi*spsi*stheta;
    dR_dPsi(1,2) = sphi*spsi + cphi*cpsi*stheta;
    dR_dPsi(2,2) = 0.;

    return Vector3
      (
       matricesInnerProd(dR_dPhi, R),
       matricesInnerProd(dR_dTheta, R),
       matricesInnerProd(dR_dPsi, R)
      );
  }

  size_t n;                                                 //!< Number of correspondences
  std::vector<std::tuple<Vector3, Vector3, Matrix3>> data;  //!< (source point, target point, weight) per correspondence
};

// ######################################################################
//! Returns the 3x3 skew-symmetric (cross-product) matrix of w:
//!   [  0   -w.z   w.y ]
//!   [  w.z   0   -w.x ]
//!   [ -w.y  w.x    0  ]
Eigen::Matrix3d skewSymmetric(Eigen::Vector3d const & w)
{
  Eigen::Matrix3d S = Eigen::Matrix3d::Zero();

  S(0,1) = -w.z();
  S(0,2) =  w.y();

  S(1,0) =  w.z();
  S(1,2) = -w.x();

  S(2,0) = -w.y();
  S(2,1) =  w.x();

  return S;
}

// ######################################################################
//! Ceres cost function with four residuals and a single 6-element parameter block,
//! comparing a source plane against a target plane under the transform built by
//! makeTransform().
//!
//! residuals[0..2] = R^T * target.normal() - source.normal()
//! residuals[3]    = t . target.normal() + target.offset() - source.offset()
struct PlaneCostFunctor : ceres::SizedCostFunction<4, 6>
{
  using Plane = Eigen::Hyperplane<double, 3>;

  //! Stores double-precision copies of the two planes' hyperplane coefficients.
  PlaneCostFunctor(DetectedPlane const & source, DetectedPlane const & target) :
    source(source.plane), target(target.plane)
  { }

  //! Computes the four residuals and, when requested, the row-major 4x6 Jacobian.
  //!
  //! @param parameters  parameters[0] points at the six transform parameters.
  //! @param residuals   Receives the four residuals described on the struct.
  //! @param jacobians   May be null, and jacobians[0] may be null; the Jacobian is only
  //!                    written when both are non-null.
  //! @return Always true.
  bool Evaluate(double const * const * parameters, double * residuals, double** jacobians) const override
  {
    auto const transform = makeTransform(parameters[0]);

    Eigen::Matrix3d const Rt = transform.rotation().transpose();
    Eigen::Vector3d const t = transform.translation();

    // Target normal rotated by R^T; shared by the residual and the Jacobian.
    Eigen::Vector3d const rotatedNormal = Rt * target.normal();

    Eigen::Map<Eigen::Vector4d> error(residuals);
    error.head<3>() = rotatedNormal - source.normal();
    error[3] = t.dot(target.normal()) + target.offset() - source.offset();

    // Ceres passes jacobians == nullptr when no derivatives are wanted, and a null
    // jacobians[0] when the derivative for this particular parameter block is not wanted.
    if(jacobians != nullptr && jacobians[0] != nullptr)
    {
      Eigen::Map<Eigen::Matrix<double,4,6,Eigen::RowMajor>> H(jacobians[0]);

      // dn / dw
      // NOTE(review): presumably a small-rotation linearisation rather than the exact
      // derivative w.r.t. the angles used by makeTransform() - confirm.
      H.block<3,3>(0,3) = skewSymmetric(rotatedNormal);

      // dd / dw
      H.block<1,3>(3,3) = Eigen::Matrix<double, 1, 3>::Zero();

      // dn / dt
      H.block<3,3>(0,0) = Eigen::Matrix<double, 3, 3>::Zero();

      // dd / dt
      // NOTE(review): error[3] depends on t only through t . target.normal(), whose
      // derivative is target.normal()^T; the row written here is (R^T * target.normal())^T,
      // which matches only when R is the identity - confirm this is intended.
      H.block<1,3>(3,0) = rotatedNormal.transpose();
    }

    return true;
  }

  Plane const source;  //!< Source plane coefficients
  Plane const target;  //!< Target plane coefficients
};

//! Debug helper: writes the three components of v on one line to std::cout.
template<class T>
void printit(Eigen::Matrix<T, 3, 1> const & v)
{
  Eigen::Matrix<T, 1, 3> const asRow = v.transpose();
  std::cout << asRow << std::endl;
}

//! Specialisation for ceres::Jet<double, 6>: writes only the value part (.a) of each
//! component, separated by single spaces.
template<>
void printit<ceres::Jet<double, 6>>(Eigen::Matrix<ceres::Jet<double,6>,3,1> const & v)
{
  for(int i = 0; i < 3; ++i)
  {
    if(i != 0)
      std::cout << " ";
    std::cout << v[i].a;
  }
  std::cout << std::endl;
}

// ######################################################################
struct PlaneOverlapCostFunctor
{
  using Plane = Eigen::Hyperplane<double, 3>;

  PlaneOverlapCostFunctor(DetectedPlane const & source, DetectedPlane const & target) :
    source(source), target(target)
  { }

  template<class T>
    bool operator()(T const * const parameters, T * residuals) const
    {
      auto const transform = makeTransform(parameters);

      assert(source.centroid);
      assert(target.centroid);

      auto const R = transform.rotation();
      auto const t = transform.translation();

      auto const n_m = source.plane.normal().cast<T>();
      auto const d_m = T(source.plane.offset());

      auto const c_m = source.centroid->cast<T>();
      auto const c_z = target.centroid->cast<T>();

      // Target centroid, moved into source frame
      auto c_z_m = R*c_z - t;

      // Target centroid, moved into source frame, and projected onto source plane
      auto c_z_m_p = c_z_m - (n_m.dot(c_z_m) + d_m) * n_m;

      Eigen::Quaternionf quat;
      quat.setFromTwoVectors(source.plane.normal(), Eigen::Vector3f::UnitZ());
      Eigen::Matrix<T,3,3> const R_2d = quat.normalized().matrix().cast<T>();


      Eigen::Matrix<T,3,1> ee = R_2d*(c_m - c_z_m_p);


      residuals[0] = ee[0];
      residuals[1] = ee[1];
      residuals[2] = T(0); //ee[2];

      return true;
    }

  DetectedPlane const source;
  DetectedPlane const target;
};

// ######################################################################
//! Builds a Ceres problem from the point correspondences and matched planes, solves it,
//! and returns the resulting rigid transform.
//!
//! @param sourceTargetPlanes  Matched (source, target) plane pairs; each adds one
//!                            PlaneCostFunctor residual block.
//! @param sourceCloud         Source cloud, forwarded to PointsCostFunctor.
//! @param targetCloud         Target cloud, forwarded to PointsCostFunctor.
//! @param correspondences     Point correspondences; one PointsCostFunctor residual block
//!                            is added when this is non-empty.
//! @param mahalanobis         Per-correspondence weight matrices, forwarded to PointsCostFunctor.
//! @param transform           Currently not used: the solver always starts from the zero
//!                            parameter vector.
//! @return Isometry with rotation Rz(x[5]) * Ry(x[4]) * Rx(x[3]) and translation
//!         (x[0], x[1], x[2]), where x is the parameter vector after the solve.
//! @throws std::runtime_error if "picp.optimizer.planeLossFunction" names an unknown loss
//!         type (only checked when at least one plane pair is supplied).
Eigen::Isometry3f CeresFunctor::operator()(
    std::vector<std::pair<DetectedPlane, DetectedPlane>> sourceTargetPlanes,
    nrt::PointCloud2 const & sourceCloud, nrt::PointCloud2 const & targetCloud,
    nrt::Correspondences const & correspondences,
    std::vector<Eigen::Matrix3d> const & mahalanobis,
    Eigen::Isometry3f const & transform,
    Eigen::Isometry3f const & )
{
  ceres::Problem problem;

  // Set up the initial conditions for the solver. Seeding from `transform` is disabled;
  // the extraction that would do so is kept here for reference:
  //   auto const & m = transform.matrix();
  //   x = { m(0,3), m(1,3), m(2,3),
  //         std::atan2(m(2,1), m(2,2)), std::asin(-m(2,0)), std::atan2(m(1,0), m(0,0)) };
  (void)transform;

  std::array<double, 6> x = {{ 0, 0, 0, 0, 0, 0 }};

  // Create a single cost functor to calculate the cost for all non-plane points
  frame_timing::start("Create Point Cost Functors");
  if(!correspondences.empty())
  {
    auto cost_function =
      new PointsCostFunctor(sourceCloud, targetCloud, correspondences, mahalanobis);

    double const scale = config::lookup<bool>("picp.optimizer.scaleBySize")
      ? static_cast<double>(correspondences.size())
      : 1.0;

    problem.AddResidualBlock(cost_function,
        new ceres::ScaledLoss(new ceres::TrivialLoss(), scale, ceres::TAKE_OWNERSHIP), x.data());
  }
  frame_timing::stop("Create Point Cost Functors");

  // Create a cost functor for each plane
  frame_timing::start("Create Plane Cost Functors");
  if(!sourceTargetPlanes.empty())
  {
    // These settings do not depend on the individual plane pair, so read them once.
    double const loss_function_scale = config::lookup<double>("picp.optimizer.planeLossScale");
    std::string const loss_function_type = config::lookup<std::string>("picp.optimizer.planeLossFunction");
    bool const scale_by_size = config::lookup<bool>("picp.optimizer.scaleBySize");
    double const plane_scale = config::lookup<double>("picp.optimizer.planeScale");

    // Allocates the configured robust loss. The scale is taken as double so the configured
    // value reaches Ceres without being truncated to float.
    auto create_loss_function = [&loss_function_type](double const loss_scale) -> ceres::LossFunction *
    {
      if(loss_function_type == "Trivial")
        return new ceres::TrivialLoss();
      if(loss_function_type == "Huber")
        return new ceres::HuberLoss(loss_scale);
      if(loss_function_type == "SoftLOneLoss")
        return new ceres::SoftLOneLoss(loss_scale);
      if(loss_function_type == "Cauchy")
        return new ceres::CauchyLoss(loss_scale);
      if(loss_function_type == "Arctan")
        return new ceres::ArctanLoss(loss_scale);

      throw std::runtime_error("Unknown loss function type " + loss_function_type);
    };

    for(auto const & stp : sourceTargetPlanes)
    {
      double const size_scale = scale_by_size
        ? (stp.first.indices.size() + stp.second.indices.size()) / 2.0
        : 1.0;
      double const scale = size_scale * plane_scale;

      // Build the loss before allocating the cost functor, so a throw from the factory
      // cannot leave an already-allocated cost functor behind.
      ceres::LossFunction * const loss =
        new ceres::ScaledLoss(create_loss_function(loss_function_scale), scale, ceres::TAKE_OWNERSHIP);

      problem.AddResidualBlock(new PlaneCostFunctor(stp.first, stp.second), loss, x.data());

      //if(std::abs(stp.first.plane.normal().dot(Eigen::Vector3f::UnitZ())) > .5) continue;

      //float const overlap_scale = std::pow(polygonhelpers::intersectionOverUnion(stp.first, stp.second), 1);

      //std::cout << "::::::::::::::: " << overlap_scale << std::endl;

      //auto overlap_cost_function = new ceres::AutoDiffCostFunction<PlaneOverlapCostFunctor, 3, 6>(
      //    new PlaneOverlapCostFunctor(stp.first, stp.second));
      //problem.AddResidualBlock(overlap_cost_function, new ceres::ScaledLoss(
      //      create_loss_function(loss_function_scale), overlap_scale, ceres::TAKE_OWNERSHIP), x.data());
    }
  }
  frame_timing::stop("Create Plane Cost Functors");

  // Solve!
  ceres::Solver::Options options;
  options.num_threads = 1;
  //options.dense_linear_algebra_library_type  = ceres::LAPACK;
  //options.sparse_linear_algebra_library_type = ceres::SUITE_SPARSE;

  // Line Search / BFGS converges _much_ better than any of the trust region methods.
  options.minimizer_type = ceres::MinimizerType::LINE_SEARCH;
  options.line_search_direction_type = ceres::LineSearchDirectionType::BFGS;
  //options.parameter_tolerance = config::lookup<double>("picp.optimizer.parameterTolerance");// 1e-12;
  //options.function_tolerance = config::lookup<double>("picp.optimizer.functionTolerance"); //1e-16;
  //options.max_num_iterations = 100;

  //options.use_approximate_eigenvalue_bfgs_scaling = true;

  //options.check_gradients = true;
  //options.minimizer_progress_to_stdout = true;
  options.logging_type = ceres::LoggingType::SILENT;
  ceres::Solver::Summary summary;

  frame_timing::start("Ceres Solve");
  ceres::Solve(options, &problem, &summary);
  frame_timing::stop("Ceres Solve");

  //std::cout << summary.FullReport() << std::endl;

  // Convert the optimised parameters back into a single-precision isometry, using the
  // same Z-Y-X rotation order as makeTransform().
  Eigen::Isometry3f result = Eigen::Isometry3f::Identity();
  Eigen::Matrix3f R;
  R = Eigen::AngleAxisf (x[5], Eigen::Vector3f::UnitZ ())
    * Eigen::AngleAxisf (x[4], Eigen::Vector3f::UnitY ())
    * Eigen::AngleAxisf (x[3], Eigen::Vector3f::UnitX ());
  result.matrix().topLeftCorner<3,3>() = R.matrix();

  result.matrix().block<3,1>(0,3) = Eigen::Vector3f(x[0], x[1], x[2]);

  return result;
}
