import pygame
import numpy as np
import math
import os
from pathlib import Path
# Initialize Pygame and open the editor window.
pygame.init()
WINDOW_SIZE = (1920, 1440)  # width, height in pixels
screen = pygame.display.set_mode(WINDOW_SIZE)
pygame.display.set_caption("Track Generator")
clock = pygame.time.Clock()

# Parameters (lengths in meters, angles in radians)

TURN_ANGLE = math.radians(45)  # heading change produced by one turn segment
# NOTE(review): add_turn uses TURN_RADIUS / 2 as the arc radius -- confirm
# whether this value is meant to be a diameter.
TURN_RADIUS = 0.890
RESOLUTION = 0.025  # target distance between consecutive track points
SCALE = 400  # drawing scale in pixels per meter (1 meter = 400 pixels)
TRACK_WIDTH = 0.55  # distance between the two outer border lines
# Border-strip width; draw_track places the inner strip line STRIP_WIDTH / 2
# inside each outer border line.
STRIP_WIDTH = 0.02

# Track data: current arc length s, position (x, y) and heading psi.
s = 0.0
x, y = 0.0, 0.0  # in meters
psi = 0.0
# Every point is (s, x, y, psi, curvature); the track starts at the origin.
points = [(s, x, y, psi, 0.0)]
# Segment boundaries: segments[k] (k >= 1) is len(points) right after segment k
# was added, so segment k spans points[segments[k-1]:segments[k]].
segments = [0]

# Functions

def add_straight(repeat=1, length=0.17):
    """Append ``repeat`` straight segments of ``length`` meters to the track.

    Each segment continues from the current pose (x, y, psi) along the current
    heading. It is sampled in equal steps of at most roughly RESOLUTION meters,
    and its end index is recorded in ``segments`` so it can be deleted on its own.

    Args:
        repeat: Number of consecutive straight segments to add.
        length: Length of each segment in meters; must be positive.

    Raises:
        ValueError: If ``length`` is not positive.
    """
    global s, x, y, psi
    if length <= 0:
        raise ValueError(f"length must be positive, got {length!r}")

    # At least one step, so a length shorter than RESOLUTION still yields a
    # point instead of dividing by zero.
    n_steps = max(1, int(length / RESOLUTION))
    step_len = length / n_steps
    # The heading is constant along a straight, so the per-step displacement is too.
    dx = step_len * math.cos(psi)
    dy = step_len * math.sin(psi)

    for _ in range(repeat):
        for _ in range(n_steps):
            x += dx
            y += dy
            s += step_len
            points.append((s, x, y, psi, 0.0))
        segments.append(len(points))

def add_turn(direction):
    """Append a circular-arc turn of TURN_ANGLE radians to the track.

    The arc starts at the current pose (x, y, psi) and has radius
    TURN_RADIUS / 2. It is sampled so that consecutive points are at most
    roughly RESOLUTION meters apart along the arc. Every new point stores the
    signed curvature ``sign / radius``: -1/radius for ``'left'``, +1/radius
    for ``'right'``.

    Args:
        direction: ``'left'`` or ``'right'``.

    Raises:
        ValueError: If ``direction`` is neither ``'left'`` nor ``'right'``.
    """
    global s, x, y, psi
    if direction == 'left':
        sign = -1
    elif direction == 'right':
        sign = 1
    else:
        # Previously any unrecognised string silently produced a right turn.
        raise ValueError(f"direction must be 'left' or 'right', got {direction!r}")

    # NOTE(review): the arc radius is half of TURN_RADIUS -- confirm whether
    # TURN_RADIUS is really meant as a diameter.
    radius = TURN_RADIUS / 2

    # Choose the step count so each arc step is about RESOLUTION long; keep at
    # least one step so a coarse RESOLUTION cannot cause a division by zero.
    n_steps = max(1, int(TURN_ANGLE / (RESOLUTION / radius)))
    d_angle = TURN_ANGLE / n_steps
    d_len = radius * d_angle

    # Centre of the turning circle: perpendicular to the current heading, on
    # the side selected by `sign`.
    center_x = x - sign * radius * math.sin(psi)
    center_y = y + sign * radius * math.cos(psi)

    curvature = sign / radius
    for _ in range(n_steps):
        psi += sign * d_angle
        x = center_x + sign * radius * math.sin(psi)
        y = center_y - sign * radius * math.cos(psi)
        s += d_len
        points.append((s, x, y, psi, curvature))
    segments.append(len(points))

def delete_last_segment():
    """Remove the most recently added segment and rewind the current pose.

    Does nothing when no segment has been added yet. The starting point at
    index 0 is always kept. ``segments`` begins with boundary 0, so without
    the clamp below, deleting the first segment would also delete the origin.
    ``points`` would then be empty and the pose (s, x, y, psi) would stay at
    the end of the removed segment.
    """
    global s, x, y, psi
    if len(segments) <= 1:
        return

    start = max(segments[-2], 1)  # never delete the origin point at index 0
    del points[start:segments[-1]]
    del segments[-1]

    if points:
        s, x, y, psi = points[-1][:4]
    else:
        # Defensive: nothing left to rewind to, so return to the origin pose.
        s = x = y = psi = 0.0


