import math
import sys
import time
from typing import Callable, Tuple, Optional, Any, Union, List, Dict
import line_profiler

import chess
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from matplotlib.patches import Rectangle, FancyArrowPatch

class MCTSNode:

    @line_profiler.profile
    def __init__(self, board: chess.Board, node_value: float, node_policy: Optional[torch.Tensor], parent: Optional['MCTSNode'], terminal: bool) -> None:
        """
        Create a search-tree node for ``board``.

        Args:
            board: Position represented by this node.
            node_value: Evaluation of the position; it seeds the accumulated
                value ``W``.
            node_policy: Move priors read by ``get_P``. ``MCTS.explore_node``
                passes ``None`` for terminal positions.
            parent: Node this one was expanded from, or ``None`` for the root.
            terminal: Whether the game is over in ``board``.
        """
        self.board = board
        self.parent = parent
        self.is_terminal = terminal
        # Children are keyed by the UCI string of the move leading to them
        # (the same strings that ``legal_moves_uci`` yields).
        self.childrens: Dict[str, 'MCTSNode'] = {}
        self.player_turn = board.turn
        self.node_value = node_value
        self.node_policy = node_policy
        # The node's own evaluation is recorded as its first visit.
        self.W = self.node_value
        self.N = 1
        self.timing_stats = {
            'selection': 0.0,
            'evaluation': 0.0,
            'expansion': 0.0,
            'backpropagation': 0.0,
            'total_simulation': 0.0
        }

        # Legal moves are generated lazily (see ``legal_moves_uci``).
        self._legal_moves_uci_cache: Optional[List[str]] = None
        self._legal_moves_obj_cache: Optional[List[chess.Move]] = None
        # Memoised priors, keyed by the UCI string passed to ``get_P``.
        self._P_cache: Dict[str, float] = {}

    def get_Q(self) -> float:
        """Return the mean accumulated value ``W / N``, or 0.0 if ``N`` is not positive."""
        if self.N > 0.0:
            return self.W / self.N
        return 0.0

    @line_profiler.profile
    def get_P(self, action: str, all_moves_dict) -> float:
        """
        Return the prior stored in ``node_policy`` for ``action``.

        Args:
            action: Move as a UCI string (e.g. ``"e2e4"``).
            all_moves_dict: Mapping from UCI string to an index into
                ``node_policy``.

        Returns:
            The prior as a Python float. Results are memoised per node.

        Raises:
            KeyError: If the (possibly shortened) UCI string is not a key of
                ``all_moves_dict``.
        """
        cached = self._P_cache.get(action)
        if cached is not None:
            return cached

        # Queen promotions are looked up without the suffix: "f7f8q" is
        # stored in the move table as "f7f8".
        uci = action
        if len(uci) == 5 and uci[-1] == "q":
            uci = uci[:4]

        prior = self.node_policy[all_moves_dict[uci]].item()  # Python float

        self._P_cache[action] = prior
        return prior

    def backpropagate(self, value: float) -> None:
        """
        Record one visit on this node and on every ancestor up to the root.

        ``value`` is added to this node's ``W``; the sign is flipped at each
        step towards the root.
        """
        current = self
        signed_value = value
        while current is not None:
            current.N += 1
            current.W += signed_value
            current = current.parent
            signed_value = -signed_value

    def get_unexplored_actions(self) -> list[str]:
        """Return the legal moves (UCI strings) that have no child node yet."""
        # Test membership on the dict itself: O(1) per move instead of
        # rebuilding a list of keys for every legal move.
        return [move for move in self.legal_moves_uci if move not in self.childrens]

    # TODO: Further possible performance improvement:
    #           Refactor code to remove sorting. Return N-th best action.
    @line_profiler.profile
    def get_actions(self, c: float, all_moves_dict) -> List[str]:
        """
        Rank all legal moves by their UCB score, best first.

        The score of a move is ``Q + c * P * sqrt(N_parent) / (1 + N_child)``,
        where ``Q`` is the negated mean value of the child (0 when the child
        does not exist yet) and ``P`` is the prior from ``get_P``.

        Args:
            c: Exploration coefficient.
            all_moves_dict: Mapping from UCI string to policy index,
                forwarded to ``get_P``.

        Returns:
            All legal moves as UCI strings, sorted by descending score. Moves
            with equal scores keep their ``legal_moves_uci`` order.
        """
        sqrt_parent_visits = math.sqrt(self.N)  # loop-invariant
        ucb_action_pairs = []
        for action in self.legal_moves_uci:
            child_P = self.get_P(action, all_moves_dict)
            child = self.childrens.get(action)
            if child is not None:
                # The child's Q is from the opponent's point of view.
                child_Q = -child.get_Q()
                child_N = child.N
            else:
                child_Q = 0
                child_N = 0
            ucb_value = (child_Q + c * child_P * (sqrt_parent_visits / (1 + child_N)))
            ucb_action_pairs.append((ucb_value, action))

        # Sort on the score only, so ties never compare the move strings.
        ucb_action_pairs.sort(key=lambda pair: pair[0], reverse=True)
        return [action for _, action in ucb_action_pairs]

    # Move generation is expensive, so it is deferred until somebody asks for the moves.
    @property
    def legal_moves_uci(self) -> List[str]:
        """
        Legal moves of this node's board as UCI strings.

        The first access generates the moves and fills both the move-object
        cache and the UCI cache; later accesses return the cached list.
        """
        cached = self._legal_moves_uci_cache
        if cached is not None:
            return cached

        # First access: this is where the expensive generation happens.
        move_objects = list(self.board.legal_moves)
        self._legal_moves_obj_cache = move_objects
        self._legal_moves_uci_cache = [move.uci() for move in move_objects]
        return self._legal_moves_uci_cache

    @property
    def legal_moves_obj(self) -> List[chess.Move]:
        """
        Legal moves of this node's board as ``chess.Move`` objects.

        The moves are generated and cached on first access, so this property
        can be read before ``legal_moves_uci`` without terminating the
        process.
        """
        if self._legal_moves_obj_cache is None:
            moves = list(self.board.legal_moves)
            self._legal_moves_obj_cache = moves
            # Fill the UCI cache from the same list so both stay aligned.
            self._legal_moves_uci_cache = [move.uci() for move in moves]

        return self._legal_moves_obj_cache

