#include "IOUtils.H"

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cmath>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <mutex>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <nrt/Core/Model/Manager.H>
#include <nrt/Graphics/ShapeRendererBasic.H>
#include <nrt/Graphics/Shapes.H>
#include <LiDARPlaneDetector.H>
#include <details/Timing.H>
#include <PlaneAssociation.H>
#include <FindTransforms.H>
#include <valgrind/callgrind.h>
#include <X11/keysym.h>
#include <DrawingUtils/DrawingUtils.H>
#include <details/Polygonalization.H>
#include "OctomapUtils.H"
#include <glog/logging.h>
#include <glog/log_severity.h>
#include <gflags/gflags.h>

#include <Config/Config.H>
#include <Util/Util.H>

#include <PointCloud/Registration/RegistrationGICP.H>
#include <PointCloud/Registration/Transformation/CeresFunctor.H>
#include <PointCloud/Registration/Convergence/GICPConvergenceCriteria.H>
#include <PointCloud/Features/VelodyneNormalsFLAS.H>
#include <PointCloud/Features/VelodynePlanarity.H>
#include <nrt/ImageProc/Drawing/ColorMapping.H>
#include <nrt/PointCloud2/Registration/Correspondence/CorrespondenceEstimationNearestNeighborNano.H>
#include <nrt/PointCloud2/Registration/Correspondence/Rejection/CorrespondenceRejectionDistance.H>
#include <nrt/PointCloud2/Registration/Correspondence/Rejection/CorrespondenceRejectionPredicate.H>
#include <nrt/PointCloud2/Filter/VoxelFilter.H>
#include <nrt/PointCloud2/Filter/RandomRemovalFilter.H>
#include <nrt/PointCloud2/Common/Centroid.H>
#include <nrt/PointCloud2/Common/Transforms.H>

#include <nrt/ImageProc/IO/ImageSink/VideoWriters/FfmpegVideoWriter.H>
#include <nrt/ImageProc/Reshaping/Transform.H>

using namespace nrt;

// Program parameters, declared through nrt's NRT_DECLARE_PARAMETER macro and
// bundled into the Parameters component that follows. Arguments as used here:
// name, type, description, default value and (for `data`) the set of allowed
// values. NOTE(review): argument meaning inferred from usage — confirm against
// the nrt macro definition.
// minheight and maxheight are not referenced in main(); verbose is read there
// but its value is not otherwise used.
// Input log file; the default is a developer-local path.
NRT_DECLARE_PARAMETER(filename,         std::string,  "The data file or image directory", "/home/rand/drc/messagelogs/hedco0.nrtlog");
// Reader backend handed to VelodyneIO together with `filename`.
NRT_DECLARE_PARAMETER(data,             DataSource,   "The source for the data", DataSource::nrtlog, DataSource_Values);
// Lower end of the window [frame, frame + 1000] from which main() samples scans.
NRT_DECLARE_PARAMETER(frame,            int,          "The starting frame number", 1);
NRT_DECLARE_PARAMETER(minheight,        float,        "Minimum height for the jet color map", -1.0);
NRT_DECLARE_PARAMETER(maxheight,        float,        "Maximum height for the jet color map", 10.0);
NRT_DECLARE_PARAMETER(verbose,          bool,         "Whether to output diagnostic information", true);
// When set, main() rotates each loaded cloud by -20 degrees about the X axis.
NRT_DECLARE_PARAMETER(tilt,             bool,         "Tilt the incoming cloud by 20 degrees", false);
// Passed to config::openconfig() at startup.
NRT_DECLARE_PARAMETER(configfile,       std::string,  "Path to the config file", "../config/default.cfg");
// Empty means "<input file basename>.log"; main() opens "<logfile>.detailed".
NRT_DECLARE_PARAMETER(logfile,          std::string,  "Log file name", "");

// Single nrt component owning every parameter declared above, so that main()
// can register all of them with one mgr.addComponent<Parameters>() call.
struct Parameters
  : public nrt::Component
  , public nrt::Parameter<filename,
                          data,
                          frame,
                          minheight,
                          maxheight,
                          verbose,
                          tilt,
                          configfile,
                          logfile>
{
  // Reuse nrt::Component's constructors (main() passes the instance name).
  using nrt::Component::Component;
};

// Correspondence strategy, parsed in main() from the "correspondences.method"
// config key (the config strings match the enumerator names exactly).
enum class CorrespondenceMethod
{
  MUMC,
  MUMCFast,
  Good
};

// ######################################################################
// Random number utilities
// ######################################################################
/// @brief Build a generator of uniformly distributed integers in [lower, upper].
///
/// The bounds are validated once, when the generator is created. Every call of
/// the returned function draws one value from a std::mt19937 that is seeded
/// once from std::random_device. That engine is a function-local static of the
/// lambda, so all generators created with the same T share it; it is not
/// synchronized.
///
/// @tparam T integer type accepted by std::uniform_int_distribution
///           (short/int/long/long long and their unsigned variants).
/// @param lower inclusive lower bound.
/// @param upper inclusive upper bound.
/// @return a callable producing values in [lower, upper].
/// @throws std::invalid_argument if upper < lower.
template <class T> inline
auto uniformInt( T lower, T upper ) -> std::function<T()>
{
  // Throw a std::exception-derived type (not a raw string literal) so that
  // standard handlers can catch it and report what().
  if( upper < lower )
    throw std::invalid_argument( "Upper bound less than lower bound" );

  return [=]()
  {
    static std::random_device rd;
    static std::mt19937 gen(rd());
    std::uniform_int_distribution<T> rng( lower, upper );
    return rng(gen);
  };
}

/// @brief Build a generator of normally distributed doubles.
///
/// Each call draws one sample from N(mean, stddev^2). A fresh distribution is
/// constructed per call, so no state is carried between samples except the
/// engine. The engine is a std::mt19937 seeded once from std::random_device
/// and shared by every generator returned from this function (unsynchronized).
///
/// @param mean   mean of the distribution.
/// @param stddev standard deviation; std::normal_distribution requires > 0.
/// @return a callable that returns one sample per invocation.
auto normalReal( double mean, double stddev ) -> std::function<double()>
{
  auto sample = [mean, stddev]() -> double
  {
    static std::random_device seedSource;
    static std::mt19937 engine( seedSource() );
    return std::normal_distribution<double>( mean, stddev )( engine );
  };
  return sample;
}

// ######################################################################
// ######################################################################
/// @brief Mean squared point-to-point distance between two corresponding clouds.
///
/// Point i of @p groundTruth is compared with point i of @p transformed. Pairs
/// where either point is invalid are skipped.
///
/// @param groundTruth reference cloud.
/// @param transformed cloud to compare; expected to have the same size and
///                    point ordering as @p groundTruth (asserted in debug builds).
/// @return sum of squared distances over valid pairs divided by their count,
///         or NaN if there is no valid pair.
double computeMSE( nrt::PointCloud2 const & groundTruth, nrt::PointCloud2 const & transformed )
{
  assert( groundTruth.size() == transformed.size() );

  double sumSquaredError = 0.0;
  size_t valid = 0;

  auto gtIter = groundTruth.geometry_begin();
  auto const gtEnd = groundTruth.geometry_end();
  auto trIter = transformed.geometry_begin();
  auto const trEnd = transformed.geometry_end();

  // Also bound by the end of `transformed`: the assert above is compiled out in
  // release builds, and a shorter `transformed` would otherwise be read past its end.
  for( ; gtIter != gtEnd && trIter != trEnd; ++gtIter, ++trIter )
  {
    if( !gtIter->isValid() || !trIter->isValid() )
      continue;

    sumSquaredError += (gtIter->getVector3Map() - trIter->getVector3Map()).squaredNorm();
    ++valid;
  }

  // No valid pair: the mean is undefined (previously an implicit 0.0 / 0).
  if( valid == 0 )
    return std::numeric_limits<double>::quiet_NaN();

  return sumSquaredError / valid;
}


/// @brief Print, per noise level, the mean and (population) standard deviation
///        of the IC3PO and random-selection errors collected so far.
///
/// @param noiseValues  the noise standard deviations that were tested.
/// @param noiseResults for each noise level, (IC3PO error, random error) pairs.
///                     Levels with no samples are reported as such instead of
///                     dividing by zero.
void printStatistics( std::array<double, 5> const & noiseValues, std::array<std::vector<std::pair<double, double>>, 5> const & noiseResults )
{
  for( size_t i = 0; i < noiseValues.size(); ++i )
  {
    std::cout << "For noise: " << noiseValues[i] << std::endl;

    auto const & samples = noiseResults[i];
    if( samples.empty() )
    {
      std::cout << "\tno samples" << std::endl;
      continue;
    }

    double const count = static_cast<double>( samples.size() );

    double meanIC3PO = 0.0, meanRandom = 0.0;
    for( auto const & vals : samples )
    {
      meanIC3PO += vals.first;
      meanRandom += vals.second;
    }
    meanIC3PO /= count;
    meanRandom /= count;

    double stdIC3PO = 0.0, stdRandom = 0.0;
    for( auto const & vals : samples )
    {
      stdIC3PO += (vals.first - meanIC3PO) * (vals.first - meanIC3PO);
      // The random-selection spread must use the random errors (.second);
      // it previously used the IC3PO errors (.first).
      stdRandom += (vals.second - meanRandom) * (vals.second - meanRandom);
    }
    stdIC3PO = std::sqrt( stdIC3PO / count );
    stdRandom = std::sqrt( stdRandom / count );

    std::cout << "\tIC3PO: " << meanIC3PO << ", " << stdIC3PO << " || random: " << meanRandom << ", " << stdRandom << std::endl;
  }
}


// ######################################################################
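// The purpose of this program is to test whether selecting points with the
// constraint ellipse (IC3PO/PICP) gives better registration than selecting
// the same number of points at random. It performs the following test:
//
// Given an input dataset, repeat MAX_ITERS times:
//   Pick a random scan n from frames [frame, frame + 1000]
//   scan_new = n with 5% of its points randomly removed
//   scan_new = rotate and translate scan_new by a random ground-truth transform
//
//   For each noise level sigma:
//     scan_noisy = jitter all points of scan_new with Gaussian noise of std. dev. sigma
//     n_points   = number of points selected via the constraint ellipse
//     register scan_noisy to n using the ellipse-selected points (IC3PO)
//     register scan_noisy to n using n_points randomly selected points
//     record the MSE of each estimated transform against the ground truth
//
//   Print the mean and standard deviation of both errors for each sigma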
//
#include <DrawingUtils/ConstraintDisplay.H>
int main(int argc, const char ** argv)
{
//  { // for some reason this is needed to prevent a crash later?
//    Eigen::Matrix3d snarf( Eigen::Matrix3d::Zero() );
//    Eigen::JacobiSVD<decltype(snarf)> svd1(snarf, Eigen::ComputeFullV);
//    auto s = svd1.singularValues();
//    auto v = svd1.singularValues();
//  }

  google::InitGoogleLogging(argv[0]);
  google::SetCommandLineOption("minloglevel", "4");

  Manager mgr(argc, argv);
  auto parameters = mgr.addComponent<Parameters>("Parameters");
  auto constraintDisplay = mgr.addComponent<ConstraintDisplay>("ConstraintDisplay"); // taking this out will cause boost lanczos run time error?!
  auto renderer = mgr.addComponent<graphics::ShapeRendererBasic>("Renderer");

  mgr.launch();

  config::openconfig(parameters->configfile::get());

  bool const verbose = parameters->verbose::get();
  bool const is32 = (
      parameters->data::get() == DataSource::nrtlog
#ifdef USE_MRPT
      || parameters->data::get() == DataSource::mrptlog
#endif // USE_MRPT
      );

  CorrespondenceMethod correspondenceMethod;
  if(config::lookup<std::string>("correspondences.method") == "MUMC")
    correspondenceMethod = CorrespondenceMethod::MUMC;
  else if(config::lookup<std::string>("correspondences.method") == "MUMCFast")
    correspondenceMethod = CorrespondenceMethod::MUMCFast;
  else if(config::lookup<std::string>("correspondences.method") == "Good")
    correspondenceMethod = CorrespondenceMethod::Good;
  else
    throw std::runtime_error("Unknown value for correspondences.method: " + config::lookup<std::string>("correspondences.method"));

  std::map<std::string, std::pair<KeySym, bool>> drawOptions =
  {
    // name  ,            key ,      on/off
    {"pause",             {XK_space, true}},
    {"step",              {XK_n,     true}},
    {"input cloud",       {XK_0,     true}},
    {"ground truth cloud",{XK_1,     false}},
    {"ground truth noise",{XK_2,     false}},
    {"ic3po cloud",       {XK_i,     false}},
    {"random cloud",      {XK_r,     false}},
    {"picp points",       {XK_p,     false}},
    {"random points",     {XK_k,     false}},
  };

  std::mutex mtx;

  bool updated = true;
  renderer->setKeyboardCallback([&drawOptions,&updated](std::vector<graphics::ShapeRenderer::KeyboardPress> const & keys)
      {
        for(auto & option : drawOptions)
          for(auto key : keys)
            if(key.release == false && key.key == option.second.first)
            {
              option.second.second = !option.second.second;
              updated = true;
            }
      });

  nrt::PointCloud2 gfxInputCloud, gfxGTCloudNoise, gfxGTCloud, gfxIC3POCloud, gfxRandomCloud;
  nrt::Indices nonPlaneIndices;
  nrt::Indices lastNonPlaneIndices;
  nrt::Indices sourceIndices;

  std::thread drawingThread([&]()
  {
    std::vector<std::shared_ptr<nrt::graphics::Shape>> shapes;
    while(true)
    {
      bool did_update = false;
      auto start_time = std::chrono::system_clock::now();
      {
        std::lock_guard<std::mutex> _(mtx);
        if(updated)
        {
          shapes.clear();

          if(drawOptions["input cloud"].second)
            drawCurrentCloud(shapes, Eigen::Isometry3d::Identity(), gfxInputCloud);
          if(drawOptions["ground truth cloud"].second)
            drawCurrentCloud(shapes, Eigen::Isometry3d::Identity(), gfxGTCloud);
          if(drawOptions["ground truth noise"].second)
            drawCurrentCloud(shapes, Eigen::Isometry3d::Identity(), gfxGTCloudNoise);
          if(drawOptions["ic3po cloud"].second)
            drawCurrentCloud(shapes, Eigen::Isometry3d::Identity(), gfxIC3POCloud);
          if(drawOptions["random cloud"].second)
            drawCurrentCloud(shapes, Eigen::Isometry3d::Identity(), gfxRandomCloud);
          if(drawOptions["picp points"].second)
            drawPICPPoints(shapes, Eigen::Isometry3d::Identity(), gfxInputCloud, lastNonPlaneIndices);
          if(drawOptions["random points"].second)
            drawPICPPoints(shapes, Eigen::Isometry3d::Identity(), gfxInputCloud, sourceIndices);

          drawMenu(shapes, drawOptions);

          updated = false;
          did_update = true;
        } // end updated
      } // end mutex lock

      renderer->initFrame();
      for(auto s : shapes) s->render(*renderer);
      renderer->renderFrame();

      auto end_time = std::chrono::system_clock::now();

      if(end_time - start_time < std::chrono::milliseconds(33))
        std::this_thread::sleep_for(std::chrono::milliseconds(33) - (end_time - start_time));
    }
  });

  int frame = parameters->frame::get();
  nrt::Correspondences correspondences;
  std::vector<nrt::PointCloud2> cloudHistory;
  nrt::PointCloud2 lastCloud, lastLastCloud;
  std::vector<Eigen::Isometry3d> transformHistory;
  std::vector<std::chrono::microseconds> times;

  std::string filename = parameters->filename::get();
  size_t const lastslash = filename.find_last_of('/') + 1;
  filename = filename.substr(lastslash);
  size_t const firstdot = filename.find_last_of('.');
  std::string filebasename = filename.substr(0, firstdot);


  std::string logfile = parameters->logfile::get();
  if(logfile == "")
  {
    logfile = filebasename + ".log";
  }

  std::ofstream logfile_detailed(logfile + ".detailed");

  nrt::FfmpegVideoWriter videoWriter;
  videoWriter.open(filebasename + ".mpeg");

  std::vector<std::array<std::vector<LiDARGroup>, 32>> groupHistory32;
  std::vector<std::array<std::vector<LiDARGroup>, 64>> groupHistory64;

  ////////////////////////////////////////////////////////////////////////
  // Plane Detection / SLAM Thread
  ////////////////////////////////////////////////////////////////////////


  std::array<std::vector<std::pair<double, double>>, 5> noiseResults;
  const std::array<double, 5> noise = {0.01, 0.05, 0.1, 0.5, 1.0};

  double const maxICPMatchDistance = std::pow(config::lookup<double>("picp.maxMatchDistance"), 2);
  auto convergence    = std::make_shared<GICPConvergenceCriteria>(nrt::PointCloud2::AffineTransform::Identity(), 40);
  auto correspondence = std::make_shared<nrt::CorrespondenceEstimationNearestNeighborNano>();
  auto picpRejection  = std::make_shared<nrt::CorrespondenceRejectionPredicate>(
      [maxICPMatchDistance](nrt::PointCloud2 const & source, nrt::PointCloud2 const & target, nrt::Correspondence const & i)
      {
        auto s = source.at<Covariance>(i.sourceIndex);
        if(!s.geometry().isValid()) return false;
        if(!s.get<Covariance>().valid) return false;

        auto t = target.at<Covariance>(i.targetIndex);
        if(!t.geometry().isValid()) return false;
        if(!t.get<Covariance>().valid) return false;

        if((s.geometry().getVector3Map() - t.geometry().getVector3Map()).squaredNorm() >
          maxICPMatchDistance) return false;

        return true;
      });

  RegistrationGICP registration(correspondence, { picpRejection }, convergence );

  //LiDARPlaneDetector<32> planeDetector(90, 45, 280, 70, 0.1, false);
  LiDARPlaneDetector<32> planeDetector(
      config::lookup<double>("planedetection.thetaBins"),
      config::lookup<double>("planedetection.phiBins"),
      config::lookup<double>("planedetection.rhoBins"),
      70, 0.1, false);

  //LiDARPlaneDetector<32> planeDetector(180, 90, 300, 70, 0.1, false);
  LiDARPlaneDetector<64> planeDetector64(180, 90, 51, 120, 0.1, false);

  planeDetector.setAccumulatorThreshold(config::lookup<double>("planedetection.accumulatorThreshold"));//2.0);
  planeDetector64.setAccumulatorThreshold(2.0);

  // The Normal engine
  VelodyneNormalsFLAS normalsEngine(3,1); // theta, phi

  const int MAX_ITERS = 100;
  for( size_t iter = 0; iter < MAX_ITERS; ++iter )
  {
    // Grab a random frame between "frame" and "frame" + 1000
    VelodyneIO reader( parameters->filename::get(), parameters->data::get());
    nrt::Optional<nrt::PointCloud2> cloud = nrt::OptionalEmpty;
    auto rng = uniformInt( frame, frame + 1000 );

    // Create the input clouds
    cloud = reader.getCloud( rng() );
    if( !cloud )
    {
      std::cout << "NO CLOUD DATA THIS ITERATION" << std::endl;
      continue;
    }

    if(parameters->tilt::get())
      nrt::transformPointCloudInPlace(*cloud, nrt::PointCloud2::AffineTransform(Eigen::AngleAxisf(-20*M_PI/180, Eigen::Vector3f::UnitX())));


    // Remove 5% of points randomly
    auto newCloud = nrt::filterPointCloud( *cloud, nrt::RandomRemovalFilter(0.05) );

    // Create random ground truth transformation
    auto angleRNG = normalReal( 0, 5.0 );
    auto transRNG = normalReal( 0, 0.15 );
    nrt::PointCloud2::AffineTransform groundTruth = Eigen::Translation3f( Eigen::Vector3f( transRNG(), transRNG(), transRNG() ) ) *
                                                    Eigen::AngleAxisf( angleRNG() * M_PI/180.0, Eigen::Vector3f::UnitZ() ) *
                                                    Eigen::AngleAxisf( angleRNG() * M_PI/180.0, Eigen::Vector3f::UnitY() ) *
                                                    Eigen::AngleAxisf( angleRNG() * M_PI/180.0, Eigen::Vector3f::UnitX() );

    // Transform cloud
    nrt::transformPointCloudInPlace( newCloud, groundTruth );

    // Iterate over noise values for jitter
    for( size_t noiseIndex = 0; noiseIndex < noise.size(); ++noiseIndex )
    {
      const double std = noise[noiseIndex];
      bool running = false;

      while( !running )
      {
        std::lock_guard<std::mutex> _(mtx);
        running = !drawOptions["pause"].second || drawOptions["step"].second;

        drawOptions["step"].second = false;
      }

      nrt::PointCloud2 cloudTransformed = newCloud;

      auto jitterRNG = normalReal( 0, std );
      for( nrt::PointCloud2::Geometry & p : cloudTransformed.geometry_range() )
      {
        p.x() += jitterRNG();
        p.y() += jitterRNG();
        p.z() += jitterRNG();
      }

      // IC3PO version
      frame_timing::begin_frame();
      frame_timing::start("IC3PO");
      auto frameStartTime = std::chrono::steady_clock::now();

      // Normal computation
      {
        frame_timing::start("Normals FLAS");

        auto doNormals = [&](nrt::PointCloud2 & cld, std::string const & name)
        {
          frame_timing::start(name);
          normalsEngine.compute( cld );
          frame_timing::stop(name);
        };

        doNormals( *cloud, "Base cloud" );
        doNormals( cloudTransformed, "Jittered cloud" );

        frame_timing::stop("Normals FLAS");
      }

      // Planarity
      {
        frame_timing::start("Normals Planarity");

        static const int planarityRho = config::lookup<double>("normals.planarity.rho");
        static const int planarityTheta = config::lookup<double>("normals.planarity.theta");

        auto doPlanarity = [=](nrt::PointCloud2 & cld, std::string const & name)
        {
          frame_timing::start(name);
          computeVelodynePlanarity( cld, planarityRho, planarityTheta );
          frame_timing::stop(name);
        };

        doPlanarity( *cloud, "Base cloud" );
        doPlanarity( cloudTransformed, "Jittered cloud" );

        frame_timing::stop("Normals Planarity");
      }

      // Variables
      nrt::Correspondences correspondences_;
      Eigen::Isometry3d odometryEstimateIC3PO = Eigen::Isometry3d::Identity();
      Eigen::Isometry3d odometryEstimateRandom = Eigen::Isometry3d::Identity();
      Eigen::Matrix3d translationCovariance, rotationCovariance;

      // perform GICP with no plane information
      {
        frame_timing::start("GICP");

        auto doCovariance = [&](nrt::PointCloud2 & cld, std::string const & name)
        {
          nrt::Indices result;

          frame_timing::start(name);
          PlaneConstraint constraintSphere(0);
          result = computeCovarianceGICP2(cld, {}, config::lookup<double>("picp.constraintthresh"), constraintSphere);
          frame_timing::stop(name);

          return result;
        };

        // The size of these indices equals the number of points selected by PICP for each cloud
        lastNonPlaneIndices = doCovariance(*cloud, "Base cloud covariance");
        nonPlaneIndices     = doCovariance(cloudTransformed, "Jittered cloud covariance");

        {
          Eigen::Isometry3f guess = Eigen::Isometry3f::Identity();
          convergence->setGuess(guess);

          odometryEstimateIC3PO = registration.align(cloudTransformed, nonPlaneIndices,
              *cloud, lastNonPlaneIndices, {}, guess).cast<double>();
        }

        frame_timing::stop("GICP");
      }


      const auto numPointsCurrent = lastNonPlaneIndices.size();
      const auto currentPercent = (double)numPointsCurrent / cloud->size();
      const auto numPointsTransformed = nonPlaneIndices.size();
      const auto transformedPercent = (double)numPointsTransformed / cloudTransformed.size();
      const auto msePICP = computeMSE( nrt::transformPointCloud( *cloud, groundTruth ), nrt::transformPointCloud( *cloud, Eigen::Affine3f(odometryEstimateIC3PO.cast<float>()) ) );

      auto frameEndTime = std::chrono::steady_clock::now();
      times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(frameEndTime - frameStartTime));
      frame_timing::stop("IC3PO");

      // ######################################################################
      // Perform GICP with a random selection of points, using the same number as PICP selected

      // Construct indices holding a random selection of points equal in number to those selected by PICP
      sourceIndices = nrt::Indices( cloud->size() ); std::iota( sourceIndices.begin(), sourceIndices.end(), 0 );
      nrt::Indices targetIndices( cloudTransformed.size() ); std::iota( targetIndices.begin(), targetIndices.end(), 0 );

      std::random_device rd;
      std::mt19937 rng(rd());

      std::shuffle( sourceIndices.begin(), sourceIndices.end(), rng );
      std::shuffle( targetIndices.begin(), targetIndices.end(), rng );

      sourceIndices.resize( numPointsCurrent );
      targetIndices.resize( numPointsTransformed );

      frame_timing::start("RANDOM");
      frameStartTime = std::chrono::steady_clock::now();

      // perform GICP with no plane information
      {
        frame_timing::start("GICP");

        {
          Eigen::Isometry3f guess = Eigen::Isometry3f::Identity();
          convergence->setGuess(guess);

          odometryEstimateRandom = registration.align(cloudTransformed, targetIndices,
              *cloud, sourceIndices, {}, guess).cast<double>();
        }

        frame_timing::stop("GICP");
      }

      const auto mseRandom = computeMSE( nrt::transformPointCloud( *cloud, groundTruth ), nrt::transformPointCloud( *cloud, Eigen::Affine3f(odometryEstimateRandom.cast<float>()) ) );

      std::cerr << msePICP << " vs " << mseRandom << " diff: " << (msePICP - mseRandom) << std::endl;
      noiseResults[noiseIndex].push_back( std::make_pair( msePICP, mseRandom ) );
      //std::cerr << currentPercent << " " << transformedPercent << " " << msePICP << " " << mseRandom << std::endl;
      //std::cerr << "   " << sourceIndices.size() << " " << targetIndices.size() << std::endl;

      frameEndTime = std::chrono::steady_clock::now();
      frame_timing::stop("RANDOM");
      frame_timing::end_frame();
      {
        std::lock_guard<std::mutex> _(mtx);
        updated = true;
        gfxInputCloud = *cloud;
        gfxGTCloud = newCloud;
        gfxGTCloudNoise = cloudTransformed;
        gfxIC3POCloud = nrt::transformPointCloud( *cloud, Eigen::Affine3f(odometryEstimateIC3PO.cast<float>()) );
        gfxRandomCloud = nrt::transformPointCloud( *cloud, Eigen::Affine3f(odometryEstimateRandom.cast<float>()) );
      }
    } // End for loop over noise values

    printStatistics( noise, noiseResults );

  } // End for loop over averaging iterations

  return 0;
}
