#include "IOUtils.H"

#include <nrt/Core/Model/Manager.H>
#include <nrt/Graphics/ShapeRendererBasic.H>
#include <nrt/Graphics/Shapes.H>
#include "PlaneSLAM.H"
#include <LiDARPlaneDetector.H>
#include <details/Timing.H>
#include <PlaneAssociation.H>
#include <FindTransforms.H>
#include <valgrind/callgrind.h>
#include <X11/keysym.h>
#include <DrawingUtils/DrawingUtils.H>
#include <DrawingUtils/ConstraintDisplay.H>
#include <details/Polygonalization.H>
#include "OctomapUtils.H"
#include <glog/logging.h>
#include <glog/log_severity.h>
#include <gflags/gflags.h>

#include <Config/Config.H>
#include <Util/Util.H>

#include <PointCloud/Registration/RegistrationGICP.H>
#include <PointCloud/Registration/Transformation/CeresFunctor.H>
#include <PointCloud/Registration/Convergence/GICPConvergenceCriteria.H>
#include <PointCloud/Features/VelodyneNormalsFLAS.H>
#include <PointCloud/Features/VelodynePlanarity.H>
#include <nrt/ImageProc/Drawing/ColorMapping.H>
#include <nrt/PointCloud2/Registration/Correspondence/CorrespondenceEstimationNearestNeighborNano.H>
#include <nrt/PointCloud2/Registration/Correspondence/Rejection/CorrespondenceRejectionDistance.H>
#include <nrt/PointCloud2/Registration/Correspondence/Rejection/CorrespondenceRejectionPredicate.H>
#include <nrt/PointCloud2/Filter/VoxelFilter.H>
#include <nrt/PointCloud2/Filter/RandomRemovalFilter.H>
#include <nrt/PointCloud2/Common/Centroid.H>
#include <nrt/PointCloud2/Common/Transforms.H>

#include <nrt/ImageProc/IO/ImageSink/VideoWriters/FfmpegVideoWriter.H>
#include <nrt/ImageProc/Reshaping/Transform.H>

using namespace nrt;

// Command-line parameters for this tool, bundled into the `Parameters` component below and read in main().
// Arguments appear to be: parameter name, value type, help text, default value (and, for `data`, the set of
// allowed values).
// NOTE(review): the default `filename` is a developer-specific absolute path; override it on the command line.
// NOTE(review): `play` is not read anywhere in main() in this file (the "pause" key toggles playback instead) --
//               confirm whether it is still needed.
// `minheight`/`maxheight` bound the jet color map used for the integrated cloud display; `tilt` rotates each
// incoming cloud by -20 degrees about X; an empty `logfile` falls back to "<data file basename>.log".
NRT_DECLARE_PARAMETER(filename,         std::string,  "The data file or image directory", "/home/rand/drc/messagelogs/hedco0.nrtlog");
NRT_DECLARE_PARAMETER(data,             DataSource,   "The source for the data", DataSource::nrtlog, DataSource_Values);
NRT_DECLARE_PARAMETER(frame,            int,          "The starting frame number", 1);
NRT_DECLARE_PARAMETER(play,             bool,         "Play through the file", true);
NRT_DECLARE_PARAMETER(minheight,        float,        "Minimum height for the jet color map", -1.0);
NRT_DECLARE_PARAMETER(maxheight,        float,        "Maximum height for the jet color map", 10.0);
NRT_DECLARE_PARAMETER(verbose,          bool,         "Whether to output diagnostic information", true);
NRT_DECLARE_PARAMETER(picp,             bool,         "Whether to use PICP for frame-frame alignment", true);
NRT_DECLARE_PARAMETER(justgicp,         bool,         "Just perform raw GICP", false);
NRT_DECLARE_PARAMETER(slam,             bool,         "Perform SLAM", true);
NRT_DECLARE_PARAMETER(tilt,             bool,         "Tilt the incoming cloud by 20 degrees", false);
NRT_DECLARE_PARAMETER(configfile,       std::string,  "Path to the config file", "../config/default.cfg");
NRT_DECLARE_PARAMETER(logfile,          std::string,  "Log file name", "");
NRT_DECLARE_PARAMETER(moviemode,        bool,         "Record a video of the run", false);

