@@ -1216,24 +1216,30 @@ def get_tensor_shape(tensor: onnx.TensorProto) -> List[int]:
1216
1216
return [dim .dim_value for dim in tensor .type .tensor_type .shape .dim ]
1217
1217
1218
1218
1219
- def get_tensor_dim_shape (tensor : onnx .TensorProto , dim : int ) -> int :
1219
+ def get_tensor_dim_shape (tensor : onnx .TensorProto , dim : Union [ int , str ] ) -> int :
1220
1220
"""
1221
1221
:param tensor: ONNX tensor to get the shape of a dimension of
1222
1222
:param dim: dimension index of the tensor to get the shape of
1223
1223
:return: shape of the tensor at the given dimension
1224
1224
"""
1225
- return tensor .type .tensor_type .shape .dim [dim ].dim_value
1225
+ return (
1226
+ tensor .type .tensor_type .shape .dim [dim ].dim_value
1227
+ or tensor .type .tensor_type .shape .dim [dim ].dim_param
1228
+ )
1226
1229
1227
1230
1228
- def set_tensor_dim_shape (tensor : onnx .TensorProto , dim : int , value : int ):
1231
+ def set_tensor_dim_shape (tensor : onnx .TensorProto , dim : int , value : Union [ int , str ] ):
1229
1232
"""
1230
1233
Sets the shape of the tensor at the given dimension to the given value
1231
1234
1232
1235
:param tensor: ONNX tensor to modify the shape of
1233
1236
:param dim: dimension index of the tensor to modify the shape of
1234
1237
:param value: new shape for the given dimension
1235
1238
"""
1236
- tensor .type .tensor_type .shape .dim [dim ].dim_value = value
1239
+ if isinstance (value , str ):
1240
+ tensor .type .tensor_type .shape .dim [dim ].dim_param = value
1241
+ else :
1242
+ tensor .type .tensor_type .shape .dim [dim ].dim_value = value
1237
1243
1238
1244
1239
1245
def override_model_input_shape (model : Union [str , onnx .ModelProto ], shape : List [int ]):
0 commit comments