An AI code assistant that helps write Triton kernel code.
Summary:
We introduce Kernel Coder, an AI assistant for writing Triton GPU kernels, trained using a novel GRPO-based reinforcement learning pipeline. By combining format validation and similarity rewards, we optimize the compact Qwen2.5-Coder-3B model on the KernelBook dataset. Our approach aims to match or surpass KerneLllm (Llama 3.1 8B) in generating correct and efficient Triton code.
We design an RL training pipeline to train a base model for generating Triton Kernel code. Triton is a Python-based DSL for GPU programming. Inspired by DeepSeek-R1-Zero, we implement a GRPO-based RL pipeline to train a base model (Qwen2.5-Coder-3B
).
🎯 Goal: RL-train the Qwen2.5-Coder-3B base model on a Triton kernel dataset (KernelBook), aiming for competitive performance compared to the SFT-trained KerneLllm (based on Llama-3.1-8B
).
We design the reward function with two components:
- ✅ Format Checking: Validate correct usage of
<thinking>
and<answer>
tags. - 🔍 Similarity Score: Measure string similarity between generated and ground-truth Triton kernels using Python’s
difflib.SequenceMatcher
. This idea is inspired bySWE-RL
.
We evaluate the generated Triton kernels using KernelBench (triton_backend_v2
branch) on:
- The base model (
Qwen2.5-Coder-3B
) - The SFT model (
KernelLLM
) kernel-coder
(our model): we will evaluate once the training is complete
- 🎓 DeepSeek R1-Zero style RL pipeline for Triton kernel generation
- 📊 Reward model design: Combining format and similarity-based rewards
- 🧪 Add verifiable rewards: Use
KernelBench
to check compilation, correctness, and speedup. - 🔄 Explore knowledge distillation: Distill
KerneLllm
into a smaller model before applying RL training, then compare with our RL-trained model.
📦 Dataset (KernelBook (Triton Kernel Dataset))
- ~18K samples of PyTorch kernels and corresponding Triton kernels (generated by Torch Inductor).
- Used for SFT training of KerneLllm (SFT model with KernelBook dataset) and our GRPO training of
kernel-coder
.
We apply GRPO training to Qwen2.5-Coder-3B
, a compact yet strong code model from the Qwen 2.5 family, balancing performance and compute cost.
Group Relative Policy Optimization (GRPO), proposed by DeepSeek, uses rule-based rewards for math and code tasks. GRPO avoids using a value model, instead estimating the advantage from relative reward rankings across multiple rollouts:
This approach improves efficiency by comparing rollout quality relative to the batch.
Generate Triton kernels equivalent or superior to provided PyTorch kernels.
Using KernelBench (Cuda and Triton kernel benchmark), (triton_backend_v2
branch) to evaluate:
Model | Compilation Rate (%) | Correctness Rate (%) |
---|---|---|
KernelLLM | 77.0% | 12.0% |
Qwen2.5-Coder-3B (untrained) | 29.0% | 3.0% |
kernel-coder (ours, GRPO-trained) | 🚧 TBD | 🚧 TBD |
We test on label 1 (100 test cases) with temperature 1.0 and top_p 0.97. Preliminary results show the importance of Triton-specific training for compilation and correctness.
The codebase consists of two main components:
nano_r1_script.py
- Modified for our project and originally from nano-aha-momentKernelBench
- Forked and modified from ScalingIntelligence/KernelBench
Project structure:
kernel-coder/
├── README.md
├── kernel-coder
│ ├── nano_r1_script.py # main code
│ └── utils.py
└── scripts
└── kernelllm.py # helper script from KernelLLM model, https://huggingface.co/facebook/KernelLLM
cd kernel-coder # cd to the project root
python kernel-coder/nano_r1_script.py --nproc 8 --max_response_tokens 2048
We build on the following resources: