Skip to content

Commit 3fa8cc4

Browse files
authored
[ONNX] support for setting/getting string onnx tensor shapes (#1478) (#1480)
* [ONNX] support for setting/getting string onnx tensor shapes * Apply suggestions from code review
1 parent faebb8b commit 3fa8cc4

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/sparseml/onnx/utils/helpers.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1216,24 +1216,30 @@ def get_tensor_shape(tensor: onnx.TensorProto) -> List[int]:
12161216
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
12171217

12181218

1219-
def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int:
1219+
def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: Union[int, str]) -> int:
12201220
"""
12211221
:param tensor: ONNX tensor to get the shape of a dimension of
12221222
:param dim: dimension index of the tensor to get the shape of
12231223
:return: shape of the tensor at the given dimension
12241224
"""
1225-
return tensor.type.tensor_type.shape.dim[dim].dim_value
1225+
return (
1226+
tensor.type.tensor_type.shape.dim[dim].dim_value
1227+
or tensor.type.tensor_type.shape.dim[dim].dim_param
1228+
)
12261229

12271230

1228-
def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
1231+
def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: Union[int, str]):
12291232
"""
12301233
Sets the shape of the tensor at the given dimension to the given value
12311234
12321235
:param tensor: ONNX tensor to modify the shape of
12331236
:param dim: dimension index of the tensor to modify the shape of
12341237
:param value: new shape for the given dimension
12351238
"""
1236-
tensor.type.tensor_type.shape.dim[dim].dim_value = value
1239+
if isinstance(value, str):
1240+
tensor.type.tensor_type.shape.dim[dim].dim_param = value
1241+
else:
1242+
tensor.type.tensor_type.shape.dim[dim].dim_value = value
12371243

12381244

12391245
def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):

0 commit comments

Comments
 (0)