File tree 3 files changed +12
-36
lines changed
distributed_shampoo/utils
3 files changed +12
-36
lines changed Original file line number Diff line number Diff line change @@ -299,18 +299,10 @@ def _split_local_dist_buffers(
299
299
)
300
300
split_tensors_list .append (split_tensors )
301
301
302
- # Obtain ordered buffer ranks containing (view of local buffer, rank).
303
- splitted_local_dist_buffers = []
304
- buffer_indices = [0 ] * len (
305
- local_dist_buffers
306
- ) # index counter for each rank for obtaining right buffer
307
- for _ , rank in buffer_size_ranks :
308
- splitted_local_dist_buffers .append (
309
- split_tensors_list [rank ][buffer_indices [rank ]]
310
- )
311
- buffer_indices [rank ] += 1
312
-
313
- return tuple (splitted_local_dist_buffers )
302
+ split_tensors_iterators = list (map (iter , split_tensors_list ))
303
+ return tuple (
304
+ next (split_tensors_iterators [rank ]) for _ , rank in buffer_size_ranks
305
+ )
314
306
315
307
def _construct_distributed_buffers (
316
308
self ,
Original file line number Diff line number Diff line change @@ -492,18 +492,10 @@ def _split_local_dist_buffers(
492
492
)
493
493
split_tensors_list .append (split_tensors )
494
494
495
- # Obtain ordered buffer ranks containing (view of local buffer, rank).
496
- splitted_local_dist_buffers = []
497
- buffer_indices = [0 ] * len (
498
- local_dist_buffers
499
- ) # index counter for each rank for obtaining right buffer
500
- for _ , rank in buffer_size_ranks :
501
- splitted_local_dist_buffers .append (
502
- split_tensors_list [rank ][buffer_indices [rank ]]
503
- )
504
- buffer_indices [rank ] += 1
505
-
506
- return tuple (splitted_local_dist_buffers )
495
+ split_tensors_iterators = list (map (iter , split_tensors_list ))
496
+ return tuple (
497
+ next (split_tensors_iterators [rank ]) for _ , rank in buffer_size_ranks
498
+ )
507
499
508
500
def _construct_distributed_buffers (
509
501
self ,
Original file line number Diff line number Diff line change @@ -428,18 +428,10 @@ def _split_local_dist_buffers(
428
428
)
429
429
split_tensors_list .append (split_tensors )
430
430
431
- # Obtain ordered buffer ranks containing (view of local buffer, rank).
432
- splitted_local_dist_buffers = []
433
- buffer_indices = [0 ] * len (
434
- local_dist_buffers
435
- ) # index counter for each rank for obtaining right buffer
436
- for _ , rank in buffer_size_ranks :
437
- splitted_local_dist_buffers .append (
438
- split_tensors_list [rank ][buffer_indices [rank ]]
439
- )
440
- buffer_indices [rank ] += 1
441
-
442
- return tuple (splitted_local_dist_buffers )
431
+ split_tensors_iterators = list (map (iter , split_tensors_list ))
432
+ return tuple (
433
+ next (split_tensors_iterators [rank ]) for _ , rank in buffer_size_ranks
434
+ )
443
435
444
436
def _construct_distributed_buffers (
445
437
self ,
You can’t perform that action at this time.
0 commit comments