# Training MLP *AlphaZero* Agent

This notebook contains configuration and code to train a neural network, 
which is integrated with a `Mathematico` agent to play the game.

Consider an agent **A** that uses MCTS together with a neural network *N* to find
the best move.

This algorithm works as follows: 

* play *M* games, recording the whole MCTS tree, including all statistics and 
the final outcome, in the expert memory.
* sample *k* moves from the memory (experience replay), compute the loss between 
*N(s)* and the expected reward, and update the parameters of the network.


In [None]:
# Experiment configuration; the whole dict is also handed to wandb.init(config=...).
config = {
    "algorithm": "recreate fasst 208",  # rough description of the algorithm
    "name": "more-batch-samples",  # wandb-run name
    "cuda": True,  # use CUDA if possible (beware of memory)
    
    # randomness related
    "seed": 0,  # seeds torch, random and numpy
    "test_seed": 42,  # used for always measuring performance on the same games
    "test after . epochs": 500,  # how many epochs to train before conducting a tournament between agents
    
    # neural net
    "network": "Simple_Board_v0",  # name of the network class to look up in src.nets
    
    ## pretraining on completely filled boards only
    ## NOTE(review): the pretraining cell schedules the LR with ReduceLROnPlateau, not exponential decay
    "pretrain": {  # None for no pretraining
        "epochs": 1024,
        "samples": 128,  # boards generated per pretraining epoch
        "start-lr": 0.001,
        "weight-decay": 0,  # adam weight decay
    },
    
    ## Optimizer params
    "optimizer": "Adam",  # the only option
    "lr": 0.005,  # initial learning rate
    "betas": (0.9, 0.999),
    "weight-decay": 0,
    
    ## LR scheduler params
    "lr-scheduler": "ExponentialLR", # only option for now (a ReduceLROnPlateau scheduler is stepped alongside it)
    "lr-gamma": 0.96,  
    
    ## gradient clipping
    "max-gradient-norm": 1,
    
    # about MCTS
    "stochastic": False,  # the only option now is deterministic
    "time_limit": None,    # milliseconds
    "simuls_limit": 20, # per move
    "policy repeats": 1, # how many times to rerun the rollout policy (the results are averaged)
    "static_policy": False, # if True, the policy just returns the network's value of the node
    
    # algo params
    "test_games": 10,
    "n_simulated_games": 20,  # at least 2 for stddev to exist
    "sample": True,  # do random sampling from data or just shuffle, the only option due to RAM constraints
    "batch_size": 8,  # only applicable if "sample" = True
    "n_training_loops": 4, # per one RL epoch
    "n_epochs": 200,

    # loss function calculation
    "alpha": 1,  # weight of the loss term that approximates the MCTS estimate
    "beta": 0,  # weight of the loss term that approximates the final score
    "target-scale": 1.05,  # make the neural network overpredict by 5%
}

# Must be set to the name of your Weights & Biases project before running the notebook;
# the assertion below fails on purpose until it is.
WANDB_PROJECT_NAME = None
assert WANDB_PROJECT_NAME is not None, "please provide w&b project name"
# Only these optimizer / scheduler choices are implemented in the cells below.
assert config["optimizer"] == "Adam"
assert config["lr-scheduler"] == "ExponentialLR"
# learn_episode() calls statistics.stdev, which needs at least two scores.
assert config["n_simulated_games"] > 1

In [None]:
import os, sys
sys.path.append(os.path.abspath(os.path.join("../")))

import random
import statistics
from copy import deepcopy
import time
import math
import warnings

import torch
from torchview import draw_graph
from torchsummary import summary
import numpy as np
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import wandb
import graphviz
graphviz.set_jupyter_format('png')  # VS code fix for cropped images from torchview

import mathematico
from src.utils import mcts
from src.utils.extract_data import extract
from src.utils.symmetries import all_symmetries
import src.nets as nets
from src.utils.lr import display_learning_rate


########################################
# random seed
########################################

# Seed every RNG used in this notebook (torch, random, numpy) with the same value.
_seed = config["seed"]
for _seed_fn in (torch.random.manual_seed, random.seed, np.random.seed):
    _seed_fn(_seed)


