
Commit 2ea1218

tsunghsienlee authored and facebook-github-bot committed
Slight refactor of _split_local_dist_buffers()
Summary: Iterators enable this refactor, so there is no more bookkeeping of buffer indices.

Reviewed By: gajjanag

Differential Revision: D74159431

fbshipit-source-id: ab448d231d6d6d9bee3bbb7b801f8bd2cdffd859
1 parent 4429a4a · commit 2ea1218
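A minimal, self-contained sketch of the before/after pattern, assuming toy data: the names split_tensors_list and buffer_size_ranks mirror the diffs below, but the placeholder values and the expected output are illustrative, not taken from the actual Shampoo buffers.

# Sketch of the refactor on toy data (illustrative values, not real buffers).
# split_tensors_list[rank] holds the ordered splits of that rank's local
# buffer; buffer_size_ranks lists (size, rank) pairs in global buffer order.
split_tensors_list = [["a0", "a1"], ["b0"], ["c0", "c1", "c2"]]
buffer_size_ranks = [(8, 0), (4, 2), (8, 1), (16, 2), (8, 0), (4, 2)]

# Before: a per-rank index counter tracks how many splits were consumed.
buffer_indices = [0] * len(split_tensors_list)
splitted = []
for _, rank in buffer_size_ranks:
    splitted.append(split_tensors_list[rank][buffer_indices[rank]])
    buffer_indices[rank] += 1

# After: one iterator per rank; next() advances the position implicitly,
# so no index counters are needed.
split_tensors_iterators = list(map(iter, split_tensors_list))
refactored = tuple(
    next(split_tensors_iterators[rank]) for _, rank in buffer_size_ranks
)

assert tuple(splitted) == refactored == ("a0", "c0", "b0", "c1", "a1", "c2")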

File tree

3 files changed: +12, −36 lines


distributed_shampoo/utils/shampoo_ddp_distributor.py (+4, −12)

@@ -299,18 +299,10 @@ def _split_local_dist_buffers(
             )
         split_tensors_list.append(split_tensors)
 
-        # Obtain ordered buffer ranks containing (view of local buffer, rank).
-        splitted_local_dist_buffers = []
-        buffer_indices = [0] * len(
-            local_dist_buffers
-        )  # index counter for each rank for obtaining right buffer
-        for _, rank in buffer_size_ranks:
-            splitted_local_dist_buffers.append(
-                split_tensors_list[rank][buffer_indices[rank]]
-            )
-            buffer_indices[rank] += 1
-
-        return tuple(splitted_local_dist_buffers)
+        split_tensors_iterators = list(map(iter, split_tensors_list))
+        return tuple(
+            next(split_tensors_iterators[rank]) for _, rank in buffer_size_ranks
+        )
 
     def _construct_distributed_buffers(
         self,
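One behavioral nuance of the iterator version, an observation about the pattern rather than anything stated in the commit: the old code raised IndexError if a rank appeared in buffer_size_ranks more often than it had split tensors, whereas the new code exhausts that rank's iterator, and PEP 479 converts the resulting StopIteration inside the generator expression into a RuntimeError. A tiny sketch with made-up values:

# Hypothetical mismatch: rank 0 occurs twice but has only one split.
iters = list(map(iter, [["only_split"]]))
try:
    tuple(next(iters[rank]) for _, rank in [(8, 0), (8, 0)])
except RuntimeError as err:
    print(type(err).__name__, "-", err)  # generator raised StopIteration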

distributed_shampoo/utils/shampoo_hsdp_distributor.py (+4, −12)

@@ -492,18 +492,10 @@ def _split_local_dist_buffers(
             )
         split_tensors_list.append(split_tensors)
 
-        # Obtain ordered buffer ranks containing (view of local buffer, rank).
-        splitted_local_dist_buffers = []
-        buffer_indices = [0] * len(
-            local_dist_buffers
-        )  # index counter for each rank for obtaining right buffer
-        for _, rank in buffer_size_ranks:
-            splitted_local_dist_buffers.append(
-                split_tensors_list[rank][buffer_indices[rank]]
-            )
-            buffer_indices[rank] += 1
-
-        return tuple(splitted_local_dist_buffers)
+        split_tensors_iterators = list(map(iter, split_tensors_list))
+        return tuple(
+            next(split_tensors_iterators[rank]) for _, rank in buffer_size_ranks
+        )
 
     def _construct_distributed_buffers(
         self,

distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py (+4, −12)

@@ -428,18 +428,10 @@ def _split_local_dist_buffers(
             )
         split_tensors_list.append(split_tensors)
 
-        # Obtain ordered buffer ranks containing (view of local buffer, rank).
-        splitted_local_dist_buffers = []
-        buffer_indices = [0] * len(
-            local_dist_buffers
-        )  # index counter for each rank for obtaining right buffer
-        for _, rank in buffer_size_ranks:
-            splitted_local_dist_buffers.append(
-                split_tensors_list[rank][buffer_indices[rank]]
-            )
-            buffer_indices[rank] += 1
-
-        return tuple(splitted_local_dist_buffers)
+        split_tensors_iterators = list(map(iter, split_tensors_list))
+        return tuple(
+            next(split_tensors_iterators[rank]) for _, rank in buffer_size_ranks
+        )
 
     def _construct_distributed_buffers(
         self,