class MCTS:

    def __init__(
        self,
        c: float,
        tau: float,
        evaluation_function: Callable[[torch.Tensor, chess.Board], Tuple[torch.Tensor, float]],
        chess_engine,
        all_moves_dict,
    ) -> None:
        """
        Set up a search without a tree; ``init_board`` creates the root.

        Args:
            c: Exploration coefficient passed to ``MCTSNode.get_actions``.
            tau: Temperature used by ``get_action_probs`` / ``select_move``.
            evaluation_function: Called as ``f(network_input, board)`` and
                expected to return ``(policy, value)``.
            chess_engine: Object whose ``generate_network_input(board)`` is
                used to build the network input for new positions.
            all_moves_dict: Mapping from UCI string to policy index.
        """
        # Search hyperparameters.
        self.c = c
        self.tau = tau

        # Collaborators used while expanding the tree.
        self.evaluation_function = evaluation_function
        self.chess_engine = chess_engine
        self.all_moves_dict = all_moves_dict

        # Root of the current tree (none until init_board is called).
        self.node: Optional[MCTSNode] = None

        # Accumulated wall-clock seconds per search phase.
        self.timing_stats = dict.fromkeys(
            ('selection', 'evaluation', 'expansion', 'backpropagation', 'total_simulation'),
            0.0,
        )

    def init_board(self, board: chess.Board, network_input: torch.Tensor) -> None:
        """
        Start a new search tree rooted at ``board``.

        The position is evaluated once with ``evaluation_function`` and the
        result becomes the root node (stored in ``self.node``).

        Args:
            board: Root position.
            network_input: Network input for ``board``, forwarded to
                ``evaluation_function``.
        """
        policy, value = self.evaluation_function(network_input, board)
        terminal = board.outcome() is not None

        policy = policy.squeeze(0)  # Remove batch dimension if present
        # Store the value as a plain Python scalar, as is done for child
        # nodes, so that W and Q of the root do not become tensors.
        if hasattr(value, "item"):
            value = value.item()

        self.node = MCTSNode(board, value, policy, None, terminal)


    @line_profiler.profile
    def explore_node(self, node: MCTSNode) -> Optional[float]:
        """
        Run one simulation step starting at ``node``.

        Moves are tried in descending UCB order. The search descends through
        existing children until it reaches a move without a child; that move
        is expanded into a new leaf, which is evaluated, and the result is
        backed up to the root.

        Args:
            node: Node to start from.

        Returns:
            The value of the reached leaf, from the point of view of the
            side to move in that leaf; or ``None`` if no move of ``node``
            produced a value.
        """
        # A terminal node cannot be expanded: just record another visit.
        if node.is_terminal:
            node.backpropagate(node.node_value)
            return node.node_value

        # --- Selection ---
        for action in node.get_actions(self.c, self.all_moves_dict):
            existing_child = node.childrens.get(action)
            if existing_child is not None:
                value = self.explore_node(existing_child)
                if value is not None:
                    return value
                # Nothing could be explored below this child; try the next move.
                continue

            # --- Expansion: this move has not been explored yet ---
            new_board = node.board.copy(stack=False)
            new_board.push(new_board.parse_uci(action))

            outcome = new_board.outcome()
            if outcome is None:
                terminal = False
                network_input = self.chess_engine.generate_network_input(new_board)
                policy, value = self.evaluation_function(network_input, new_board)

                policy = policy.squeeze(0)  # Remove batch dimension if present
                # Accept tensors as well as plain floats from the evaluator.
                if hasattr(value, "item"):
                    value = value.item()
            else:
                terminal = True
                policy = None
                # Game result from the point of view of the side to move in
                # the new position: +1 win, -1 loss, 0 draw.
                if outcome.winner is None:
                    value = 0.0
                elif outcome.winner == new_board.turn:
                    value = 1.0
                else:
                    value = -1.0

            child = MCTSNode(new_board, value, policy, node, terminal)
            node.childrens[action] = child
            # --- Backpropagation ---
            # MCTSNode.__init__ already records the leaf's own evaluation
            # (N=1, W=value). Calling child.backpropagate(value) would count
            # it a second time, so the update starts at the parent.
            node.backpropagate(-value)
            return value

        return None

    @line_profiler.profile
    def simulate(self) -> bool:
        """
        Run a single simulation from the root.

        Selection, expansion, evaluation and backpropagation all happen
        inside ``explore_node``, so only the total wall-clock time is
        recorded here (in ``timing_stats['total_simulation']``). The time is
        recorded for unsuccessful simulations too.

        Returns:
            ``True`` if ``explore_node`` produced a value, ``False`` otherwise.
        """
        start_total = time.perf_counter()
        value = self.explore_node(self.node)
        self.timing_stats['total_simulation'] += time.perf_counter() - start_total
        return value is not None

    def get_action_probs(self) -> dict[str, Any]:
        """
        Turn the root's child visit counts into move probabilities.

        With ``tau == 0`` the most visited move gets probability 1 (the first
        one on ties). Otherwise each move's probability is proportional to
        ``visits ** (1 / tau)``. If no move has been visited and ``tau != 0``
        the distribution is uniform.

        Returns:
            Mapping from UCI move string to probability, in
            ``legal_moves_uci`` order. Empty if there is no root node or the
            root has no legal moves.
        """
        if self.node is None:
            return {}
        moves = self.node.legal_moves_uci
        if not moves:
            return {}

        children = self.node.childrens
        visits = np.array([children[a].N if a in children else 0 for a in moves], dtype=float)

        if self.tau == 0:  # Deterministic selection of the best move
            probs = np.zeros_like(visits)
            probs[np.argmax(visits)] = 1.0
        else:
            peak = visits.max()
            if peak > 0:
                # Scale by the largest count before exponentiating. This is
                # mathematically the same distribution, but it cannot
                # overflow for small tau.
                weights = (visits / peak) ** (1.0 / self.tau)
                probs = weights / weights.sum()
            else:
                # Nothing has been visited yet: no move is preferred.
                probs = np.full_like(visits, 1.0 / len(moves))

        return dict(zip(moves, probs))

    def select_move(self) -> tuple[str, dict[str, Any]]:
        """
        Choose a move from the distribution given by ``get_action_probs``.

        With ``tau == 0`` the move with the highest probability is returned;
        otherwise a move is sampled according to the probabilities.

        Returns:
            ``(move, probs)``: the chosen move as a UCI string and the full
            move-to-probability mapping.

        Raises:
            ValueError: If there are no moves to choose from.
        """
        probs = self.get_action_probs()
        if not probs:
            raise ValueError("select_move: no legal moves to choose from")

        if self.tau == 0:
            # Deterministic selection of the best move
            return max(probs.items(), key=lambda x: x[1])[0], probs

        # Sample from the probability distribution
        moves, probs_list = zip(*probs.items())
        chosen = np.random.choice(moves, p=probs_list)
        # np.random.choice yields a numpy string; hand back a plain str.
        return str(chosen), probs

    def visualize_tree(self, max_depth: int = 3, figsize: Tuple[int, int] = (25, 10), ax: Optional[Axes] = None) -> \
        Optional[Union[Figure, SubFigure]]:
        """
        Creates a visualization of the MCTS tree up to a certain depth.

        Every node is drawn as a box showing the action that led to it, its
        Q, N and W values and, for non-root nodes, its UCB score as seen
        from the parent. Arrows connect parents to children.

        Args:
            max_depth (int): Maximum depth of the tree to visualize
            figsize (tuple): Size of the figure (width, height) - used only if ax is None
            ax (matplotlib.axes.Axes, optional): Axes to draw on. If None, a new figure is created.
                The axes are cleared before drawing.

        Returns:
            Union[Figure, SubFigure]: A figure object (only if ax is None); None when an
            existing axes was passed in.
        """
        # NOTE(review): if self.node is None, count_nodes_by_level returns {} and the
        # max(...) call below raises ValueError — confirm callers always call init_board first.

        # Create figure and axes if not provided
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
            created_figure = True
        else:
            new_fig: Figure = ax.figure  # This can be a Figure or SubFigure
            created_figure = False
            fig = new_fig

        # First pass: count nodes at each level to calculate positions
        def count_nodes_by_level(node: Optional[MCTSNode], depth: int = 0, counts: Optional[dict[int, int]] = None) -> \
                dict[int, int]:
            # Returns {depth: number of nodes at that depth}, descending no deeper than max_depth.
            if not node:
                return {}
            if counts is None:
                counts = {}

            if depth not in counts:
                counts[depth] = 0
            counts[depth] += 1

            if depth < max_depth:
                for child in node.childrens.values():
                    count_nodes_by_level(child, depth + 1, counts)

            return counts

        level_counts = count_nodes_by_level(self.node)
        max_nodes_in_level = max(level_counts.values())

        # Node style parameters
        node_width = min(300, 2400 / max_nodes_in_level)  # At most 300 wide; narrower when a level holds many nodes
        node_height = 150
        level_height = 200  # Vertical space between levels

        # Track positions of drawn nodes (written by draw_node, keyed by id(node))
        node_positions = {}

        # Get axes dimensions
        # NOTE(review): get_window_extent presumably reports display units rather than
        # points — confirm the units match the "* 72" fallback below.
        fig_width = ax.get_window_extent().width
        fig_height = ax.get_window_extent().height

        # If dimensions are 0 (not rendered yet), use figure size in inches * 72 (points per inch)
        if fig_width == 0:
            fig_width = fig.get_figwidth() * 72
        if fig_height == 0:
            fig_height = fig.get_figheight() * 72

        # Function to recursively draw nodes.
        # horizontal_idx is the slot of this node within its level; the return value is
        # the next index handed to the following draw_node call.
        def draw_node(node: Optional[MCTSNode], depth: int = 0, parent_pos: Optional[Tuple[float, int]] = None,
                    action: str = "root", horizontal_idx: int = 0, parent: Optional[MCTSNode] = None) -> int:
            # Calculate position
            if node is None:
                return 0
            if depth == 0:
                # Root node is centered horizontally
                x = fig_width / 2
                # NOTE(review): the y-limits set further down start at 0 and are not inverted,
                # so y = 50 looks like the bottom of the axes and deeper levels are drawn
                # above it — confirm whether the root was meant to be at the top.
                y = 50
            else:
                # Spread the nodes of this level evenly over the width.
                # level_counts[depth] is the number of all nodes at this depth, not only
                # the children of one parent.
                siblings_count = level_counts[depth]
                level_width = fig_width - 100  # Full width minus margins
                x = 50 + horizontal_idx * (level_width / siblings_count)
                y = 50 + (depth * level_height)

            # Store position
            node_id = id(node)
            node_positions[node_id] = (x, y)

            # Draw connecting line from parent
            if parent_pos:
                px, py = parent_pos
                # Draw arrow from parent to child
                arrow = FancyArrowPatch(
                    (px, py + node_height / 2),
                    (x, y - node_height / 2),
                    connectionstyle="arc3,rad=0.0",
                    arrowstyle="-|>",
                    mutation_scale=15,
                    linewidth=1.5,
                    color='black',
                    zorder=0
                )
                ax.add_patch(arrow)

            # Create node rectangle, centered on (x, y)
            rect = Rectangle(
                (x - node_width / 2, y - node_height / 2),
                node_width, node_height,
                linewidth=2,
                edgecolor='black',
                facecolor='lightblue',
                alpha=0.8,
                zorder=1
            )
            ax.add_patch(rect)

            # Add text information with fixed font size
            q_value = round(float(node.get_Q()), 2)
            w_value = round(float(node.W), 2)
            n_value = node.N

            # Compute UCB value if not root
            ucb_value = None
            if parent is not None and action != "root":
                # Compute UCB for this node from parent's perspective
                # (same formula as MCTSNode.get_actions)
                c = self.c
                all_moves_dict = self.all_moves_dict
                # parent.get_P expects action as UCI string
                child_P = parent.get_P(action, all_moves_dict) if hasattr(parent, "get_P") else 0
                child_Q = -node.get_Q()
                child_N = node.N
                parent_N = parent.N
                ucb_value = child_Q + c * child_P * (math.sqrt(parent_N) / (1 + child_N))
                ucb_value = round(float(ucb_value), 3)
            else:
                ucb_value = None

            fontsize = 10

            # Split action text into two lines
            action_text = f"Action:\n{action}"

            ax.text(x, y + node_height / 4, action_text,
                    ha='center', va='center', fontsize=fontsize, fontweight='bold')
            ax.text(x, y, f"Q: {q_value}",
                    ha='center', va='center', fontsize=fontsize)
            ax.text(x - node_width / 4, y - node_height / 4, f"N: {n_value}",
                    ha='center', va='center', fontsize=fontsize)
            ax.text(x + node_width / 4, y - node_height / 4, f"W: {w_value}",
                    ha='center', va='center', fontsize=fontsize)
            if ucb_value is not None:
                ax.text(x, y - node_height / 2 + 10, f"UCB: {ucb_value}",
                        ha='center', va='bottom', fontsize=fontsize, color='purple')

            # Return if max depth reached or no children
            if depth >= max_depth or not node.childrens:
                return horizontal_idx + 1

            # Draw children; each call returns the index for the next one
            next_idx = horizontal_idx
            for child_action, child in node.childrens.items():
                next_idx = draw_node(
                    child,
                    depth + 1,
                    (x, y),
                    child_action,
                    next_idx,
                    parent=node
                )

            return next_idx

        # Clear the axis before drawing
        ax.clear()

        # Start drawing from root
        draw_node(self.node)

        # Set axis properties
        ax.set_xlim(0, fig_width)
        ax.set_ylim(0, fig_height)
        ax.axis('off')
        ax.set_title('MCTS Tree Visualization', fontsize=16, pad=10)

        # Height needed to show the deepest drawn level plus a margin
        total_height = 50 + (max(level_counts.keys()) * level_height) + 100

        # Only return the figure if we created it
        if created_figure:
            # Resize if needed
            # NOTE(review): only the figure is enlarged here; the y-limits stay at
            # (0, fig_height), unlike the branch below — confirm deep levels are not cut off.
            if total_height > fig_height:
                fig.set_figheight(total_height / 72 * 1.1)  # Add 10% margin
            return fig
        else:
            # Just adjust the view for an existing axis
            ax.set_ylim(0, max(total_height, fig_height))
        return None
