# GTrXL

In [None]:
%load_ext autoreload
%autoreload 2

import jax
from quicknav_jax import (
  RoomParams,
  generate_rooms,
  NavigationEnvParams,
  NavigationEnv
)

# Seed for the PRNG key that drives room generation.
ROOM_SEED = 42

# Generate the rooms once from a dedicated key.
room_key = jax.random.PRNGKey(ROOM_SEED)
room_params = RoomParams(size=8.0, grid_size=16)
obstacles, free_positions = generate_rooms(room_key, room_params)

# Keyword arguments that are later unpacked into `NavigationEnvParams`:
# the generated rooms plus the lidar and reward-shaping settings.
env_params = dict(
    rooms=room_params,
    obstacles=obstacles,
    free_positions=free_positions,
    lidar_fov=90,
    step_penalty=0.2316337353320162,
    progress_reward=0.29118751458609604,
    cycling_penalty=0.2755619843588947,
)

# Keyword arguments that are later unpacked into `PPO.create` (together with
# the model-specific environment parameters).
config = dict(
    env=NavigationEnv(),
    learning_rate=0.0004203802421088965,
    num_minibatches=128,  # adjusted to fit model
    num_steps=512,
    gae_lambda=0.9265010993996222,
    ent_coef=0.00951169140356109,
    clip_eps=0.14550640064594372,
    gamma=0.9601333552614683,
    total_timesteps=2_000_000,
    normalize_observations=True,
    num_envs=512,
)

In [None]:
from models.critic import Critic
from criteria.ppo import PPO
from criteria.gaussian_policy import GaussianPolicy
import numpy as np


def create_actor(model):
    """Return a `GaussianPolicy` built around `model`.

    The policy is constructed with `2` as its first argument and a pair of
    two-element arrays, `[-1, -1]` and `[1, 1]`, as its second.
    NOTE(review): these are presumably the action dimension and the
    (low, high) action bounds -- confirm against `GaussianPolicy`.
    """
    lower = np.array([-1., -1.])
    upper = np.array([1., 1.])
    return GaussianPolicy(2, (lower, upper), model)

In [None]:
from models.gtrxl import GTrXL


# GTrXL architecture hyperparameters, shared by the memory initialiser and by
# both the actor and the critic networks.
head_dim: int = 4
embedding_dim: int = 8
head_num: int = 2
mlp_num: int = 2
layer_num: int = 3
memory_len: int = 32


def _new_gtrxl():
    """Build a fresh GTrXL model from the hyperparameters above."""
    return GTrXL(head_dim, embedding_dim, head_num, mlp_num, layer_num, memory_len)


# Environment parameters for this model: the shared `env_params` plus a
# callback that creates the initial GTrXL memory.
gtrxl_env_params = NavigationEnvParams(
    memory_init=lambda: GTrXL.init_memory(memory_len, embedding_dim, layer_num),
    **env_params,
)

# Full set of keyword arguments for the training algorithm: the environment
# parameters first, followed by the shared PPO hyperparameters.
grtxl_config = {"env_params": gtrxl_env_params, **config}

# Create the PPO agent, then swap in GTrXL-backed networks. The actor and the
# critic each receive their own, separately constructed GTrXL instance.
grtxl_agent = PPO.create(**grtxl_config)
grtxl_agent = grtxl_agent.replace(
    actor=create_actor(_new_gtrxl()),
    critic=Critic(_new_gtrxl()),
)

In [None]:
import jax
import time

# Set the seed for reproducibility
TRAIN_SEED = 43

# Set training seed and jit train function
rng = jax.random.PRNGKey(TRAIN_SEED)
grtxl_train_fn = jax.jit(grtxl_agent.train)

print("Starting to train")

# Train!
# JAX dispatches device work asynchronously, so the jitted call may return
# before the computation has actually finished. Block on the outputs so the
# measured time covers the whole run. Because this is the first call of the
# jitted function, the measurement also includes tracing/compilation time.
# `perf_counter` is used because it is monotonic and meant for interval timing.
start: float = time.perf_counter()
grtxl_train_state, grtxl_train_evaluation = jax.block_until_ready(
    grtxl_train_fn(rng)
)
time_elapsed = time.perf_counter() - start

# Throughput in environment steps per wall-clock second.
sps = grtxl_agent.total_timesteps / time_elapsed
print(f"Finished training in {time_elapsed:g} seconds ({sps:g} steps/second).")

In [None]:
from matplotlib import pyplot as plt
import os


# Dump the final training state for inspection.
print(grtxl_train_state)

# The evaluation output unpacks into episode lengths and episode returns.
# Averaging the returns over axis 1 gives one value per entry along axis 0.
# NOTE(review): axis 1 presumably indexes the evaluation episodes -- confirm.
episode_lengths, episode_returns = grtxl_train_evaluation
mean_return = episode_returns.mean(axis=1)

agent_name = grtxl_agent.__class__.__name__

# Spread the evaluation points evenly over the total number of training steps.
env_steps = jax.numpy.linspace(0, grtxl_agent.total_timesteps, len(mean_return))

plt.plot(env_steps, mean_return)
plt.title(f"Training of {agent_name} agent")
plt.xlabel("Environment step")
plt.ylabel("Episodic return")
plt.show()

# Create temp directory if it doesn't exist
os.makedirs("temp", exist_ok=True)

# Save the training curve data as numpy array for comparison with other algorithms
np.save(f"temp/{agent_name}_training_curve.npy", mean_return)


In [None]:
from quicknav_jax import evaluate_model

import os

# Set the seed for reproducibility
TEST_SEED = 100

# Evaluate the trained agent for 10 episodes, with rendering enabled.
evaluation = evaluate_model(
    agent=grtxl_agent,
    train_state=grtxl_train_state,
    seed=TEST_SEED,
    render=True,
    n_eval_episodes=10,
)

# Make sure the output directory exists: `np.save` does not create missing
# parent directories, and this cell must not depend on the plotting cell
# having been run first.
os.makedirs("temp", exist_ok=True)

# Save the returns for comparison with other algorithms
np.save(f"temp/{grtxl_agent.__class__.__name__}_returns.npy", evaluation.returns)

In [None]:
from quicknav_utils.env_vis import save_gif
from pathlib import Path
from IPython.display import Image as IPImage, display

# Export the rendered rollout (only when frames were actually recorded).
if evaluation.rendered_frames is not None:
    path = Path(f"temp/{grtxl_agent.__class__.__name__}_policy.gif")

    # Make sure the output directory exists so this cell does not depend on
    # an earlier cell having created it.
    path.parent.mkdir(parents=True, exist_ok=True)

    save_gif(evaluation.rendered_frames, path)

    # Show the saved animation inline in the notebook.
    display(IPImage(filename=path))

    # Save the rendered frames as numpy array for comparison with other algorithms.
    # The frames are stored as an object array, which is pickled on disk;
    # reading it back therefore requires `np.load(..., allow_pickle=True)`.
    np.save(
        f"temp/{grtxl_agent.__class__.__name__}_rendered_frames.npy",
        np.array(evaluation.rendered_frames, dtype=object),
        allow_pickle=True,
    )