########################################
# cuda settings
########################################

# Hide all GPUs from CUDA when the config asks for CPU-only execution.
if not config["cuda"]:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Use the GPU only when it is both available and requested.
if torch.cuda.is_available() and config["cuda"]:
    dev = torch.device("cuda")
else:
    dev = torch.device("cpu")
dev

### Weight & Biases Initialisation

In [None]:
import wandb
# Start the W&B run: the whole `config` dict is attached to the run and the
# run name comes from config["name"] (None when the key is missing).
# NOTE(review): start_method="fork" presumably needs a platform that supports fork() — confirm before running on Windows.
wandb.init(config=config, project=WANDB_PROJECT_NAME, name=config.get("name", None), settings=wandb.Settings(start_method="fork"))

### Neural Network Definition

As input, the network takes the board as a `list[list[int]]`, and it approximates the value function `V(s) = V(board)`.

All inputs are one-hot encoded.

In [None]:
# The architecture is selected by name, so the config alone decides which net is built.
_nn_cls = getattr(nets, config["network"])
model = _nn_cls().to(dev)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config["lr"],
    betas=config["betas"],
    weight_decay=config["weight-decay"],
)
# Both schedulers act on the same optimizer and are stepped once per RL epoch.
scheduler1 = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=config["lr-gamma"],
)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Dummy batch of 32 identical 5x5 boards: smoke-tests the forward pass and
# feeds the summary / graph visualisation.
_board_batch = torch.tensor([[[0, 1, 11, 12, 13]] * 5] * 32, device=dev)
_out = model.forward(_board_batch)
summary(model, [(5, 5)], depth=7);

In [None]:
# Render the network's computation graph (left-to-right layout, nesting depth 3)
# using the dummy batch as the traced input; the notebook displays the result.
draw_graph(model, input_data=_board_batch, depth=3, graph_dir="LR").visual_graph

### Agent

In [None]:
from src.agents.mcts_player import MctsPlayer, CardState, MoveState

def policy_static(state: mcts.StateI) -> float:
    """Return the value network's estimate for the board of `state`.

    The board grid is wrapped into a batch of size one and evaluated without
    gradient tracking; the single network output is converted to a plain
    Python float so that the return value matches the annotation (and the
    plain numbers returned by the dynamic policy) instead of leaking a tensor
    that carries autograd history into the search tree.
    """
    board = torch.tensor([state.board.grid], device=dev)
    with torch.no_grad():
        # a batch of one board yields exactly one value
        return model(board).item()


def policy_dynamic(state: mcts.StateI) -> float:
    """Finish the game greedily with the value network and return the final score.

    The remaining deck (every card repeated by its count in `state.deck`) is
    shuffled; then, one card per still-empty cell, the network scores the board
    resulting from each possible placement and the best-scored placement is
    played on a private copy of the board. The score of the completed copy is
    returned; `state` itself is not modified.
    """
    rollout_board = deepcopy(state.board)
    open_cells = set(rollout_board.possible_moves())
    cards = [rank for rank, count in state.deck.items() for _ in range(count)]
    random.shuffle(cards)

    def board_after(cell, card):
        # grid that would result from putting `card` on `cell`
        grid = deepcopy(rollout_board.grid)
        grid[cell[0]][cell[1]] = card
        return grid

    with torch.no_grad():
        for step in range(len(open_cells)):
            card = cards[step]
            # freeze the iteration order so batch rows and candidates line up
            candidates = list(open_cells)
            batch = torch.tensor(
                [board_after(cell, card) for cell in candidates],
                device=dev,
            )
            best = torch.argmax(model(batch))
            chosen = candidates[best]
            rollout_board.make_move(chosen, card)
            open_cells.discard(chosen)
        return rollout_board.score()
            
        
        
def repeated_dynamic(state):
    """Average `policy_dynamic(state)` over config['policy repeats'] rollouts."""
    repeats = config['policy repeats']
    return sum(policy_dynamic(state) for _ in range(repeats)) / repeats
    


# Choose the rollout policy once: the raw network value, or averaged greedy rollouts.
if config["static_policy"]:
    _rollout_policy = policy_static
