## Performance comparison between JAX and NumPy environments

## Side-by-side comparison

First, let's verify that the two environments are equivalent, meaning that they produce the same observations and rewards for the same actions.

In [None]:
from typing import Tuple, cast
import quicknav_jax as env_jax
import quicknav_numpy as env_np
import jax.numpy as jnp
import numpy as np


def create_environments(
    grid_size: int = 8,
    seed: int = 42,
) -> Tuple[env_np.NavigationEnv, env_jax.NavigationEnv, env_jax.NavigationEnvParams]:
    """Build a NumPy and a JAX navigation environment that share one room layout.

    The layout (obstacles and free positions) is generated once with NumPy and
    handed to both implementations, together with the same fixed robot/goal
    spawn positions and the same episode step limit.

    Args:
        grid_size: Passed to ``RoomParams`` as ``grid_size``; ``size`` is set to
            ``grid_size // 2`` and ``num_rooms`` to 2.
        seed: Seed for the layout generator and for the NumPy environment.

    Returns:
        ``(numpy_env, jax_env, jax_params)``. The JAX environment is constructed
        without parameters; ``jax_params`` must be passed to it explicitly.
    """
    # Values both implementations must agree on, defined once so they cannot drift apart.
    robot_spawn = (2.0, 2.0)  # fixed spawn position for the robot
    goal_spawn = (3.0, 3.0)  # fixed spawn position for the goal
    episode_limit = 1_000_000  # "Arbitrary" large number of steps

    # Room layout, generated with a seeded NumPy Generator.
    layout_rng = np.random.default_rng(seed)
    room_params = env_np.RoomParams(num_rooms=2, grid_size=grid_size, size=grid_size // 2)
    obstacles, free_positions = env_np.generate_rooms(layout_rng, room_params)

    np_env = env_np.NavigationEnv(
        params=env_np.NavigationEnvParams(
            rooms=room_params,
            obstacles=obstacles,
            free_positions=free_positions,
            robot_spawn_pos=robot_spawn,
            goal_spawn_pos=goal_spawn,
            max_steps_in_episode=episode_limit,
        ),
        seed=seed,
    )

    # NOTE(review): `cast` only informs the type checker; at runtime the JAX params
    # receive the NumPy-side RoomParams object unchanged — confirm this is intended.
    jax_params = env_jax.NavigationEnvParams(
        rooms=cast(env_jax.RoomParams, room_params),
        obstacles=jnp.array(obstacles),
        free_positions=jnp.array(free_positions),
        robot_spawn_pos=jnp.array(robot_spawn),
        goal_spawn_pos=jnp.array(goal_spawn),
        max_steps_in_episode=episode_limit,
    )
    jax_env = env_jax.NavigationEnv()

    return np_env, jax_env, jax_params

In [None]:
def concat_frames(np_frame: np.ndarray, jax_frame: np.ndarray) -> np.ndarray:
    """Place two rendered frames next to each other (NumPy left, JAX right).

    If the frames differ in height (axis 0), both are cut down to the shorter
    height by dropping trailing rows. They are then joined along axis 1; all
    other dimensions must already match.
    """
    rows = min(np_frame.shape[0], jax_frame.shape[0])
    return np.concatenate((np_frame[:rows], jax_frame[:rows]), axis=1)

In [None]:
import os
from pathlib import Path

import jax
from quicknav_utils import render_frame, save_gif


def run_comparison(
    num_episodes: int,
    max_steps: int,
    seed: int,
    output_path: Path,
    grid_size: int = 8,
) -> None:
    """Step the NumPy and JAX environments in lockstep and save a side-by-side GIF.

    Both environments receive the same sequence of random actions. After every
    step, each environment is rendered, the two frames are concatenated
    horizontally (NumPy left, JAX right) and per-step reward/done info is printed.

    Args:
        num_episodes: Number of episodes to run.
        max_steps: Maximum steps per episode. An episode also ends early as soon
            as either environment reports it is done.
        seed: Seed for environment creation, the action RNG and the JAX PRNG key.
        output_path: Destination of the GIF. Its parent directory is created if
            missing.
        grid_size: Grid size passed to ``create_environments`` (default 8).
    """
    # Pass by keyword: the first positional parameter of create_environments is
    # grid_size, so passing `seed` positionally would silently build a seed-sized grid.
    np_env, jax_env, jax_params = create_environments(grid_size=grid_size, seed=seed)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Actions are drawn from NumPy and mirrored to JAX, so both envs see identical inputs.
    rng = np.random.RandomState(seed)
    key = jax.random.PRNGKey(seed)

    # Side-by-side frames for all episodes, in order.
    frames = []

    for episode in range(num_episodes):
        # NOTE(review): the resets are seeded differently (integer seed vs. a split
        # PRNG key). Initial states should still agree because create_environments
        # fixes the robot and goal spawn positions — re-check if spawning changes.
        np_env.reset(seed=seed + episode)
        key, reset_key = jax.random.split(key)
        _, jax_state = jax_env.reset_env(reset_key, jax_params)

        for step in range(max_steps):
            action = rng.uniform(-1.0, 1.0, size=(2,))
            jax_action = jnp.array(action)

            # Step NumPy environment (Gymnasium-style 5-tuple).
            _, np_reward, np_terminated, np_truncated, _ = np_env.step(action)
            np_done = np_terminated or np_truncated

            # Step JAX environment; the state is threaded through explicitly.
            key, step_key = jax.random.split(key)
            _, jax_state, jax_reward, jax_done, _ = jax_env.step_env(step_key, jax_state, jax_action, jax_params)
            jax_done = bool(jax_done)

            assert np_env.state is not None
            np_frame = render_frame(np_env.state, np_env.params)
            jax_frame = render_frame(jax_state, jax_params)
            frames.append(concat_frames(np_frame, jax_frame))

            print(f"Step {step}")
            print(f"NumPy - Reward: {np_reward:.2f}, Done: {np_done}")
            print(f"JAX  - Reward: {float(jax_reward):.2f}, Done: {jax_done}")
            print("-" * 50)

            if np_done or jax_done:
                break

    if not frames:
        # Nothing was stepped (num_episodes or max_steps is 0); don't write an empty GIF.
        print("No frames were rendered; skipping GIF export.")
        return

    save_gif(frames, output_path, duration_per_frame=1 / 15.0)  # 15 FPS
    print(f"Saved comparison GIF to {output_path}")

## Run comparison

The idea is the following:
- Create two environments with the same parameters
- Run them for a given number of time steps
- Visualize both environments side by side

In [None]:
from IPython.display import Image, display

comparison_path = Path("./temp/env_comparison.gif")

# Render a single 100-step episode of both environments side by side.
run_comparison(num_episodes=1, max_steps=100, seed=42, output_path=comparison_path)

# Show the GIF
display(Image(filename=comparison_path))

Nice — judging by the side-by-side GIF and the printed per-step rewards, the two implementations of the environment behave the same!

## Performance comparison

Now let's compare the performance of the two environments.

- To make the comparison as fair as possible, we only time the `step` function, not a whole training loop with a neural network. The environments are fed random actions, and nothing is rendered.

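- The JAX environment's `step` function is JIT-compiled and its internal logic is vectorized (avoiding Python `for` loops), so our hypothesis is that it will be faster. The JIT compilation is triggered by a warm-up call before timing starts, so compile time is not counted.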

- Most of the speedup from vectorization should come from collision detection, whose cost scales with the size of the environment. We therefore measure how the speedup changes with the environment size.



In [None]:
import time
from tqdm import tqdm


def run_speed_test(
    np_env: env_np.NavigationEnv,
    jax_env: env_jax.NavigationEnv,
    jax_params: env_jax.NavigationEnvParams,
    num_steps: int,
    seed: int,
) -> Tuple[float, float]:
    """Measure the wall-clock time of ``num_steps`` steps in each environment.

    Each timed iteration draws a random action and calls the step function. For
    JAX, each iteration also splits the PRNG key and converts the action with
    ``jnp.array``. The JAX step is JIT-compiled and warmed up before timing, so
    compilation is excluded. JAX execution is awaited before the clock stops,
    because JAX dispatches work asynchronously.

    Returns:
        ``(numpy_time, jax_time)`` in seconds.
    """
    rng = np.random.RandomState(seed)
    key = jax.random.PRNGKey(seed)

    # Reset environments
    np_env.reset(seed=seed)
    key, reset_key = jax.random.split(key)
    _, jax_state = jax_env.reset_env(reset_key, jax_params)

    jit_step = jax.jit(jax_env.step_env)
    # Warm up the JIT. Build the dummy action through the same numpy(float64) ->
    # jnp.array path as the timed actions, so the dtype (and thus the compiled
    # signature) matches even when x64 is enabled and no recompile hits the
    # timed loop. Block so no warm-up work leaks into the measurement.
    dummy_action = jnp.array(np.zeros(2))
    dummy_key = jax.random.PRNGKey(0)
    jax.block_until_ready(jit_step(dummy_key, jax_state, dummy_action, jax_params))

    # Run NumPy environment
    np_start_time = time.perf_counter()
    for _ in tqdm(range(num_steps), desc="NumPy environment"):
        action = rng.uniform(-1.0, 1.0, size=(2,))
        np_env.step(action)
    np_time = time.perf_counter() - np_start_time

    # Run JAX environment
    jax_start_time = time.perf_counter()
    for _ in tqdm(range(num_steps), desc="JAX environment"):
        action = rng.uniform(-1.0, 1.0, size=(2,))
        jax_action = jnp.array(action)
        key, step_key = jax.random.split(key)
        _, jax_state, _, _, _ = jit_step(step_key, jax_state, jax_action, jax_params)
    # Calls return before the device finishes. Wait for the last step (which
    # depends on all earlier ones) so queued work is included in the timing.
    jax.block_until_ready(jax_state)
    jax_time = time.perf_counter() - jax_start_time

    return np_time, jax_time

In [None]:
from typing import Dict, Tuple


def run_comparison(
    room_sizes: list[int],
    num_steps: int,
    seed: int = 42,
) -> Dict[int, Tuple[float, float]]:
    """Benchmark both environments once for every grid size in ``room_sizes``.

    NOTE: this redefines the GIF-rendering ``run_comparison`` from an earlier
    cell; after this cell runs, the name refers to this benchmarking version.

    Args:
        room_sizes: Grid sizes to build environments for.
        num_steps: Number of steps timed per environment and size.
        seed: Seed used for environment creation and for the speed test.

    Returns:
        Mapping ``grid_size -> (numpy_time, jax_time)`` in seconds.
    """
    timings: Dict[int, Tuple[float, float]] = {}
    for size in room_sizes:
        print(f"\nTesting room size {size}x{size}")
        np_env, jax_env, jax_params = create_environments(grid_size=size, seed=seed)
        timings[size] = run_speed_test(
            np_env=np_env,
            jax_env=jax_env,
            jax_params=jax_params,
            num_steps=num_steps,
            seed=seed,
        )
    return timings

In [None]:
# Benchmark settings: grid sizes to sweep, steps timed per environment, RNG seed.
room_sizes = [4, 8, 16]
num_steps = 100_000
seed = 42

results = run_comparison(room_sizes=room_sizes, num_steps=num_steps, seed=seed)

In [None]:
# One report per room size, followed by a blank separator line.
for size, (np_time, jax_time) in results.items():
    print(
        f"Room size {size}x{size}:\n"
        f"NumPy time: {np_time:.2f} seconds\n"
        f"JAX time: {jax_time:.2f} seconds\n"
        f"Speedup: {np_time / jax_time:.2f}x\n"
    )

In [None]:
from matplotlib import pyplot as plt


def plot_results(results: Dict[int, Tuple[float, float]]) -> None:
    """Show a grouped bar chart of NumPy vs. JAX timings per room size.

    Room sizes are plotted in ascending order. Each size gets two adjacent bars
    (NumPy, then JAX), each annotated with its height in seconds. The figure is
    displayed and then closed.

    Args:
        results: Mapping ``room_size -> (numpy_time, jax_time)`` in seconds.
    """
    ordered_sizes = sorted(results)
    numpy_seconds = [results[s][0] for s in ordered_sizes]
    jax_seconds = [results[s][1] for s in ordered_sizes]

    positions = np.arange(len(ordered_sizes))
    bar_width = 0.35

    _, ax = plt.subplots(figsize=(10, 6))
    bar_groups = [
        ax.bar(positions - bar_width / 2, numpy_seconds, bar_width, label="NumPy", color="blue", alpha=0.7),
        ax.bar(positions + bar_width / 2, jax_seconds, bar_width, label="JAX", color="orange", alpha=0.7),
    ]

    ax.set_xlabel("Room Size")
    ax.set_ylabel("Time (seconds)")
    ax.set_title("Environment Performance Comparison")
    ax.set_xticks(positions)
    ax.set_xticklabels([str(s) for s in ordered_sizes])
    ax.legend()

    # Label every bar with its value, lifted 3 points above the bar top.
    for group in bar_groups:
        for bar in group:
            top = bar.get_height()
            ax.annotate(
                f"{top:.1f}s",
                xy=(bar.get_x() + bar.get_width() / 2, top),
                xytext=(0, 3),
                textcoords="offset points",
                ha="center",
                va="bottom",
            )

    ax.grid(True, axis="y", alpha=0.3)
    plt.tight_layout()
    plt.show()
    plt.close()

In [None]:
plot_results(results=results)  # chart of the freshly measured timings

## Cached results (the benchmark above might take an hour to run)

In [None]:
# Timings from an earlier run of the benchmark above, stored as
# room size -> (numpy_seconds, jax_seconds), so the chart can be reproduced
# without re-running it.
cached_results: Dict[int, Tuple[float, float]] = {
    4: (149.15101432800293, 45.90607666969299),
    8: (317.48660945892334, 58.82390379905701),
    16: (993.1000814437866, 110.98208546638489),
}

plot_results(results=cached_results)