# LSTM: PPO navigation agent with a multi-layer LSTM

In [None]:
%load_ext autoreload
%autoreload 2

import jax
from quicknav_jax import (
  RoomParams,
  generate_rooms,
  NavigationEnvParams,
  NavigationEnv
)

# Seed that fixes the generated room layout
ROOM_SEED = 42

# Generate the rooms: the room generator returns the obstacles and free positions
room_params = RoomParams(size=8.0, grid_size=16)
room_key = jax.random.PRNGKey(ROOM_SEED)
obstacles, free_positions = generate_rooms(room_key, room_params)

# Environment parameters shared by every run: the generated rooms plus the
# sensor and reward-shaping settings. These are later expanded as keyword
# arguments into NavigationEnvParams.
env_params = dict(
    rooms=room_params,
    obstacles=obstacles,
    free_positions=free_positions,
    lidar_fov=90,
    step_penalty=0.3699952377493355,
    progress_reward=0.44331196445624815,
    cycling_penalty=0.2030851055334363,
)

# Training hyperparameters, later expanded as keyword arguments into PPO.create
config = dict(
    env=NavigationEnv(),
    learning_rate=0.0004203802421088965,
    num_minibatches=128,  # adjusted to fit model
    num_steps=512,
    gae_lambda=0.9265010993996222,
    ent_coef=0.00951169140356109,
    clip_eps=0.14550640064594372,
    gamma=0.9601333552614683,
    total_timesteps=2_000_000,
    normalize_observations=True,
    num_envs=512,
)

In [None]:
from models.critic import Critic
from criteria.ppo import PPO
from criteria.gaussian_policy import GaussianPolicy
import numpy as np

def create_actor(model):
    """Wrap ``model`` in a ``GaussianPolicy`` with fixed action settings.

    Args:
        model: Network passed through unchanged as the policy's last argument.

    Returns:
        A ``GaussianPolicy`` built with a leading argument of 2 and a pair of
        per-component bounds, ``[-1, -1]`` and ``[1, 1]``.
        NOTE(review): the 2 is presumably the action dimension and the pair
        the (low, high) action range -- confirm against ``GaussianPolicy``.
    """
    lower_bound = np.array([-1., -1.])
    upper_bound = np.array([1., 1.])
    return GaussianPolicy(2, (lower_bound, upper_bound), model)

In [None]:
from models.lstm import LSTMMultiLayer


# LSTM size settings, used for the initial memory state as well as for the
# actor and critic networks so that all three stay consistent
N_LAYERS = 2
D_MODEL = 15
# Original (misspelled) name, kept as an alias so that any cell still
# referring to it keeps working
N_LAYRERS = N_LAYERS

# Initialize the training algorithm parameters
lstm_config = {
    # Environment parameters: the shared settings plus a factory that builds
    # the initial LSTM state with the same d_model / n_layers as the networks
    "env_params": NavigationEnvParams(
        memory_init=lambda: LSTMMultiLayer.initialize_state(d_model=D_MODEL, n_layers=N_LAYERS),
        **env_params,
    ),
    **config,
}

# Create the PPO agent, then replace its actor and critic with LSTM-based
# networks (each one gets its own LSTMMultiLayer instance)
lstm_agent = PPO.create(**lstm_config)
lstm_agent = lstm_agent.replace(
    actor=create_actor(LSTMMultiLayer(d_model=D_MODEL, n_layers=N_LAYERS)),
    critic=Critic(LSTMMultiLayer(d_model=D_MODEL, n_layers=N_LAYERS)),
)

In [None]:
import jax
import time

# Set the seed for reproducibility
TRAIN_SEED = 42

# Set training seed and jit train function
rng = jax.random.PRNGKey(TRAIN_SEED)
lstm_train_fn = jax.jit(lstm_agent.train)

print("Starting to train")

# Train!
# JAX dispatches work asynchronously, so the call can return before the
# computation has finished. Block on the results so the timer measures the
# actual training time. perf_counter is used because it is monotonic.
start = time.perf_counter()
lstm_train_state, lstm_train_evaluation = jax.block_until_ready(lstm_train_fn(rng))
time_elapsed = time.perf_counter() - start

# NOTE: this is the first call of the jitted function, so the measured time
# (and therefore the steps/second figure) also includes JIT compilation.
sps = lstm_agent.total_timesteps / time_elapsed
print(f"Finished training in {time_elapsed:g} seconds ({sps:g} steps/second).")

In [None]:
from matplotlib import pyplot as plt
import os


# The training evaluation unpacks into episode lengths and episode returns
episode_lengths, episode_returns = lstm_train_evaluation
# Average over axis 1 to get a single return value per entry of the curve
mean_return = episode_returns.mean(axis=1)

# x-axis: evenly spaced environment steps, one per entry of the curve
env_steps = jax.numpy.linspace(0, lstm_agent.total_timesteps, len(mean_return))

# Draw the learning curve using the explicit figure/axes interface
fig, ax = plt.subplots()
ax.plot(env_steps, mean_return)
ax.set_xlabel("Environment step")
ax.set_ylabel("Episodic return")
ax.set_title(f"Training of {lstm_agent.__class__.__name__} agent")
plt.show()

# Make sure the output directory is there before writing into it
os.makedirs("temp", exist_ok=True)

# Persist the curve as a numpy array so it can be compared with other algorithms
np.save(f"temp/{lstm_agent.__class__.__name__}_training_curve.npy", mean_return)


In [None]:
from quicknav_jax import evaluate_model
from quicknav_jax.eval import EvaluationResult

# Imported here so this cell does not depend on the plotting cell having run
import os

# Set the seed for reproducibility
TEST_SEED = 100

# Evaluate the trained agent for 10 episodes with rendering enabled
evaluation: EvaluationResult = evaluate_model(
    agent=lstm_agent,
    train_state=lstm_train_state,
    seed=TEST_SEED,
    render=True,
    n_eval_episodes=10,
)

# Create the output directory here as well, so saving does not rely on an
# earlier cell having created it
os.makedirs("temp", exist_ok=True)

# Save the returns for comparison with other algorithms
np.save(f"temp/{lstm_agent.__class__.__name__}_returns.npy", evaluation.returns)

In [None]:
from quicknav_utils.env_vis import save_gif
from pathlib import Path
from IPython.display import Image as IPImage, display

# Only export when the evaluation actually produced frames
# (presumably None when rendering was disabled -- confirm in evaluate_model)
if evaluation.rendered_frames is not None:
    path = Path(f"temp/{lstm_agent.__class__.__name__}_policy.gif")
    # Create the output directory here so this cell does not rely on an
    # earlier cell having created it
    path.parent.mkdir(parents=True, exist_ok=True)
    save_gif(evaluation.rendered_frames, path)

    # Show the saved GIF inline
    display(IPImage(filename=path))

    # Save the rendered frames as numpy array for comparison with other algorithms.
    # The frames are stored as a pickled object array, so reading them back
    # requires np.load(..., allow_pickle=True).
    # NOTE(review): if all frames share one shape, a plain numeric array would
    # be far smaller and faster to save -- confirm the frame layout first.
    np.save(
        f"temp/{lstm_agent.__class__.__name__}_rendered_frames.npy",
        np.array(evaluation.rendered_frames, dtype=object),
        allow_pickle=True,
    )