def _draw_parallel_pair(x1, y1, psi1, x2, y2, psi2, offset):
    """Draw two black lines ``offset`` meters to either side of one centreline step.

    (x1, y1) and (x2, y2) are the step's endpoints in screen pixels; psi1 and
    psi2 are the headings at those endpoints, used to build the side normals.
    """
    dx1 = offset * math.sin(psi1) * SCALE
    dy1 = -offset * math.cos(psi1) * SCALE
    dx2 = offset * math.sin(psi2) * SCALE
    dy2 = -offset * math.cos(psi2) * SCALE
    pygame.draw.line(screen, (0, 0, 0), (x1 - dx1, y1 - dy1), (x2 - dx2, y2 - dy2), 1)
    pygame.draw.line(screen, (0, 0, 0), (x1 + dx1, y1 + dy1), (x2 + dx2, y2 + dy2), 1)


def draw_track():
    """Render the track so its most recent point sits at the window centre.

    Clears the window to white, then draws for every consecutive pair of points:
      - a red centreline;
      - black outer border lines TRACK_WIDTH / 2 to either side;
      - black inner strip lines (TRACK_WIDTH - STRIP_WIDTH) / 2 to either side.

    Only the background is drawn when ``points`` is empty.
    """
    screen.fill((255, 255, 255))
    if not points:
        return

    # Centre the view on the last point. Use the real window size: the old
    # hard-coded y-centre of 540 matched a 1080-pixel-high window, not the
    # 1440-pixel-high one opened at start-up.
    width, height = screen.get_size()
    x_last, y_last = points[-1][1], points[-1][2]
    x_offset = width / 2 - x_last * SCALE
    y_offset = height / 2 - y_last * SCALE

    # Lateral offsets (meters) of the outer border and inner strip lines.
    edge_offset = TRACK_WIDTH / 2
    strip_offset = TRACK_WIDTH / 2 - STRIP_WIDTH / 2

    for prev, curr in zip(points, points[1:]):
        x1 = prev[1] * SCALE + x_offset
        y1 = prev[2] * SCALE + y_offset
        x2 = curr[1] * SCALE + x_offset
        y2 = curr[2] * SCALE + y_offset

        # Centerline
        pygame.draw.line(screen, (255, 0, 0), (x1, y1), (x2, y2), 2)

        # Border lines, then strip lines
        for offset in (edge_offset, strip_offset):
            _draw_parallel_pair(x1, y1, prev[3], x2, y2, curr[3], offset)


def save_track(filename="generated_track.txt", track=None):
    """Write track points to a whitespace-separated text file.

    Each point becomes one row, written by ``np.savetxt`` with every value in
    ``%.7e`` format. With the editor's points the columns are
    (s, x, y, psi, curvature).

    Args:
        filename: Destination path (str or path-like).
        track: Sequence of equal-length numeric tuples to save. Defaults to
            the module-level ``points`` list built by the editor.
    """
    if track is None:
        track = points
    np.savetxt(filename, np.asarray(track, dtype=float), fmt='%.7e')

# Keyboard controls shown at start-up, one "key .... action" row each.
_HELP_LINES = (
    "=== Track Generator Controls ===",
    "Arrow UP ........... Add short straight segment",
    "Shift + Arrow UP ... Add long straight segment",
    "Arrow LEFT ......... Add left turn",
    "Arrow RIGHT ........ Add right turn",
    "Backspace .......... Delete last segment",
    "ESC ................ Exit and save track",
)


def print_help():
    """Print the keyboard controls, framed by a blank line above and below."""
    body = "\n".join(_HELP_LINES)
    print("\n" + body + "\n")

def _prompt_and_save(default="generated_track.txt"):
    """Ask for an output file name and save the track there.

    An empty answer selects ``default``. If writing fails with an OSError
    (bad path, missing directory, no permission), the error is reported and
    the user is asked again, so the drawn track is not lost. If stdin is
    closed (EOFError), the track is saved to ``default`` without prompting again.

    Returns:
        The file name the track was written to.
    """
    while True:
        try:
            answer = input(f"Enter filename to save track (default: {default}): ")
        except EOFError:
            print()  # finish the prompt line
            save_track(default)
            return default
        filename = answer.strip() or default
        try:
            save_track(filename)
        except OSError as err:
            print(f"Could not save track to {filename}: {err}")
        else:
            return filename


print_help()

# Main loop: redraw the track, then handle window and keyboard events.
running = True
while running:
    draw_track()
    pygame.display.flip()
    clock.tick(30)

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_UP:
                # Use the modifier state recorded with this key press rather
                # than the live keyboard state, which may have changed since.
                if event.mod & pygame.KMOD_SHIFT:
                    add_straight(length=0.72)
                else:
                    add_straight()
            elif event.key == pygame.K_LEFT:
                add_turn('left')
            elif event.key == pygame.K_RIGHT:
                add_turn('right')
            elif event.key == pygame.K_BACKSPACE:
                delete_last_segment()
            elif event.key == pygame.K_ESCAPE:
                running = False

pygame.quit()

# Prompt for file name after Pygame exits
filename = _prompt_and_save()
print(f"Track saved to: {filename}")
