Skip to content

Commit 0b568d3

Browse files
committed
fix: Upgrade to_numpy to allow boolean constants
1 parent 5f02996 commit 0b568d3

File tree

7 files changed

+55
-8
lines changed

7 files changed

+55
-8
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
345345

346346
def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
347347
with _disable_current_modes():
348-
from torch_tensorrt.fx.converters import to_numpy
348+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
349349

350350
frozen_attr = self.fetch_attr(target)
351351

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from torch_tensorrt.fx.converters.converter_utils import (
1818
Frameworks,
1919
get_axes_for_reduce_op,
20-
to_numpy,
2120
unified_dtype_converter,
2221
)
2322
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
@@ -414,3 +413,50 @@ def convert_with_type_enforcement(
414413
return convert_with_type_enforcement
415414

416415
return wrapper
416+
417+
418+
def to_numpy(
419+
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
420+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
421+
) -> np.ndarray:
422+
"""
423+
Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is
424+
quantized it will be dequantized first.
425+
Args:
426+
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
427+
A PyTorch tensor, Numpy array, int, float, or bool
428+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
429+
If a dtype is given, we will convert the type of the given `value` to this dtype.
430+
Returns:
431+
A Numpy array.
432+
"""
433+
output = None
434+
435+
if value is None or isinstance(value, np.ndarray):
436+
output = value
437+
438+
elif isinstance(value, torch.Tensor):
439+
if value.is_quantized:
440+
value = value.dequantize()
441+
442+
output = value.cpu().detach().contiguous().numpy()
443+
444+
elif isinstance(value, int):
445+
output = np.array([value], dtype=np.int32)
446+
447+
elif isinstance(value, float):
448+
output = np.array([value], dtype=np.float32)
449+
450+
elif isinstance(value, bool):
451+
output = np.array([value], dtype=np.bool_)
452+
453+
if isinstance(output, np.ndarray):
454+
return (
455+
output
456+
if (dtype is None or output is None)
457+
else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY))
458+
)
459+
else:
460+
raise AssertionError(
461+
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
462+
)

py/torch_tensorrt/dynamo/conversion/impl/conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
from torch_tensorrt.dynamo.conversion import impl
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
SourceIR,
1213
extend_attr_to_tuple,
1314
get_trt_tensor,
15+
to_numpy,
1416
)
1517
from torch_tensorrt.fx.converters.converter_utils import (
16-
SourceIR,
1718
get_dyn_range,
1819
has_dynamic_shape,
1920
mark_as_int8_layer,
2021
set_layer_name,
21-
to_numpy,
2222
)
2323
from torch_tensorrt.fx.types import TRTTensor
2424

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
1212
extend_attr_to_tuple,
1313
get_trt_tensor,
14+
to_numpy,
1415
)
1516
from torch_tensorrt.fx.converters.converter_utils import (
1617
SourceIR,
1718
get_dyn_range,
1819
has_dynamic_shape,
1920
mark_as_int8_layer,
2021
set_layer_name,
21-
to_numpy,
2222
)
2323
from torch_tensorrt.fx.types import TRTTensor
2424

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.fx.node import Target
88
from torch_tensorrt.dynamo._SourceIR import SourceIR
99
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
10+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
1011
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1112
convert_binary_elementwise,
1213
)
@@ -16,7 +17,6 @@
1617
get_trt_plugin,
1718
has_dynamic_shape,
1819
set_layer_name,
19-
to_numpy,
2020
)
2121
from torch_tensorrt.fx.types import TRTTensor
2222
from torch_tensorrt.fx.utils import get_dynamic_dims

py/torch_tensorrt/dynamo/conversion/impl/select.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
78
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
89
from torch_tensorrt.fx.converters.converter_utils import (
910
get_positive_dim,
1011
has_dynamic_shape,
11-
to_numpy,
1212
)
1313
from torch_tensorrt.fx.types import Shape, TRTTensor
1414

py/torch_tensorrt/dynamo/conversion/impl/shape.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
11+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
1112
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1213
convert_binary_elementwise,
1314
)
14-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, to_numpy
15+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
1516
from torch_tensorrt.fx.types import TRTTensor
1617

1718

0 commit comments

Comments
 (0)