feat: support group_norm, batch_norm, and layer_norm #2330
Conversation
Updates look great - added some suggestions to better follow the Torch schemas for these functions
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)

if running_mean is None:
    running_mean = np.array(0.0)

if running_var is None:
    running_var = np.array(1.0)
```
For these, it should be okay to not cast to `np.array` in the converter (instead leave them as ints or floats), since `to_numpy` should dictate this casting behavior for ints and floats. Specifically, one small difference is that I think `np.array(1.0)` has shape `()` (0-D), but `to_numpy` generally adds a dimension, to make it 1-D.
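For illustration, a minimal sketch of that shape difference (the 1-D wrapping attributed to `to_numpy` is an assumption based on the comment above, emulated here with `np.atleast_1d`):

```python
import numpy as np

scalar_default = np.array(1.0)
print(scalar_default.shape)  # () -- a 0-D array

# Assumed to_numpy-like behavior for plain floats: promote to a 1-D array.
one_d_default = np.atleast_1d(1.0)
print(one_d_default.shape)   # (1,)
```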
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)
```
See above comment
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)
```
See above comment
Since line 189 is `shape = weight.shape` and lines 191 and 192 call `weight.reshape` and `bias.reshape`, I think `weight` and `bias` shouldn't be scalars.
I see - in that case, it might be preferable to use `to_numpy(0.0)`, for instance, to get back a default-formatted numpy array for the float default. Additionally, I noticed the code below has some issues:

```python
gamma = to_numpy(weight.reshape(*shape))
```

The above is invalid, since the reshape should apply to the numpy output. It should instead be:

```python
gamma = to_numpy(weight).reshape(shape)
```

The same applies for `beta`.

Additionally, lines 194 - 196 should be using `get_axes_for_reduce_op`, as here:

```python
get_axes_for_reduce_op = functools.partial(
```
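The quoted helper is only partially shown above; a hedged sketch of what such a `functools.partial`-based axes helper typically does (the implementation below is an illustrative assumption, not the actual Torch-TensorRT utility) -- TensorRT reduce/scale layers take the reduction dimensions as a bitmask:

```python
import functools
from typing import Sequence, Union


def get_axes_for_reduce_op_impl(
    dim: Union[int, Sequence[int]], has_implicit_batch_dimension: bool
) -> int:
    # Convert a dim (or list of dims) into the bitmask format TensorRT
    # reduce layers expect; shift down by one if an implicit batch dim exists.
    if isinstance(dim, int):
        dim = (dim,)
    offset = -1 if has_implicit_batch_dimension else 0
    axes = 0
    for d in dim:
        axes |= 1 << (d + offset)
    return axes


# Mirrors the partial binding shown in the snippet above (flag value assumed):
get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op_impl, has_implicit_batch_dimension=False
)

print(bin(get_axes_for_reduce_op([1, 2, 3])))  # 0b1110
```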
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)
```
I see - in that case, it might be preferable to use `to_numpy(0.0)`, for instance, to get back a default-formatted numpy array for the float default. Additionally, I noticed the code below has some issues:

```python
gamma = to_numpy(weight.reshape(*shape))
```

The above is invalid, since the reshape should apply to the numpy output. It should instead be:

```python
gamma = to_numpy(weight).reshape(shape)
```

The same applies for `beta`.

Additionally, lines 194 - 196 should be using `get_axes_for_reduce_op`, as here:

```python
get_axes_for_reduce_op = functools.partial(
```
```python
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
```
`TRTTensor` would not be a valid input here, for the scale layer.
Do you mean the type of `weight` and `bias` in all three functions should be `Optional[Union[torch.Tensor, np.ndarray]]`? I see its native function:

```
func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
```
Yes, I think it should be `Optional[Union[torch.Tensor, np.ndarray]]`, because if either of those is a `TRTTensor`, the computation below would not work (`to_numpy` can't be called on a `TRTTensor`).
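A minimal sketch of the narrowed typing plus the schema defaults discussed above (the function name and the `np.asarray`-based normalization are illustrative assumptions, not the actual converter code):

```python
from typing import Optional, Tuple, Union

import numpy as np
import torch


def resolve_affine_params(
    weight: Optional[Union[torch.Tensor, np.ndarray]],
    bias: Optional[Union[torch.Tensor, np.ndarray]],
) -> Tuple[np.ndarray, np.ndarray]:
    # Fall back to the schema defaults (weight=1, bias=0) when omitted, and
    # normalize everything to np.ndarray so later reshape calls behave uniformly.
    gamma = np.asarray(weight if weight is not None else 1.0, dtype=np.float32)
    beta = np.asarray(bias if bias is not None else 0.0, dtype=np.float32)
    return gamma, beta


gamma, beta = resolve_affine_params(torch.ones(4), None)
print(gamma.shape, beta.shape)  # (4,) ()
```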
As discussed, add …
@gs-olive group_norm was added!
Force-pushed from fd820e6 to 4aa4dce
Added a few comments. Additionally, if the dynamic shape version of this converter is not passing, that is okay since it is not required for the first pass of support
```diff
  scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-     cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
+     cast(torch.Tensor, to_numpy(running_var)) + eps
```
The `torch.Tensor` cast can be removed, because `to_numpy` will return an `np.ndarray`, so this typing would be incorrect.
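A runnable sketch of that suggestion (the `to_numpy` shim below is a simplified stand-in for the converter utility, and the tensor values are placeholders):

```python
from typing import Union

import numpy as np
import torch


def to_numpy(t: Union[torch.Tensor, np.ndarray, float]) -> np.ndarray:
    # Simplified stand-in: always hand back an np.ndarray.
    if isinstance(t, torch.Tensor):
        return t.detach().cpu().numpy()
    return np.asarray(t)


weight = torch.ones(8)
running_var = torch.full((8,), 0.5)
eps = 1e-5

# No cast(torch.Tensor, ...) wrapper is needed: to_numpy already returns np.ndarray.
scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
print(scale.dtype, scale.shape)  # float32 (8,)
```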
```python
eps_field = trt.PluginField(
    "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_filed = trt.PluginField(
    "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)

field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])

try:
    # Here's the schema of the plugin:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
    _LOGGER.error(
        "Unable to find group norm plugin, fall back to TensorRT implementation."
    )

layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)

# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor
```
Is it possible to avoid invoking the plugin here, and instead use the full implementation, adapting from here: https://github.com/NVIDIA-AI-IOT/torch2trt/blob/36656b614f3fbc067ac673932e2200d7afdae712/torch2trt/converters/group_norm.py#L7-L73? The plugin is not preferable for use in new converters unless it cannot be otherwise supported.
Alternatively, the TRT layer-based implementation can be the backup for the plugin, etc.
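For reference, a NumPy sketch of the math a layer-based (non-plugin) group-norm lowering would need to express -- this is a hand-written illustration, not code from the linked torch2trt converter or from this PR:

```python
import numpy as np


def group_norm_reference(
    x: np.ndarray,       # (N, C, H, W)
    num_groups: int,
    gamma: np.ndarray,   # (C,)
    beta: np.ndarray,    # (C,)
    eps: float = 1e-5,
) -> np.ndarray:
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    # Normalize per (sample, group): reduce over the within-group channel and spatial dims.
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    normalized = ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # Per-channel affine transform (the scale/shift carried by weight and bias).
    return normalized * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)


out = group_norm_reference(np.random.randn(2, 4, 8, 8), 2, np.ones(4), np.zeros(4))
print(out.shape)  # (2, 4, 8, 8)
```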
```python
eps_field = trt.PluginField(
    "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_filed = trt.PluginField(
    "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)

field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])

try:
    # Here's the schema of the plugin:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
    _LOGGER.error(
        "Unable to find group norm plugin, fall back to TensorRT implementation."
    )

layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)

# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor
```
The returned values here should be the actual intermediate tensors from the computation, unless we explicitly remove support for nodes which need the other two values.
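For context, a small eager-mode check of what the three outputs look like (assuming the aten schema `native_group_norm(input, weight, bias, N, C, HxW, group, eps)`):

```python
import torch

x = torch.randn(2, 4, 8, 8)
weight = torch.ones(4)
bias = torch.zeros(4)

# native_group_norm returns (out, mean, rstd); mean and rstd are per
# (sample, group), so dummy scalar placeholders break any consumer that
# reads outputs 1 or 2.
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, 2, 4, 8 * 8, 2, 1e-5
)
print(out.shape, mean.shape, rstd.shape)
# torch.Size([2, 4, 8, 8]) torch.Size([2, 2]) torch.Size([2, 2])
```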
```python
)


@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default)  # type: ignore[misc]
```
Based on the schema of `native_layer_norm`, it looks like it requires 3 outputs much like `native_group_norm`. As a comment on both of those - if you want to support it with essentially the same converter as the regular layer norm, you can do the following:

Add this validator:

```python
import operator

from torch.fx.node import Node


def validator(layer_norm: Node) -> bool:
    # Validate only one user, which is a getitem node that accesses the first element in the list
    users = list(layer_norm.users)
    return (
        len(users) == 1
        and users[0].target == operator.getitem
        and users[0].args[1] == 0
    )
```

Add this converter:

```python
@dynamo_tensorrt_converter(
    torch.ops.aten.native_layer_norm.default, capability_validator=validator
)
def converter(...):
    return (regular_layer_norm, )
```

It is important that the above converter returns a tuple, because it will be accessed by `getitem`, but as you have validated, it will only access the first element. This should also work for group norm.
@zewenli98 - when you have the chance, please rebase this PR to the latest
Yes! It's still in progress. Thanks for the reminder!
Force-pushed from 8a41cf9 to 4f585d8
support group norm, and improve batch and layer norms
Force-pushed from f628c0c to 84b58dd
Looks good to me - will update again pending a manual check against SD
Works on SD - looks good to me!
Description

Update `batch_norm` and `layer_norm`

Fixes #2225

Type of change

Checklist: