import math
import pygame
import numpy as np

import matplotlib.cm as cm
import matplotlib.colors as mcolors

from .processing import transform_from_poly_proj
from .time2spatial import transformProj2Orig

# RGB colour constants (0-255 per channel) used by the drawing routines below.
BLACK = (0,) * 3
WHITE = (255,) * 3
RED = (255, 0, 0)
BLUE = (0, 0, 255)
GREEN = (0, 255, 0)


class RaceVisualization:
    """Pygame window that draws a race track plus cars, plans and trajectories.

    ``track`` is unpacked as ``(s_ref, center_x, center_y, alpha_ref, kappa_ref)``.
    The centerline ``center_x``/``center_y`` is used to fit the view. The whole
    tuple is passed to ``transformProj2Orig`` to turn projected states
    ``(s, ni, alpha)`` into world poses ``(x, y, psi)``.
    """

    # Fraction of the window (per axis) the track's bounding box may occupy;
    # the remainder is left as margin.
    _FILL_FRACTION = 0.8

    def __init__(self, width, height, track, timestep=0.1, track_width=0.55):
        """Open the pygame window and fit the track centerline into it.

        Args:
            width, height: window size in pixels.
            track: five-component track description (see class docstring).
            timestep: stored as ``self.time_step``; not used by the drawing
                methods of this class.
            track_width: track width in world units. Borders are drawn at
                lateral offsets of ``-track_width / 2`` and ``+track_width / 2``.
        """
        pygame.init()
        self.width, self.height = width, height  # Window size
        self.screen = pygame.display.set_mode((self.width, self.height))

        self.time_step = timestep
        self.track = track
        self.track_offset = track_width / 2.0

        pygame.display.set_caption("Race Simulation")

        # Only the centerline is needed to fit the view; the unpacking still
        # checks that the track has exactly five components.
        _, center_x, center_y, _, _ = self.track
        self.scale, self.offset_x, self.offset_y = self._fit_view(
            self.width, self.height, center_x, center_y
        )

    @classmethod
    def _fit_view(cls, width, height, xs, ys):
        """Return ``(scale, offset_x, offset_y)`` centering the (xs, ys) bounding box.

        One scale is shared by both axes so the aspect ratio is preserved.
        An axis with zero extent does not constrain the scale. If both axes
        are degenerate (e.g. a single point), a scale of 1.0 is used instead
        of dividing by zero.
        """
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)

        scale = min(
            cls._axis_scale(width, min_x, max_x),
            cls._axis_scale(height, min_y, max_y),
        )
        if math.isinf(scale):
            scale = 1.0

        # Place the bounding-box midpoint at the window centre.
        offset_x = width // 2 - scale * (max_x + min_x) / 2
        offset_y = height // 2 - scale * (max_y + min_y) / 2
        return scale, offset_x, offset_y

    @classmethod
    def _axis_scale(cls, extent_px, lo, hi):
        """Return the pixels per world unit that make [lo, hi] fill _FILL_FRACTION of extent_px."""
        span = hi - lo
        if span <= 0:
            return math.inf  # degenerate axis: imposes no limit on the scale
        return (extent_px * cls._FILL_FRACTION) / span

    def draw_xy_car(self, xy_car, v, color=RED, radius=4):
        """Draw a car at world pose ``(x, y, psi)`` as a dot plus a heading line.

        The line points along ``psi`` (radians). Its length in pixels is
        ``3 * v * v``, so faster cars get visibly longer lines.
        """
        x, y, psi = xy_car
        px, py = self.project(x, y)

        line_length = 3 * v * v  # quadratic in speed, purely for visibility
        end_x = px + line_length * math.cos(psi)
        end_y = py + line_length * math.sin(psi)

        pygame.draw.line(self.screen, color, (px, py), (int(end_x), int(end_y)), 2)
        pygame.draw.circle(self.screen, color, (px, py), radius)

    def draw_car(self, x_car, color=RED, radius=12):
        """Draw a car from its projected state ``(s, ni, alpha, v, torque, steering)``.

        The pose is converted to world coordinates with ``transformProj2Orig``.
        Torque and steering are ignored.
        """
        s, ni, alpha, v, _torque, _steering = x_car
        x, y, psi = transformProj2Orig(s, ni, alpha, self.track)
        self.draw_xy_car((x, y, psi), v, color, radius)

    def project(self, x: float, y: float):
        """Map world ``(x, y)`` to integer pixel coordinates.

        ``int()`` truncates toward zero. y is not inverted, so world +y
        points down on screen.
        """
        return int(x * self.scale + self.offset_x), int(y * self.scale + self.offset_y)

    def project_pair(self, xy):
        """Same as ``project`` but takes a single ``(x, y)`` pair."""
        x, y = xy
        return self.project(x, y)

    def draw_track(self, draw_center=True, draw_points=False):
        """Draw the track borders, and optionally the centerline and its sample points.

        The borders are the centerline shifted laterally by ``-track_offset``
        and ``+track_offset`` via ``transformProj2Orig``.
        """
        s_ref, x_center, y_center, alpha_ref, _ = self.track
        points = [self.project(x, y) for x, y in zip(x_center, y_center)]

        if draw_center and len(points) > 1:
            pygame.draw.lines(self.screen, GREEN, False, points, 1)
        if draw_points:
            for pt in points:
                pygame.draw.circle(self.screen, GREEN, pt, 2)

        # Track borders: left (negative lateral offset) first, then right.
        for lateral in (-self.track_offset, self.track_offset):
            border_x, border_y, _ = transformProj2Orig(s_ref, lateral, alpha_ref, self.track)
            border = [self.project(x, y) for x, y in zip(border_x, border_y)]
            if len(border) > 1:  # pygame.draw.lines needs at least two points
                pygame.draw.lines(self.screen, BLACK, False, border, 4)

    @staticmethod
    def _plan_shade(index):
        """Return the grey colour for the index-th planned pose, clamped to valid RGB."""
        level = min(3 * index, 255)  # later poses are lighter; 255 is the max channel value
        return (level, level, level)

    def draw_plan(self, poses, draw_poly, every_nth=1):
        """Draw every ``every_nth`` planned pose as a car; later poses are lighter grey.

        ``poses`` holds states ``(s, ni, alpha, v, _, _)``. Each is converted to
        a world pose with ``transform_from_poly_proj`` using the polynomial
        pair ``draw_poly[0]``, ``draw_poly[1]``.
        """
        x_poly, y_poly = draw_poly[0], draw_poly[1]
        for index, position in enumerate(poses):
            s, ni, alpha, v, _, _ = position
            if index % every_nth != 0:
                continue
            pose = transform_from_poly_proj(s, ni, alpha, x_poly, y_poly)
            self.draw_xy_car(pose, v, self._plan_shade(index))

    def draw_trajectory(self, poses, back_color=None, thickness=9, v_min=2.5, v_max=3.5):
        """Draw a polyline through the poses, with each segment coloured by speed.

        Segment i runs from pose i to pose i+1. Its colour is pose i's speed
        ``v`` passed through ``value_to_color`` over ``[v_min, v_max]``. If
        ``back_color`` is given, a line ``thickness + 3`` pixels wide is drawn
        underneath first as an outline.
        """
        pts_list = []
        color_list = []
        for s, ni, alpha, v, _, _ in poses:
            x, y, _psi = transformProj2Orig(s, ni, alpha, self.track)
            pts_list.append(self.project(x, y))
            color_list.append(value_to_color(v, v_min, v_max))

        segments = list(zip(pts_list, pts_list[1:]))

        if back_color:  # accent outline drawn beneath the coloured line
            for start, end in segments:
                pygame.draw.line(self.screen, back_color, start, end, thickness + 3)

        for (start, end), color in zip(segments, color_list):
            pygame.draw.line(self.screen, color, start, end, thickness)

    def clear(self):
        """Fill the screen with white and redraw the track."""
        self.screen.fill(WHITE)
        self.draw_track()

    def flip(self, clear=True):
        """Show the current frame. If ``clear`` is true, start a new frame.

        A new frame is a white background with the track already drawn.
        """
        pygame.display.flip()
        if clear:
            self.clear()

    def draw_track_aproximation(self, track_poly, start=0.0, end=1.5, num_points=40):
        """Draw a polynomial track approximation in blue.

        ``track_poly[0]`` and ``track_poly[1]`` are callables giving world x and
        y for a scalar parameter (presumably arc length). The curve is sampled
        at ``num_points`` values evenly spaced in ``[start, end]``. A marker is
        drawn at parameter value 1.
        """
        x_poly, y_poly = track_poly[0], track_poly[1]
        space = np.linspace(start, end, num_points)
        poly_track_pts = [self.project(x_poly(arc), y_poly(arc)) for arc in space]

        if len(poly_track_pts) > 1:  # pygame.draw.lines needs at least two points
            pygame.draw.lines(self.screen, BLUE, False, poly_track_pts, 2)
        pygame.draw.circle(self.screen, BLUE, self.project(x_poly(1), y_poly(1)), 4)


