#include "BFGSFunctor.H"

#include <cmath>
#include <utility>

//! Construct the GICP cost functor over a fixed set of correspondences
//! @param sourceIter Iterator to the first source point (a in the GICP paper); every evaluation
//!                   walks numInputs points forward from here and then rewinds
//! @param targetIter Iterator to the first target point (b in the GICP paper), paired index-wise
//!                   with the source points
//! @param mahalanobis Per-correspondence 3x3 matrices M_i used in the cost d_i' * M_i * d_i; indexed
//!                    by correspondence number, so presumably it holds at least numInputs entries
//!                    (not checked here)
//! @param guess Base transform that the 6D optimization state is applied on top of
//! @param numInputs Number of correspondences to evaluate
BFGSFunctor::BFGSFunctor( nrt::PointCloud2::ConstIterator<> && sourceIter,
                          nrt::PointCloud2::ConstIterator<> && targetIter,
                          std::vector<Matrix3> const & mahalanobis,
                          AffineTransformf const & guess,
                          size_t numInputs ) :
  // A named rvalue reference is itself an lvalue: without std::move the iterators would be
  // copied instead of moved into the members.
  itsSourceBegin( std::move( sourceIter ) ),
  itsTargetBegin( std::move( targetIter ) ),
  itsMahalanobis( mahalanobis ),
  itsGuess( guess ),
  itsNumInputs( numInputs )
{ }

//! Apply a 6D optimization state x = (tx, ty, tz, rx, ry, rz) to a transformation matrix, in place
//! The rotation Rz(rz) * Ry(ry) * Rx(rx) is left-multiplied onto the upper-left 3x3 block of t, and
//! (tx, ty, tz) is added to the translation column of t. The translation already stored in t is
//! NOT rotated by the incremental rotation.
//! @param t Transformation matrix to update
//! @param x Optimization state
void BFGSFunctor::updateTransform( Eigen::Matrix4f & t, Vector6 const & x )
{
  using Rotation = Eigen::AngleAxis<float>;

  // !!! CAUTION Stanford GICP uses the Z Y X euler angles convention
  float const aboutX = static_cast<float>( x[3] );
  float const aboutY = static_cast<float>( x[4] );
  float const aboutZ = static_cast<float>( x[5] );

  Eigen::Matrix3f const rotation =
    ( Rotation( aboutZ, Eigen::Vector3f::UnitZ() ) *
      Rotation( aboutY, Eigen::Vector3f::UnitY() ) *
      Rotation( aboutX, Eigen::Vector3f::UnitX() ) ).toRotationMatrix();

  // Compose the incremental rotation on the left of the existing rotation block
  t.topLeftCorner<3,3>() = rotation * t.topLeftCorner<3,3>();

  // Shift the translation column; the trailing zero leaves the homogeneous entry t(3,3) alone
  Eigen::Vector4f const shift( static_cast<float>( x[0] ),
                               static_cast<float>( x[1] ),
                               static_cast<float>( x[2] ),
                               0.0f );
  t.col( 3 ) += shift;
}

