#include "PoseToPlaneFactor.H"

#include <cmath>
#include <iostream>
#include <random>
#include <string>

#include <gtsam/geometry/Pose3.h>
#include <gtsam/geometry/Point3.h>
#include <gtsam/nonlinear/Symbol.h>
#include <gtsam/slam/PriorFactor.h>
#include <gtsam/slam/BetweenFactor.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/ISAM2.h>
#include <gtsam/nonlinear/Values.h>
#include <gtsam/nonlinear/NonlinearFactor.h>

// UPDATE_ISAM(frame): push the pending factors and initial values into ISAM2,
// print the refreshed estimate labelled with `frame`, then empty the pending
// containers so the next step starts clean.
//
// The macro expands in place and expects these names to be visible at the
// expansion site: `isam`, `graph`, `initialEstimate` and `currentEstimate`.
//
// The body is wrapped in do { ... } while (0) so that `UPDATE_ISAM(n);` is a
// single statement and stays correct inside an unbraced if/else.
#define UPDATE_ISAM(frame) \
  do { \
    isam.update(graph, initialEstimate); \
    currentEstimate = isam.calculateBestEstimate(); \
    std::cout << "======================================" << std::endl; \
    currentEstimate.print("Frame " + std::to_string((frame)) + " Estimate:\n"); \
    graph.resize(0); \
    initialEstimate.clear(); \
  } while (0)

namespace
{
  // When true, the noise helpers in this file are asked for zero spread.
  constexpr bool GROUND_TRUTH = false;

  // Standard deviation used by ga() and for the first three entries of the
  // 6-dimensional pose noise models.
  constexpr double SIGMA_A = 0.1;

  // Standard deviation used by gt(), for the last three entries of the pose
  // noise models, and for all four entries of the plane measurement noise.
  constexpr double SIGMA_T = 0.05;
}

/**
 * Draws one zero-mean Gaussian sample with standard deviation SIGMA_A.
 * In this file the result is added to the angles passed to Rot3::Rz.
 *
 * @return 0 exactly when GROUND_TRUTH is set, otherwise a random sample.
 */
double ga()
{
  // std::normal_distribution requires a strictly positive standard deviation,
  // so the noise-free mode must not construct one with sigma == 0.
  if(GROUND_TRUTH)
    return 0.;

  // Generator and distribution are created once and reused across calls.
  static std::random_device rd;
  static std::mt19937 gen(rd());
  static std::normal_distribution<> d(0., SIGMA_A);

  return d(gen);
}

/**
 * Draws one zero-mean Gaussian sample with standard deviation SIGMA_T.
 * In this file the result is added to translation components and to the
 * plane parameters passed to PlaneValue.
 *
 * @return 0 exactly when GROUND_TRUTH is set, otherwise a random sample.
 */
double gt()
{
  // std::normal_distribution requires a strictly positive standard deviation,
  // so the noise-free mode must not construct one with sigma == 0.
  if(GROUND_TRUTH)
    return 0.;

  // Generator and distribution are created once and reused across calls.
  static std::random_device rd;
  static std::mt19937 gen(rd());
  static std::normal_distribution<> d(0., SIGMA_T);

  return d(gen);
}

/**
 * Incrementally builds a four-pose / three-plane problem with ISAM2, printing
 * the estimate after every frame, and then runs a re-keying experiment:
 * every factor attached to plane m1 is replaced by a copy attached to m2.
 *
 * All nominal values are perturbed with ga()/gt() samples.
 *
 * @return 0 on completion.
 */
