
❓ [Question] How do I load the torch tensorRT model on multiple gpus #2319

Closed
@agunapal

Description


❓ Question

In TorchServe, we have this concept of workers. In a multi-GPU node, we can assign each GPU to a worker.

I am noticing that the TensorRT model is getting loaded on GPU 0 even though we specify the correct GPU ID for each worker via torch.jit.load(model_pt_path, map_location=self.device).

How do we load a TensorRT model on a device ID which is not 0?

What you have already tried

I have tried loading a plain TorchScript model; that one loads correctly on all 4 GPUs, using torch.jit.load(model_pt_path, map_location=self.device) to load the same model on each of the 4 GPUs:

2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - cuda:1
2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,355 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - cuda:0
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - cuda:3
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - cuda:2
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
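For context, the per-worker load path behind these logs looks roughly like this (a minimal sketch; `resolve_device`, `load_worker_model`, and the `gpu_id` plumbing are illustrative names, not TorchServe's actual handler API):

```python
import torch

def resolve_device(gpu_id):
    """Pick a worker's device, falling back to CPU when no GPU is
    available (illustrative sketch of per-worker device selection
    in a TorchServe-style handler)."""
    if gpu_id is not None and torch.cuda.is_available():
        return torch.device(f"cuda:{int(gpu_id)}")
    return torch.device("cpu")

def load_worker_model(model_pt_path, gpu_id):
    device = resolve_device(gpu_id)
    # For a plain TorchScript model, map_location honors this device;
    # for a Torch-TensorRT model the engine still lands on cuda:0,
    # which is the behavior reported in this issue.
    return torch.jit.load(model_pt_path, map_location=device)
```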

Here is a simpler repro:

import torch
import torch_tensorrt
model = torch.jit.load("trt_model_fp16.pt", map_location="cuda:1")
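One workaround worth trying (my suggestion, not a confirmed fix from the maintainers): make the target GPU the current CUDA device before deserializing, since the embedded TensorRT engine appears to be restored onto the current device rather than the map_location target:

```python
import torch

def load_trt_model(model_pt_path, device_str):
    import torch_tensorrt  # noqa: F401 -- registers the TRT engine deserializer
    device = torch.device(device_str)
    # Make the target GPU current *before* torch.jit.load; the serialized
    # TensorRT engine seems to deserialize onto the current device, and
    # (per this issue) map_location alone does not move it off cuda:0.
    torch.cuda.set_device(device)
    return torch.jit.load(model_pt_path, map_location=device)

# usage on a multi-GPU box:
# model = load_trt_model("trt_model_fp16.pt", "cuda:1")
```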

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 3.9
  • CPU Architecture:
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: pip
  • Python version: 3.9
  • CUDA version: 11.7
  • GPU models and configuration: T4
  • Any other relevant information:
