"""Self-play training loop: plays batches of games from a fixed rook-endgame
position, trains the engine's network after each batch with a warmed-up
learning rate, and plots the loss curve."""

import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from board_wrapper import BoardWrapper
from chess_engine import ChessEngine

def main() -> None:
    total_games_to_play = 2000  # Total number of self-play games across all iterations
    games_per_iteration = 20  # Number of games per training iteration
    max_training_positions = 100000  # Maximum number of labeled positions to keep in memory
    debug = False
    all_losses = []
    
    # Statistics tracking
    stats_history = []

    search_iterations = 50  # Search iterations the engine runs per move
    c = 1.0  # Exploration constant for the engine's move search
    tau = 0.5  # Temperature controlling move-selection randomness
    chess_engine = ChessEngine(search_iterations, c, tau)
    training_set = []
    
    # Learning rate warmup parameters
    warmup_iterations = 3  # Number of iterations for warmup
    warmup_factor = 0.01  # Start with 1% of target learning rate
    
    def get_lr(iteration: int) -> float:
        """Learning rate schedule with linear warmup, keyed to the iteration number.

        Reads games_played from the enclosing scope to choose the base rate.
        """
        # Base learning rate based on training progress
        if games_played < 1000:
            base_lr = 0.001  # High learning rate at the start
        elif games_played < 2500:
            base_lr = 0.0005  # Lower learning rate as we progress
        else:
            base_lr = 0.0001  # Lowest rate for fine-tuning
        
        # Apply warmup if in warmup phase
        if iteration < warmup_iterations:
            # Linear warmup from warmup_factor to 1.0
            warmup_progress = iteration / warmup_iterations
            warmup_multiplier = warmup_factor + (1.0 - warmup_factor) * warmup_progress
            return base_lr * warmup_multiplier
        else:
            return base_lr
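    # Worked example while games_played < 1000 (base_lr = 0.001):
    #   get_lr(0) -> 0.001 * 0.01  = 0.00001  (1% of base at warmup start)
    #   get_lr(1) -> 0.001 * ~0.34 = ~0.00034
    #   get_lr(2) -> 0.001 * ~0.67 = ~0.00067
    #   get_lr(3) -> 0.001                    (warmup complete)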

    counter = 0  # Completed training iterations
    games_played = 0  # Self-play games finished so far

    while games_played < total_games_to_play:
        # Game outcome statistics for this iteration
        iteration_stats = {"white_wins": 0, "black_wins": 0, "draws": 0}
        
        # Create a progress bar for this batch of games
        pbar = tqdm(range(games_per_iteration),
                    desc=f"Training games (iteration {counter+1}/{total_games_to_play//games_per_iteration})",
                    unit="game")
        
        # Play a batch of games
        for _ in pbar:
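            # Fixed starting position: White king and two rooks vs. a lone Black
            # king, presumably chosen so self-play games end quickly in mate.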
            board = BoardWrapper(debug, custom_position="8/4k3/8/8/8/2R5/2K5/R7 w - - 0 1")
            labeled_data = board.play_game(chess_engine)
            
            # Track game outcome statistics (result is in the last element of each tuple)
            if labeled_data:
                # Check the result of the game (using the last move's result)
                result = labeled_data[-1][2]  # (state, policy, result)
                
                if result == 1:  # WHITE_WON
                    iteration_stats["white_wins"] += 1
                elif result == -1:  # BLACK_WON
                    iteration_stats["black_wins"] += 1
                elif result == 0:  # DRAW
                    iteration_stats["draws"] += 1
            
            # Add new labeled data while maintaining max size
            training_set.extend(labeled_data)
            
            # Keep only the most recent positions if we exceed the limit
            if len(training_set) > max_training_positions:
                # Drop the oldest positions to stay within max_training_positions
                excess = len(training_set) - max_training_positions
                training_set = training_set[excess:]
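            # (A collections.deque with maxlen=max_training_positions would
            # perform this sliding-window trim automatically.)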
            
            games_played += 1
            pbar.set_postfix({"Total games": games_played, "Remaining": total_games_to_play - games_played, 
                              "Training data": len(training_set)})

            # Stop if we reached the total number of games
            if games_played >= total_games_to_play:
                break

        # Store statistics for this iteration
        stats_history.append(iteration_stats)
        
        # Adjust learning rate based on the iteration and warmup schedule
        lr = get_lr(counter)
        
        # Print more detailed info about learning rate
        warmup_status = " (warming up)" if counter < warmup_iterations else ""
        print(f"Training iteration {counter + 1} with learning rate {lr:.6f}{warmup_status}")
        print(f"Training on {len(training_set)} positions")
        
        # Print game outcomes for this iteration
        print(f"Game outcomes: White wins: {iteration_stats['white_wins']}, "
              f"Black wins: {iteration_stats['black_wins']}, "
              f"Draws: {iteration_stats['draws']}")
        
        # Print win rate trend if we have more than one iteration
        if len(stats_history) > 1:
            prev_white_wins = stats_history[-2]['white_wins']
            curr_white_wins = stats_history[-1]['white_wins']
            win_change = curr_white_wins - prev_white_wins
            trend = "↑" if win_change > 0 else "↓" if win_change < 0 else "→"
            print(f"Win trend: {trend} ({win_change:+d} compared to previous iteration)")

        # Train the chess engine on the collected games
        epoch_losses = chess_engine.train(training_set, lr=lr)
        all_losses.append(epoch_losses)
        
        counter += 1

    # After all iterations, print performance trend
    print("\nPerformance History:")
    for i, stats in enumerate(stats_history):
        print(f"Iteration {i+1}: White wins: {stats['white_wins']}, "
              f"Black wins: {stats['black_wins']}, "
              f"Draws: {stats['draws']}")

    # Flatten the per-iteration lists of epoch losses into one sequence for plotting
    flat_losses = [loss for sublist in all_losses for loss in sublist]
    
    # Save model weights and entire model after training
    torch.save(chess_engine.network.state_dict(), 'chess_model_weights.pth')
    print("Model weights saved as 'chess_model_weights.pth'.")
    
    torch.save(chess_engine.network, 'chess_model.pth')
    print("Entire model saved as 'chess_model.pth'.")
    
    # Plot the training loss over time
    plt.plot(flat_losses)
    plt.title("Training Loss Over Time")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()
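    # On a headless machine, plt.savefig("training_loss.png") before plt.show()
    # will persist the plot to disk instead.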


if __name__ == '__main__':
    main()
