
from acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
from bicycle_model_poly import bicycle_model_poly
import scipy.linalg
import numpy as np


def setting_poly_track(model_name : str, prediction_horizon : float, N : int, poly_degree : int, max_predict : float):
    '''
    Build an acados OCP for tracking a polynomial track approximation and
    create the corresponding SQP solver.

    :param model_name: Name of the model; also used for the generated solver
        description file ``f"{model_name}.json"``.
    :param prediction_horizon: Length of the prediction horizon in seconds
        (becomes ``solver_options.tf``). Must be positive.
    :param N: Number of discretization (shooting) steps over the horizon.
        Must be at least 1.
    :param poly_degree: Degree of the polynomial used for the track
        approximation; the OCP is given ``poly_degree + 1`` parameter values
        (presumably the polynomial coefficients -- confirm in bicycle_model_poly).
    :param max_predict: Forwarded unchanged to ``bicycle_model_poly``;
        presumably the maximum look-ahead covered by the track polynomial --
        confirm against that function.
    :return: Tuple ``(constraint, model, acados_solver)``: the constraint and
        model objects returned by ``bicycle_model_poly`` and the created
        ``AcadosOcpSolver``.
    :raises ValueError: If ``N < 1`` or ``prediction_horizon <= 0``.
    '''
    # Validate early: ``N / prediction_horizon`` is used to scale the cost
    # weights below, and a non-positive horizon yields a meaningless OCP.
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if prediction_horizon <= 0:
        raise ValueError(
            f"prediction_horizon must be positive, got {prediction_horizon}"
        )

    ocp = AcadosOcp()

    # export model (dynamics + nonlinear constraint description)
    model, constraint = bicycle_model_poly(model_name, poly_degree, max_predict)

    # define acados ODE and the nonlinear path constraint h(x, u, p)
    model_ac = AcadosModel()
    model_ac.f_impl_expr = model.f_impl_expr
    model_ac.f_expl_expr = model.f_expl_expr
    model_ac.x = model.x
    model_ac.xdot = model.xdot
    model_ac.u = model.u
    model_ac.z = model.z
    model_ac.p = model.p
    model_ac.name = model.name
    model_ac.con_h_expr = constraint.expr
    ocp.model = model_ac

    # dimensions; the LINEAR_LS output is y = [x; u] on stages, y_e = x at the end
    nx = model.x.rows()
    nu = model.u.rows()
    ny = nx + nu
    ny_e = nx

    # one soft state bound (see idxbx below) plus every nonlinear constraint soft
    nsbx = 1
    nsh = constraint.expr.shape[0]
    ns = nsh + nsbx

    # discretization
    # NOTE(review): newer acados releases deprecate ``ocp.dims.N`` in favour of
    # ``ocp.solver_options.N_horizon`` -- confirm against the installed version.
    ocp.dims.N = N

    # State weights, in the state order s, n, alpha, v, D, delta
    # (progress, lateral deviation, orientation, velocity, then -- judging by
    # the throttle/delta bounds below -- the throttle and steering states).
    Q = np.diag([ 1e-2, 1e-3, 1e-8, 1e-8, 1e-3, 5e-3 ])

    # Input weights (rates of throttle and steering, per the input bounds below).
    R = np.eye(nu)
    R[0, 0] = 1e-3
    R[1, 1] = 5e-3

    Qe = np.diag([ 5e1, 1e1, 1e-8, 1e-8, 5e-3, 2e-3 ])

    ocp.cost.cost_type = "LINEAR_LS"
    ocp.cost.cost_type_e = "LINEAR_LS"
    # Scale the weights by the number of steps per second of horizon.
    unscale = N / prediction_horizon

    ocp.cost.W = unscale * scipy.linalg.block_diag(Q, R)
    ocp.cost.W_e = Qe / unscale

    # Vx selects the states into the first nx entries of y.
    Vx = np.zeros((ny, nx))
    Vx[:nx, :nx] = np.eye(nx)
    ocp.cost.Vx = Vx

    # Vu selects the inputs into the entries of y that follow the states
    # (derived from nx/nu rather than hard-coded indices).
    Vu = np.zeros((ny, nu))
    Vu[nx:, :] = np.eye(nu)
    ocp.cost.Vu = Vu

    Vx_e = np.zeros((ny_e, nx))
    Vx_e[:nx, :nx] = np.eye(nx)
    ocp.cost.Vx_e = Vx_e

    # slack penalties: linear (z) and quadratic (Z) for lower and upper slacks
    ocp.cost.zl = 100 * np.ones((ns,))
    ocp.cost.zu = 100 * np.ones((ns,))
    ocp.cost.Zl = 1 * np.ones((ns,))
    ocp.cost.Zu = 1 * np.ones((ns,))

    # set initial references: only the first output (progress s) is nonzero
    yref = np.zeros(ny)
    yref[0] = 1.0
    ocp.cost.yref = yref
    ocp.cost.yref_e = np.zeros(ny_e)

    # setting constraints
    # State constraint on state index 1 (n in the ordering above)
    ocp.constraints.lbx = np.array([-0.18])
    ocp.constraints.ubx = np.array([0.18])
    ocp.constraints.idxbx = np.array([1])

    # Control constraints (bounds on both inputs)
    ocp.constraints.lbu = np.array([model.dthrottle_min, model.ddelta_min])
    ocp.constraints.ubu = np.array([model.dthrottle_max, model.ddelta_max])
    ocp.constraints.idxbu = np.array([0, 1])

    # Soft state constraints: idxsbx indexes into the bx constraints above
    ocp.constraints.lsbx = np.zeros([nsbx])
    ocp.constraints.usbx = np.zeros([nsbx])
    ocp.constraints.idxsbx = np.arange(nsbx)

    # Nonlinear constraints, order must match constraint.expr:
    # longitudinal accel, lateral accel, lateral position, throttle, steering
    ocp.constraints.lh = np.array(
        [
            constraint.along_min,
            constraint.alat_min,
            model.n_min,
            model.throttle_min,
            model.delta_min,
        ]
    )
    ocp.constraints.uh = np.array(
        [
            constraint.along_max,
            constraint.alat_max,
            model.n_max,
            model.throttle_max,
            model.delta_max,
        ]
    )
    # all nonlinear constraints are softened
    ocp.constraints.lsh = np.zeros(nsh)
    ocp.constraints.ush = np.zeros(nsh)
    ocp.constraints.idxsh = np.arange(nsh)

    # placeholder parameter values; assumes model.p has poly_degree + 1 entries
    ocp.parameter_values = np.zeros(poly_degree + 1, dtype=float)
    ocp.constraints.x0 = model.x0

    # set QP solver and integration
    ocp.solver_options.tf = prediction_horizon
    ocp.solver_options.qp_solver = "PARTIAL_CONDENSING_HPIPM"
    ocp.solver_options.nlp_solver_type = "SQP"
    ocp.solver_options.hessian_approx = "GAUSS_NEWTON"
    ocp.solver_options.integrator_type = "ERK"
    ocp.solver_options.sim_method_num_stages = 4
    ocp.solver_options.sim_method_num_steps = 3
    ocp.solver_options.nlp_solver_max_iter = 20
    ocp.solver_options.tol = 1e-3

    # create solver (generates code and writes f"{model_name}.json")
    acados_solver = AcadosOcpSolver(ocp, json_file=f"{model_name}.json")
    return constraint, model, acados_solver
