
❓ [Question] dynamic engines & interpolation align_corners=True #2327


Open
ArtemisZGL opened this issue Sep 20, 2023 · 20 comments
Assignees
Labels
component: converters (Issues re: Specific op converters), question (Further information is requested)

Comments

@ArtemisZGL

❓ Question

What you have already tried

I used the latest Docker image with tag 23.08-py3. When converting a model that does interpolation with align_corners=True and dynamic input, I got the error below.

RuntimeError: [Error thrown at core/conversion/converters/impl/interpolate.cpp:412] Expected !(align_corners && ctx->input_is_dynamic) to be true but got false                                                                                               
Torch-TensorRT currently does not support the compilation of dynamc engines from code using PyTorch [bi/tri]linear interpolation via scale factor and align_corners=True 

And I found that this check does exist in the code at tag v1.4.0, but not on the main branch. Will I need to clone the latest code and recompile Torch-TensorRT to escape this error, and will it work? Or is there any simpler way?
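(For context, since the error only rules out the scale-factor form: a commonly suggested workaround, sketched below and not taken from this thread, is to call F.interpolate with an explicit output size computed from the input instead of a scale_factor, which may avoid the converter path that raises this error. Whether it helps for a given dynamic engine is an assumption.)

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# scale-factor form that the converter rejects for dynamic engines:
# y = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)

# explicit-size form: compute the target size from the runtime shape instead
out_h, out_w = 2 * x.shape[-2], 2 * x.shape[-1]
y = F.interpolate(x, size=(out_h, out_w), mode="bilinear", align_corners=True)
print(y.shape)  # torch.Size([1, 3, 16, 16])
```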

Environment

nvcr.io/nvidia/pytorch:23.08-py3

Additional context

@ArtemisZGL ArtemisZGL added the question Further information is requested label Sep 20, 2023
@ArtemisZGL
Author

BTW, there is no setup.py in the py folder on the main branch.

@gs-olive gs-olive added the component: converters Issues re: Specific op converters label Sep 20, 2023
@gs-olive
Collaborator

Hello - #2146 removed this warning, so the latest main may resolve this issue. The setup.py file has moved one directory up from py/, so it can now be invoked from the repository root, either via python setup.py install (or develop), or with a pip installation.

@ArtemisZGL
Author

ArtemisZGL commented Sep 22, 2023

@gs-olive I uninstalled the original torch_tensorrt in the Docker container, then ran python setup.py install on the main branch. After that, I got this error when importing torch_tensorrt.

ImportError: /usr/local/lib/python3.10/dist-packages/torch_tensorrt/lib/libtorchtrt.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs

Could you help? BTW, I tried the literal command python setup.py install/develop but got an error.

invalid command name 'install/develop'

@gs-olive
Collaborator

This is likely a Torch version mismatch error. Could you instead try, from the root of the directory:

pip install . --extra-index-url https://download.pytorch.org/whl/nightly/cu121 --extra-index-url https://pypi.ngc.nvidia.com

Alternatively, you can get a Docker container with the latest main branch here:

docker pull ghcr.io/pytorch/tensorrt/torch_tensorrt:main

@ArtemisZGL
Author

ArtemisZGL commented Sep 26, 2023

@gs-olive Thanks for your help. I tried your Docker image, but I still get an error when importing torch_tensorrt 😂.

/opt/python3/site-packages/torch_tensorrt/lib/libtorchtrt.so: undefined symbol: _ZNR5torch7Library4_defEON3c106eitherINS1_12OperatorNameENS1_14FunctionSchemaEEEONS_11CppFunctionE

@gs-olive
Collaborator

Thanks for letting me know - it seems there is a mismatch in the libtorch libraries being installed through Torch vs Torch-TensorRT on recent nightly versions and on the Docker container. Looking into it.

@ArtemisZGL
Author

@gs-olive Looking forward to the solution. 🚀

@gs-olive
Collaborator

I just checked and the latest deployment of ghcr.io/pytorch/tensorrt/torch_tensorrt:main is working on my machine, as the import error has been resolved. Could you retry with the new container?

@ArtemisZGL
Author

@gs-olive Now I'm getting a CUDA error when using the GPU. 😂

WARNING: [Torch-TensorRT] - Unable to read CUDA capable devices. Return status: 804                                                                                                                                                                           
Traceback (most recent call last):
  File "/code/algorithm/vits-svc-new/convert_torchscript_trt.py", line 1, in <module>                                                                                                                                                                         
    import torch_tensorrt                                                                                                                                                                                                                                     
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/__init__.py", line 84, in <module>                                                                                                                                          
    from torch_tensorrt._compile import *  # noqa: F403                                                                                                                                                                                                       
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 9, in <module>                                                                                                                                           
    import torch_tensorrt.ts                                                                                                                                                                                                                                  
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/ts/__init__.py", line 1, in <module>                                                                                                                                        
    from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec  # noqa: F401                                                                                                                                                                             
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/ts/_compile_spec.py", line 348, in <module>                                                                                                                                 
    device: torch.device | Device = Device._current_device(),                                                                                                                                                                                                 
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/_Device.py", line 155, in _current_device                                                                                                                                   
    dev = _C._get_current_device()                                                                                                                                                                                                                            
RuntimeError: [Error thrown at core/runtime/runtime.cpp:97] Expected (cudaGetDevice(reinterpret_cast<int*>(&device)) == cudaSuccess) to be true but got false 
/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)

@gs-olive
Collaborator

gs-olive commented Sep 28, 2023

Hi @ArtemisZGL - I have seen this issue before - when running the docker container, do you specify --gpus all? As in:

docker run --rm -it --gpus all ghcr.io/pytorch/tensorrt/torch_tensorrt:main

This might help with exposing the GPUs through the docker container. If you prefer to run the latest versions of Torch-TRT outside of a Docker container, you can install the nightly through pip, like so (switch cu121 for cu118, if on a different CUDA version):

pip install --pre torch torch_tensorrt --extra-index-url https://download.pytorch.org/whl/nightly/cu121

@ArtemisZGL
Author

@gs-olive Sorry for the late reply; I was on a long vacation. I used --gpus all to start the container, as below.

docker run -it --gpus all --network host ghcr.io/pytorch/tensorrt/torch_tensorrt:main

@ArtemisZGL
Author

ArtemisZGL commented Oct 7, 2023

@gs-olive I think that problem was related to the driver version, and I reinstalled using your command with cu113.

pip install --pre torch torch_tensorrt --extra-index-url https://download.pytorch.org/whl/nightly/cu113

But I hit another problem with dynamic shapes. Here is the code I used to convert the model, which should work for dynamic shapes according to the docs.

with torch_tensorrt.logging.debug():
    trt_ts_module = torch_tensorrt.compile(
        script_module,
        inputs=[
            torch_tensorrt.Input(
                min_shape=[1, 192, 1],
                opt_shape=[1, 192, 512],
                max_shape=[1, 192, 1024],
                dtype=torch.float,
            ),
            torch_tensorrt.Input(
                min_shape=[1, 1],
                opt_shape=[1, 512],
                max_shape=[1, 1024],
                dtype=torch.float,
            ),
            torch_tensorrt.Input(
                shape=[1, 768, 1],
                dtype=torch.float,
            ),
        ],
        enabled_precisions={torch.half},  # run with FP16
        # allow_shape_tensors=True,
        truncate_long_and_double=True,
    )

The model has 3 inputs; the 1st and 2nd are dynamic, but the 3rd is static. After conversion, it can run inference with shape 512 properly, but other shapes give shape errors. I also tried setting the 3rd input with min/opt/max shapes but still got the same error. Am I missing something?
log_convert.txt

ERROR: [Torch-TensorRT] - 3: [executionContext.cpp::setInputShape::2020] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::setInputShape::2020, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape.
)
ERROR: [Torch-TensorRT] - 3: [executionContext.cpp::setInputShape::2020] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::setInputShape::2020, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape.
)
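(A side note on what this error means, as a hypothetical sketch rather than a confirmed diagnosis: an Input given a single fixed shape compiles that dimension statically, so every inference call must pass exactly that shape; "Static dimension mismatch" is TensorRT rejecting a runtime tensor whose dimensions differ from the compiled ones. Padding or trimming the static input before the call illustrates the constraint; to_static_len is a made-up helper, not part of any API.)

```python
import torch
import torch.nn.functional as F

FIXED_LAST = 1  # the third input above was compiled with the fixed shape [1, 768, 1]

def to_static_len(t: torch.Tensor, fixed_last: int = FIXED_LAST) -> torch.Tensor:
    """Hypothetical helper: force the last dim to the compiled static size."""
    if t.shape[-1] < fixed_last:
        t = F.pad(t, (0, fixed_last - t.shape[-1]))  # right-pad with zeros
    return t[..., :fixed_last]  # trim any excess

x = torch.randn(1, 768, 5)
print(to_static_len(x).shape)  # torch.Size([1, 768, 1])
```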

@gs-olive
Collaborator

gs-olive commented Oct 9, 2023

Thanks for the follow-up. I don't believe there should be an issue with mixed static/dynamic inputs. Could you also try specifying ir="dynamo" in the compile command and see what output that gives you? In the meantime, @peri044 - have you seen this Dynamic Shape error before in TorchScript?

@ArtemisZGL
Author

ArtemisZGL commented Oct 12, 2023

@gs-olive Sorry for the late reply. I got this ir type error when adding ir="dynamo"; did I miss installing something?

 raise ValueError("Unknown ir was requested")
ValueError: Unknown ir was requested

@gs-olive
Collaborator

gs-olive commented Oct 12, 2023

What version of torch_tensorrt are you using? If you are in the Docker container, it should have that as a valid ir. If you prefer to install locally from pip, you could use the latest nightly, at this index, provided you also install the corresponding Torch nightly for that date.

@ArtemisZGL
Author

@gs-olive Oops. I reinstalled torch (torch-2.2.0.dev20231012+cu121.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl), torch_tensorrt (torch_tensorrt-2.2.0.dev20231012+cu121-cp310-cp310-linux_x86_64.whl), and the other dependencies from the nightly index, but found another error, below, which did not occur in torch_tensorrt 2.0.0.dev0.

RuntimeError: [Error thrown at ./core/conversion/var/Var_inl.h:61] Expected ivalue->isScalar() to be true but got false

log.txt

BTW, I also tried ir="dynamo" but still got the same error; the same happens with the Docker image nvcr.io/nvidia/pytorch:23.08-py3.

@gs-olive
Collaborator

Thanks for testing that out. When you try ir="dynamo" in the latest nightly do you still see the "Unknown ir" message? Could you share the output of python -c "import torch_tensorrt; torch_tensorrt.dump_build_info()"?

@ArtemisZGL
Author

ArtemisZGL commented Oct 19, 2023

@gs-olive Sorry for the late reply. My torch and torch_tensorrt are from the wheels built on 20231012, and I still get Unknown ir. Here is the output:

/usr/local/lib/python3.10/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/usr/local/lib/python3.10/dist-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
Torch-TensorRT Version: 2.2.0.dev20231012+cu121
Using PyTorch Version: 2.2.0.dev20231012+cu121-with-pypi-cudnn
Using TensorRT Version: 8.6.1.6
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.4  (built against CUDA 12.2)
    - Built with CuDNN 8.9.2
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.2.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

@gs-olive
Collaborator

I think the issue may be in the input model type, then. If you are passing in a TorchScript ScriptModule, ir="dynamo" will not work, since it can only consume nn.Module or fx.GraphModule objects.
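(The distinction can be checked directly; this is a minimal sketch, with Net as a stand-in module.)

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = Net()                       # plain nn.Module: usable with ir="dynamo"
scripted = torch.jit.script(model)  # ScriptModule: handled by the TorchScript frontend

print(isinstance(model, nn.Module))                  # True
print(isinstance(scripted, torch.jit.ScriptModule))  # True
print(isinstance(scripted, torch.fx.GraphModule))    # False
```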

@Feynman1999
Copy link

> [quotes ArtemisZGL's earlier comment above about the dynamic-shape "Static dimension mismatch" errors]

Hi! Does your interpolation target size depend on the input size? I have also been trying to get fully dynamic shape support working recently.

Projects: None yet
Development: No branches or pull requests

4 participants