#include "Filtering.H"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <vector>

#include <string.h>

#include <xmmintrin.h>
#include <smmintrin.h>

namespace
{
  //! Largest multiple of m that does not exceed n. m must be nonzero.
  inline size_t roundDown(size_t n, size_t m) {
    size_t const remainder = n % m;
    return n - remainder;
  }

  //! Smallest multiple of m that is not below n. m must be nonzero and n + m - 1 must not overflow.
  inline size_t roundUp(size_t n, size_t m) {
    return roundDown(n + m - 1, m);
  }
}


// ######################################################################
//! Wrap-around (circular) 1D filtering of v with kernel k, accelerated with SSE4.1.
/*! For every i in [0, v_length) this computes
        result[i] = sum_{j=0}^{k_length-1} v[(i + j - k_length/2) mod v_length] * k[j]
    The interior is processed four outputs at a time with SSE. Outputs whose kernel window
    wraps around either end of v are computed with a scalar loop. When v is shorter than the
    zero-padded kernel, every output is computed with the scalar loop.

    @param v        input signal of v_length floats (no alignment requirement)
    @param v_length number of samples in v; nothing is written when it is 0
    @param k        kernel taps, k_length floats (no alignment requirement)
    @param k_length number of taps; tap k_length/2 is aligned with the output sample
    @param result   output buffer of v_length floats; must not alias v or k (__restrict__) */
void planefiltering::convolveSSE(float const * const __restrict__ v, size_t const v_length,
    float const * const __restrict__ k, size_t const k_length,
    float * const __restrict__ result)
{
  if(v_length == 0) return;

  size_t const half_width = k_length / 2;

  // Scalar evaluation of one output sample. The index is reduced with a true signed modulo,
  // so the wrap is correct even when i < half_width or the window overhangs v by more than
  // one full period (i.e. v is shorter than the kernel).
  auto const circularSample = [&](size_t const i) -> float
  {
    std::ptrdiff_t const n = static_cast<std::ptrdiff_t>(v_length);
    std::ptrdiff_t const offset = static_cast<std::ptrdiff_t>(i) - static_cast<std::ptrdiff_t>(half_width);
    float sum = 0;
    for(size_t j = 0; j < k_length; ++j)
    {
      std::ptrdiff_t idx = (offset + static_cast<std::ptrdiff_t>(j)) % n;
      if(idx < 0) idx += n;
      sum += v[idx] * k[j];
    }
    return sum;
  };

  // Row s of the kernel table holds k shifted right by s slots and zero-padded to a multiple
  // of 4. One sweep over an input window therefore yields four consecutive outputs.
  size_t const kernelBufferSize = roundUp(k_length + 3, 4);

  // First output index NOT produced by the SIMD loop. If the loop does not run, everything
  // from half_width onward is left to the scalar tail below.
  size_t simdEnd = half_width;

  // The SIMD loop reads v[start, start + kernelBufferSize). It only applies when v holds at
  // least one full padded window; checking first prevents `v_length - kernelBufferSize`
  // from underflowing.
  if(v_length >= kernelBufferSize)
  {
    // Heap storage: a stack VLA is non-standard C++ and can overflow for large kernels.
    std::vector<float> kernels(4 * kernelBufferSize, 0.0f);
    for(size_t s = 0; s < 4; ++s)
      std::copy(k, k + k_length, kernels.begin() + s * kernelBufferSize + s);

    float const * const kernel0 = kernels.data();
    float const * const kernel1 = kernel0 + kernelBufferSize;
    float const * const kernel2 = kernel1 + kernelBufferSize;
    float const * const kernel3 = kernel2 + kernelBufferSize;

    // Last window start (a multiple of 4) whose padded window still lies inside v.
    size_t const lastBlockStart = roundDown(v_length - kernelBufferSize, 4);
    for(size_t start = 0; start <= lastBlockStart; start += 4)
    {
      float const * const window = v + start;
      __m128 acc0 = _mm_setzero_ps();
      __m128 acc1 = _mm_setzero_ps();
      __m128 acc2 = _mm_setzero_ps();
      __m128 acc3 = _mm_setzero_ps();
      for(size_t off = 0; off < kernelBufferSize; off += 4)
      {
        // Unaligned loads: v is caller-provided and the table is a plain std::vector.
        __m128 const in = _mm_loadu_ps(window + off);
        // Mask 0xF1: multiply all four lanes and place the sum in lane 0 only.
        acc0 = _mm_add_ps(acc0, _mm_dp_ps(in, _mm_loadu_ps(kernel0 + off), 0xF1));
        acc1 = _mm_add_ps(acc1, _mm_dp_ps(in, _mm_loadu_ps(kernel1 + off), 0xF1));
        acc2 = _mm_add_ps(acc2, _mm_dp_ps(in, _mm_loadu_ps(kernel2 + off), 0xF1));
        acc3 = _mm_add_ps(acc3, _mm_dp_ps(in, _mm_loadu_ps(kernel3 + off), 0xF1));
      }
      // A window starting at `start` centers on output start + half_width (+ shift s).
      float * const out = result + start + half_width;
      out[0] = _mm_cvtss_f32(acc0);
      out[1] = _mm_cvtss_f32(acc1);
      out[2] = _mm_cvtss_f32(acc2);
      out[3] = _mm_cvtss_f32(acc3);
    }
    simdEnd = lastBlockStart + 4 + half_width;
  }

  // Head: outputs whose window hangs off the start of v.
  size_t const headEnd = std::min(half_width, v_length);
  for(size_t i = 0; i < headEnd; ++i)
    result[i] = circularSample(i);

  // Tail: everything after the SIMD region (or after the head when SIMD did not run).
  for(size_t i = std::max(simdEnd, headEnd); i < v_length; ++i)
    result[i] = circularSample(i);
}


// ######################################################################
//! Wrap-around first difference of v.
/*! d[i] = v[i] - v[i-1] for i >= 1, and d[0] = v[0] - v[last] (the signal is treated as
    circular). The result has the same length as v. Empty input yields an empty result;
    previously front()/back() were evaluated on it, which is undefined behavior. */
planefiltering::AlignedVector planefiltering::difference(AlignedVector const & v)
{
  AlignedVector d(v.size());
  if(v.size() == 0) return d;

  for(size_t i=1; i<v.size(); ++i)
    d[i] = v[i] - v[i-1];
  d[0] = v.front() - v.back();
  return d;
}

// ######################################################################
//! Filters v with the odd-length kernel k, with indices wrapping around the ends of v.
/*! Returns a vector of v.size() samples. Output i is aligned with the center tap k[k.size()/2].
    The work is delegated to convolveSSE. Empty input returns an empty result without calling it.
    The odd-length requirement is checked with assert only, so it is not enforced in NDEBUG builds. */
planefiltering::AlignedVector planefiltering::convolve(AlignedVector const & v, AlignedVector const & k)
{
  // An odd length gives the kernel a well-defined center tap at k.size()/2.
  assert(k.size() % 2 == 1);

  AlignedVector result(v.size());
  if(v.size() == 0) return result; // nothing to filter

  convolveSSE(v.data(), v.size(), k.data(), k.size(), result.data());
  return result;
}

// ######################################################################
//! Unit-sum sampled Gaussian kernel with 2*half_width+1 taps, centered on index half_width.
/*! Before normalization, tap half_width +/- i holds exp(-i^2 / (2*sigma^2)) and the center tap
    holds 1. Every tap is then divided by the total, so the kernel sums to 1.
    The continuous prefactor 1/(sigma*sqrt(2*pi)) is deliberately not applied, because the
    normalization cancels it. Skipping it keeps the kernel finite when sigma == 0, or when
    sigma is so small that 2*sigma^2 underflows: the result is then a unit impulse instead of
    NaNs. Skipping it also avoids the non-standard M_PI macro. Only |sigma| affects the result. */
planefiltering::AlignedVector planefiltering::makeGaussian(size_t half_width, float sigma)
{
  AlignedVector f( half_width*2+1 );

  float const expnorm = 2.0f*sigma*sigma;
  float sum = 0;
  for(size_t i=1; i<=half_width; ++i)
  {
    // When expnorm is 0 the exponent is -inf and std::exp returns exactly 0.
    float const val = std::exp(-float(i*i) / expnorm);
    f[half_width + i] = val;
    f[half_width - i] = val;
    sum += 2*val;
  }
  f[half_width] = 1.0f; // exp(0)
  sum += 1.0f;

  for(float & val : f) val /= sum;

  return f;
}

// ######################################################################
//! Finite difference of a Gaussian kernel.
/*! Builds makeGaussian(half_width, sigma) and applies the wrap-around first difference to it.
    The result has the same 2*half_width+1 taps. */
planefiltering::AlignedVector planefiltering::makeGaussianDifference(size_t half_width, float sigma)
{
  AlignedVector const smoothing = makeGaussian(half_width, sigma);
  return difference(smoothing);
}

// ######################################################################
//! Difference-of-Gaussians kernel with 2*half_width+1 taps.
/*! Tap t is makeGaussian(half_width, sigma1)[t] minus
    makeGaussian(half_width, sigma1 * sigma2_mult)[t]. */
planefiltering::AlignedVector planefiltering::makeDifferenceOfGaussian(size_t half_width, float sigma1, float sigma2_mult)
{
  float const sigma2 = sigma1 * sigma2_mult;
  AlignedVector dog = makeGaussian(half_width, sigma1);
  AlignedVector const second = makeGaussian(half_width, sigma2);

  // Subtract in place: both kernels have exactly 2*half_width+1 taps.
  for(size_t tap = 0; tap < dog.size(); ++tap)
    dog[tap] -= second[tap];

  return dog;
}

// ######################################################################
//! Indices i at which the circular signal v crosses (or touches) zero between v[i] and v[i+1].
/*! The wrap-around pair (v[last], v[0]) is included and reported as index v.size()-1.
    A pair counts as a crossing when one value is <= 0 and the other is >= 0, so a sample
    that is exactly zero counts. A pair also counts when its second value is NaN.
    Empty input yields an empty list. */
std::vector<size_t> planefiltering::zeroCrossings(AlignedVector const & v)
{
  std::vector<size_t> crossings;
  if(v.size() == 0) return crossings; // v[0] / v.back() below would be out of range

  crossings.reserve(v.size());

  // Compare signs directly instead of testing a*b <= 0. The product of two tiny same-signed
  // values can underflow to 0 and be misreported as a crossing, and 0*inf yields NaN.
  auto const crosses = [](float const prev, float const next)
  {
    bool const straddles = (prev <= 0 && next >= 0) || (prev >= 0 && next <= 0);
    return straddles || next != next; // next != next holds only for NaN
  };

  for(size_t i=1; i<v.size(); ++i)
    if(crosses(v[i-1], v[i]))
      crossings.push_back(i-1);

  if(crosses(v.back(), v[0]))
    crossings.push_back(v.size()-1);

  return crossings;
}