else:
    _rollout_policy = repeated_dynamic

agent = MctsPlayer(
    policy=_rollout_policy,
    max_simulations=config["simuls_limit"],
    max_time_ms=config["time_limit"],
)

Check that it works by playing a random game:

In [None]:
%%time

# Play one full game with the agent and measure how long it takes.
# perf_counter() is monotonic, so the measured duration cannot be distorted
# by system clock adjustments the way time.time() differences can.
start = time.perf_counter()
arena = mathematico.Arena()
arena.add_player(agent)
arena.run(seed=0, rounds=1, verbose=True)
end = time.perf_counter()

# A game consists of 5 * 5 moves; the per-move time is reused later as the
# time budget of the time-limited MCTS baseline.
per_move_sec = (end - start) / (5 * 5)
print(f"{per_move_sec=}")

## Training

#### Utils

In [None]:
class VNPlayer(mathematico.Player):
    """Player that uses only the value network: each card goes to the cell
    whose resulting board the network scores highest (no tree search)."""

    def reset(self):
        """Start a new game with a fresh, empty board."""
        self.board = mathematico.Board()

    def move(self, card: int) -> None:
        """Place `card` on the empty cell with the best network evaluation."""
        possibles = list(self.board.possible_moves())

        def place(row, col):
            # grid that would result from putting the card on (row, col)
            _g = deepcopy(self.board.grid)
            _g[row][col] = card
            return _g

        batch = torch.tensor([place(row, col) for row, col in possibles], device=dev)
        # Pure inference: without no_grad every move would build (and keep
        # alive) an autograd graph that is never used.
        with torch.no_grad():
            scores = model(batch)
        # convert to a plain int before indexing the Python list
        idx = int(torch.argmax(scores))
        row, col = possibles[idx]
        self.board.make_move((row, col), card)


def _eval_against_players(agent, model, rounds):
    """
    Play `rounds` tournament games between the agent and four baselines.

    Participants, in this order:
        0. the agent itself
        1. random player
        2. mcts player with max_time_ms = per_move_sec * 1000 (not very precise..)
        3. mcts player with the same number of simulations
        4. player using only the value network

    Round r is seeded with r + config["test_seed"], so the same games are
    played on every evaluation. The model is switched to eval mode.

    Returns:
        per participant, the fraction of rounds in which the agent scored at
            least as much as that participant (ties count; entry 0 compares
            the agent with itself)
        per participant, the average score
        descriptions of the participants
        ranking value (avg fraction of the other players not beating the agent)
    """
    model.eval()
    contestants = [
        ("vn+mcts player", agent),
        ("random", mathematico.RandomPlayer()),
        ("mcts(time)", MctsPlayer(max_time_ms=per_move_sec * 1000)),
        ("mcts(simuls)", MctsPlayer(max_simulations=config["simuls_limit"])),
        ("value net", VNPlayer()),
    ]
    desc = [label for label, _ in contestants]
    players = [player for _, player in contestants]

    not_beaten_by = [0] * len(players)
    score_sums = [0] * len(players)
    rank = 0

    for _round in trange(rounds, desc="Evaluating performance tournament", leave=None):
        game = mathematico.Mathematico(seed=_round + config["test_seed"])
        for player in players:
            player.reset()
            game.add_player(player)

        scores = game.play()

        # first is always our agent
        ours = scores[0]
        for i in range(len(scores)):
            if ours >= scores[i]:
                not_beaten_by[i] += 1
            score_sums[i] += scores[i]

        rank += sum(ours >= scores[i] for i in range(1, len(scores)))

    rank /= (rounds * (len(players) - 1))
    win_rates = [count / rounds for count in not_beaten_by]
    avg_scores = [total / rounds for total in score_sums]
    return win_rates, avg_scores, desc, rank


