import itertools
import math
import numpy as np
from scipy.optimize import minimize_scalar


def fit_arclen_poly(s, x, y, degree=3):
    """
    Fit polynomials x(l), y(l) to the samples (x, y), parametrised by arc length l.

    A provisional least-squares polynomial of the given degree is fitted to
    x(s) and y(s) in the original parameter ``s`` (e.g. time). The cumulative
    arc length of that provisional curve is then evaluated at every sample of
    ``s``. Both coordinates are refitted against this arc length, so the result
    is parametrised by path progress instead of ``s``.

    Note: these are single global polynomials (``np.polyfit``), not piecewise
    splines.

    Args:
    - s: parameter value of each sample, in path order.
    - x, y: sample coordinates, same length as ``s``.
    - degree: polynomial degree used for both fits.

    Returns:
    - (x_poly, y_poly): two ``np.poly1d`` objects giving x and y as functions
      of arc length.
    """
    # Provisional curve in the original parameter; only used to measure length.
    provisional_x, provisional_y = (
        np.poly1d(np.polyfit(s, coords, degree)) for coords in (x, y)
    )

    # Cumulative (approximate) arc length of the provisional curve at each s.
    progress = compute_arc_lengths(s, provisional_x, provisional_y)

    # Refit both coordinates with arc length as the independent variable.
    return tuple(
        np.poly1d(np.polyfit(progress, coords, degree)) for coords in (x, y)
    )


def compute_curvature(s_values, x_poly, y_poly):
    """
    Evaluate the signed curvature of the planar curve (x_poly(s), y_poly(s)).

    Uses kappa = (x' * y'' - y' * x'') / (x'^2 + y'^2)^(3/2), where primes
    denote derivatives with respect to s. The sign is positive where the curve
    turns counter-clockwise (to the left).

    Args:
    - s_values: scalar or array of parameter values.
    - x_poly, y_poly: ``np.poly1d`` objects (anything with ``.deriv()``).

    Returns:
    - Curvature at each entry of ``s_values``. Where both first derivatives
      vanish, numpy division yields inf/nan.
    """
    # Velocity (first derivative) components.
    vel_x = x_poly.deriv()(s_values)
    vel_y = y_poly.deriv()(s_values)
    # Acceleration (second derivative) components.
    acc_x = x_poly.deriv(2)(s_values)
    acc_y = y_poly.deriv(2)(s_values)

    numerator = vel_x * acc_y - vel_y * acc_x
    speed_cubed = (vel_x ** 2 + vel_y ** 2) ** (3 / 2)
    return numerator / speed_cubed


def find_closest_s(x_poly, y_poly, target_point, s_range, num_samples=50) -> (float, float):
    """
    Find the curve parameter s closest to a target point (x, y), and the signed
    distance from the curve to that point.

    The squared distance can have several local minima along a polynomial
    curve. A single bounded Brent search over the whole range may lock onto
    the wrong one. So the distance is first sampled on a uniform grid of
    ``num_samples`` points, and the best grid cell is then refined with
    ``scipy.optimize.minimize_scalar(method='bounded')``.

    Args:
    - x_poly, y_poly: ``np.poly1d`` objects describing the curve. They must
      support array evaluation and ``.deriv()``.
    - target_point: (x, y) of the query point.
    - s_range: sequence whose first and last entries bound the search interval.
    - num_samples: number of grid points for the coarse search (at least 2 are used).

    Returns:
    - (s_closest, signed_distance) as Python floats. signed_distance is the
      projection of (target - curve point) onto the left normal
      (-sin(psi), cos(psi)) of the tangent angle psi at s_closest, so it is
      positive when the target lies to the left of the direction of travel.
    """
    x_target, y_target = target_point

    def distance_function(s):
        # Squared distance: same minimiser as the distance, without a sqrt.
        return (x_poly(s) - x_target) ** 2 + (y_poly(s) - y_target) ** 2

    s_lo, s_hi = sorted((float(s_range[0]), float(s_range[-1])))

    # Coarse global search so the refinement starts in the right basin.
    s_grid = np.linspace(s_lo, s_hi, max(int(num_samples), 2))
    d_grid = distance_function(s_grid)
    i_best = int(np.argmin(d_grid))
    bracket_lo = s_grid[max(i_best - 1, 0)]
    bracket_hi = s_grid[min(i_best + 1, s_grid.size - 1)]

    result = minimize_scalar(distance_function, bounds=(bracket_lo, bracket_hi), method='bounded')
    # result.x may be a Python float or a numpy scalar depending on scipy/bounds types.
    s_closest = float(np.asarray(result.x).item())

    # Bounded Brent only approaches an end-point minimum to within its
    # tolerance; keep the grid sample if it is strictly better.
    if d_grid[i_best] < distance_function(s_closest):
        s_closest = float(s_grid[i_best])

    # Offset from the closest curve point to the target.
    x_closest = x_poly(s_closest)
    y_closest = y_poly(s_closest)
    dx = x_target - x_closest
    dy = y_target - y_closest

    psi = compute_tangent_angle(x_poly, y_poly, s_closest)
    signed_distance = dy * np.cos(psi) - dx * np.sin(psi)  # Projection along left normal

    return s_closest, np.asarray(signed_distance).item()



