Running low_latency test on RoCE get IBGDA error #139

Open

TheRainstorm opened this issue Apr 27, 2025 · 5 comments

TheRainstorm commented Apr 27, 2025

I can run test_internode.py and test_intranode.py correctly, but I cannot run the test_low_latency.py script. The reported error is mainly related to IBGDA (a more complete log is at the end):

/repo/nvshmem_src/src/host/transport/transport.cpp:nvshmemi_transport_init:275: init failed for transport: IBGDA

I found the related issue #38, but it doesn't seem relevant to my error.

How should I run low_latency on RoCE? Are any extra settings required? Any help you can provide would be greatly appreciated.

Environment

  • Two nodes, each with 8 H20 GPUs, connected by CX6 network cards (200 Gb/s, ~25 GB/s) in RoCE mode.
  • Using the latest DeepEP code.
  • nvshmem: 3.2.5-1
  • gdrcopy: 2.4.4

I have followed the NVSHMEM install guide, manually compiled and loaded the gdrdrv module, and set the NVIDIA driver IBGDA-related options in /etc/modprobe.d (even though I use RoCE, not IB). My command to compile NVSHMEM is as follows:

CUDA_HOME=/opt/lib/cuda-12.4.1_normal/ \
GDRCOPY_HOME=/repo/gdrcopy-2.4.4 \
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/repo/nvshmem_src/install
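
For completeness, the IBGDA-related driver options mentioned above look roughly like this in /etc/modprobe.d (the file name is illustrative; the two settings match what /proc/driver/nvidia/params reports further down in this thread):

# e.g. /etc/modprobe.d/nvidia-ibgda.conf (file name is illustrative)
options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"

After this configure step, the build and install are the usual cmake steps (cmake --build build/ -j and then cmake --install build/).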

ibv_devinfo output

hca_id: mlx5_0
        transport:                      InfiniBand (0)
        fw_ver:                         20.35.2000
        node_guid:                      946d:ae03:009c:f3cc
        sys_image_guid:                 946d:ae03:009c:f3cc
        vendor_id:                      0x02c9
        vendor_part_id:                 4123
        hw_ver:                         0x0
        board_id:                       MT_0000000223
        phys_port_cnt:                  1
                port:   1
                        state:                  PORT_ACTIVE (4)
                        max_mtu:                4096 (5)
                        active_mtu:             4096 (5)
                        sm_lid:                 0
                        port_lid:               0
                        port_lmc:               0x00
                        link_layer:             Ethernet

hca_id: mlx5_1
        transport:                      InfiniBand (0)
        fw_ver:                         20.35.2000
        node_guid:                      946d:ae03:009c:f454
        sys_image_guid:                 946d:ae03:009c:f454
        vendor_id:                      0x02c9
        vendor_part_id:                 4123
        hw_ver:                         0x0
        board_id:                       MT_0000000223
        phys_port_cnt:                  1
                port:   1
                        state:                  PORT_ACTIVE (4)
                        max_mtu:                4096 (5)
                        active_mtu:             4096 (5)
                        sm_lid:                 0
                        port_lid:               0
                        port_lmc:               0x00
                        link_layer:             Ethernet

...
(mlx5_2 - mlx5_9 omitted)

Log

$ MASTER_ADDR=xxx MASTER_PORT=8362 WORLD_SIZE=2 RANK=0 python tests/test_low_latency.py

Allocating buffer size: 2116.290944 MB ...
WARN: cudaHostRegister with IoMemory failed with error=800. We may need to use a fallback path.

WARN: ibgda_nic_mem_gpu_map failed. We may need to use the CPU fallback path.

WARN: ibgda_alloc_and_map_qp_uar with GPU as handler failed. We may need to enter the CPU fallback path.

WARN: GPU cannot map UAR of device mlx5_0. Skipping...

...

WARN: ibgda_nic_mem_gpu_map failed. We may need to use the CPU fallback path.
WARN: ibgda_alloc_and_map_qp_uar with GPU as handler failed. We may need to enter the CPU fallback path.

WARN: GPU cannot map UAR of device mlx5_3. Skipping...

WARN: GPU cannot map UAR of device mlx5_8. Skipping...

WARN: GPU cannot map UAR of device mlx5_1. Skipping...

WARN: GPU cannot map UAR of device mlx5_7. Skipping...

WARN: GPU cannot map UAR of device mlx5_6. Skipping...

WARN: GPU cannot map UAR of device mlx5_bond_0. Skipping...

/repo/nvshmem_src/src/host/transport/transport.cpp:nvshmemi_transport_init:275: init failed for transport: IBGDA
WARN: GPU cannot map UAR of device mlx5_bond_0. Skipping...

/repo/nvshmem_src/src/host/transport/transport.cpp:nvshmemi_transport_init:275: init failed for transport: IBGDA
WARN: cudaHostRegister with IoMemory failed with error=800. We may need to use a fallback path.

WARN: ibgda_nic_mem_gpu_map failed. We may need to use the CPU fallback path.

...

/repo/nvshmem_src/src/host/transport/transport.cpp:nvshmemi_transport_init:275: init failed for transport: IBGDA
/repo/nvshmem_src/src/modules/bootstrap/uid/bootstrap_uid.cpp:bootstrap_net_recv:99: Message truncated : received 16 bytes instead of 8

/repo/nvshmem_src/src/modules/bootstrap/uid/bootstrap_uid.cpp:499: non-zero status: -3
/repo/nvshmem_src/src/host/mem/mem_heap.cpp:222: non-zero status: -3 allgather of heap base for all PE failed

/repo/nvshmem_src/src/host/mem/mem_heap.cpp:588: non-zero status: 7 Failed to allgather PEs peer_base values

/repo/nvshmem_src/src/host/init/init.cu:1011: non-zero status: 7 nvshmem register static heaps failed

/repo/nvshmem_src/src/host/team/team.cu:nvshmem_team_split_strided:63: NVSHMEM API called before NVSHMEM initialization has completed
TheRainstorm (Author) commented

Would the suggestion in #55 (comment) be helpful?

I remember that it was configured before; I will double-check later as I currently don't have administrator privileges.

However, these two options are used to enable IBGDA, and since IBGDA is an IB feature, is it also necessary in a RoCE environment?
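
For the record, a quick way to check this without admin rights is to read the driver's proc entries, e.g. something like:

$ grep -E 'EnableStreamMemOPs|RegistryDwords' /proc/driver/nvidia/params
$ lsmod | grep gdrdrv

The first shows whether the two driver options took effect; the second shows whether the gdrdrv module is loaded.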

sphish (Collaborator) commented Apr 27, 2025

I'm not sure whether CX6 supports IBGDA.

sphish (Collaborator) commented Apr 27, 2025

> However, these two options are used to enable IBGDA, and since IBGDA is an IB feature, is it also necessary in a RoCE environment?

The name IBGDA might be a bit confusing. It actually applies to both IB and RoCE. In the latest version of DeepEP, this is necessary for both.
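
For reference, enabling the IBGDA transport over RoCE usually comes down to a couple of NVSHMEM environment variables when launching the test. A rough sketch only (the GID index value here is an assumption; it depends on your NIC's RoCE GID table, so please verify it for your setup):

$ NVSHMEM_IB_ENABLE_IBGDA=1 NVSHMEM_IB_GID_INDEX=3 \
  MASTER_ADDR=xxx MASTER_PORT=8362 WORLD_SIZE=2 RANK=0 python tests/test_low_latency.py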

TheRainstorm (Author) commented

> However, these two options are used to enable IBGDA, and since IBGDA is an IB feature, is it also necessary in a RoCE environment?
>
> The name IBGDA might be a bit confusing. It actually applies to both IB and RoCE. In the latest version of DeepEP, this is necessary for both.

Thank you, I will check both.

TheRainstorm (Author) commented Apr 29, 2025

> I'm not sure whether CX6 supports IBGDA.

I checked that CX-6 supports IBGDA. Here is an article comparing CX6 performance with IBGDA and IBRC.

I also checked that the NVIDIA driver options are configured by running cat /proc/driver/nvidia/params, which outputs EnableStreamMemOPs: 1 and RegistryDwords: "PeerMappingOverride=1;".

> In the latest version of DeepEP

I'm sorry, what I was running before was not the latest version, but rather a version from before #130 was merged. After updating to the latest version, my error message changed. Now I am getting this error:

$ MASTER_ADDR=172.19.33.83 MASTER_PORT=8362 WORLD_SIZE=2 RANK=0 python tests/test_low_latency.py 2>&1 |tee r_lowlatency2.log
Allocating buffer size: 2115.111296 MB ...
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
[/nvshmem_src/src/host/stream/coll/barrier/barrier.cu:36] cuda failed with too many blocks in cooperative launch
W0429 20:14:08.665000 145833 /miniconda3/envs/tcu12/lib/python3.11/site-packages/torch/multiprocessing/spawn.py:169] Terminating process 145915 via signal SIGTERM
W0429 20:14:08.665000 145833 /miniconda3/envs/tcu12/lib/python3.11/site-packages/torch/multiprocessing/spawn.py:169] Terminating process 145917 via signal SIGTERM
Traceback (most recent call last):
  File "/repo/DeepEP/tests/test_low_latency.py", line 172, in <module>
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
  File "/miniconda3/envs/tcu12/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 340, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/tcu12/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 296, in start_processes
    while not context.join():
              ^^^^^^^^^^^^^^
  File "/miniconda3/envs/tcu12/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 215, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 4 terminated with the following error:
Traceback (most recent call last):
  File "/miniconda3/envs/tcu12/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
    fn(i, *args)
  File "/repo/DeepEP/tests/test_low_latency.py", line 158, in test_loop
    test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
  File "/repo/DeepEP/tests/test_low_latency.py", line 40, in test_main
    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
  File "/repo/DeepEP/deep_ep/buffer.py", line 488, in low_latency_dispatch
    self.runtime.low_latency_dispatch(x, topk_idx,
RuntimeError: Failed: CUDA error /repo/DeepEP/csrc/kernels/internode_ll.cu:341 'too many blocks in cooperative launch'

My NVSHMEM version is 3.2.5-1, compiled with the latest patch. Do you have any suggestions?