def _log(start_time, mean, std, min_score, max_score, it, loss, loss_mcts, loss_final):
    """Send the statistics of one epoch to Weights & Biases.

    Besides the values passed in, the elapsed time since `start_time` and the
    current learning rate are logged. Whenever `it` is a multiple of
    config["test after . epochs"], a tournament is played first and its win
    rates, average scores and rank are added to the same log entry.
    """
    log_dict = {"time": time.time() - start_time}
    log_dict.update({
        "mean": mean,
        "std": std,
        "min score": min_score,
        "max score": max_score,
        "loss": loss,
        "loss [mcts]": loss_mcts,
        "loss [final]": loss_final,
    })
    # valid only with one param group for optimizer
    log_dict["lr"] = optimizer.param_groups[0]['lr']

    tournament_due = it % config["test after . epochs"] == 0
    if tournament_due:
        wins, scores, desc, rank = _eval_against_players(agent, model, rounds=config["test_games"])
        for label, win_rate, avg_score in zip(desc, wins, scores):
            log_dict[label + " wins %"] = win_rate
            log_dict[label + " [[avg score]]"] = avg_score
        log_dict["rank"] = rank

    wandb.log(log_dict)



def learn_episode(agent: MctsPlayer, model: torch.nn.Module, n_games, batch_size, m_training):
    """
    Run one RL epoch: a self-play phase followed by a training phase.

    Playing phase: `agent` plays `n_games` games of 25 moves each on a freshly
    shuffled 52-card deck (values 1..13, four copies each). After every move the
    states returned by `extract(root)` are stored together with all their
    symmetries. States that were actually played are stored with the final
    score of their game, all other visited states with `None` in its place.

    Training phase: `m_training` batches of `batch_size` entries are sampled
    from the collected memory with probability proportional to the visit
    count. Each batch is split into entries with / without a final score and
    every non-empty part gets its own optimizer step.

    Besides its arguments the function uses the notebook globals `config`,
    `dev`, `optimizer`, `scheduler1` and `scheduler2`; both schedulers are
    stepped once at the end.

    Returns:
        (mean, stdev, min, max) of the final scores of the played games,
        followed by the accumulated total loss, MCTS loss and final-score
        loss (each optimizer step contributes its loss divided by `m_training`).

    Raises:
        RuntimeError: when a NaN loss is encountered.
        statistics.StatisticsError: from `statistics.stdev` when fewer than
            two games were played.
    """
    # entries: (board, final_score or None, estimate, visits, depth, height)
    expert_memory = []   
    
    # for logging..
    _scores = []
    loss = 0
    loss_mcts = 0
    loss_final = 0

    #############################################################################
    #                           playing phase
    #############################################################################
    
    model.eval()
    for game in trange(n_games, desc="Game playing phase", leave=None, position=1):
        agent.reset()
        # full deck: four copies of every value 1..13
        cards = [i for i in range(1, 13+1) for _ in range(4)]
        random.shuffle(cards)

        # game memory - all states visited during mcts
        game_memory = []
        
        # which states were actually played
        true_states = []

        # play all the moves till the end
        for move in trange(5*5, desc="Playing moves", leave=None, position=2):
            # NOTE(review): `state` is never read afterwards
            state = deepcopy(agent.board.grid)
            card = cards[move]
            # NOTE(review): `estimate` is never read afterwards
            estimate, root = agent.move_(card)
            visited_states = extract(root)
            for b, e, v, d, h in visited_states:
                for s in all_symmetries(b):
                    game_memory.append((s, e, v, d, h))
            # assumes extract() returns the root (the state the move was made from) first — TODO confirm
            true_states.append(visited_states[0])

        final_score = agent.board.score()
        _scores.append(final_score)
        
        # played states are labelled with the real outcome of the game ...
        for board, exp, visits, depth, height in true_states:
            for s in all_symmetries(board):
                expert_memory.append((s, final_score, exp, visits, depth, height))
        # ... every state visited by the search is stored without one
        for b, e, v, d, h in game_memory:
            expert_memory.append((b, None, e, v, d, h))

            
    #############################################################################
    #                           training phase
    #############################################################################
    
    model.train()
    # sampling probability is proportional to the visit count of an entry
    weights = [entry[3] for entry in expert_memory]  # visit counts
    _s = sum(weights)
    weights = [w/_s for w in weights]
    # all batches are drawn at once; sample with replacement only when the
    # memory holds fewer entries than are requested in total
    indices = np.random.choice(len(expert_memory), size=(m_training, batch_size), replace=m_training*batch_size > len(expert_memory), p=weights)
    
    for train in trange(m_training, desc="Training loop", leave=None, position=1):
        batch = [expert_memory[idx] for idx in indices[train]]
        
        for with_final in (True, False):  # two passes, one for played states, one for hypothetical        
            boards, reals, exps, viss, deps, heis = [], [], [], [], [], []
            for b, real, exp, vis, dep, hei in batch:
                if (real is not None) == with_final:
                    boards.append(b)
                    reals.append(real)
                    exps.append(exp)
                    viss.append(vis)
                    deps.append(dep)
                    heis.append(hei)

            if boards:                
                
                optimizer.zero_grad()
                outs = torch.squeeze(model(torch.tensor(boards, device=dev)), dim=1)
                # MCTS estimates, scaled by the configured target scale
                target = config["target-scale"] * torch.tensor(exps, device=dev)
                
                # per-sample weight: log(visits + 1) / (height + 1)
                _mcts_loss_norm = torch.log(torch.tensor(viss, device=dev) + 1) / (torch.tensor(heis, device=dev) + 1)
                _mcts_loss = torch.mean(config["alpha"] * _mcts_loss_norm * (target - outs)**2)
                
                # additional height-dependent factor for the final-score term
                _final_coef = (math.log(2) + 1) / (1 + torch.log(1 + torch.tensor(heis, device=dev)))
                # the final-score term exists only for states that were actually played
                _final_loss = 0 if not with_final else torch.mean(config["beta"] * _final_coef * _mcts_loss_norm * (torch.tensor(reals, device=dev) - outs)**2)
                
                _loss = _mcts_loss + _final_loss
                
                if torch.any(torch.isnan(_loss)):
                    raise RuntimeError("NaN detected, instable learning..." + f"{_mcts_loss_norm=}\t{_mcts_loss=}\t{_final_loss=}")
                    
                _loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config["max-gradient-norm"]) 
                optimizer.step()
                
                # accumulate the losses for logging only, outside of autograd
                with torch.no_grad():
                    loss += _loss / m_training
                    loss_mcts += _mcts_loss / m_training
                    if with_final:
                        loss_final += _final_loss / m_training
            
    # one scheduler step per RL epoch; the plateau scheduler monitors the MCTS loss
    scheduler1.step()  
    scheduler2.step(loss_mcts)
    
    return (
        statistics.mean(_scores), 
        statistics.stdev(_scores),
        min(_scores),
        max(_scores),
        loss,
        loss_mcts,
        loss_final
    )


