
Commit 36865a7

yaochengji authored and kylesayrs committed
[TPU] support disabling xla compilation cache (vllm-project#15567)
Signed-off-by: Chengji Yao <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9085e1d commit 36865a7

File tree: 2 files changed, +20 -6 lines changed

vllm/v1/worker/tpu_worker.py (+10 -3)
@@ -113,9 +113,16 @@ def init_device(self):
         # can have slightly different XLA graphs.
         world_size = self.parallel_config.world_size
         rank = xr.global_ordinal()
-        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
-                                     f"tp{world_size}_rank{rank}")
-        xr.initialize_cache(per_rank_path, readonly=False)
+        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
+        # Consequently, changes in optimization flags, which affect compilation
+        # results, don't change the cache key. This can result in the wrong
+        # compilation being used. To prevent this, disabling the XLA compilation
+        # cache during development is recommended. We can disable it by
+        # `export VLLM_XLA_CACHE_PATH=`
+        if envs.VLLM_XLA_CACHE_PATH:
+            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                         f"tp{world_size}_rank{rank}")
+            xr.initialize_cache(per_rank_path, readonly=False)
 
         # Init ModelRunner here, so that we have access to self.device.
         self.model_runner = TPUModelRunner(self.vllm_config, self.device)
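Taken on its own, the gating added above is easy to replicate. The following is a minimal standalone sketch of the same pattern, assuming a working torch_xla installation; it reads VLLM_XLA_CACHE_PATH directly from the environment rather than through vllm.envs, and the helper name maybe_init_xla_cache is hypothetical, not part of vLLM.

import os

import torch_xla.runtime as xr


def maybe_init_xla_cache(world_size: int) -> None:
    # An empty or unset VLLM_XLA_CACHE_PATH disables the on-disk compilation
    # cache, avoiding stale hits while iterating on optimization flags.
    cache_path = os.environ.get("VLLM_XLA_CACHE_PATH", "")
    if not cache_path:
        return
    # Each (tp size, rank) pair gets its own cache directory, since different
    # ranks can have slightly different XLA graphs.
    rank = xr.global_ordinal()
    per_rank_path = os.path.join(cache_path, f"tp{world_size}_rank{rank}")
    xr.initialize_cache(per_rank_path, readonly=False)

In practice, running `export VLLM_XLA_CACHE_PATH=` before launching makes the condition falsy and skips cache initialization entirely, which is the development workflow this commit enables.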

vllm/worker/tpu_worker.py (+10 -3)
@@ -93,9 +93,16 @@ def init_device(self) -> None:
         # can have slightly different XLA graphs.
         world_size = self.parallel_config.world_size
         rank = xr.global_ordinal()
-        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
-                                     f"tp{world_size}_rank{rank}")
-        xr.initialize_cache(per_rank_path, readonly=False)
+        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
+        # Consequently, changes in optimization flags, which affect compilation
+        # results, don't change the cache key. This can result in the wrong
+        # compilation being used. To prevent this, disabling the XLA compilation
+        # cache during development is recommended. We can disable it by
+        # `export VLLM_XLA_CACHE_PATH=`
+        if envs.VLLM_XLA_CACHE_PATH:
+            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                         f"tp{world_size}_rank{rank}")
+            xr.initialize_cache(per_rank_path, readonly=False)
 
         self.profiler = None
         if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