def value_to_color(value, vmin=0, vmax=1, cmap_name='jet'):
    """Map a scalar to an ``(r, g, b)`` tuple of 0-255 ints through a matplotlib colormap.

    ``value`` is linearly normalised from ``[vmin, vmax]`` to ``[0, 1]`` and
    then looked up in the colormap. The alpha channel is dropped.
    ``cmap_name`` may be a registered colormap name or a ``Colormap`` instance.

    Raises:
        ValueError: if ``cmap_name`` is not a known colormap name.
    """
    import matplotlib  # already loaded via matplotlib.cm; needed for the registry

    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    registry = getattr(matplotlib, "colormaps", None)
    if isinstance(cmap_name, mcolors.Colormap):
        cmap = cmap_name
    elif registry is None:
        # Older matplotlib (< 3.5) without the colormap registry.
        cmap = cm.get_cmap(cmap_name)
    else:
        # cm.get_cmap was deprecated in 3.7 and removed in 3.9.
        try:
            cmap = registry[cmap_name]
        except KeyError as err:
            # Keep the exception type cm.get_cmap used to raise.
            raise ValueError(f"Unknown colormap name: {cmap_name!r}") from err

    rgba = cmap(norm(value))  # (r, g, b, a) floats in [0, 1]
    return tuple(int(255 * channel) for channel in rgba[:3])
