
import time
from helpers.time2spatial import *

from helpers.processing import *
from settings_acados_poly_track import *


def original_solver_update(real_x, solver, sref_N, N):
    """Prepare ``solver`` for the next solve from the measured state.

    Three things are written to the solver, in this order:

    1. A stage reference ``yref`` (8 entries) for stages ``0 .. N-1`` whose
       first entry advances linearly from ``real_x[0]`` towards
       ``real_x[0] + sref_N``; all remaining entries are zero.
    2. A shifted state guess: the ``x`` value of stage ``j + 1`` is copied to
       stage ``j`` for ``j = 0 .. N-2``.  Stages ``N-1`` and ``N`` keep the
       values they already hold.
    3. The terminal reference (6 entries) at stage ``N`` and the initial
       state constraint (``lbx == ubx == real_x``) at stage 0.

    Args:
        real_x: Measured state vector; entry 0 is the reference origin
            (presumably the progress coordinate along the track -- confirm
            against the model definition).
        solver: Object exposing ``get(stage, field)`` and
            ``set(stage, field, value)``.
        sref_N: Offset added to ``real_x[0]`` to form the terminal reference.
        N: Number of shooting stages in the horizon.
    """
    origin = real_x[0]

    # Linearly increasing target in the first output, zeros everywhere else.
    for stage in range(N):
        stage_ref = np.array([origin + sref_N * stage / N, 0, 0, 0, 0, 0, 0, 0])
        solver.set(stage, "yref", stage_ref)

    # Re-use the previous solution as the initial guess, moved one stage
    # towards the start of the horizon.
    for stage in range(N - 1):
        solver.set(stage, "x", solver.get(stage + 1, "x"))

    # The terminal reference is shorter: it has 6 entries instead of 8.
    terminal_ref = np.array([origin + sref_N, 0, 0, 0, 0, 0])
    solver.set(N, "yref", terminal_ref)

    # Equal lower and upper bounds pin the first stage to the measurement.
    for bound in ("lbx", "ubx"):
        solver.set(0, bound, real_x)

def poly_solver_update(real_x, poly_solver, track, poly_estimate_dist, degree, N, sref_N):
    """Re-fit a local polynomial track model and load it into ``poly_solver``.

    A window of track points around the current position is fitted with
    polynomials (``fit_arclen_poly``).  The curvature of that fit is sampled,
    approximated by a polynomial of the same ``degree`` and passed to every
    stage ``0 .. N-1`` as parameter ``p``.  The measured state is then
    re-expressed relative to the fitted curve and pinned as the initial state
    of the problem.

    Args:
        real_x: Measured state.  Entries 0-3 are handed to
            ``transformProj2XY`` together with ``track``; entries 4 and 5 are
            copied unchanged into the new initial state.
        poly_solver: Object exposing ``get(stage, field)`` and
            ``set(stage, field, value)``.
        track: Track description forwarded to the helper functions.
        poly_estimate_dist: Distance ahead of the current position covered by
            the fit; also the upper end of the curvature sampling grid.
        degree: Polynomial degree used for the track fit and for the
            curvature fit.
        N: Number of shooting stages in the horizon.
        sref_N: Offset added to the local coordinate ``s0`` to form the
            terminal reference.

    Returns:
        tuple: ``(x_poly, y_poly)`` exactly as returned by
        ``fit_arclen_poly``.
    """
    track_s = real_x[0]
    world_x, world_y, world_psi, world_v = transformProj2XY(
        track_s, real_x[1], real_x[2], real_x[3], track
    )

    # Sample the track from slightly behind the current position (0.07 in
    # track-coordinate units) up to the look-ahead distance, then fit it.
    window_start = track_s - 0.07
    window_end = track_s + poly_estimate_dist
    samples_s, samples_x, samples_y = get_track_points(window_start, window_end, track)
    x_poly, y_poly = fit_arclen_poly(samples_s, samples_x, samples_y, degree)

    # Curvature of the fitted curve on a 60-point grid, condensed into
    # polynomial coefficients for the solver.
    s_grid = np.linspace(0.0, poly_estimate_dist, 60)
    curvature_samples = compute_curvature(s_grid, x_poly, y_poly)
    curvature_poly = np.polyfit(s_grid, curvature_samples, degree)

    # Express the measured pose relative to the fitted curve.
    # NOTE(review): [0.02, 0.30] is presumably the search interval for the
    # closest point -- confirm against find_closest_s.
    s0, lateral = find_closest_s(x_poly, y_poly, [world_x, world_y], [0.02, 0.30])
    tangent_angle = compute_tangent_angle(x_poly, y_poly, s0)
    alpha = wrap_angle(world_psi.item() - tangent_angle)
    x0 = np.array([s0, lateral, alpha, world_v, real_x[4], real_x[5]])

    # Per-stage reference (8 entries, first one advancing linearly from s0)
    # and the curvature coefficients as stage parameter.
    for stage in range(N):
        stage_ref = np.array([s0 + sref_N * stage / N, 0, 0, 0, 0, 0, 0, 0])
        poly_solver.set(stage, "yref", stage_ref)
        poly_solver.set(stage, "p", curvature_poly)

    # Shift the previous state guess one stage towards the start of the
    # horizon; stages N-1 and N keep the values they already hold.
    for stage in range(N - 1):
        poly_solver.set(stage, "x", poly_solver.get(stage + 1, "x"))

    # The terminal reference has 6 entries instead of 8.
    terminal_ref = np.array([s0 + sref_N, 0, 0, 0, 0, 0])
    poly_solver.set(N, "yref", terminal_ref)

    # Equal lower and upper bounds pin the first stage to the local state.
    for bound in ("lbx", "ubx"):
        poly_solver.set(0, bound, x0)

    return x_poly, y_poly



def run_solver(solver, N, nx, nu):
    """Solve the problem once and collect the planned trajectory.

    Args:
        solver: Object exposing ``solve()`` and ``get(stage, field)``.
        N: Number of shooting stages in the horizon.
        nx: Length of the state vector returned by ``get(stage, "x")``.
        nu: Length of the control vector returned by ``get(stage, "u")``.

    Returns:
        tuple: ``(x_plan, u_plan, elapsed_time, status)`` where

        * ``x_plan`` is an ``(N + 1, nx)`` array with the states of stages
          ``0 .. N``,
        * ``u_plan`` is an ``(N, nu)`` array with the controls of stages
          ``0 .. N-1``,
        * ``elapsed_time`` is the duration of the ``solve()`` call in
          seconds,
        * ``status`` is whatever ``solver.solve()`` returned; it is passed
          through unchecked, so callers must inspect it themselves.
    """
    # perf_counter() is monotonic and high resolution. time.time() follows
    # the wall clock, which can jump (NTP, manual changes) and distort or
    # even negate short intervals.
    start_time = time.perf_counter()
    status = solver.solve()
    elapsed_time = time.perf_counter() - start_time

    # The trajectory is extracted even when the solve failed; only the
    # solve() call itself is timed.
    x_plan = np.zeros((N + 1, nx))
    u_plan = np.zeros((N, nu))

    # States exist for stages 0 .. N (terminal stage included).
    for stage in range(N + 1):
        x_plan[stage, :] = solver.get(stage, "x")

    # Controls exist only for stages 0 .. N-1.
    for stage in range(N):
        u_plan[stage, :] = solver.get(stage, "u")

    return x_plan, u_plan, elapsed_time, status