//! Fill in the rotational part of the cost gradient from an accumulated 3x3 rotation term
//! The rotation is R(x) = Rz(psi) * Ry(theta) * Rx(phi) with phi = x[3], theta = x[4], psi = x[5]
//! (the same Z Y X convention used by updateTransform()). For each angle a the gradient entry is
//! trace( dR/da * R ), where R is the matrix built by df()/fdf() as a sum of outer products
//! p_i * (M_i * d_i)' scaled by 2/N.
//! @param x Optimization state; only the angles x[3], x[4], x[5] are read
//! @param R Accumulated rotation term described above
//! @param g Gradient; g[3], g[4], g[5] are overwritten, g[0..2] are left untouched
void BFGSFunctor::computeRDerivative( Vector6 const & x, Matrix3 const & R, Vector6 & g ) const
{
  //! Computes trace( mat1 * mat2 ) = sum_{i,j} mat1(j,i) * mat2(i,j)
  //! NOTE: this is the trace of the plain product, NOT of mat1' * mat2. The plain product is what
  //! the chain rule yields: sum_i (M_i d_i)' * dR * p_i = trace( dR * sum_i p_i * (M_i d_i)' ).
  auto traceOfProduct = [](Matrix3 const & mat1, Matrix3 const & mat2)
  {
    double r = 0.0;
    for( size_t i = 0; i < 3; ++i )
      for( size_t j = 0; j < 3; ++j )
        r += mat1(j,i) * mat2(i,j);

    return r;
  };

  Matrix3 dR_dPhi;
  Matrix3 dR_dTheta;
  Matrix3 dR_dPsi;

  double phi = x[3], theta = x[4], psi = x[5];

  double cphi = std::cos(phi), sphi = std::sin(phi);
  double ctheta = std::cos(theta), stheta = std::sin(theta);
  double cpsi = std::cos(psi), spsi = std::sin(psi);

  // Partial derivative of Rz(psi) * Ry(theta) * Rx(phi) with respect to phi (rotation about X)
  dR_dPhi(0,0) = 0.;
  dR_dPhi(1,0) = 0.;
  dR_dPhi(2,0) = 0.;

  dR_dPhi(0,1) = sphi*spsi + cphi*cpsi*stheta;
  dR_dPhi(1,1) = -cpsi*sphi + cphi*spsi*stheta;
  dR_dPhi(2,1) = cphi*ctheta;

  dR_dPhi(0,2) = cphi*spsi - cpsi*sphi*stheta;
  dR_dPhi(1,2) = -cphi*cpsi - sphi*spsi*stheta;
  dR_dPhi(2,2) = -ctheta*sphi;

  // Partial derivative with respect to theta (rotation about Y)
  dR_dTheta(0,0) = -cpsi*stheta;
  dR_dTheta(1,0) = -spsi*stheta;
  dR_dTheta(2,0) = -ctheta;

  dR_dTheta(0,1) = cpsi*ctheta*sphi;
  dR_dTheta(1,1) = ctheta*sphi*spsi;
  dR_dTheta(2,1) = -sphi*stheta;

  dR_dTheta(0,2) = cphi*cpsi*ctheta;
  dR_dTheta(1,2) = cphi*ctheta*spsi;
  dR_dTheta(2,2) = -cphi*stheta;

  // Partial derivative with respect to psi (rotation about Z); the bottom row of R has no psi term
  dR_dPsi(0,0) = -ctheta*spsi;
  dR_dPsi(1,0) = cpsi*ctheta;
  dR_dPsi(2,0) = 0.;

  dR_dPsi(0,1) = -cphi*cpsi - sphi*spsi*stheta;
  dR_dPsi(1,1) = -cphi*spsi + cpsi*sphi*stheta;
  dR_dPsi(2,1) = 0.;

  dR_dPsi(0,2) = cpsi*sphi - cphi*spsi*stheta;
  dR_dPsi(1,2) = sphi*spsi + cphi*cpsi*stheta;
  dR_dPsi(2,2) = 0.;

  g[3] = traceOfProduct(dR_dPhi, R);
  g[4] = traceOfProduct(dR_dTheta, R);
  g[5] = traceOfProduct(dR_dPsi, R);
}

