
import types

import casadi as ca
import numpy as np
from casadi import *

def bicycle_model_poly(model_name: str, poly_degree: int = 3, max_predict: float = 1.0):
    """
    Build a CasADi bicycle model whose reference curvature is a polynomial in the path progress.

    The reference curvature is ``sum_i p_i * s ** (poly_degree - i)``: ``p_0`` multiplies
    the highest power of ``s`` and ``p_{poly_degree}`` is the constant term. The coefficients
    are symbolic and exposed as ``model.p``. For ``s > max_predict`` or ``s < 0`` the
    curvature is forced to zero.

    :param model_name: Name under which to export model
    :param poly_degree: Degree of the curvature polynomial (must be >= 0)
    :param max_predict: Maximum path progress for which the track approximation is relied on
    :return: Tuple ``(model, constraint)`` of ``types.SimpleNamespace`` objects. ``model``
        holds the symbolic states/controls/parameters, the explicit and implicit dynamics,
        box bounds and the initial state; ``constraint`` holds the nonlinear constraint
        expression and the acceleration bounds.
    :raises ValueError: If ``poly_degree`` is negative
    """
    # Fail early: a negative degree would yield an empty coefficient vector.
    if poly_degree < 0:
        raise ValueError(f"poly_degree must be non-negative, got {poly_degree}")

    # define structs
    constraint = types.SimpleNamespace()
    model = types.SimpleNamespace()

    # Setup function for polynomial tracking
    # Symbolic polynomial coefficients, ordered from highest power to constant term
    poly_params = ca.vertcat(*[ca.MX.sym(f'p_{i}') for i in range(poly_degree + 1)])
    s_eq = ca.MX.sym('s_eq')
    # Powers s^poly_degree, ..., s^1, s^0 matching the coefficient ordering
    s_powers = ca.vertcat(*[s_eq ** i for i in range(poly_degree, -1, -1)])
    kapparef_equation = ca.sum1(poly_params * s_powers)
    # Outside [0, max_predict] the approximation is not trusted: use zero curvature
    kapparef_equation = ca.if_else(
        ca.logic_or(s_eq > max_predict, s_eq < 0.0),  # Condition
        0,  # Output if condition is True
        kapparef_equation,  # Output if condition is False
    )

    kapparef_s = ca.Function('kapparef_s', [s_eq, poly_params], [kapparef_equation])

    ## Race car parameters
    # car weight
    m = 0.043
    # Control effectivity
    C1 = 0.5
    C2 = 15.0

    # Motor resistance
    Cm1 = 0.28
    Cm2 = 0.05
    # Rolling resistance
    Cr0 = 0.011
    Cr2 = 0.006

    MaxSteering = 0.8

    ## CasADi Model
    # set up states & controls
    s = ca.MX.sym("s")  # path progress
    n = ca.MX.sym("n")  # lateral position
    alpha = ca.MX.sym("alpha")  # orientation
    v = ca.MX.sym("v")  # velocity
    D = ca.MX.sym("D")  # acceleration (throttle) state
    delta = ca.MX.sym("delta")  # steering
    x = ca.vertcat(s, n, alpha, v, D, delta)

    # controls: rates of change of the throttle and steering states
    derD = ca.MX.sym("derD")
    derDelta = ca.MX.sym("derDelta")
    u = ca.vertcat(derD, derDelta)

    # xdot
    sdot = ca.MX.sym("sdot")
    ndot = ca.MX.sym("ndot")
    alphadot = ca.MX.sym("alphadot")
    vdot = ca.MX.sym("vdot")
    Ddot = ca.MX.sym("Ddot")
    deltadot = ca.MX.sym("deltadot")
    xdot = ca.vertcat(sdot, ndot, alphadot, vdot, Ddot, deltadot)

    # algebraic variables (none)
    z = ca.vertcat([])

    # dynamics
    # Evaluate the reference curvature once and reuse it in both places below
    kappa = kapparef_s(s, poly_params)
    Fxd = (Cm1 - Cm2 * v) * D - Cr2 * v * v - Cr0 * ca.tanh(5 * v)
    sdota = (v * ca.cos(alpha + C1 * delta)) / (1 - kappa * n)
    f_expl = ca.vertcat(
        sdota,
        v * ca.sin(alpha + C1 * delta),
        v * C2 * delta - kappa * sdota,
        Fxd / m * ca.cos(C1 * delta),
        derD,
        derDelta,
    )

    # constraint on accelerations
    a_lat = C2 * v * v * delta + Fxd * ca.sin(C1 * delta) / m
    a_long = Fxd / m

    # Model bounds
    model.n_min = -0.18  # lower bound on lateral position (track width) [m]
    model.n_max = 0.18  # upper bound on lateral position (track width) [m]

    # state bounds
    model.throttle_min = -1.0
    model.throttle_max = 1.0

    model.delta_min = -MaxSteering  # minimum steering angle [rad]
    model.delta_max = MaxSteering  # maximum steering angle [rad]

    # input bounds
    model.ddelta_min = -4.0  # minimum change rate of steering angle [rad/s]
    model.ddelta_max = 4.0  # maximum change rate of steering angle [rad/s]
    model.dthrottle_min = -20  # minimum throttle change rate
    model.dthrottle_max = 20  # maximum throttle change rate

    # nonlinear constraint
    constraint.alat_min = -10  # minimum lateral acceleration [m/s^2]
    constraint.alat_max = 10  # maximum lateral acceleration [m/s^2]

    constraint.along_min = -10  # minimum longitudinal acceleration [m/s^2]
    constraint.along_max = 6  # maximum longitudinal acceleration [m/s^2]

    # Define initial conditions
    model.x0 = np.array([0, 0, 0, 0, 0, 0])

    # define constraints struct
    constraint.alat = ca.Function("a_lat", [x, u], [a_lat])
    constraint.expr = ca.vertcat(a_long, a_lat, n, D, delta)

    # Define model struct
    params = types.SimpleNamespace()
    params.C1 = C1
    params.C2 = C2
    params.Cm1 = Cm1
    params.Cm2 = Cm2
    params.Cr0 = Cr0
    params.Cr2 = Cr2
    model.f_impl_expr = xdot - f_expl
    model.f_expl_expr = f_expl
    model.x = x
    model.xdot = xdot
    model.u = u
    model.z = z
    model.p = poly_params
    model.name = model_name
    model.params = params
    return model, constraint