## Pre-training

In [None]:
# Supervised pretraining: fit the value network to the true scores of randomly
# filled boards before the RL loop starts.
#
# The history lists are created unconditionally so that the plotting cell that
# follows still runs (with empty curves) when pretraining is disabled.
losses = []
lr = []

if config["pretrain"] is not None:
    _pre_cfg = config["pretrain"]

    o = torch.optim.Adam(
        model.parameters(),
        lr=_pre_cfg["start-lr"],
        # honour the configured weight decay instead of silently ignoring it
        weight_decay=_pre_cfg.get("weight-decay", 0),
    )
    s = torch.optim.lr_scheduler.ReduceLROnPlateau(o)
    _loss_fn = torch.nn.MSELoss()  # stateless, so it is built once

    # make sure layers that behave differently in eval mode are in training
    # mode, even when this cell is re-run after an evaluation cell
    model.train()

    for epoch in (pbar := trange(_pre_cfg["epochs"], desc="Pretraining")):
        # generate a fresh batch of completely filled random boards
        boards = []
        vals = []
        for _ in range(_pre_cfg["samples"]):
            deck = [i for i in range(1, 14) for _ in range(4)]
            random.shuffle(deck)
            b = mathematico.Board()
            for r in range(5):
                for c in range(5):
                    b.make_move((r, c), deck[5*r + c])
            boards.append(b.grid)
            vals.append(b.score())

        o.zero_grad()
        outs = model(torch.tensor(boards, device=dev))
        target = torch.unsqueeze(torch.tensor(vals, device=dev, dtype=outs.dtype), -1)
        loss = _loss_fn(outs, target)

        # Stop BEFORE the update: stepping on a NaN loss would write NaNs into
        # the weights and leave the model unusable.
        if torch.isnan(loss):
            print("NaN detected...")
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config["max-gradient-norm"])
        o.step()

        loss_value = loss.item()
        s.step(loss_value)

        losses.append(loss_value)
        pbar.set_description(f"Pretraining (loss: {losses[-1]:.02f})")
        lr.append(o.param_groups[0]['lr'])
        
    