//! Evaluate the GICP cost at the optimization state x
//! f(x) = (1/N) * sum_i d_i' * M_i * d_i with d_i = T(x) * a_i - b_i, where T(x) is the guess
//! updated by x (see updateTransform), a_i / b_i are the i-th source / target points, M_i is the
//! i-th Mahalanobis matrix and N = itsNumInputs.
//! @param x Optimization state (tx, ty, tz, rx, ry, rz)
//! @return The normalized cost; 0 when there are no inputs (instead of the 0/0 NaN the
//!         normalization would otherwise produce)
double BFGSFunctor::operator()( Vector6 const & x )
{
  // Nothing to align: the cost is an empty sum. Bail out before dividing by zero.
  if( itsNumInputs == 0 )
    return 0.0;

  // Create transform matrix
  Eigen::Matrix4f transformMatrix = itsGuess.matrix();
  updateTransform( transformMatrix, x );

  // Evaluate cost function
  double f = 0.0;

  for( size_t i = 0; i < itsNumInputs; ++i )
  {
    // paper deals with aligning A to B, in our case we align SOURCE to TARGET
    auto pSrc = itsSourceBegin->geometry().getVectorMap(); // a in paper
    auto pTgt = itsTargetBegin->geometry().getVectorMap(); // b in paper

    ++itsSourceBegin;
    ++itsTargetBegin;

    // Evaluate into a concrete vector. With `auto` this would hold an unevaluated Eigen product
    // expression (Eigen's documented `auto` pitfall); df() and fdf() already use Vector4f here.
    Eigen::Vector4f const pTransformed = transformMatrix * pSrc; // T * a

    // even though the paper now says to do b - T*a, authors original source
    // code does T*a - b
    Vector3 const d = pTransformed.topLeftCorner<3,1>().cast<double>() - pTgt.topLeftCorner<3,1>().cast<double>(); // d in paper

    // d' * M * d, evaluated as d' * (M * d) -- the same expression fdf() uses, so both report
    // the same cost for the same x
    Vector3 const temp = itsMahalanobis[i] * d;
    f += static_cast<double>( d.transpose() * temp );
  }

  // reset iterators for next time around
  itsSourceBegin -= itsNumInputs;
  itsTargetBegin -= itsNumInputs;

  // return normalized cost
  return f / static_cast<double>( itsNumInputs );
}

//! Evaluate the gradient of the GICP cost at the optimization state x
//! Translation entries: (2/N) * sum_i M_i * d_i. Rotation entries: trace( dR/da * Racc ) for each
//! angle a (see computeRDerivative), with Racc = (2/N) * sum_i (G_R * a_i) * (M_i * d_i)', where
//! G_R is the rotation block of the guess.
//! @param x Optimization state (tx, ty, tz, rx, ry, rz)
//! @param g Output gradient; all six entries are overwritten; zero when there are no inputs
void BFGSFunctor::df( Vector6 const & x, Vector6 & g )
{
  g.setZero(); // zero out g

  // Nothing to align: the gradient of an empty sum is zero. Bail out before the 2/0
  // normalization below turns it into NaN.
  if( itsNumInputs == 0 )
    return;

  // updateTransform() left-multiplies the incremental rotation onto the guess rotation G_R and
  // adds the translation without rotating the guess translation G_t, so
  //   T(x) * a = R(x) * G_R * a + (G_t + t(x)) * w      (w = homogeneous coordinate of a)
  // and d(T * a)/d(angle) = dR/d(angle) * G_R * a -- only the guess rotation is involved.
  Eigen::Matrix4f const guessMatrix = itsGuess.matrix();
  Eigen::Matrix3f const guessRotation = guessMatrix.topLeftCorner<3,3>();

  // Create transform matrix
  Eigen::Matrix4f transformMatrix = guessMatrix;
  updateTransform( transformMatrix, x );

  Matrix3 R = Matrix3::Zero();

  for( size_t i = 0; i < itsNumInputs; ++i )
  {
    // paper deals with aligning A to B, in our case we align SOURCE to TARGET
    auto pSrc = itsSourceBegin->geometry().getVectorMap(); // a in paper
    auto pTgt = itsTargetBegin->geometry().getVectorMap(); // b in paper

    ++itsSourceBegin;
    ++itsTargetBegin;

    Eigen::Vector4f const pTransformed = transformMatrix * pSrc; // T * a

    Vector3 const d = pTransformed.topLeftCorner<3,1>().cast<double>() - pTgt.topLeftCorner<3,1>().cast<double>(); // d in paper

    Vector3 const temp = itsMahalanobis[i] * d;

    // Increment translation gradient
    g.head<3>() += temp;

    // Increment rotation gradient with G_R * a. Using the full guess (G * a = G_R * a + G_t * w)
    // would add a spurious dR/d(angle) * G_t term whenever the guess carries a translation.
    Vector3 const pRotated = ( guessRotation * pSrc.topLeftCorner<3,1>() ).cast<double>();
    R += pRotated * temp.transpose();
  }

  // reset iterators for next time around
  itsSourceBegin -= itsNumInputs;
  itsTargetBegin -= itsNumInputs;

  // Normalize values (the factor 2 comes from differentiating the quadratic form d' * M * d)
  const double norm = 2.0 / static_cast<double>( itsNumInputs );
  g.head<3>() *= norm;
  R *= norm;
  computeRDerivative( x, R, g );
}

