feat: support group_norm, batch_norm, and layer_norm by zewenli98 · Pull Request #2330 · pytorch/TensorRT
Conversation
Description
Update batch_norm and layer_norm
Fixes #2225
Type of change
- Bug fix (non-breaking change which fixes an issue)
Checklist:
- My code follows the style guidelines of this project (You can use the linters)
- I have performed a self-review of my own code
- I have commented my code, particularly in hard-to-understand areas and hacks
- I have made corresponding changes to the documentation
- I have added tests to verify my fix or my feature
- New and existing unit tests pass locally with my changes
- I have added the relevant labels to my PR so that relevant reviewers are notified
Updates look great - added some suggestions to better follow the Torch schemas for these functions
Comment on lines 50 to 60
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)

if running_mean is None:
    running_mean = np.array(0.0)

if running_var is None:
    running_var = np.array(1.0)
```
For these, it should be okay not to cast to `np.array` in the converter (instead leave them as ints or floats), since `to_numpy` should dictate this casting behavior for ints and floats. One small difference is that I think `np.array(1.0)` has shape `()` (0D), but `to_numpy` generally adds a dimension, making it 1D.
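For illustration, a minimal sketch of that shape difference (the 1D behavior of `to_numpy` is as described above and is an assumption here, not verified against the utility itself):

```python
import numpy as np

# np.array(1.0) is a 0-D array:
assert np.array(1.0).shape == ()

# A 1-D default, closer to what to_numpy is said to produce:
assert np.atleast_1d(np.array(1.0)).shape == (1,)
```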
Comment on lines 117 to 121
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)
```
See above comment
Comment on lines 183 to 187
```python
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)
```
See above comment
Since line 189 is shape = weight.shape and lines 191 and 192 call weight.reshape and bias.reshape, I think weight and bias shouldn't be scalars.
I see - in that case, it might be preferable to use `to_numpy(0.0)`, for instance, to get back a default-formatted numpy array for the float default. Additionally, I noticed the code below has some issues:

```python
gamma = to_numpy(weight.reshape(*shape))
```

The above is invalid, since the reshape should apply to the numpy output. It should instead be:

```python
gamma = to_numpy(weight).reshape(shape)
```

The same applies for beta.

Additionally, lines 194 - 196 should be using `get_axes_for_reduce_op`, as referenced here: `get_axes_for_reduce_op = functools.partial(...)`
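For context, TensorRT reduce and normalization layers take their reduction axes as a bitmask, which is what that helper produces. A minimal, hypothetical stand-in (not the library's actual implementation) for how such a mask is formed:

```python
from typing import Sequence


def axes_bitmask(dims: Sequence[int]) -> int:
    """Fold dimension indices into the bitmask format TensorRT reduce ops expect."""
    axes = 0
    for d in dims:
        axes |= 1 << d
    return axes


# Reducing over the last two dims of a 4-D tensor (dims 2 and 3) -> 0b1100
assert axes_bitmask([2, 3]) == 0b1100
```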
Comment on lines 173 to 174
```python
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
```
TRTTensor would not be a valid input here for the scale layer.
Do you mean the type of weight and bias in all three functions should be `Optional[Union[torch.Tensor, np.ndarray]]`? Its native function schema is:

```
func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
```
Yes, I think it should be Optional[Union[torch.Tensor, np.ndarray]], because if either of those is a TRTTensor, the computation below would not work (to_numpy can't be called on a TRTTensor)
As discussed, add the group_norm implementation here. Additionally, for any converters added, remove those converters from `torch_tensorrt.dynamo.lowering._decomposition_groups`.
Added a few comments. Additionally, if the dynamic shape version of this converter is not passing, that is okay since it is not required for the first pass of support
```diff
 scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-    cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
+    cast(torch.Tensor, to_numpy(running_var)) + eps
```
The torch.Tensor cast can be removed, because to_numpy will return an np.ndarray, so this typing would be incorrect.
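For context, a minimal numpy sketch of the scale/shift folding this computation is part of (the function name and argument handling here are illustrative, not the PR's exact code):

```python
import numpy as np


# Hedged sketch: batch_norm folded into the per-channel scale/shift consumed by a
# TRT scale layer, assuming all inputs are already numpy arrays (e.g. via to_numpy).
def fold_batch_norm(weight, bias, running_mean, running_var, eps):
    scale = weight / np.sqrt(running_var + eps)  # per-channel multiplier
    shift = bias - running_mean * scale          # per-channel additive term
    return scale, shift
```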
Comment on lines 351 to 374
```python
eps_field = trt.PluginField(
    "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_filed = trt.PluginField(
    "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)
field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])

try:
    # Here's the schema of the plugin:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
    _LOGGER.error(
        "Unable to find group norm plugin, fall back to TensorRT implementation."
    )

layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)

# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor
```
Alternatively, the TRT layer-based implementation can be the backup for the plugin, etc.
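As one possible shape of that fallback, a hedged sketch using TensorRT >= 8.6's INormalizationLayer; the axes choice and attribute usage here are assumptions to be checked against the TensorRT docs, not the PR's implementation:

```python
# Hedged sketch: group norm via INormalizationLayer (TensorRT >= 8.6), assuming
# `network` is a trt.INetworkDefinition and `input`, `scale`, `bias` are ITensors.
def group_norm_layer_fallback(network, input, scale, bias, num_groups, eps):
    # Assumption: statistics over channel + spatial dims (everything but batch);
    # the layer splits the channel dim into `num_groups` groups internally.
    axes = 0
    for d in range(1, len(input.shape)):
        axes |= 1 << d
    layer = network.add_normalization(input, scale, bias, axes)
    layer.num_groups = num_groups
    layer.epsilon = eps
    return layer.get_output(0)
```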
Also on lines 351 to 374: the returned values here should be the correct intermediate tensors from the computation, unless we explicitly remove support for nodes which need the other two values.
```python
)
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default)  # type: ignore[misc]
```
Based on the schema of native_layer_norm, it looks like it requires 3 outputs, much like native_group_norm. As a comment on both of those - if you want to support it with essentially the same converter as the regular layer norm, you can do the following:

Add this validator:

```python
def validator(layer_norm: Node) -> bool:
    # Validate only one user, which is a getitem node that accesses the first element in the list
    return (
        len(layer_norm.users) == 1
        and list(layer_norm.users)[0].target == operator.getitem
        and list(layer_norm.users)[0].args[1] == 0
    )
```

Add this converter:

```python
@dynamo_tensorrt_converter(
    torch.ops.aten.native_layer_norm.default, capability_validator=validator
)
def converter(...):
    return (regular_layer_norm,)
```

It is important that the above converter returns a tuple, because it will be accessed by getitem, but as you have validated, it will only access the first element. This should also work for group norm.
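A tiny illustration of the point above (illustrative, not part of the PR): FX lowers tuple outputs into `operator.getitem` accesses, so the converter's result must be indexable, and the validator guarantees only index 0 is ever read:

```python
import operator

result = ("layer_norm_trt_tensor",)  # stand-in for the converter's 1-tuple return
assert operator.getitem(result, 0) == "layer_norm_trt_tensor"
```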
zewenli98 changed the title from "fix: update batch_norm and layer_norm" to "feat: support group_norm, batch_norm, and layer_norm".
@zewenli98 - when you have the chance, please rebase this PR to the latest main. Additionally, to follow up on the discussion from this comment, the individual functions for layer_norm, batch_norm, etc. should likely return their relevant intermediate values too, so that we can convert those native_layer_norm-style functions.
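For reference, a hedged numpy sketch of what the three outputs of native_layer_norm correspond to (illustrative only; in the converter these would be TRT tensors rather than numpy arrays):

```python
import numpy as np


def native_layer_norm_ref(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Normalize over the trailing `normalized_shape` dimensions.
    dims = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    mean = x.mean(axis=dims, keepdims=True)
    rstd = 1.0 / np.sqrt(x.var(axis=dims, keepdims=True) + eps)
    out = (x - mean) * rstd
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out, mean, rstd  # the (out, mean, rstd) triple discussed above
```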
Yes! It's still in progress. Thanks for the reminder!
Label added: Work is in progress, pull request should not be merged yet.
support group norm, and improve batch and layer norms
Looks good to me - will update again pending a manual check against SD
Works on SD - looks good to me!