from enum import Enum
from io import BytesIO
from time import sleep
from typing import Tuple, Optional

# Libraries to display the svg board
import cairosvg
import chess
import chess.svg
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from matplotlib.widgets import Button
from matplotlib.backend_bases import Event
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from torch import Tensor

from chess_engine import ChessEngine


class GameState(Enum):
    """Status of a game as returned by BoardWrapper.check_if_game_ended.

    For the three terminal members the value is the game result from White's
    point of view (1 = White won, -1 = Black won, 0 = draw); play_game reads
    it via ``game_state.value`` to build the value targets.
    """
    # Game still in progress. 2 lies outside the result set {-1, 0, 1}, so it
    # cannot be mistaken for a finished-game result.
    GAME_VALID = 2
    CHECKMATE_WHITE_WON = 1
    CHECKMATE_BLACK_WON = -1
    DRAW = 0


class BoardWrapper:
    """Drive a self-play game on a ``chess.Board`` and collect training data.

    When ``debug`` is true, each position is rendered in a matplotlib window
    (board on top, MCTS tree below) and play pauses until the user presses
    "Continue" or "Finish Game".
    """

    def __init__(self, debug: bool, custom_position: Optional[str] = None) -> None:
        """Create the wrapper.

        Args:
            debug: Show the interactive matplotlib visualization while playing.
            custom_position: Optional FEN passed to ``chess.Board``; the
                standard starting position is used when it is None or empty.
        """
        self.debug = debug
        self.board: chess.Board = (
            chess.Board(custom_position) if custom_position else chess.Board()
        )

        # Visualization state; the figure is created lazily by print_board().
        self.fig: Optional[Figure] = None
        self.ax: Optional[Axes] = None
        self.ax_tree: Optional[Axes] = None
        self.btn_continue: Optional[Button] = None
        self.btn_finish: Optional[Button] = None
        self.continue_clicked: bool = False
        self.finish_clicked: bool = False
        self.game_running: bool = True
        self.skip_to_end: bool = False

    def check_if_game_ended(self) -> GameState:
        """Map ``self.board.outcome()`` to a GameState.

        ``outcome()`` is called with its default arguments, so only automatic
        game endings are detected, not claimable draws.
        """
        # TODO: Timeout?
        outcome = self.board.outcome()
        if outcome is None:
            return GameState.GAME_VALID
        if outcome.winner == chess.WHITE:
            return GameState.CHECKMATE_WHITE_WON
        if outcome.winner == chess.BLACK:
            return GameState.CHECKMATE_BLACK_WON
        # outcome.winner is None: the game ended without a winner.
        return GameState.DRAW

    def play_game(self, chess_engine: ChessEngine) -> list[Tuple[Tensor, np.ndarray, float]]:
        """Let ``chess_engine`` play the current board to completion and label every position.

        Args:
            chess_engine: Engine providing ``get_move`` (used to choose each
                move) and ``probs_dict_to_policy_vector``.

        Returns:
            One ``(network_input, policy_vector, value_target)`` tuple per move
            played. ``value_target`` is the final result as seen by the side to
            move in that position: 1.0 win, -1.0 loss, 0.0 draw. The list is
            empty when the starting position is already finished.
        """
        self.game_running = True
        self.skip_to_end = False
        # Entries: (board snapshot before the move, network input, move probabilities).
        mcts_distributions = []

        # Check before the first search so a finished custom position never
        # reaches the engine, and so game_state is always bound after the loop.
        game_state = self.check_if_game_ended()
        self.game_running = game_state == GameState.GAME_VALID

        while self.game_running:
            move, probs, network_input = chess_engine.get_move(self.board)
            mcts_distributions.append((self.board.copy(), network_input, probs))
            if not self.skip_to_end:
                self.print_board(False, chess_engine)

            self.board.push(self.board.parse_uci(move))
            game_state = self.check_if_game_ended()
            if game_state != GameState.GAME_VALID:
                self.game_running = False

        self.print_board(True, chess_engine)

        # Result from White's point of view: 1 White won, -1 Black won, 0 draw.
        result = float(game_state.value)
        labeled_data = []
        for board_snapshot, network_input, probs_dict in mcts_distributions:
            policy_vector = chess_engine.probs_dict_to_policy_vector(probs_dict, board_snapshot)
            # Flip the sign for Black-to-move positions so the target is
            # relative to the player who was about to move.
            value_target = result if board_snapshot.turn == chess.WHITE else -result
            labeled_data.append((network_input, policy_vector, value_target))
        return labeled_data

    def on_continue(self, event: Event) -> None:
        """Button callback: release the wait in print_board for this move."""
        self.continue_clicked = True

    def on_finish(self, event: Event) -> None:
        """Button callback: stop pausing/rendering for the rest of the game."""
        self.skip_to_end = True

    def _figure_is_open(self) -> bool:
        """Return True if the visualization figure exists and has not been closed."""
        return self.fig is not None and plt.fignum_exists(self.fig.number)

    def _create_figure(self) -> None:
        """Create the figure (board axes on top, MCTS tree axes below) and its buttons."""
        self.fig = plt.figure(figsize=(12, 12))

        # 2 rows x 1 column with equal heights.
        gs = self.fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.05)
        self.ax = self.fig.add_subplot(gs[0, 0])
        self.ax_tree = self.fig.add_subplot(gs[1, 0])

        # Button axes in figure coordinates (left, bottom, width, height), to
        # the right of the board. Added to self.fig explicitly instead of via
        # plt.axes(), which targets whichever figure is current.
        button_width = 0.1
        button_height = 0.05
        continue_ax = self.fig.add_axes((0.75, 0.75, button_width, button_height))
        finish_ax = self.fig.add_axes((0.75, 0.65, button_width, button_height))

        # Keep references on self so the widgets (and their callbacks) stay alive.
        self.btn_continue = Button(continue_ax, 'Continue')
        self.btn_finish = Button(finish_ax, 'Finish Game')
        self.btn_continue.on_clicked(self.on_continue)
        self.btn_finish.on_clicked(self.on_finish)

    def _wait_for_user(self) -> None:
        """Block until Continue/Finish is pressed, the game stops, or the window closes."""
        while not (self.continue_clicked or self.skip_to_end) and self.game_running:
            if not self._figure_is_open():
                # Window closed by the user: treat it like "Finish Game".
                # Otherwise this loop would spin forever, because plt.pause()
                # only sleeps when no figure exists and no button can fire.
                self.skip_to_end = True
                break
            plt.pause(0.1)

    def print_board(self, final: bool, chess_engine: ChessEngine) -> None:
        """Display the chess board and MCTS tree visualization (debug mode only).

        Args:
            final: True once the game is over; the window is kept for 5 seconds
                instead of waiting for a button press.
            chess_engine: Engine whose ``mcts.visualize_tree`` draws the tree.
        """
        if not self.debug:
            return

        # Create the figure on first use, or again if it was closed.
        if not self._figure_is_open():
            self._create_figure()

        # Chess board: SVG -> PNG bytes -> PIL image for imshow.
        if self.ax is not None:
            self.ax.clear()
            svg_board = chess.svg.board(self.board, size=800)
            img = Image.open(BytesIO(cairosvg.svg2png(svg_board)))
            self.ax.imshow(img)
            self.ax.axis('off')
            self.ax.set_title('Chess Board')

        # MCTS tree: best effort; a drawing failure is reported but does not
        # interrupt the game.
        if self.ax_tree is not None:
            try:
                chess_engine.mcts.visualize_tree(max_depth=3, ax=self.ax_tree)
            except Exception as e:
                self.ax_tree.clear()
                self.ax_tree.text(0.5, 0.5, f"Tree visualization error: {str(e)}",
                                  ha='center', va='center', fontsize=12)
                print(f"Error in MCTS visualization: {str(e)}")

        if self.ax is not None and self.ax_tree is not None:
            # Force the board into the top half and the tree into the bottom
            # half of the figure (figure coordinates).
            pos_board = self.ax.get_position()
            pos_tree = self.ax_tree.get_position()
            self.ax.set_position((pos_board.x0, 0.5, pos_board.width, 0.45))
            self.ax_tree.set_position((pos_tree.x0, 0.05, pos_tree.width, 0.45))
            self.fig.canvas.draw()

        if final:
            plt.pause(5)  # Keep the final position visible for 5 seconds.
            return

        self.continue_clicked = False
        plt.pause(0.1)  # Let the GUI event loop process pending events.
        self._wait_for_user()