def compute_arc_length(s0, s1, x_poly, y_poly, xy1=None, xy2=None, max_arc_increase=0.1) -> float:
    """
    Approximate the arc length of the curve (x_poly(s), y_poly(s)) between s0 and s1.

    The length is approximated by summing straight chords. If the chord
    between the two end points is longer than ``max_arc_increase``, [s0, s1] is
    split uniformly in s into ``ceil(1.3 * chord / max_arc_increase)`` pieces.
    Each piece is measured recursively, so every chord that is finally summed
    is at most ``max_arc_increase`` long.

    Args:
    - s0, s1: parameter values of the start and end points.
    - x_poly, y_poly: callables returning x and y for a given s (e.g. ``np.poly1d``).
    - xy1, xy2: optional precomputed (x, y) at s0 and s1; evaluated when None.
    - max_arc_increase: longest chord accepted without further subdivision.
      Must be strictly positive.

    Returns:
    - The approximate arc length between s0 and s1.

    Raises:
    - ValueError: if ``max_arc_increase`` is not strictly positive (or NaN),
      because the subdivision could never terminate.
    """
    if not max_arc_increase > 0:
        raise ValueError(f"max_arc_increase must be > 0, got {max_arc_increase!r}")

    if xy1 is None:
        xy1 = (x_poly(s0), y_poly(s0))
    if xy2 is None:
        xy2 = (x_poly(s1), y_poly(s1))

    # Straight-line distance between the two end points.
    dist = np.sqrt((xy2[0] - xy1[0]) ** 2 + (xy2[1] - xy1[1]) ** 2)

    if dist > max_arc_increase:
        # 1.3 is a safety factor: the curve is usually longer than its chord,
        # so slightly over-subdivide to reduce further recursion.
        num_intervals = int(np.ceil(dist * 1.3 / max_arc_increase))
        refined_s_values = np.linspace(s0, s1, num_intervals + 1)

        arc_length = 0
        s_start, xy_start = s0, xy1
        for s_end in refined_s_values[1:]:
            xy_end = (x_poly(s_end), y_poly(s_end))
            # Each sub-chord is measured recursively in case it is still too long.
            arc_length += compute_arc_length(s_start, s_end, x_poly, y_poly,
                                             xy_start, xy_end, max_arc_increase)
            s_start, xy_start = s_end, xy_end
        return arc_length

    # Chord is short enough: use it directly as the arc length.
    return dist


def compute_arc_lengths(s, x_poly, y_poly, max_arc_step=0.1):
    """
    Cumulative arc length of the curve (x_poly(s), y_poly(s)) at each value of s.

    Consecutive parameter pairs are measured with ``compute_arc_length`` in the
    order given, and their lengths are accumulated.

    Args:
    - s: indexable sequence of parameter values, in path order.
    - x_poly, y_poly: callables giving x and y for a parameter value (e.g. ``np.poly1d``).
    - max_arc_step: longest chord ``compute_arc_length`` accepts without subdividing.

    Returns:
    - A list with one entry per element of ``s``. Entry i is the approximate
      arc length from s[0] to s[i], and entry 0 is 0.0. Empty ``s`` gives ``[]``.
    """
    if len(s) == 0:
        return []

    arc_lengths = [0.0]  # the path starts at zero length
    cumulative_length = 0.0
    start_xy = (x_poly(s[0]), y_poly(s[0]))
    # Walk consecutive pairs (s[i-1], s[i]).
    for s0, s1 in zip(s, itertools.islice(s, 1, None)):
        end_xy = (x_poly(s1), y_poly(s1))
        cumulative_length += compute_arc_length(
            s0, s1, x_poly, y_poly, start_xy, end_xy, max_arc_increase=max_arc_step
        )
        arc_lengths.append(cumulative_length)
        start_xy = end_xy  # reuse the evaluated end point as the next start
    return arc_lengths