//! Component that exposes every command-line parameter declared above to the nrt Manager.
/*! main() adds one instance via mgr.addComponent<Parameters>("Parameters") and reads the values with
    parameters-><name>::get(). The order of the Parameter<> list is kept as originally declared. */
struct Parameters
  : public nrt::Component
  , public nrt::Parameter<
        filename, data, frame, play,
        minheight, maxheight, verbose,
        picp, justgicp, slam, tilt,
        configfile, logfile, moviemode>
{
  // Reuse nrt::Component's constructors unchanged (main() passes the instance name through them).
  using nrt::Component::Component;
};

//! Strategy for matching planes between consecutive frames, chosen by the
//! "correspondences.method" config key in main().
enum class CorrespondenceMethod
{
  MUMC,      // "MUMC"     -> findTransformMUMC
  MUMCFast,  // "MUMCFast" -> findTransformMUMCFast
  Good       // "Good"     -> findGoodCorrespondences + findTransform
};

// ######################################################################
int main(int argc, const char ** argv)
{
//  { // for some reason this is needed to prevent a crash later?
//    Eigen::Matrix3d snarf( Eigen::Matrix3d::Zero() );
//    Eigen::JacobiSVD<decltype(snarf)> svd1(snarf, Eigen::ComputeFullV);
//    auto s = svd1.singularValues();
//    auto v = svd1.singularValues();
//  }

  google::InitGoogleLogging(argv[0]);
  google::SetCommandLineOption("minloglevel", "4");

  Manager mgr(argc, argv);
  auto parameters = mgr.addComponent<Parameters>("Parameters");

  auto constraintDisplay = mgr.addComponent<ConstraintDisplay>("ConstraintDisplay");

  auto renderer = mgr.addComponent<graphics::ShapeRendererBasic>("Renderer");

  mgr.launch();

  config::openconfig(parameters->configfile::get());

  bool const verbose = parameters->verbose::get();
  bool const is32 = (
      parameters->data::get() == DataSource::nrtlog
#ifdef USE_MRPT
      || parameters->data::get() == DataSource::mrptlog
#endif // USE_MRPT
      );
  bool const use_picp = parameters->picp::get();
  bool const use_just_gicp = parameters->justgicp::get();
  bool const use_slam = use_just_gicp ? false : parameters->slam::get();
  bool const use_moviemode = parameters->moviemode::get();

  CorrespondenceMethod correspondenceMethod;
  if(config::lookup<std::string>("correspondences.method") == "MUMC")
    correspondenceMethod = CorrespondenceMethod::MUMC;
  else if(config::lookup<std::string>("correspondences.method") == "MUMCFast")
    correspondenceMethod = CorrespondenceMethod::MUMCFast;
  else if(config::lookup<std::string>("correspondences.method") == "Good")
    correspondenceMethod = CorrespondenceMethod::Good;
  else
    throw std::runtime_error("Unknown value for correspondences.method: " + config::lookup<std::string>("correspondences.method"));

  VelodyneIO reader( parameters->filename::get(), parameters->data::get());


  std::map<std::string, std::pair<KeySym, bool>> drawOptions =
  {
    // name  ,            key ,      on/off
    {"pause",             {XK_space, true}},
    {"step",              {XK_n,     true}},
    {"planes",            {XK_p,     false}},
    {"lastplanes",        {XK_o,     false}},
    {"correspondences",   {XK_c,     false}},
    {"slam map",          {XK_m,     false}},
    {"current cloud",     {XK_0,     true}},
    {"integrated cloud",  {XK_i,     false}},
    {"jet",               {XK_j,     false}},
    {"path",              {XK_l,     true}},
    {"groups",            {XK_g,     false}},
    {"breaks",            {XK_b,     false}},
    {"picp points",       {XK_u,     false}},
    {"pure picp points",  {XK_v,     false}},
    {"random points",     {XK_y,     false}},
    {"planes last frame", {XK_x,     false}},
    {"normals",           {XK_1,     false}},
    {"planarity",         {XK_2,     false}},
    {"write log",         {XK_w,     false}},
  };

  bool updated = true;
  bool first_logfile_entry = true;

  renderer->setKeyboardCallback([&drawOptions,&updated](std::vector<graphics::ShapeRenderer::KeyboardPress> const & keys)
      {
        for(auto & option : drawOptions)
          for(auto key : keys)
            if(key.release == false && key.key == option.second.first)
            {
              option.second.second = !option.second.second;
              updated = true;
            }
      });

  int frame = parameters->frame::get();
  nrt::Correspondences correspondences;
  PlaneSLAM::RenderedMap globalMap;
  std::vector<nrt::PointCloud2> cloudHistory;
  nrt::PointCloud2 lastCloud, lastLastCloud;
  std::vector<DetectedPlane> lastPlanes_;
  std::vector<DetectedPlane> lastLastPlanes_;
  std::vector<Eigen::Isometry3d> transformHistory;
  std::vector<std::chrono::microseconds> times;

  std::string filename = parameters->filename::get();
  size_t const lastslash = filename.find_last_of('/') + 1;
  filename = filename.substr(lastslash);
  size_t const firstdot = filename.find_last_of('.');
  std::string filebasename = filename.substr(0, firstdot);


  std::string logfile = parameters->logfile::get();
  if(logfile == "")
  {
    logfile = filebasename + ".log";
  }

  std::ofstream logfile_detailed(logfile + ".detailed");

  nrt::FfmpegVideoWriter videoWriter;
  videoWriter.open(filebasename + ".mpeg");

  nrt::Indices lastNonPlaneIndices;
  nrt::Indices purePICPIndices;
  nrt::Indices randomIndices;

  std::vector<std::array<std::vector<LiDARGroup>, 32>> groupHistory32;
  std::vector<std::array<std::vector<LiDARGroup>, 64>> groupHistory64;

  std::mutex mtx;

  ////////////////////////////////////////////////////////////////////////
  // Drawing Thread
  ////////////////////////////////////////////////////////////////////////
  std::thread drawingThread([&]()
  {
    std::vector<std::shared_ptr<nrt::graphics::Shape>> shapes;
    while(true)
    {
      bool did_update = false;
      auto start_time = std::chrono::system_clock::now();
      {
        std::lock_guard<std::mutex> _(mtx);
        if(updated)
        {
          shapes.clear();

          if(drawOptions["current cloud"].second && !transformHistory.empty())
          {
            drawCurrentCloud(shapes, transformHistory.back(), lastCloud);
          }

          if(drawOptions["planarity"].second && !transformHistory.empty())
          {
            drawPlanarity(shapes, transformHistory.back(), lastCloud);
          }

          if(drawOptions["integrated cloud"].second && !transformHistory.empty())
          {
            if(drawOptions["jet"].second)
              drawIntegratedJetCloud(shapes, transformHistory, cloudHistory, parameters->minheight::get(), parameters->maxheight::get());
            else
              drawIntegratedCloud(shapes, transformHistory, cloudHistory);
          }

          if(drawOptions["planes"].second && !transformHistory.empty())
          {
            drawPlanes(shapes, lastPlanes_, transformHistory.back());
          }
          if(drawOptions["lastplanes"].second && transformHistory.size() > 2)
          {
            drawPlanes(shapes, lastLastPlanes_, *(transformHistory.rbegin()+1));
          }
          if(drawOptions["planes last frame"].second && transformHistory.size() > 2)
          {
            drawPlanes(shapes, lastPlanes_, *(transformHistory.rbegin()+1));
          }

          if(drawOptions["correspondences"].second && cloudHistory.size() > 1 && transformHistory.size() > 1)
          {
            drawCorrespondences(shapes, correspondences,
                lastLastPlanes_, lastPlanes_,
                lastLastCloud, lastCloud,
                globalMap, transformHistory.back(), *(transformHistory.rbegin()+1));
          }

          if(drawOptions["normals"].second && !transformHistory.empty())
          {
            drawNormals(shapes, transformHistory.back(), lastCloud, config::lookup<bool>("display.scaleNormalByCurvature"));
          }

          if(drawOptions["slam map"].second)
            drawSLAMMap(shapes, globalMap);

          if(drawOptions["path"].second)
            drawPath(shapes, transformHistory);

          if(drawOptions["groups"].second || drawOptions["breaks"].second )
          {
            if( is32 )
              drawBreaksGroups<32>( shapes, transformHistory.back(), lastCloud, groupHistory32.back(), drawOptions["breaks"].second, drawOptions["groups"].second );
            else
              drawBreaksGroups<64>( shapes, transformHistory.back(), lastCloud, groupHistory64.back(), drawOptions["breaks"].second, drawOptions["groups"].second );
          }

          if(drawOptions["picp points"].second)
            drawPICPPoints(shapes, transformHistory.back(), lastCloud, lastNonPlaneIndices, {178, 24, 43});

          if(drawOptions["pure picp points"].second)
            drawPICPPoints(shapes, transformHistory.back(), lastCloud, purePICPIndices, {178, 24, 43});

          if(drawOptions["random points"].second)
            drawPICPPoints(shapes, transformHistory.back(), lastCloud, randomIndices, {178, 24, 43});

          if(!use_moviemode) drawMenu(shapes, drawOptions);

          updated = false;
          did_update = true;
        } // end updated
      } // end mutex lock

      if(use_moviemode)
      {
        Eigen::Vector3d eye, look;
        std::tie(eye, look) = filterCamera(transformHistory);
        renderer->lookAt(eye, look, Eigen::Vector3d::UnitZ());
      }

      renderer->initFrame();
      for(auto s : shapes) s->render(*renderer);
      renderer->renderFrame();

      if(did_update && use_moviemode)
      {
        nrt::Image<nrt::PixRGB<uint8_t>> framebuffer(renderer->dims());
        glReadPixels(0,0,framebuffer.width(),framebuffer.height(), GL_RGB, GL_UNSIGNED_BYTE, framebuffer.pod_begin());
        framebuffer = nrt::flipVertical(framebuffer);
        videoWriter.appendFrame(nrt::GenericImage(framebuffer));
      }

      auto end_time = std::chrono::system_clock::now();

      if(end_time - start_time < std::chrono::milliseconds(33))
        std::this_thread::sleep_for(std::chrono::milliseconds(33) - (end_time - start_time));
    }
  });




  ////////////////////////////////////////////////////////////////////////
  // Plane Detection / SLAM Thread
  ////////////////////////////////////////////////////////////////////////

  double const maxICPMatchDistance = std::pow(config::lookup<double>("picp.maxMatchDistance"), 2);
  auto convergence    = std::make_shared<GICPConvergenceCriteria>(nrt::PointCloud2::AffineTransform::Identity(), 40);
  auto correspondence = std::make_shared<nrt::CorrespondenceEstimationNearestNeighborNano>();
  auto picpRejection  = std::make_shared<nrt::CorrespondenceRejectionPredicate>(
      [maxICPMatchDistance](nrt::PointCloud2 const & source, nrt::PointCloud2 const & target, nrt::Correspondence const & i)
      {
        auto s = source.at<Covariance>(i.sourceIndex);
        if(!s.geometry().isValid()) return false;
        if(!s.get<Covariance>().valid) return false;

        auto t = target.at<Covariance>(i.targetIndex);
        if(!t.geometry().isValid()) return false;
        if(!t.get<Covariance>().valid) return false;

        if((s.geometry().getVector3Map() - t.geometry().getVector3Map()).squaredNorm() >
          maxICPMatchDistance) return false;

        return true;
      });

  RegistrationGICP registration(correspondence, { picpRejection }, convergence );

  // SLAM
  PlaneSLAM slam;

  //LiDARPlaneDetector<32> planeDetector(90, 45, 280, 70, 0.1, false);
  LiDARPlaneDetector<32> planeDetector(
      config::lookup<double>("planedetection.thetaBins"),
      config::lookup<double>("planedetection.phiBins"),
      config::lookup<double>("planedetection.rhoBins"),
      70, 0.1, false);

  //LiDARPlaneDetector<32> planeDetector(180, 90, 300, 70, 0.1, false);
  LiDARPlaneDetector<64> planeDetector64(180, 90, 51, 120, 0.1, false);

  planeDetector.setAccumulatorThreshold(config::lookup<double>("planedetection.accumulatorThreshold"));//2.0);
  planeDetector64.setAccumulatorThreshold(2.0);

  // The Normal engine
  VelodyneNormalsFLAS normalsEngine(3,1); // theta, phi

  while(true)
  {
    frame_timing::begin_frame();
    nrt::Optional<nrt::PointCloud2> cloud = nrt::OptionalEmpty;

    bool running = false;
    {
      std::lock_guard<std::mutex> _(mtx);
      running = !drawOptions["pause"].second || drawOptions["step"].second;

      drawOptions["step"].second = false;

      if(drawOptions["write log"].second && !transformHistory.empty())
      {
        writePath(logfile, transformHistory, times);
        drawOptions["write log"].second = false;
      }

    }

    try { if(running) cloud = reader.getCloud( frame++ ); }
    catch(...) { }

    if(cloud)
    {
      auto frameStartTime = std::chrono::steady_clock::now();

      std::cout << "----------------------------------- Frame: " << frame << "-----------------------------------" << std::endl;

      if(parameters->tilt::get())
        nrt::transformPointCloudInPlace(*cloud, nrt::PointCloud2::AffineTransform(Eigen::AngleAxisf(-20*M_PI/180, Eigen::Vector3f::UnitX())));

      std::vector<DetectedPlane> planes;
      std::vector<DetectedPlane> lastPlanes;
      nrt::Indices nonPlaneIndices;
      nrt::Correspondences correspondences_;
      Eigen::Isometry3d odometryEstimate = Eigen::Isometry3d::Identity();
      Eigen::Matrix3d translationCovariance, rotationCovariance;
      std::vector<std::pair<DetectedPlane, DetectedPlane>> zippedPlanes;

      // Compute fast normals
      frame_timing::start("Normals FLAS");
      normalsEngine.compute( *cloud );
      frame_timing::stop("Normals FLAS");

      frame_timing::start("Normals Planarity");
      static const int planarityRho = config::lookup<double>("normals.planarity.rho");
      static const int planarityTheta = config::lookup<double>("normals.planarity.theta");
      computeVelodynePlanarity( *cloud, planarityRho, planarityTheta );
      frame_timing::stop("Normals Planarity");



      if(use_just_gicp == false)
      {
        // Detect the planes in the current cloud
        frame_timing::start("Detect Planes");
        if( is32 )
          planes = planeDetector.detect(*cloud);
        else
          planes = planeDetector64.detect(*cloud);
        frame_timing::stop("Detect Planes");

        // Retrieve the last set of detected planes
        if(cloudHistory.size() > 1) lastPlanes = lastPlanes_;

        frame_timing::start("Correspondence/Transform Estimation");

        switch(correspondenceMethod)
        {
          case CorrespondenceMethod::Good:
            correspondences_ = findGoodCorrespondences(lastPlanes, planes);
            odometryEstimate = findTransform(correspondences_, lastPlanes, planes,
                translationCovariance, rotationCovariance).inverse();
            break;
          case CorrespondenceMethod::MUMC:
            odometryEstimate = findTransformMUMC(lastPlanes, planes, correspondences_).inverse();
            break;
          case CorrespondenceMethod::MUMCFast:
            odometryEstimate = findTransformMUMCFast(lastPlanes, planes, correspondences_).inverse();
            break;
        }

        // Zip up the corresponding planes
        for(auto c : correspondences_)
          zippedPlanes.emplace_back(planes[c.targetIndex], lastPlanes[c.sourceIndex]);

        static const int useCeres = config::lookup<bool>("correspondences.useCeres");
        if(useCeres)
        {
          CeresFunctor ceresSolver;
          odometryEstimate = ceresSolver(zippedPlanes, {}, {}, {}, {}, Eigen::Isometry3f::Identity(), Eigen::Isometry3f::Identity()).cast<double>();
        }

        //if(translationCovariance != translationCovariance)
        translationCovariance = Eigen::Matrix3d::Identity() * config::lookup<double>("correspondences.translationCovariance"); //0.03;

        //if(rotationCovariance != rotationCovariance)
        rotationCovariance = Eigen::Matrix3d::Identity() * config::lookup<double>("correspondences.rotationCovariance"); //0.01;

        frame_timing::stop("Correspondence/Transform Estimation");
      }


      if(use_picp || use_just_gicp)
      {
        frame_timing::start("PICP");

        frame_timing::start("Compute Covariance");
        {
          PlaneConstraint constraintSphere(0);
          nonPlaneIndices = computeCovarianceGICP2(*cloud, zippedPlanes, config::lookup<double>("picp.constraintthresh"), constraintSphere);

          constraintDisplay->update(constraintSphere);
        }
        frame_timing::stop("Compute Covariance");

        // FOR DISPLAY TAKE OUT FOR TIMING
        {
          PlaneConstraint constraintSphere(0);
          purePICPIndices = computeCovarianceGICP2(*cloud, {}, config::lookup<double>("picp.constraintthresh"), constraintSphere);

          randomIndices = nrt::Indices( cloud->size() ); std::iota( randomIndices.begin(), randomIndices.end(), 0 );

          std::random_device rd;
          std::mt19937 rng(rd());

          std::shuffle( randomIndices.begin(), randomIndices.end(), rng );

          randomIndices.resize( purePICPIndices.size() );
        }

        if(!cloudHistory.empty() && (!nonPlaneIndices.empty() || config::lookup<bool>("picp.forcePICP")))
        {
          //Eigen::Isometry3f guess(odometryEstimate.cast<float>());
          Eigen::Isometry3f guess = Eigen::Isometry3f::Identity();
          convergence->setGuess(guess);

          odometryEstimate = registration.align(*cloud, nonPlaneIndices,
              lastCloud, lastNonPlaneIndices, zippedPlanes, guess).cast<double>();

          // Calculate correspondences again
          std::vector<DetectedPlane> transformedLastPlanes( lastPlanes.size() );
          std::transform( lastPlanes.begin(), lastPlanes.end(), transformedLastPlanes.begin(),
              [odometryEstimate](DetectedPlane const & p)
              {
              DetectedPlane ret = p;
              ret.plane.transform( Eigen::Affine3f(odometryEstimate.cast<float>()) );
              return ret;
              } );


          translationCovariance = Eigen::Matrix3d::Identity() * config::lookup<double>("picp.translationCovariance"); //0.03;
          rotationCovariance = Eigen::Matrix3d::Identity() * config::lookup<double>("picp.rotationCovariance"); //0.01;
        }
        else
        {
          std::cout << "SKIPPING PICP"<< std::endl;
        }

        frame_timing::stop("PICP");
      }

      std::vector<Eigen::Isometry3d> tformHistory;
      PlaneSLAM::RenderedMap map;
      if(use_slam)
      {
        // Update SLAM to find the latest path estimate
        frame_timing::start("SLAM Update");
        tformHistory = slam.update(odometryEstimate, correspondences_, lastPlanes, planes, translationCovariance, rotationCovariance);
        frame_timing::stop("SLAM Update");

        // Get the latest global map
        frame_timing::start("SLAM Extract Map");
        map = slam.getRenderMap();
        frame_timing::stop("SLAM Extract Map");
      }
      else
      {
        tformHistory = transformHistory;
        Eigen::Isometry3d lastTransform = transformHistory.empty() ? Eigen::Isometry3d::Identity() : transformHistory.back();
        tformHistory.push_back(lastTransform * odometryEstimate);
      }
      auto frameEndTime = std::chrono::steady_clock::now();
      times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(frameEndTime - frameStartTime));

      {
        std::lock_guard<std::mutex> _(mtx);
        updated = true;

        lastNonPlaneIndices = nonPlaneIndices;
        cloudHistory.push_back(filterPointCloud( *cloud, RandomRemovalFilter(0.9) ));
        lastLastCloud = lastCloud;
        lastCloud = *cloud;
        lastLastPlanes_ = lastPlanes_;
        lastPlanes_ = planes;
        transformHistory = tformHistory;
        globalMap = map;
        correspondences = correspondences_;

        if( is32 )
          groupHistory32.push_back( planeDetector.getGroups() );
        else
          groupHistory64.push_back( planeDetector64.getGroups() );
      }
      frame_timing::end_frame();

      if( verbose )
      {
        frame_timing::report_all(std::chrono::seconds(1));
      }

      if(!first_logfile_entry)
      {
        logfile_detailed << ",\n";
      }
      first_logfile_entry = false;
      logfile_detailed << "{\n\"frame\" : " << frame << ",\n \"tasks\": [" << frame_timing::frame_json() << "]\n}";
      logfile_detailed.flush();
    }
  }

  drawOptions["pause"].second = false;
  drawingThread.join();
  return 0;
}