In [None]:
# Show the pretraining loss and the learning rate, each on a logarithmic y axis.
for _series, _title in ((losses, "log-loss"), (lr, "learning rate")):
    plt.plot(_series)
    ax = plt.gca()
    ax.set_yscale('log')
    ax.set_title(_title)
    plt.show()

Check how good the predictions are on different boards:

In [None]:
# Hand-made board: four rows 1..5 and one row 6..10.
_set = [[1, 2, 3, 4, 5]] * 4 + [[6, 7, 8, 9, 10]]
_b = mathematico.Board()
for r, _row in enumerate(_set):
    for c, _card in enumerate(_row):
        _b.make_move((r, c), _card)
print(_b)

# Compare the true score of the board with the network's prediction.
model.eval()
with torch.no_grad():
    _pred = model(torch.tensor([_set], device=dev)).cpu().numpy()[0][0]
print(f"Real = {_b.score()}\tPredicted = {_pred:.3f}")

In [None]:
# Compare true score and prediction on 25 randomly filled boards.
model.eval()
with torch.no_grad():
    for _ in range(25):
        _b = mathematico.Board()
        _deck = [k for k in range(1, 14) for _ in range(4)]
        random.shuffle(_deck)
        # the first 25 cards fill the board row by row
        for _pos, _card in enumerate(_deck[:25]):
            _b.make_move(divmod(_pos, 5), _card)

        _pred = model(torch.tensor([_b.grid], device=dev)).cpu().numpy()[0][0]
        print(f"Real = {_b.score()}\tPredicted = {_pred:.3f}")

## Training

In [None]:
START = time.time()

# The per-epoch settings do not change during training, so read them once.
_episode_kwargs = dict(
    n_games=config["n_simulated_games"],
    batch_size=config["batch_size"],
    m_training=config["n_training_loops"],
)

for epoch in trange(1, 1+config["n_epochs"], desc="Epochs"):
    mean, std, mini, maxi, L, Lm, Lf = learn_episode(agent, model, **_episode_kwargs)

    # stop as soon as the losses are no longer numbers
    diverged = torch.any(torch.isnan(L)) or torch.any(torch.isnan(Lm))
    if diverged:
        print("Instabilities (NaN), aborting...")
        break

    _log(START, mean, std, mini, maxi, epoch, L, Lm, Lf)

    # form of early stopping: with such a small learning rate nothing changes any more
    _current_lr = optimizer.param_groups[0]['lr']
    if _current_lr < 1e-8:
        print(f"[{epoch=}] Learning rate is too low, aborting...")
        break

In [None]:
# Close the W&B run that was opened by wandb.init() earlier in this notebook.
wandb.finish()

## Save & Load Trained Agent

In [None]:
# save the trained network's weights to "<network name>.pt" in the working directory
torch.save(model.state_dict(), config["network"] + ".pt")

In [None]:
# Load the saved model again. The class and the file name are both derived
# from config["network"], exactly like when the model was built and saved, so
# this cell keeps working when another architecture is configured.
new_model = getattr(nets, config["network"])().to(dev)
# map_location lets weights saved on a GPU be restored on a CPU-only machine.
new_model.load_state_dict(torch.load(config["network"] + ".pt", map_location=dev))
# Sanity check (displayed by the notebook): every restored tensor must equal
# the corresponding tensor of the model that is still in memory.
all(
    torch.equal(restored, current)
    for restored, current in zip(new_model.state_dict().values(), model.state_dict().values())
)