tl;dr: open-source Llama 4 inference using JAX, minimal yet performant
Note: work in progress; Scout is supported with tuned defaults, while tuning of Maverick's defaults is in progress.
This is a pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights. It currently runs on TPU; GPU support is in progress.
The entire model is defined in model.py and invoked via main.py. Among other things, the model code demonstrates:
- a grouped-query attention (GQA) implementation;
- expert and tensor parallelism via JAX's shard_map for easy multi-device/multi-host computation; and
- simple int8 quantization.
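For illustration, here is a minimal sketch of those last two pieces together: a tensor-parallel matmul over int8-quantized weights using shard_map. The names, mesh axis, and shapes below are made up for the example and are not the repo's actual API.

```python
# Sketch only: tensor-parallel matmul over int8-quantized weights with shard_map.
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))  # 1D tensor-parallel mesh

def quantize(w):
    # per-output-channel int8 quantization: int8 values plus a float scale
    scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
    return (w / scale).round().astype(jnp.int8), scale

@partial(shard_map, mesh=mesh, in_specs=(P(), P(None, "tp"), P(None, "tp")),
         out_specs=P(None, "tp"))
def up_proj(x, w_int8, w_scale):
    # x is replicated [B, D]; the weight is sharded over its output dim F, so each
    # device computes its own [B, F / tp] slice with no communication at all.
    return x @ (w_int8.astype(jnp.bfloat16) * w_scale)

x = jnp.ones((4, 512), dtype=jnp.bfloat16)
w_int8, w_scale = quantize(jax.random.normal(jax.random.PRNGKey(0), (512, 2048)))
y = jax.jit(up_proj)(x, w_int8, w_scale)  # globally [4, 2048], sharded over the last axis
```

The model code applies the same pattern to much larger weights and combines it with expert parallelism for the MoE layers.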
This example aims to be a concise, self-contained, fully open-source codebase, with performance and cost reasonably comparable to other Llama 4 inference offerings. We hope it is easy to understand and offers an accessible starting point for performant inference with JAX. See the performance rundown below.
In addition, this repo includes an overview of how to shard transformers and a discussion of the specific optimizations used in this implementation, as well as a workflow for interactive development on multi-host GPU and TPU clusters using ipyparallel.
- Quickstart
- Inference performance results
- Transformer parallelism strategies
- Optimizing Llama 4
- Working with multi-host clusters
- Next steps
Due to the large model size (about 109B total parameters for Scout and 400B for Maverick), a multi-host platform is required to run the full model. We've tested on v5e-64.
Run on all hosts in the TPU cluster:
$ python3 main.py
e.g. for Cloud TPU:
$ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \
--command="cd ~/llama4_jax && python3 main.py"
Responses:
[
"\n\nI'm Llama, a model designed by Meta. What’s your name, or should I start guessing?<|eot|><|header_start|>assistant<|header_end|>\n\nI'm Llama, a",
"\n\nHear me, ye knaves! Gather 'round and heed my words, for I shall regale thee with tales of yonder skies and the",
"\n\nA question that requires precision!\n\nAs a computer program, I don't have personal preferences, taste buds, or a physical body. Therefore, I neither like nor"
]
(See Working with multi-host clusters for full setup.)
TPU | batch size | context length | tok/s | HBM BW util | comments |
---|---|---|---|---|---|
v5e-16 | 1 | 4096 | 85.5 | 71.7% | |
v5e-16 | 16 | 4096 | 76.3 | 64.0% | |
v5e-16 | 32 | 4096 | 73.0 | 61.2% | |

TPU | batch size | context length | tok/s | HBM BW util | comments |
---|---|---|---|---|---|
TODO
Results were generated using jax 0.5.3 and Python 3.10.15. Cost computation is based on https://cloud.google.com/spot-vms/pricing (region us-central1, as of Feb 28, 2025).
Llama 4 is a fairly minimal MoE model in that it (i) uses grouped-query attention (GQA) and (ii) uses either a simple MLP layer with gate, up, and down projections, or a simple MoE layer with the same linear matrices (but repeated for each expert). This presents several opportunities for optimizing the model on TPUs and GPUs to maximize either compute (in training) or memory-bandwidth use (in inference).
- Q: What parameter influences inference speed the most?
  A: HBM bandwidth.
- Q: Fully-replicated or sharded activations?
  A: For low-latency decoding, fully-replicated activations are usually faster since that strategy relies on a single all-reduce communication instead of repeated reduce-scatters. Computation over the weight shards is still partitioned, but some local memory is traded for lower-latency communication.
- Q: Why doesn't this match the cost and performance of proprietary inference APIs?
  A: This example aims to balance simplicity with performance, so it does not implement every possible optimization when doing so would add considerable complexity (e.g. heavy use of custom kernels). In addition, it only uses well-known optimization strategies and does not aim to introduce any new or closed-source techniques that inference providers may have developed independently.
- Q: Inference or training TPUs (e.g., v5e or v5p)?
  A: Inference TPUs (v5e): their matmul units are less powerful, but they can achieve lower latency at low utilization.
- Q: How do we work with multiple hosts?
  A: Either (1) launch the same Python script on every host (e.g. python3 script.py over ssh), or (2) use our ipyparallel setup; see Working with multi-host clusters below.
- Q: Which TPU image should we use?
  A: For v5e: v2-alpha-tpuv5-lite; for v6e: v2-alpha-tpuv6e. See runtimes.
TODO
This section overviews different sharding strategies and their performance considerations for Transformer architectures in general. For a very in-depth guide on this topic, check out How to Scale Your Model. The next section goes over Llama-specific optimizations.
A typical decoder-only transformer consists of the following (a shape-level code sketch follows the abbreviation table below):
- An input embedding
  - a single weight $V \times D$
- Repeated Decoder Layers (Attention + a Feed-forward layer)
  - Attention Layer
    - project the input $BSD$ to $BSNH$ for queries, $BSNH$ for keys and $BSNH$ for values; typically $D \approx N \cdot H$
    - compute the attention operation on ($BSNH$, $BSNH$, $BSNH$), giving $BSNH$
    - project the output $BSNH$ back to $BSD$ using a projection matrix
  - Feed-forward Layer - a Multilayer Perceptron (MLP) or a Mixture-of-Experts (MoE)
    - always (i) up-projection -> (ii) nonlinearity -> (iii) down-projection
    - MLP
      - up-projection: $BSD \times DF \rightarrow BSF$
      - down-projection: $BSF \times DF \rightarrow BSD$
    - MoE
      - each token in $BS$ can be routed to a matrix slice $EDF[\text{idx}, :, :]$
      - up-projection: $BSD \times EDF \rightarrow BSF$
      - down-projection: $BSF \times EDF \rightarrow BSD$
- An output projection
  - a single weight $D \times V$
Abbreviation | Dimension |
---|---|
V | vocabulary size |
B | batch |
S | sequence |
D | model dimension |
F | up-projection dimension |
N | number of query, key or value heads |
H | head dimension |
E | expert dimension |
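To make the shapes above concrete, here is a shape-level sketch of one decoder layer's contractions written as einsums. The sizes and weights are toy placeholders, and the real model additionally applies RoPE, normalization, masking, and learned routing, all omitted here.

```python
# Shape-level sketch of the contractions above (toy sizes, random placeholder weights).
import jax
import jax.numpy as jnp

B, S, D, F, N, H, E = 2, 16, 512, 2048, 8, 64, 4  # toy sizes, with D == N * H
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (B, S, D))

# Attention: project BSD -> BSNH for q, k and v, attend, then project back to BSD.
wq = wk = wv = jax.random.normal(key, (D, N, H))
wo = jax.random.normal(key, (N, H, D))
q = jnp.einsum("bsd,dnh->bsnh", x, wq)
k = jnp.einsum("bsd,dnh->bsnh", x, wk)
v = jnp.einsum("bsd,dnh->bsnh", x, wv)
scores = jnp.einsum("bsnh,btnh->bnst", q, k) / jnp.sqrt(H)
attn = jnp.einsum("bnst,btnh->bsnh", jax.nn.softmax(scores, axis=-1), v)
x = jnp.einsum("bsnh,nhd->bsd", attn, wo)

# MLP: up-projection BSD x DF -> BSF, nonlinearity, down-projection back to BSD.
w_up, w_down = jax.random.normal(key, (D, F)), jax.random.normal(key, (F, D))
x = jnp.einsum("bsf,fd->bsd", jax.nn.silu(jnp.einsum("bsd,df->bsf", x, w_up)), w_down)

# MoE: each token picks an expert index, selecting a slice of the per-expert weights.
w_up_e, w_down_e = jax.random.normal(key, (E, D, F)), jax.random.normal(key, (E, F, D))
idx = jax.random.randint(key, (B, S), 0, E)        # routing decision per token
h = jnp.einsum("bsd,bsdf->bsf", x, w_up_e[idx])    # gather expert slices, then contract
x = jnp.einsum("bsf,bsfd->bsd", jax.nn.silu(h), w_down_e[idx])
print(x.shape)  # (B, S, D)
```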
The simplest sharding strategy, naive eager pipeline parallelism, puts the first few layers on the first device, the next few on the second, and so on; it only requires passing activations between devices every few layers. Unfortunately, for fast inference this means later devices wait for earlier ones to finish, so the pipeline decodes at the speed of a single device. Strategies that favor parallel work across devices, tensor-parallelism and fully-sharded data-parallelism (FSDP), are a better fit. We find tensor-parallelism results in the fastest inference.
Strategy | Input | QKV | Output | Up | Down |
---|---|---|---|---|---|
where:
- ${\color{red} \text{AR}}$ - all-reduce (jax.lax.psum)
- ${\color{red} \text{RS}}$ - reduce-scatter (jax.lax.psum_scatter)
- ${\color{red} \text{AG}}$ - all-gather (jax.lax.all_gather)
The key to designing a sharding strategy is minimizing communication overhead. There are typically several alternatives, and the compiler will overlap communication with computation as much as possible. Given this, it's usually worth trying several alternatives and picking the one that minimizes total runtime. The best strategy depends on the hardware configuration. The following are general rules of thumb in different contexts:
For low latency with 1D/2D sharding, the primary sharded matrix multiplication strategies are:
- $BD_x \times D_xF \underset{\text{scatter}}{\longrightarrow} BF_x$ - contracting dimension with a scatter (1 unit of comms; sketched in code below)
- $BD \times DF_x \longrightarrow BF_x$ - replicated activations (no comms)
- $BD_x \times D_xF \underset{\text{all-reduce}}{\longrightarrow} BF$ - contracting dimension with an all-reduce after (2 units of comms)
- for attention, activations should be sharded over heads (effectively the feature dimension)
- do not all-gather weights (no FSDP)
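As a rough sketch (toy shapes and names, not the repo's code), the first and third strategies are the same local matmul over the contracting dimension and differ only in the collective used to combine the partial products:

```python
# Contracting-dimension sharded matmul finished with either a reduce-scatter
# (sharded output, ~1 unit of comms) or an all-reduce (replicated output, ~2 units).
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
B, D, F = 4, 512, 2048  # D and F are assumed divisible by the device count

@partial(shard_map, mesh=mesh, in_specs=(P(None, "x"), P("x", None)), out_specs=P(None, "x"))
def matmul_scatter(x, w):    # x: [B, D_x], w: [D_x, F]
    partial_out = x @ w      # each device holds a partial sum over its shard of D
    return jax.lax.psum_scatter(partial_out, "x", scatter_dimension=1, tiled=True)  # [B, F_x]

@partial(shard_map, mesh=mesh, in_specs=(P(None, "x"), P("x", None)), out_specs=P())
def matmul_allreduce(x, w):  # same partial sums, but the result is fully replicated
    return jax.lax.psum(x @ w, "x")  # [B, F] on every device

x, w = jnp.ones((B, D)), jnp.ones((D, F))
print(matmul_scatter(x, w).shape, matmul_allreduce(x, w).shape)  # both (B, F) globally
```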
Total FSDP comms scale with the size of the weights, which must be all-gathered, while total tensor-parallelism comms scale with the size of the activations. Very roughly, this implies that the trade-off tips in favor of FSDP only when the batch size is on the order of the model dimension, so that the activations being communicated are at least as large as the weight shards being gathered. For low-latency decoding of Llama 4 Scout the batch is far smaller than the model dimension, strongly implying a preference for tensor-parallelism.
The MLP in the network is a fairly standard up-/down-projection MLP. Llama 4 uses either an MLP or an MoE layer (which contains an MLP referred to as the shared expert) after every attention layer, so we have to choose a two-step matrix multiplication sharding strategy (TODO).
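A minimal sketch of one such two-step choice for the MLP, assuming toy shapes and omitting the gate projection (this is not the repo's actual code): the up-projection output stays sharded over F with no communication, and a single all-reduce follows the down-projection.

```python
# Two-step sharded MLP: no comms after the up-projection, one all-reduce after the down-projection.
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))
B, D, F = 4, 512, 2048  # F is assumed divisible by the device count

@partial(shard_map, mesh=mesh, in_specs=(P(), P(None, "tp"), P("tp", None)), out_specs=P())
def mlp(x, w_up, w_down):                  # x: [B, D], replicated
    h = jax.nn.silu(x @ w_up)              # [B, F_tp]: sharded intermediate, no comms
    return jax.lax.psum(h @ w_down, "tp")  # [B, D]: a single all-reduce for the whole MLP

x, w_up, w_down = jnp.ones((B, D)), jnp.ones((D, F)), jnp.ones((F, D))
print(mlp(x, w_up, w_down).shape)  # (B, D), replicated on every device
```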
Fig: Decode profile TODO
When working with many accelerators, JAX offers Distributed arrays and automatic parallelization with a global view of the computation, but the program needs to be run on many hosts, each of which controls a subset of the accelerators.
The simplest way to run a JAX program on multiple hosts is to run the same Python file from all the hosts at the same time - for example by launching an ssh command on all hosts in the cluster.
However, for development it's often easier to (1) efficiently share code changes with all hosts, (2) have a way of easily launching computation on all hosts, and (3) be able to debug interactively.
This section shows how you can do that:
- Shared disk setup - NFS & gcsfuse
- Batch SSH commands
- Interactive cluster setup with ipyparallel
This guide has specific instructions for setting up a TPU Pod with GCS, but a similar setup can be applied to any Linux multi-host platform, including GPU.
TPU_ZONE="zone, e.g. us-central1-a"
PROJECT="your-project"
IMAGE="v2-alpha-tpuv5-lite"
ACCELERATOR="v5litepod-64"
TPU_NAME="my_tpu"
TPU_NAME="$NAME_PREFIX"-"$ACCELERATOR"
gcloud alpha compute tpus tpu-vm create "$TPU_NAME" --zone="$TPU_ZONE" \
--project="$PROJECT" --accelerator-type="$ACCELERATOR" --version="$IMAGE"
1. gcsfuse
For datasets and checkpoints.
gcsfuse --implicit-dirs {bucket_name_no_gs://} {local_folder}
2. nfs
For code consistency between hosts in the TPU Pod / Cluster.
# on worker 0
WORKER0_IP="..."
sudo apt install -y nfs-server nfs-common net-tools tmux
mkdir -p ~/nfs; sudo umount ~/nfs
echo "$HOME/nfs $WORKER0_IP/24(rw,sync,no_subtree_check)" | sudo tee /etc/exports
sudo exportfs -a
sudo systemctl enable nfs-server; sudo systemctl restart nfs-server
sudo chown $USER:$USER -R ~/nfs
# on all other workers (!= 0)
SERVER_IP="..."
mkdir -p ~/nfs
sudo umount ~/nfs; sudo mount -t nfs $SERVER_IP:/home/$USER/nfs ~/nfs
(Optional) 3. sshfs
For a quick preview from a local machine.
sshfs TPU_WORKER_0_IP:remote_folder ~/local_folder
TPU_NAME="..."
TPU_ZONE="..."
TPU_PROJECT="..."
tpu_exec() {
  # usage: tpu_exec <first_worker> <last_worker> <command>
  local workers=$(seq $1 $2 | tr '\n' ',' | sed 's/,$//')
  gcloud alpha compute tpus tpu-vm ssh --zone="$TPU_ZONE" --project="$TPU_PROJECT" \
    "$TPU_NAME" --worker="$workers" --command="$3"
}
tpu_exec 0 15 'pip install -U "jax[tpu]"'
Start the controller on worker 0 and the engines on all other workers (but not on worker 0), because we want worker 0 to execute interactively.
SERVER_IP="..."
CONTROLLER_SETUP=$(cat << EOM
tmux kill-session -t controller; pkill -9 python
tmux new -d -s controller '\
. ~/venv/bin/activate && ipcontroller --profile-dir=~/nfs --ip=$SERVER_IP'
EOM
)
ENGINE_SETUP=$(cat << EOM
tmux kill-session -t engine; pkill -9 ipengine
tmux new -d -s engine '. ~/venv/bin/activate && ipengine --profile-dir=~/nfs'
EOM
)
tpu_exec 0 0 "$CONTROLLER_SETUP"  # only worker 0
tpu_exec 1 15 "$ENGINE_SETUP"     # all workers except worker 0
Cell 0:
import ipyparallel as ip
from pathlib import Path
connection_file = Path("~/nfs/security/ipcontroller-client.json").expanduser()
client = ip.Client(connection_info=connection_file)
print(sorted(list(client._engines.keys()))) # one less than worker num
# this process is the final worker
Cell 1:
%%px --local
import socket
import jax
jax.distributed.initialize() # no arguments, TPUs automatically detect peers
print(f"Hello from {socket.gethostname()}")
Note: "--local" argument means "also run on this process", it's necessary to get easy access to the output of computations on worker 0
- GPU support
- ragged attention auto-tuning
- ragged decode kernel