This is the official repository for the paper *ThinkEdit: Interpretable Weight Editing to Mitigate Overly Short Thinking in Reasoning Models* [project website].
Install the dependencies:

```shell
pip install -r requirements.txt
```
If you want to skip all the steps and directly access the resulting output files, you can download them with:

```shell
gdown https://drive.google.com/uc?id=1WGJOV_Uh1UulU-sNwA7Gy82NddvljDlP
```

and then unzip the archive:

```shell
unzip ThinkEdit.zip
```
First, collect the responses from the reasoning models and store them in `responses/` for extracting hidden states later:

```shell
python generate_response_gsm8k.py
```

Specify the `--model` argument: `deepseek-qwen-1.5b`, `deepseek-llama3-8b`, or `deepseek-qwen-14b`.
Next, extract the layerwise directions from Self-Attn or MLP and store them in `directions/`:

```shell
python extract_thinking_length_direction_gsm8k_attn.py
python extract_thinking_length_direction_gsm8k_mlp.py
```

Specify the `--model` argument: `deepseek-qwen-1.5b`, `deepseek-llama3-8b`, or `deepseek-qwen-14b`.
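Conceptually, the extraction scripts compute a per-layer "thinking length" direction from the stored hidden states. A minimal numpy sketch of the idea (the array shapes here are our assumptions, not the repo's exact tensors):

```python
import numpy as np

def thinking_length_directions(h_long, h_short):
    """Layerwise thinking-length direction: difference of mean hidden
    states between long-reasoning and short-reasoning responses.

    h_long, h_short: (n_examples, n_layers, d_model) -- assumed shapes.
    Returns (n_layers, d_model) unit-norm directions.
    """
    d = h_long.mean(axis=0) - h_short.mean(axis=0)
    # normalize each layer's direction to unit length
    return d / np.linalg.norm(d, axis=-1, keepdims=True)
```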
Finally, steer the models with the directions and observe the changes in accuracy and reasoning length. To evaluate on 200 test examples from GSM8K and store the results in `gsm8k_all_layer_thinking_length_steering_results/`:

```shell
python thinking_length_steering_gsm8k.py
```

Specify the `--model` argument: `deepseek-qwen-1.5b`, `deepseek-llama3-8b`, or `deepseek-qwen-14b`.

The `--control` argument selects the direction type: `thinking_length_attn` or `thinking_length_mlp`.

The steering strength alpha (`--direction_weight`): we use `-0.08 -0.07 ... 0.07 0.08` in our paper.
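The steering itself amounts to shifting hidden states along the extracted direction with strength alpha. A minimal sketch of that idea (not the repo's actual forward-hook implementation):

```python
import numpy as np

def steer_hidden(hidden, direction, alpha):
    """Shift every token's hidden state along a direction: h' = h + alpha * d.
    hidden: (seq_len, d_model); direction is normalized to unit length here.
    A conceptual sketch of steering, not the repo's hook code."""
    d = direction / np.linalg.norm(direction)
    return hidden + alpha * d
```

Positive alpha pushes activations toward longer reasoning, negative alpha toward shorter, which is why the sweep above is symmetric around zero.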
Similarly, to evaluate 140 Level-5 examples from MATH and store the results in `math_level5_all_layer_thinking_length_steering_results/`:

```shell
python thinking_length_steering_math_level5.py
```

Specify the arguments accordingly.
To steer only one layer at a time and store the results in `gsm8k_layerwise_thinking_length_steering_results/`:

```shell
python thinking_length_layerwise_steering_gsm8k.py
```

Specify the arguments accordingly. Use `--layer` to specify the layer and set `--direction_weight` to `-1` or `1` (as in our paper). Running the layerwise analysis can take considerable time, so we suggest using `automate_layerwise_steering_jobs.sh` to handle the jobs; please modify the script based on your hardware.
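The job list the automation script has to cover is just the cross product of layers and weights. A sketch of how you might generate it (28 layers is our assumption for the 1.5b model; adjust `n_layers` to your architecture):

```python
import itertools

def layerwise_jobs(model="deepseek-qwen-1.5b", n_layers=28):
    """Build one steering command per (layer, weight) pair.
    n_layers=28 is an assumption for the 1.5b model."""
    jobs = []
    for layer, w in itertools.product(range(n_layers), (-1, 1)):
        jobs.append(
            f"python thinking_length_layerwise_steering_gsm8k.py "
            f"--model {model} --layer {layer} --direction_weight {w}"
        )
    return jobs
```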
First, identify the short reasoning heads by computing each head's contribution to the short reasoning direction:

```shell
python find_short_thinking_attn_heads.py
```

Specify the `--model` argument: `deepseek-qwen-1.5b`, `deepseek-llama3-8b`, or `deepseek-qwen-14b`.

This will output a list of short reasoning heads and a heatmap figure of every head's contribution.
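One way to score heads (a sketch of the idea, not necessarily the script's exact computation) is to measure how strongly each head's slice of the `o_proj` matrix can write along the short reasoning direction:

```python
import numpy as np

def head_contributions(W_o, direction, n_heads):
    """Score each head by the norm of its o_proj column slice projected
    onto the short-reasoning direction. W_o: (d_model, d_model), columns
    grouped by head -- an assumption about the weight layout."""
    head_dim = W_o.shape[1] // n_heads
    d = direction / np.linalg.norm(direction)
    return np.array([
        np.linalg.norm(d @ W_o[:, h * head_dim:(h + 1) * head_dim])
        for h in range(n_heads)
    ])
```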
Next, perform weight editing on the `o_proj` layer of the short reasoning heads and store the model under `ThinkEdit_models/`:

```shell
python get_ThinkEdit_models.py
```

Specify the `--model` argument: `deepseek-qwen-1.5b`, `deepseek-llama3-8b`, or `deepseek-qwen-14b`.
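Conceptually, the edit removes the short-reasoning component from the selected heads' output projections. A hedged numpy sketch of that update (the actual script operates on the model checkpoint, and the exact update may differ):

```python
import numpy as np

def edit_o_proj(W_o, direction, heads, n_heads):
    """Project the short-reasoning direction out of the selected heads'
    o_proj columns: W_h <- (I - d d^T) W_h. A sketch of the idea only;
    W_o layout (columns grouped by head) is an assumption."""
    d = direction / np.linalg.norm(direction)
    head_dim = W_o.shape[1] // n_heads
    W = W_o.copy()
    for h in heads:
        sl = slice(h * head_dim, (h + 1) * head_dim)
        W[:, sl] -= np.outer(d, d @ W[:, sl])
    return W
```

After the edit, the chosen heads can no longer write along the short reasoning direction, while all other heads are untouched.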
We have provided the ThinkEdit models on the Hugging Face Hub:

- `cesun/ThinkEdit-deepseek-qwen-14b`
- `cesun/ThinkEdit-deepseek-llama3-8b`
- `cesun/ThinkEdit-deepseek-qwen-1.5b`

You can skip this step; our evaluation script will download the models from Hugging Face directly.
Finally, evaluate the performance of the original and ThinkEdit models and store the results under `ThinkEdit_model_evaluation_results/`. We use vLLM to speed up evaluation:

```shell
CUDA_VISIBLE_DEVICES={your available gpus} python evaluate_ThinkEdit_models.py
```

Specify the `--model` argument: `deepseek-qwen-1.5b`, `deepseek-llama3-8b`, `deepseek-qwen-14b`, `ThinkEdit-deepseek-qwen-14b`, `ThinkEdit-deepseek-llama3-8b`, or `ThinkEdit-deepseek-qwen-1.5b`.

The `--dataset` argument: `gsm8k`, `mmlu_elementary_math`, `MATH-500`, `MATH-level1`, or `MATH-level5`.

The `--n_samples` argument: we set this to 10 in our paper, meaning each question is evaluated 10 times.

The `--tensor_parallel_size` argument: set this according to your number of GPUs; it should be a factor of the number of attention heads in each model. We recommend setting it to 4.
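With `--n_samples 10`, each question's result is an average over repeated attempts. One plausible way to aggregate such runs (our sketch; the evaluation script's exact reporting may differ):

```python
import numpy as np

def aggregate_accuracy(correct):
    """correct: (n_questions, n_samples) boolean array -- each question is
    answered n_samples times. Returns overall mean accuracy and the std
    across sampling rounds. A sketch; the repo's reporting may differ."""
    per_round = correct.mean(axis=0)  # accuracy of each sampling round
    return float(per_round.mean()), float(per_round.std())
```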
After you have all the results, run:

```shell
python analyze_ThinkEdit_performance.py
```

to generate the plots and tables shown in our paper.
Chung-En Sun, Ge Yan, Tsui-Wei Weng, "ThinkEdit: Interpretable Weight Editing to Mitigate Overly Short Thinking in Reasoning Models", arXiv preprint.
```bibtex
@article{ThinkEdit,
  title={ThinkEdit: Interpretable Weight Editing to Mitigate Overly Short Thinking in Reasoning Models},
  author={Chung-En Sun and Ge Yan and Tsui-Wei Weng},
  journal={arXiv preprint},
  year={2025}
}
```