import torch
import chess
from chess_engine import ChessEngine
from chess_network import ChessNet
from board_wrapper import BoardWrapper

def _prompt_human_move(position):
    """Ask on stdin until the user enters a legal move for *position*.

    Args:
        position: The ``chess.Board`` the move must be legal on. It is only
            read here; the caller is responsible for pushing the move.

    Returns:
        The parsed ``chess.Move``, guaranteed to be in
        ``position.legal_moves``.
    """
    while True:
        # Surrounding whitespace would otherwise make a valid move unparsable.
        user_move = input("Enter your move (e.g., 'e7e5'): ").strip()
        try:
            move = chess.Move.from_uci(user_move)
        except ValueError:
            print("Invalid format. Please use format like 'e7e5'")
            continue
        if move in position.legal_moves:
            return move
        print("Illegal move. Try again.")


def load_model_and_play():
    """Load the trained network and play an interactive console game.

    The AI plays white and the human plays black, entering moves in UCI
    notation on stdin. The game starts from a fixed custom position, runs
    until ``is_game_over()`` reports the end of the game, and then prints
    who won (or that the game was drawn).
    """
    # Registering ChessNet as a safe global only matters for loads done with
    # weights_only=True; it is kept so that such a load can be enabled later.
    torch.serialization.add_safe_globals([ChessNet])

    # Method 1: load the entire pickled model.
    # NOTE(security): weights_only=False unpickles arbitrary objects, so
    # 'chess_model.pth' must come from a trusted source. A safer alternative
    # (Method 2) is to build ChessNet() and call load_state_dict() on weights
    # saved separately, e.g. 'chess_model_weights.pth'.
    network = torch.load('chess_model.pth', weights_only=False)

    # Make sure model is in evaluation mode
    network.eval()

    # Configure the chess engine to use the loaded network
    chess_engine = ChessEngine(iteration_count=100, c=1.0, tau=0.0)
    chess_engine.network = network

    # Create a board for playing; debug=True to visualize the board
    board = BoardWrapper(debug=True, custom_position="8/1n6/8/7Q/k2K4/pNB5/b2p4/1n1r4")
    position = board.board

    # Game loop
    while not position.is_game_over():
        if position.turn == chess.WHITE:  # AI plays white
            print("AI is thinking...")
            # The first element returned by get_move is treated as a UCI
            # string (it is parsed with from_uci below); the other two
            # returned values are unused here.
            ai_move_uci, _, _ = chess_engine.get_move(position)
            print(f"AI plays: {ai_move_uci}")
            position.push(chess.Move.from_uci(ai_move_uci))
            board.print_board(False, chess_engine)
        else:  # Human plays black
            board.print_board(False, chess_engine)
            # The helper already returns a chess.Move, so push it directly;
            # parsing it a second time with from_uci would fail.
            position.push(_prompt_human_move(position))

    # Game over
    result = position.outcome()
    if result.winner == chess.WHITE:
        print("AI wins!")
    elif result.winner == chess.BLACK:
        print("You win!")
    else:
        # winner is None when the game ended without a winner.
        print("It's a draw!")

# Start the interactive game only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    load_model_and_play()