import sys
import time
from typing import Any, Optional, Tuple

import chess
import line_profiler
import numpy as np
import torch
from torch import Tensor, nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

from chess_network import ChessNet
from mcts import MCTS

# Speed over the last bits of float32 precision: "high" lets float32 matrix
# multiplications use TensorFloat-32 where the hardware supports it.
torch.set_float32_matmul_precision("high")

# The network always sees the same input shape, so let cuDNN benchmark
# convolution algorithms once and reuse the fastest one.
torch.backends.cudnn.benchmark = True

# Explicit TF32 switches for cuBLAS matmuls and cuDNN convolutions
# (the matmul one is presumably already implied by "high" above).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

class ChessEngine:
    """AlphaZero-style chess engine: a policy/value network guided by MCTS.

    Moves live in an 8 x 8 x 73 action space, flattened to 4672 entries. For
    each from-square (row = rank, col = file) there are 73 move planes:

    * 0-55:  "queen-like" moves, 8 direction groups x distance 1-7
             (also used for king, pawn and queen-promotion moves);
    * 56-63: the 8 knight jumps;
    * 64-72: 9 underpromotions (3 column shifts x knight/bishop/rook).

    A move's flat index is ``(from_row * 8 + from_col) * 73 + move_index``.
    """

    # Move planes per from-square, and the first underpromotion plane
    # (after the 56 queen-like planes and the 8 knight planes).
    _PLANES_PER_SQUARE = 73
    _UNDERPROMOTION_PLANE_OFFSET = 8 * 7 + 8

    # (sign(from_row - to_row), sign(from_col - to_col)) -> queen direction group.
    # NOTE(review): the labels treat a decreasing file (from_col > to_col) as
    # "east", i.e. E/W are mirrored relative to a standard board. The encoding
    # is still one-to-one, so only the names are misleading.
    _QUEEN_DIRECTION_GROUPS = {
        (-1, 0): 0,   # "N"
        (-1, 1): 1,   # "NE"
        (0, 1): 2,    # "E"
        (1, 1): 3,    # "SE"
        (1, 0): 4,    # "S"
        (1, -1): 5,   # "SW"
        (0, -1): 6,   # "W"
        (-1, -1): 7,  # "NW"
    }

    # (from_row - to_row, from_col - to_col) -> offset within the knight planes.
    _KNIGHT_SLOTS = {
        (-2, 1): 0, (-2, -1): 1, (-1, 2): 2, (-1, -2): 3,
        (1, 2): 4, (1, -2): 5, (2, 1): 6, (2, -1): 7,
    }

    def __init__(self, iteration_count: int, c: float, tau: float) -> None:
        """Build the move-index mapping, the MCTS searcher and the (compiled) network.

        Args:
            iteration_count: maximum number of MCTS simulations per ``get_move``.
            c: forwarded unchanged to ``MCTS`` (presumably its exploration constant).
            tau: forwarded unchanged to ``MCTS`` (presumably its temperature).
        """
        # The three move families have disjoint UCI keys, so a plain merge is safe.
        # NOTE(review): queen promotions are keyed without a promotion suffix
        # (e.g. "e7e8", not "e7e8q"); confirm MCTS looks them up that way.
        combined_move_index_mapping = {
            **self.create_pure_queen_move_to_index_mapping(),
            **self.create_knight_move_to_index_mapping(),
            **self.create_underpromotion_move_to_index_mapping(),
        }
        self.iteration_count = iteration_count
        self.mcts = MCTS(c, tau, self.evaluation_function, self, combined_move_index_mapping)

        uncompiled_net = ChessNet()
        model_device = uncompiled_net.device  # the network chooses its own device

        if not hasattr(torch, 'compile'):
            # torch.compile needs PyTorch 2.0+. Run eagerly, as the message says
            # (this path used to exit the process instead).
            print("torch.compile not available (requires PyTorch 2.0+). Using uncompiled model.")
            self.network = uncompiled_net
            return

        print(f"Attempting to compile the model for device: {model_device}...")
        try:
            # "max-autotune" trades a longer compile for potentially faster kernels;
            # "reduce-overhead" is another inference-friendly option.
            self.network = torch.compile(uncompiled_net, mode="max-autotune")
        except Exception as e:
            sys.exit(f"Failed to compile model: {e}.")
        print("Model compiled successfully.")

    def train(self, training_data, lr, epochs=5, batch_size=32, device=None):
        """Fit the network on self-play samples and return per-epoch mean losses.

        Args:
            training_data: iterable of ``(state, policy_target, result)`` tuples.
                Each element may be a tensor or array-like. ``state`` presumably
                comes from ``generate_network_input`` (its leading batch axis of
                size 1 is squeezed away); ``policy_target`` presumably from
                ``probs_dict_to_policy_vector``.
            lr: Adam learning rate.
            epochs: number of passes over the data.
            batch_size: mini-batch size.
            device: training device; defaults to CUDA when available, else CPU.

        Returns:
            List with the mean (policy + value) loss of every epoch; ``[]`` when
            ``training_data`` is empty.
        """
        training_data = list(training_data)
        if not training_data:
            return []

        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): evaluation_function sends inputs to ``self.network.device``;
        # confirm that attribute follows this ``.to(device)`` call.
        self.network.to(device)
        self.network.train()

        optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
        states, policy_targets, results = zip(*training_data)

        states_tensor = torch.stack([self._as_float_tensor(s) for s in states]).squeeze(1)
        policy_tensor = torch.stack([self._as_float_tensor(p) for p in policy_targets])
        result_tensor = torch.tensor(results, dtype=torch.float32).unsqueeze(1)

        dataloader = DataLoader(
            TensorDataset(states_tensor, policy_tensor, result_tensor),
            batch_size=batch_size,
            shuffle=True,
        )
        loss_fn_value = nn.MSELoss()
        epoch_losses = []

        for epoch in tqdm(range(epochs), desc="Training Epochs", unit="epoch"):
            total_loss = 0.0

            for state_batch, pi_batch, z_batch in tqdm(dataloader, desc="Batch Progress", leave=False, unit="batch"):
                state_batch = state_batch.to(device)
                pi_batch = pi_batch.to(device)
                z_batch = z_batch.to(device)

                policy_logits, value_pred = self.network(state_batch)
                # Cross-entropy between the MCTS target distribution and the
                # network policy. (A constant added to every logit cancels in
                # log_softmax, so the old "+ 1e-10" was a no-op and is gone.)
                policy_log_probs = torch.log_softmax(policy_logits, dim=1)
                policy_loss = -torch.sum(pi_batch * policy_log_probs, dim=1).mean()
                value_loss = loss_fn_value(value_pred, z_batch)
                loss = policy_loss + value_loss
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=1.0)
                optimizer.step()

            avg_loss = total_loss / len(dataloader)
            epoch_losses.append(avg_loss)
            print(f"Epoch {epoch + 1}/{epochs}, Avg Loss: {avg_loss:.4f}")

        return epoch_losses

    @staticmethod
    def _as_float_tensor(x) -> torch.Tensor:
        """Return ``x`` as a float32 tensor, converting non-tensor inputs."""
        return x.float() if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float32)

    @line_profiler.profile
    def evaluation_function(self, network_input: torch.Tensor, board: chess.Board) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate a position with the network, restricted to its legal moves.

        Args:
            network_input: encoded position, e.g. from ``generate_network_input``.
            board: the same position, used to build the legal-move mask.

        Returns:
            ``(policy, value)`` as returned by the network when called with the
            input and the flattened legal-move mask, cloned so callers own them.
        """
        device = self.network.device
        legal_moves_mask = torch.tensor(self.create_moves_mask(board), dtype=torch.float32, device=device)
        network_input = network_input.to(device)  # move to device BEFORE calling forward

        with torch.no_grad():
            policy, value = self.network(network_input, legal_moves_mask)

        # NOTE(review): the clones are presumably needed because a compiled
        # "max-autotune" model may reuse its output buffers between calls.
        return policy.clone(), value.clone()

    def get_move(self, board: chess.Board) -> tuple[str, dict[str, Any], Tensor]:
        """Run up to ``iteration_count`` MCTS simulations from ``board`` and pick a move.

        Returns:
            ``(move, probs, network_input)``: ``move`` and ``probs`` come from
            ``MCTS.select_move``; ``network_input`` is the root encoding.
        """
        network_input = self.generate_network_input(board)
        self.mcts.init_board(board, network_input)
        for i in range(self.iteration_count):
            if not self.mcts.simulate():
                print(f"Note: MCTS search ended early on iteration {i} out of {self.iteration_count}")
                break
        move, probs = self.mcts.select_move()
        return move, probs, network_input

    def generate_network_input(self, board: chess.Board) -> torch.Tensor:
        """Encode ``board`` as a ``(1, 21, 8, 8)`` float32 tensor.

        Planes, each indexed ``[rank, file]``:
            0-5    white pawn, knight, bishop, rook, queen, king
            6-11   the same for black
            12     1 if ``board.is_repetition(2)``
            13     1 if ``board.is_repetition(3)``
            14     1 if white is to move
            15-18  castling rights: white K-side, white Q-side, black K-side, black Q-side
            19     fullmove number / 100
            20     halfmove clock / 100 (50-move rule progress)
        """
        planes = np.zeros((21, 8, 8), dtype=np.float32)

        for square, piece in board.piece_map().items():
            plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
            rank, file = divmod(square, 8)
            planes[plane, rank, file] = 1.0

        # Whole-board flags: assigning a scalar broadcasts over the 8x8 plane.
        planes[12] = board.is_repetition(2)
        planes[13] = board.is_repetition(3)
        planes[14] = board.turn == chess.WHITE
        planes[15] = board.has_kingside_castling_rights(chess.WHITE)
        planes[16] = board.has_queenside_castling_rights(chess.WHITE)
        planes[17] = board.has_kingside_castling_rights(chess.BLACK)
        planes[18] = board.has_queenside_castling_rights(chess.BLACK)
        planes[19] = board.fullmove_number / 100
        planes[20] = board.halfmove_clock / 100

        return torch.from_numpy(planes[np.newaxis]).float()  # add batch axis -> (1, 21, 8, 8)

    @classmethod
    def _flat_index(cls, from_row: int, from_col: int, move_index: int) -> int:
        """Flatten a ``(from_row, from_col, move_index)`` triple into the 4672-entry action space."""
        return (from_row * 8 + from_col) * cls._PLANES_PER_SQUARE + move_index

    def _encode_move(self, board: chess.Board, move: chess.Move) -> Tuple[int, int, int]:
        """Classify ``move`` on ``board`` and return ``(from_row, from_col, move_index)``.

        Underpromotions use the underpromotion planes, knight moves the knight
        planes, and everything else (including promotion to a queen) the
        queen-like planes. ``move_index`` is -1 when the move cannot be encoded.
        """
        piece = board.piece_at(move.from_square)
        from_rank = chess.square_rank(move.from_square)
        to_rank = chess.square_rank(move.to_square)
        reaches_last_rank = (from_rank == 6 and to_rank == 7) or (from_rank == 1 and to_rank == 0)
        if (
                reaches_last_rank and
                piece is not None and
                piece.piece_type == chess.PAWN and
                move.promotion is not None and
                move.promotion != chess.QUEEN
        ):
            return self.encode_underpromotions(move, move.promotion)
        if piece is not None and piece.piece_type == chess.KNIGHT:
            return self.encode_knight_moves(move)
        return self.encode_queen_moves(move)

    @line_profiler.profile
    def create_moves_mask(self, board: chess.Board) -> np.ndarray[Tuple[int], np.dtype[np.float64]]:
        """Return a flat 4672-entry float64 mask with 1.0 at every legal move of ``board``.

        Promotion to a queen has no plane of its own; it is encoded as the
        plain queen-like pawn move.
        """
        moves_mask = np.zeros(8 * 8 * self._PLANES_PER_SQUARE)
        for move in board.legal_moves:
            from_row, from_col, move_index = self._encode_move(board, move)
            if move_index != -1:
                moves_mask[self._flat_index(from_row, from_col, move_index)] = 1
        return moves_mask

    # (from_col - to_col, promoted piece) -> offset within the underpromotion planes.
    _UNDERPROMOTION_SLOTS = {
        (0, chess.BISHOP): 0, (0, chess.ROOK): 1, (0, chess.KNIGHT): 2,
        (1, chess.BISHOP): 3, (1, chess.ROOK): 4, (1, chess.KNIGHT): 5,
        (-1, chess.BISHOP): 6, (-1, chess.ROOK): 7, (-1, chess.KNIGHT): 8,
    }

    def encode_underpromotions(
        self, move: chess.Move, promoted_piece: int, multiplier: Optional[int] = None
    ) -> Tuple[int, int, int]:
        """Encode an underpromotion as ``(from_row, from_col, move_index)``.

        Args:
            move: the promoting pawn move; only its from/to squares are read.
            promoted_piece: ``chess.KNIGHT``, ``chess.BISHOP`` or ``chess.ROOK``.
            multiplier: legacy override. When given, the planes start at
                ``multiplier * 7`` as in the original encoding. By default they
                start at 64, right after the 56 queen-like and 8 knight planes.
                (The old default, 9 * 7 = 63, overlapped the last knight plane
                and left plane 72 unused.)

        Returns:
            ``move_index`` is -1 when the piece / column shift has no
            underpromotion plane (e.g. a queen promotion).
        """
        from_row, from_col = divmod(move.from_square, 8)
        to_col = move.to_square % 8
        slot = self._UNDERPROMOTION_SLOTS.get((from_col - to_col, promoted_piece))
        if slot is None:
            return from_row, from_col, -1
        base = self._UNDERPROMOTION_PLANE_OFFSET if multiplier is None else multiplier * 7
        return from_row, from_col, base + slot

    def encode_queen_moves(self, move: chess.Move) -> Tuple[int, int, int]:
        """Encode a straight or diagonal move as ``(from_row, from_col, move_index)``.

        ``move_index = direction_group * 7 + distance - 1``. The distance is
        the file difference for same-rank moves and the rank difference
        otherwise. Returns ``(-1, -1, -1)`` when from and to squares coincide.
        The result is only meaningful for moves along a rank, file or diagonal.
        """
        from_row, from_col = divmod(move.from_square, 8)
        to_row, to_col = divmod(move.to_square, 8)
        vertical_difference = from_row - to_row
        horizontal_difference = from_col - to_col

        direction = (
            (vertical_difference > 0) - (vertical_difference < 0),
            (horizontal_difference > 0) - (horizontal_difference < 0),
        )
        group = self._QUEEN_DIRECTION_GROUPS.get(direction)
        if group is None:
            return -1, -1, -1
        distance = abs(horizontal_difference) if vertical_difference == 0 else abs(vertical_difference)
        return from_row, from_col, group * 7 + distance - 1

    def encode_knight_moves(self, move: chess.Move, multiplier: int = 8) -> Tuple[int, int, int]:
        """Encode a knight jump as ``(from_row, from_col, move_index)``.

        The knight planes start at ``multiplier * 7`` (56 by default).
        ``move_index`` is -1 for moves that are not knight-shaped. Such moves
        previously fell through to the last knight plane.
        """
        from_row, from_col = divmod(move.from_square, 8)
        to_row, to_col = divmod(move.to_square, 8)
        slot = self._KNIGHT_SLOTS.get((from_row - to_row, from_col - to_col))
        if slot is None:
            return from_row, from_col, -1
        return from_row, from_col, multiplier * 7 + slot

    def probs_dict_to_policy_vector(self, probs_dict: dict[str, float], board: chess.Board) -> np.ndarray:
        """Scatter move probabilities into a flat 4672-entry float32 policy vector.

        Args:
            probs_dict: maps UCI move strings to probabilities. ``board.parse_uci``
                raises for keys that are not legal in ``board``.
            board: the position the moves are played from.

        Uses the same move classification as ``create_moves_mask``. Queen
        promotions therefore land on their queen-like plane instead of being
        dropped.
        """
        policy_vector = np.zeros(8 * 8 * self._PLANES_PER_SQUARE, dtype=np.float32)
        for uci, prob in probs_dict.items():
            # Workaround: MCTS hands back UCI strings rather than Move objects.
            move = board.parse_uci(uci)
            from_row, from_col, move_index = self._encode_move(board, move)
            if move_index != -1:
                policy_vector[self._flat_index(from_row, from_col, move_index)] = prob
        return policy_vector

    def create_pure_queen_move_to_index_mapping(self) -> dict[str, int]:
        """Map every rank/file/diagonal from->to UCI move (no promotion suffix)
        to its flat action index."""
        move_to_index = {}
        for from_square in chess.SQUARES:
            for to_square in chess.SQUARES:
                rank_diff = abs(chess.square_rank(from_square) - chess.square_rank(to_square))
                file_diff = abs(chess.square_file(from_square) - chess.square_file(to_square))
                # Exactly one axis changes (straight) or both change equally
                # (diagonal); this also rejects from_square == to_square.
                is_straight = (rank_diff == 0) != (file_diff == 0)
                is_diagonal = rank_diff == file_diff != 0
                if not (is_straight or is_diagonal):
                    continue
                move = chess.Move(from_square, to_square)
                from_row, from_col, move_index = self.encode_queen_moves(move)
                if move_index != -1:
                    move_to_index[move.uci()] = self._flat_index(from_row, from_col, move_index)
        return move_to_index

    def create_knight_move_to_index_mapping(self) -> dict[str, int]:
        """Map every on-board knight jump (as a UCI string) to its flat action index."""
        move_to_index = {}
        knight_offsets = [
            (-2, -1), (-2, 1), (-1, -2), (-1, 2),
            (1, -2), (1, 2), (2, -1), (2, 1),
        ]
        for from_square in chess.SQUARES:
            from_row, from_col = divmod(from_square, 8)
            for d_row, d_col in knight_offsets:
                to_row, to_col = from_row + d_row, from_col + d_col
                if not (0 <= to_row < 8 and 0 <= to_col < 8):
                    continue
                move = chess.Move(from_square, chess.square(to_col, to_row))
                _, _, move_index = self.encode_knight_moves(move)
                if move_index != -1:
                    move_to_index[move.uci()] = self._flat_index(from_row, from_col, move_index)
        return move_to_index

    def create_underpromotion_move_to_index_mapping(self) -> dict[str, int]:
        """Map every geometrically possible underpromotion (to knight, bishop or
        rook, as a UCI string with promotion suffix) to its flat action index."""
        move_to_index = {}
        underpromotion_pieces = [chess.KNIGHT, chess.BISHOP, chess.ROOK]

        for from_square in chess.SQUARES:
            from_rank = chess.square_rank(from_square)
            if from_rank == 6:      # white pawn stepping onto the 8th rank
                to_rank = 7
            elif from_rank == 1:    # black pawn stepping onto the 1st rank
                to_rank = 0
            else:
                continue

            from_file = chess.square_file(from_square)
            for file_offset in (-1, 0, 1):  # capture, push, capture
                to_file = from_file + file_offset
                if not 0 <= to_file < 8:
                    continue
                to_square = chess.square(to_file, to_rank)
                for piece in underpromotion_pieces:
                    move = chess.Move(from_square, to_square, promotion=piece)
                    from_row, from_col, move_index = self.encode_underpromotions(move, promoted_piece=piece)
                    if move_index != -1:
                        move_to_index[move.uci()] = self._flat_index(from_row, from_col, move_index)
        return move_to_index