int main()
{
  using namespace gtsam;

  // Pose keys p0..p3 and plane keys m0..m2.
  Symbol p0('p', 0), p1('p', 1), p2('p', 2), p3('p', 3);
  Symbol m0('m', 0), m1('m', 1), m2('m', 2);

  // Shared by several nominal plane values below.
  const double invSqrt2 = 1./std::sqrt(2.);

  // `graph` / `initialEstimate` hold only what is new since the last
  // UPDATE_ISAM; the macro empties both after handing them to ISAM2.
  NonlinearFactorGraph graph;
  Values initialEstimate;
  Values currentEstimate;

  ISAM2Params parameters;
  parameters.relinearizeThreshold = 0.01;
  parameters.relinearizeSkip = 1;
  ISAM2 isam(parameters);

  // finished() turns the comma initialiser into a Vector explicitly rather
  // than depending on an implicit conversion.
  // NOTE(review): assumes gtsam::Vector is Eigen-based - confirm for the GTSAM version in use.
  noiseModel::Diagonal::shared_ptr priorNoise       = noiseModel::Diagonal::Sigmas((Vector(6) << SIGMA_A, SIGMA_A, SIGMA_A, SIGMA_T, SIGMA_T, SIGMA_T).finished());
  noiseModel::Diagonal::shared_ptr odometryNoise    = noiseModel::Diagonal::Sigmas((Vector(6) << SIGMA_A, SIGMA_A, SIGMA_A, SIGMA_T, SIGMA_T, SIGMA_T).finished());
  noiseModel::Diagonal::shared_ptr measurementNoise = noiseModel::Diagonal::Sigmas((Vector(4) << SIGMA_T, SIGMA_T, SIGMA_T, SIGMA_T).finished());

  // ---- Frame 0: prior on p0, which observes planes m0 and m2 ----
  graph.add(PriorFactor<Pose3>(p0, Pose3(Rot3::identity(), Point3(0+gt(), 0+gt(), 0+gt())), priorNoise));
  graph.add(PoseToPlaneFactor(PlaneValue(-1+gt(), 0+gt(), 0+gt(), 10+gt()), measurementNoise, p0, m0));
  graph.add(PoseToPlaneFactor(PlaneValue(0+gt(), 0+gt(), -1+gt(), 10+gt()), measurementNoise, p0, m2));

  // Initial guesses: nominal values plus noise.
  initialEstimate.insert(p0, Pose3(Rot3::Rz(ga()), Point3(gt(), gt(), gt())));
  initialEstimate.insert(m0, PlaneValue(-1 + gt(), 0 + gt(), 0 + gt(), 10 + gt()));
  initialEstimate.insert(m2, PlaneValue(0 + gt(), 0 + gt(), -1 + gt(), 10 + gt()));

  UPDATE_ISAM(0);

  // ---- Frame 1: odometry p0 -> p1; p1 observes m0, m1 (new) and m2 ----
  graph.add(BetweenFactor<Pose3>(p0, p1, Pose3(Rot3::Rz(M_PI/4.0+ga()), Point3(2+gt(), 2+gt(), 0+gt())), odometryNoise));
  graph.add(PoseToPlaneFactor(PlaneValue(-invSqrt2+gt(), invSqrt2+gt(), 0+gt(), 8+gt()), measurementNoise, p1, m0));
  graph.add(PoseToPlaneFactor(PlaneValue(0+gt(), -1+gt(), 0+gt(), 10+gt()), measurementNoise, p1, m1));
  graph.add(PoseToPlaneFactor(PlaneValue(0+gt(), 0+gt(), -1+gt(), 10+gt()), measurementNoise, p1, m2));

  // Initial guesses: nominal values plus noise.
  initialEstimate.insert(p1, Pose3(Rot3::Rz(M_PI/4.0 + ga()), Point3(2 + gt(), 2 + gt(), 0 + gt())));
  initialEstimate.insert(m1, PlaneValue(invSqrt2 + gt(), -invSqrt2 + gt(), 0 + gt(), 10 + gt()));

  UPDATE_ISAM(1);

  // ---- Frame 2: odometry p1 -> p2; p2 observes m0 and m2 ----
  graph.add(BetweenFactor<Pose3>(p1, p2, Pose3(Rot3::Rz(-M_PI/4.0+ga()), Point3(0+gt(), 2.*std::sqrt(2.)+gt(), 0+gt())), odometryNoise));
  graph.add(PoseToPlaneFactor(PlaneValue(-1+gt(), 0+gt(), 0+gt(), 10+gt()), measurementNoise, p2, m0));
  graph.add(PoseToPlaneFactor(PlaneValue(0+gt(), 0+gt(), -1+gt(), 10+gt()), measurementNoise, p2, m2));

  // Initial guess: nominal value plus noise.
  initialEstimate.insert(p2, Pose3(Rot3::Rz(ga()), Point3(0 + gt(), 4 + gt(), 0 + gt())));

  UPDATE_ISAM(2);

  // ---- Frame 3: odometry p2 -> p3; p3 observes m1 and m2 ----
  graph.add(BetweenFactor<Pose3>(p2, p3, Pose3(Rot3::Rz(M_PI/2.+ga()), Point3(-2+gt(), 0+gt(), 0+gt())), odometryNoise));
  // 5.7573593128588 is numerically 10 - 3*sqrt(2).
  graph.add(PoseToPlaneFactor(PlaneValue(-invSqrt2+gt(), -invSqrt2+gt(), 0+gt(), 5.7573593128588+gt()), measurementNoise, p3, m1));
  graph.add(PoseToPlaneFactor(PlaneValue(0+gt(), 0+gt(), -1+gt(), 10+gt()), measurementNoise, p3, m2));

  // Initial guess: nominal value plus noise.
  initialEstimate.insert(p3, Pose3(Rot3::Rz(M_PI/2. + ga()), Point3(-2 + gt(), 4 + gt(), 0 + gt())));

  UPDATE_ISAM(3);

  std::cout << "---------------------------------------------------" << std::endl;

  // ---- Re-keying experiment ----
  // Bind ISAM2's factor list and variable index by const reference instead of
  // copying them; all three references are only read before the
  // isam.update() call below.
  auto const & factorGraph = isam.getFactorsUnsafe();
  auto const & variableIndex = isam.getVariableIndex();

  // Indices (into factorGraph) of all factors that involve m1.
  auto const & factors = variableIndex[m1];

  gtsam::FastVector<size_t> indicesToRemove;
  for(auto f : factors)
  {
    // Queue a copy of the factor re-keyed with the pair (m1, m2)
    // (presumably: every occurrence of key m1 becomes m2) ...
    graph.add( factorGraph[f]->rekey( {{{ m1, m2 }}} ));

    // ... and schedule the original factor for removal.
    indicesToRemove.push_back(f);
  }

  // One update both adds the re-keyed factors and removes the originals;
  // no new variables are introduced, hence the empty Values.
  isam.update(graph, gtsam::Values(), indicesToRemove);

  currentEstimate = isam.calculateBestEstimate();
  std::cout << "======================================" << std::endl;
  currentEstimate.print("Remixed Estimate:\n");

  return 0;
}