//! Evaluate the GICP cost and its gradient together in a single pass over the correspondences
//! This is a first-order evaluation (cost + first derivative); no second derivative is computed.
//! Cost: (1/N) * sum_i d_i' * M_i * d_i with d_i = T(x) * a_i - b_i.
//! Gradient: translation entries (2/N) * sum_i M_i * d_i; rotation entries trace( dR/da * Racc )
//! (see computeRDerivative) with Racc = (2/N) * sum_i (G_R * a_i) * (M_i * d_i)'.
//! @param x Optimization state (tx, ty, tz, rx, ry, rz)
//! @param f Output normalized cost; 0 when there are no inputs
//! @param g Output gradient; all six entries are overwritten; zero when there are no inputs
void BFGSFunctor::fdf( Vector6 const & x, double & f, Vector6 & g )
{
  // set zeros
  f = 0.0;
  g.setZero();

  // Nothing to align: cost and gradient of an empty sum are zero. Bail out before the
  // normalizations below divide by zero and produce NaN.
  if( itsNumInputs == 0 )
    return;

  // updateTransform() left-multiplies the incremental rotation onto the guess rotation G_R and
  // adds the translation without rotating the guess translation G_t, so
  //   T(x) * a = R(x) * G_R * a + (G_t + t(x)) * w      (w = homogeneous coordinate of a)
  // and d(T * a)/d(angle) = dR/d(angle) * G_R * a -- only the guess rotation is involved.
  Eigen::Matrix4f const guessMatrix = itsGuess.matrix();
  Eigen::Matrix3f const guessRotation = guessMatrix.topLeftCorner<3,3>();

  // Create transform matrix
  Eigen::Matrix4f transformMatrix = guessMatrix;
  updateTransform( transformMatrix, x );

  Matrix3 R = Matrix3::Zero();

  for( size_t i = 0; i < itsNumInputs; ++i )
  {
    // paper deals with aligning A to B, in our case we align SOURCE to TARGET
    auto pSrc = itsSourceBegin->geometry().getVectorMap(); // a in paper
    auto pTgt = itsTargetBegin->geometry().getVectorMap(); // b in paper

    ++itsSourceBegin;
    ++itsTargetBegin;

    Eigen::Vector4f const pTransformed = transformMatrix * pSrc; // T * a

    Vector3 const d = pTransformed.topLeftCorner<3,1>().cast<double>() - pTgt.topLeftCorner<3,1>().cast<double>(); // d in paper

    Vector3 const temp = itsMahalanobis[i] * d;

    // Increment total error
    f += static_cast<double>( d.transpose() * temp );

    // Increment translation gradient
    g.head<3>() += temp;

    // Increment rotation gradient with G_R * a. Using the full guess (G * a = G_R * a + G_t * w)
    // would add a spurious dR/d(angle) * G_t term whenever the guess carries a translation.
    Vector3 const pRotated = ( guessRotation * pSrc.topLeftCorner<3,1>() ).cast<double>();
    R += pRotated * temp.transpose();
  }

  // reset iterators for next time around
  itsSourceBegin -= itsNumInputs;
  itsTargetBegin -= itsNumInputs;

  // Normalize values
  f /= static_cast<double>( itsNumInputs );

  // The factor 2 comes from differentiating the quadratic form d' * M * d
  const double norm = 2.0 / static_cast<double>( itsNumInputs );
  g.head<3>() *= norm;
  R *= norm;
  computeRDerivative( x, R, g );
}
