❓ Question
In TorchServe, we have the concept of workers. On a multi-GPU node, we can assign each GPU to a worker.
I am noticing that the TensorRT model is loaded on GPU 0 even though we specify the correct GPU ID for each worker:
torch.jit.load(model_pt_path, map_location=self.device)
How do we load a TensorRT model on a device ID other than 0?
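For context, a TorchServe handler typically derives `self.device` from the `gpu_id` that TorchServe places in `context.system_properties` for each worker. The sketch below isolates that mapping without importing torch. The function name `device_for_worker` is hypothetical, and a real handler would also check `torch.cuda.is_available()` and wrap the result in `torch.device(...)`.

```python
def device_for_worker(system_properties: dict) -> str:
    """Map a TorchServe worker's system properties to a device string.

    Hypothetical, simplified sketch of the usual handler logic: a real
    handler would also check torch.cuda.is_available() before picking CUDA.
    """
    gpu_id = system_properties.get("gpu_id")
    if gpu_id is None:
        # No GPU assigned to this worker, so fall back to the CPU.
        return "cpu"
    return f"cuda:{int(gpu_id)}"


# Each worker receives its own gpu_id, so on a 4-GPU node the four workers
# should resolve to four different devices.
print([device_for_worker({"gpu_id": i}) for i in range(4)])
# ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']
print(device_for_worker({}))  # cpu
```

The returned string is the value a handler would pass as `map_location` to `torch.jit.load`.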
What you have already tried
I have tried loading a plain TorchScript model. In that case it loads on all 4 GPUs, using
torch.jit.load(model_pt_path, map_location=self.device)
to load the same model on each of the 4 GPUs:
2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - cuda:1
2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,355 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - cuda:0
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - cuda:3
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - cuda:2
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
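The log excerpt lists the device each worker was handed. As a quick mechanical check, the sketch below extracts the device printed by each worker and confirms that no two workers share a GPU. The helper `worker_devices` is hypothetical and assumes the `W-<port>-<model>-stdout MODEL_LOG - cuda:<n>` line shape seen above. It only inspects what the handler printed; it cannot show where the TensorRT engine was actually allocated, which `nvidia-smi` reports per GPU.

```python
import re
from typing import Dict

# Matches lines such as
#   "... W-9002-resnet-18_1.0-stdout MODEL_LOG - cuda:3"
# capturing the worker port and the CUDA device index.
LOG_DEVICE = re.compile(r"W-(\d+)-\S+-stdout MODEL_LOG - cuda:(\d+)\s*$")


def worker_devices(log_text: str) -> Dict[int, int]:
    """Return {worker_port: cuda_index} for every device line in the log."""
    devices = {}
    for line in log_text.splitlines():
        match = LOG_DEVICE.search(line)
        if match:
            devices[int(match.group(1))] = int(match.group(2))
    return devices


sample = """\
2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - cuda:1
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - cuda:0
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - cuda:3
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - cuda:2
"""
devices = worker_devices(sample)
print(devices)  # {9000: 1, 9003: 0, 9002: 3, 9001: 2}
# True when every worker was handed a distinct GPU.
print(len(set(devices.values())) == len(devices))  # True
```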

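Here is a simpler repro:
import torch
import torch_tensorrt  # registers the TensorRT runtime ops needed to deserialize the module

# Expected to place the module on GPU 1, but the TensorRT engine ends up on GPU 0
model = torch.jit.load("trt_model_fp16.pt", map_location="cuda:1")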

Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- PyTorch Version (e.g., 1.0): 3.9
- CPU Architecture:
- OS (e.g., Linux): Ubuntu 20.04
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives: pip
- Python version: 3.9
- CUDA version: 11.7
- GPU models and configuration: T4
- Any other relevant information: