include("geometry.jl")

using Winston

# Express the plane `p` in the frame given by the homogeneous transform `T`.
#
# `p = [n; d]` holds a (not necessarily unit) normal `n` and an offset `d`; only the
# normal is normalized, `d` is used as given. With R = T[1:3,1:3] and t = T[1:3,4]
# the result is the 4-vector [R*n; d - (R*n)'t].
# NOTE(review): the offset update matches the convention n'x + d = 0 with points
# mapped as x' = R*x + t — confirm against geometry.jl.
# Throws ArgumentError if the normal is zero (the plane would be undefined).
function transform_plane(T, p)
    s = norm(p[1:3])
    s == 0 && throw(ArgumentError("plane normal must be non-zero"))
    n = p[1:3] / s
    d = p[4]

    R = T[1:3, 1:3]
    t = T[1:3, 4]

    Rn = R * n
    # `;` concatenates vertically into a 4-vector; `,` would build a 2-element
    # container ([vector, scalar]) in Julia >= 0.5.
    return [Rn; d - dot(Rn, t)]
end

# Jacobian of the plane error e = measured - transform_plane(T, p) with respect to
# a perturbation of the pose T, evaluated at (T, p).
#
# Rows are [e_n (3); e_d (1)], columns are [rotation (3) | translation (3)]:
#   [ de_n/dw  de_n/dt ]
#   [ de_d/dw  de_d/dt ]
# NOTE(review): the rotation columns assume the perturbation convention implied by
# Geometry.skew / Geometry.rodrigues (defined in geometry.jl) — confirm there.
# Throws ArgumentError if the normal of `p` is zero.
function H1(T, p)
    H = zeros(4, 6)

    R = T[1:3, 1:3]
    t = T[1:3, 4]

    # transform_plane works with the unit normal, so the derivatives must be
    # evaluated at the unit normal as well (matters when p[1:3] is not unit length).
    s = norm(p[1:3])
    s == 0 && throw(ArgumentError("plane normal must be non-zero"))
    n = p[1:3] / s

    Rskew = R * Geometry.skew(n)

    # de_n/dw
    H[1:3, 1:3] = Rskew

    # de_n/dt is zero: the rotated normal does not depend on the translation.

    # de_d/dw: d' = d - (R n)'t, so de_d/dw = -t' * de_n/dw (stored as a column).
    H[4, 1:3] = -(Rskew' * t)

    # de_d/dt
    H[4, 4:6] = R * n

    return H
end

# Jacobian of the plane error e = measured - transform_plane(T, p) with respect to
# the map plane p = [n; d], evaluated at pose T:
#
#   [ de_n/dn  de_n/dd ]   [ -R    0 ]
#   [ de_d/dn  de_d/dd ] = [ t'R  -1 ]
#
# with R = T[1:3,1:3] and t = T[1:3,4]. `p` is accepted for symmetry with H1 but
# does not enter the result.
# NOTE(review): n is treated as unconstrained; the normalization done inside
# transform_plane is not differentiated, so this is exact only for a unit n and
# perturbations orthogonal to it — confirm this is intended.
function H2(T, p)
    H = zeros(4, 4)

    R = T[1:3, 1:3]
    t = T[1:3, 4]

    # de_n/dn
    H[1:3, 1:3] = -R

    # de_d/dn: d' = d - (R n)'t, so de_d/dn = t'R (written as the column R't).
    H[4, 1:3] = R' * t

    # de_n/dd is zero: the normal does not depend on the offset (left at zero).

    # de_d/dd
    H[4, 4] = -1

    return H
end

# Run one Monte-Carlo sample: draw a random pose, "measure" the fixed map plane
# through it, and return the 4 x length(perturbations) matrix of absolute residuals
# |e(x) - (e0 + J1*[dw; dt] + J2*dp)| of the first-order error model, one column
# per perturbation magnitude x. The perturbation along each parameter group is
# x times the corresponding direction; by default only the x component of the
# map-plane normal is perturbed (exercising H2).
# The measurement noise cancels in the residual; it only exercises the bookkeeping.
function sample_linearization_errors(perturbations;
                                     w_dir = [0 0 0]',
                                     t_dir = [0 0 0]',
                                     p_dir = [1 0 0 0]')
    # Fixed map plane: unit normal, offset 100.
    plane_map = [-1/sqrt(2) -1/sqrt(2) 0 100]'

    # Random pose: rotation vector in [-1, 1)^3, translation in [-0.5, 0.5)^3.
    w = (rand(3, 1) .- 0.5) * 2.0
    t = (rand(3, 1) .- 0.5) * 1.0
    T = [Geometry.rodrigues(w) t; [0 0 0 1]]

    # Noisy "measurement" of the map plane as seen through T.
    plane_predicted = transform_plane(T, plane_map)
    plane_measured = plane_predicted + (rand(4, 1) .- 0.5) * 0.1
    plane_error = plane_measured - plane_predicted

    # The Jacobians are evaluated at the unperturbed point, so compute them once.
    J1 = H1(T, plane_map)
    J2 = H2(T, plane_map)

    residuals = zeros(4, length(perturbations))
    for (k, x) in enumerate(perturbations)
        w_perturb = w_dir * x
        t_perturb = t_dir * x
        p_perturb = p_dir * x

        T_perturb = [Geometry.rodrigues(w + w_perturb) t + t_perturb; [0 0 0 1]]
        error_perturb = plane_measured - transform_plane(T_perturb, plane_map + p_perturb)
        error_linear = plane_error + J1 * [w_perturb; t_perturb] + J2 * p_perturb

        residuals[:, k] = map(abs, error_perturb - error_linear)
    end
    return residuals
end

num_samples = 100

# Perturbation magnitudes swept in every sample. Must exist before the
# accumulators below, which are sized by its length.
perturbations = linspace(-.5, .5)

# Welford running statistics (mean and sum of squared deviations) of the
# residuals: one row per plane-error component, one column per perturbation.
errors_mean = zeros(4, length(perturbations))
errors_m2 = zeros(4, length(perturbations))

for n in 1:num_samples
    # Update the top-level accumulators instead of creating loop-local copies.
    global errors_mean, errors_m2

    errors = sample_linearization_errors(perturbations)

    delta = errors - errors_mean
    errors_mean = errors_mean + delta ./ n
    errors_m2 = errors_m2 + delta .* (errors - errors_mean)
end

errors_variance = errors_m2 / (num_samples - 1)
# Error bars are drawn as +/- one standard deviation (same units as the mean).
errors_std = map(sqrt, errors_variance)

# One framed plot per plane-error component (three normal entries, then offset).
plots = [FramedPlot() for i in 1:4]
plot_indices = 1:length(plots)

for i in plot_indices
    fp = plots[i]
    add(fp, Curve(perturbations, errors_mean[i, :]))
    add(fp, SymmetricErrorBarsY(perturbations, errors_mean[i, :], errors_std[i, :]))
    add(fp, LineY(0, "color", "red"))
    setattr(fp, "ylabel", string("e", i))
    setattr(fp.frame, "draw_grid", true)
    setattr(fp.frame, "grid_style", ("linewidth", .2, "linetype", "dotted"))
end

# Stack the plots vertically in a single table.
table = Table(length(plots), 1)
for i in plot_indices
    table[i, 1] = plots[i]
end

Winston.display(table)

println("Done")