def wrap_angle(x):
    """
    Wrap an angle in radians into the interval [-pi, pi].

    Uses ``math.remainder`` (the IEEE 754 remainder), which is exact and O(1).
    Repeatedly adding or subtracting 2*pi would take time proportional to |x|.
    It would also never terminate for very large magnitudes (e.g. 1e17), where
    subtracting 2*pi no longer changes the float.

    Values already in [-pi, pi] (including both end points) are returned
    unchanged. NaN propagates, and +/-inf raises ValueError (math domain error).
    """
    if -math.pi <= x <= math.pi:
        return x
    # |remainder(x, 2*pi)| <= pi exactly, since 2*math.pi is an exact doubling.
    return math.remainder(x, 2 * math.pi)


def compute_tangent_angle(x_poly, y_poly, s):
    """
    Return the heading of the curve's tangent at parameter s.

    The tangent direction is (dx/ds, dy/ds), and its angle is computed with
    ``np.arctan2``, so the full (-pi, pi] range is resolved correctly.

    Parameters:
    - x_poly: ``np.poly1d`` (anything with ``.deriv()``) for x(s)
    - y_poly: ``np.poly1d`` (anything with ``.deriv()``) for y(s)
    - s: parameter value (arc length when the polynomials are arc-length
      parametrised). It must be a scalar or size-1 array, because the result
      is converted with ``.item()``.

    Returns:
    - theta: the tangent angle in radians, as a Python float
    """
    slope_x = x_poly.deriv()(s)
    slope_y = y_poly.deriv()(s)
    heading = np.arctan2(slope_y, slope_x)
    return heading.item()


def get_track_points(min_s, max_s, track):
    """
    Extract track points whose Sref lies in the window [min_s, max_s] on a circular track.

    The track is treated as a closed loop of length Sref[-1]. Both window bounds
    are wrapped into [0, Sref[-1]). A window that runs past the end of the lap,
    starts before 0, or is expressed in multi-lap (cumulative) s therefore
    selects the points on both sides of the start/finish line.

    Parameters:
        min_s (float): Start of the window (may be negative or beyond one lap).
        max_s (float): End of the window.
        track (tuple): Track data as (Sref, xRef, yRef, psiRef, kappa). The first
            three are numpy arrays of equal length. Sref is assumed ascending and
            presumably starting at 0 (TODO confirm).

    Returns:
        tuple: (s, x, y) filtered arrays in circular order. When the window
        crosses the seam, points with s >= min_s come first, followed by those
        with s <= max_s. The raw Sref values are returned, so s is not monotonic
        across the seam. If the window spans a whole lap or more, every point is
        returned.
    """
    s_ref, x_ref, y_ref, _, _ = track  # Unpack the track data

    track_len = s_ref[-1]

    if max_s - min_s >= track_len:
        # The window covers at least one full lap. Without this check, wrapping
        # both bounds would collapse it to (almost) nothing.
        return s_ref.copy(), x_ref.copy(), y_ref.copy()

    # Wrap both bounds onto a single lap.
    min_s = min_s % track_len
    max_s = max_s % track_len

    if max_s >= min_s:
        # Normal case: select values within [min_s, max_s]
        mask = (s_ref >= min_s) & (s_ref <= max_s)
        return s_ref[mask], x_ref[mask], y_ref[mask]

    # Circular case: the tail of the lap first, then its beginning.
    tail = s_ref >= min_s
    head = s_ref <= max_s
    s_filtered = np.concatenate((s_ref[tail], s_ref[head]))
    x_filtered = np.concatenate((x_ref[tail], x_ref[head]))
    y_filtered = np.concatenate((y_ref[tail], y_ref[head]))
    return s_filtered, x_filtered, y_filtered


def transform_from_poly_proj(si, ni, alpha, x_poly, y_poly):
    """
    Map curvilinear coordinates (si, ni, alpha) on a reference curve back to
    Cartesian (x, y, psi).

    Args:
    - si: curve parameter of the reference point (arc length when the
      polynomials come from ``fit_arclen_poly``).
    - ni: signed lateral offset along the curve's left normal
      (-sin(psi_ref), cos(psi_ref)). This is the same sign convention as the
      signed distance computed in ``find_closest_s``.
    - alpha: heading relative to the curve tangent at si, in radians.
    - x_poly, y_poly: ``np.poly1d`` objects describing the reference curve.

    Returns:
    - (x, y, psi): Cartesian position and absolute heading, with psi wrapped
      to [-pi, pi].
    """
    psi_ref = compute_tangent_angle(x_poly, y_poly, si)

    # The lateral offset is taken along the normal of the reference tangent.
    # It depends only on psi_ref, not on the relative heading alpha.
    x = x_poly(si) - np.sin(psi_ref) * ni
    y = y_poly(si) + np.cos(psi_ref) * ni

    # Absolute heading = reference tangent + relative heading.
    psi = wrap_angle(psi_ref + alpha)
    return x, y, psi

