
Commit db1e9ca

Provide a minimal reproducible experiment using GRPO for mathematical reasoning on a base model, referencing the approach from SimpleRL-Reason (huggingface#197)
* Create config_base_math_smalllr.yaml
* Update README.md
* Update README.md
1 parent a9c51ab commit db1e9ca

File tree

- README.md
- recipes/deepseek/DeepSeek-R1-Distill-Qwen-7B/grpo/config_base_math_smalllr.yaml

2 files changed: +53 -0 lines changed

README.md

Lines changed: 8 additions & 0 deletions
@@ -119,6 +119,14 @@ To train via the GRPO trainer, we use one GPU to run vLLM for faster generation
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/config_full.yaml
```

We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason), which uses a 7B model trained on 8K examples. Running this on 8 H100 80GB GPUs takes about 3 hours:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/deepseek/DeepSeek-R1-Distill-Qwen-7B/grpo/config_base_math_smalllr.yaml
```
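
Individual values in the recipe can plausibly be overridden from the command line, assuming `grpo.py` follows TRL's `TrlParser` convention where CLI flags take precedence over the `--config` YAML; as a sketch (the flag-precedence behavior is an assumption, not something this commit documents), a smaller learning rate could be tried like this:

```shell
# Sketch: override one YAML value via a CLI flag (assumes TrlParser-style precedence).
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes=7 src/open_r1/grpo.py \
    --config recipes/deepseek/DeepSeek-R1-Distill-Qwen-7B/grpo/config_base_math_smalllr.yaml \
    --learning_rate 1.0e-06
```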

Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using learning rates, loss functions, and reward structures that differ from SimpleRL-Reason, achieves 69.4% accuracy on MATH-500, demonstrating a more than 17% improvement over the base model.
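
To spot-check the reported MATH-500 score, here is a minimal evaluation sketch using lighteval's vLLM backend; the `math_500` task name, the `src/open_r1/evaluate.py` custom-tasks module, and the output directory are assumptions based on this repo's evaluation setup, not part of this commit:

```shell
# Sketch only: assumes the repo's lighteval-based MATH-500 task registration.
MODEL=Dongwei/Qwen-2.5-7B_Base_Math_smalllr
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16"
lighteval vllm "$MODEL_ARGS" "custom|math_500|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir data/evals/$MODEL
```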
To launch a Slurm job, run:

```shell
recipes/deepseek/DeepSeek-R1-Distill-Qwen-7B/grpo/config_base_math_smalllr.yaml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
torch_dtype: bfloat16

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_configs:
- train
# Num processes is one fewer than the GPU count, as vLLM uses 1 GPU
num_processes: 7

# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen-2.5-7B_Base_Math_smalllr
hub_strategy: every_save
learning_rate: 3.0e-06
log_level: info
logging_steps: 10
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B_Base_Math_smalllr
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
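
As a quick sanity check on this recipe, the effective train batch size follows from three of the values above (a back-of-the-envelope sketch; the factor 7 is `num_processes`, since the eighth GPU is reserved for vLLM generation):

```shell
# per_device_train_batch_size (1) x gradient_accumulation_steps (16) x 7 training GPUs
echo $((1 * 16 * 7))  # 112 prompts per optimizer step
```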
