feat: support group_norm, batch_norm, and layer_norm by zewenli98 · Pull Request #2330 · pytorch/TensorRT



Conversation


@zewenli98

Description

Update batch_norm and layer_norm

Fixes #2225


gs-olive


Updates look great - added some suggestions to better follow the Torch schemas for these functions

gs-olive

Comment on lines 50 to 60

if weight is None:
    weight = np.array(1.0)
if bias is None:
    bias = np.array(0.0)
if running_mean is None:
    running_mean = np.array(0.0)
if running_var is None:
    running_var = np.array(1.0)


For these, it should be okay to not cast to np.array in the converter (instead leave them as ints or floats), since to_numpy should dictate this casting behavior for ints and floats. Specifically, one small difference is that I think np.array(1.0) has shape () (0D), but to_numpy generally adds a dimension, to make it 1D.
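A small illustration of the shape difference being described (np.atleast_1d is used here only as a stand-in for how to_numpy is assumed to promote Python scalars; it is not the actual helper):

import numpy as np

print(np.array(1.0).shape)       # () -- 0-D array
print(np.atleast_1d(1.0).shape)  # (1,) -- 1-D array, the shape to_numpy is expected to yield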

Comment on lines 117 to 121

if weight is None:
    weight = np.array(1.0)
if bias is None:
    bias = np.array(0.0)


See above comment

Comment on lines 183 to 187

if weight is None:
    weight = np.array(1.0)
if bias is None:
    bias = np.array(0.0)


See above comment


Since line 189 is shape = weight.shape and lines 191 and 192 call weight.reshape and bias.reshape, I think weight and bias shouldn't be scalars.


I see - in that case, it might be preferable to use to_numpy(0.0), for instance, to get back a default-formatted numpy array for the float default. Additionally, I noticed the code below has some issues:

gamma = to_numpy(weight.reshape(*shape))

Above is invalid, since the reshape should apply to the numpy output. It should instead be:

gamma = to_numpy(weight).reshape(shape)

The same as the above applies for beta.
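Concretely, a minimal sketch of the corrected pattern for both parameters (shape and to_numpy are the names already used in the converter):

gamma = to_numpy(weight).reshape(shape)  # reshape the numpy output, not the torch tensor
beta = to_numpy(bias).reshape(shape)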

Additionally, lines 194-196 should use get_axes_for_reduce_op, as defined here:

get_axes_for_reduce_op = functools.partial(

gs-olive


Comment on lines 173 to 174

weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],


TRTTensor would not be a valid input here, for the scale layer


Do you mean the type of weight and bias in all three functions should be Optional[Union[torch.Tensor, np.ndarray]]? I see the native function is:

func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor


Yes, I think it should be Optional[Union[torch.Tensor, np.ndarray]], because if either of those is a TRTTensor, the computation below would not work (to_numpy can't be called on a TRTTensor)
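A sketch of how the two annotations would then read; the surrounding signature is assumed unchanged from the PR:

weight: Optional[Union[torch.Tensor, np.ndarray]],  # TRTTensor excluded: to_numpy cannot convert it
bias: Optional[Union[torch.Tensor, np.ndarray]],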

@gs-olive

As discussed, add the group_norm implementation here. Additionally, for any converters added, remove those converters from torch_tensorrt.dynamo.lowering._decomposition_groups.

@zewenli98

gs-olive


Added a few comments. Additionally, if the dynamic shape version of this converter is not passing, that is okay since it is not required for the first pass of support

scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-    cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
+    cast(torch.Tensor, to_numpy(running_var)) + eps


The torch.Tensor cast can be removed, because to_numpy will return an np.ndarray, so this typing would be incorrect.
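For illustration, a sketch of the same computation with the casts dropped; the bias/power folding shown follows the standard scale-layer formulation and may not match the PR line for line:

# to_numpy already returns np.ndarray, so no cast() is needed
scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
bias = to_numpy(bias) - to_numpy(running_mean) * scale
power = np.ones_like(scale)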

Comment on lines 351 to 374

eps_field = trt.PluginField(
    "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_filed = trt.PluginField(
    "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)
field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])
try:
    # Here's the schema of the plugin:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
    _LOGGER.error(
        "Unable to find group norm plugin, fall back to TensorRT implementation."
    )
layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor


Alternatively, the TRT layer-based implementation can be the backup for the plugin, etc.
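A rough sketch of that structure; group_norm_with_trt_layers is a hypothetical helper standing in for a layer-based implementation, and the other names are the ones already in the converter:

try:
    # Plugin schema:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
    layer = network.add_plugin_v2([input, scale, bias], plugin)
    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
    return layer.get_output(0)
except AssertionError:
    _LOGGER.warning(
        "Group norm plugin not found, falling back to a TRT layer-based implementation."
    )
    # hypothetical fallback built from elementwise/reduce/scale layers
    return group_norm_with_trt_layers(
        network, target, source_ir, name, input, num_groups, weight, bias, eps
    )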


The returned values here should be the correct intermediate tensors from the computation, unless we explicitly remove support for nodes which need the other two values.
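As a plain-numpy reference for what those two extra outputs represent (this only defines the values native_group_norm is expected to return; it is not converter code):

import numpy as np

def group_norm_stats(x: np.ndarray, num_groups: int, eps: float):
    # x: (N, C, H, W); native_group_norm returns mean and rstd of shape (N, num_groups)
    n = x.shape[0]
    grouped = x.reshape(n, num_groups, -1)
    mean = grouped.mean(axis=-1)
    rstd = 1.0 / np.sqrt(grouped.var(axis=-1) + eps)
    return mean, rstd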

gs-olive

)
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc]


Based on the schema of native_layer_norm, it looks like it requires 3 outputs much like native_group_norm. As a comment on both of those - if you want to support it with essentially the same converter as the regular layer norm, you can do the following:

Add this validator

def validator(layer_norm: Node) -> bool:
    # Validate only one user, which is a getitem node that accesses
    # the first element in the list
    return (
        len(layer_norm.users) == 1
        and list(layer_norm.users)[0].target == operator.getitem
        and list(layer_norm.users)[0].args[1] == 0
    )

Add this converter

@dynamo_tensorrt_converter(
    torch.ops.aten.native_layer_norm.default, capability_validator=validator
)
def converter(...):
    return (regular_layer_norm,)

It is important that the above converter returns a tuple, because it will be accessed by getitem, but as you have validated, it will only access the first element. This should also work for group norm.
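For context, the graph pattern the validator accepts looks roughly like this (variable names are illustrative):

# native_layer_norm produces a 3-tuple node in the FX graph...
ln = torch.ops.aten.native_layer_norm.default(x, normalized_shape, weight, bias, eps)
# ...and its only permitted user is a getitem taking element 0,
# which the tuple-returning converter above satisfies
out = operator.getitem(ln, 0)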

@zewenli98 changed the title from "fix: update batch_norm and layer_norm" to "feat: support group_norm, batch_norm, and layer_norm" on Sep 29, 2023

@zewenli98

@gs-olive

@zewenli98 - when you have the chance, please rebase this PR to the latest main. Additionally, to follow up on the discussion from this comment, the individual functions for layer_norm, batch_norm, etc. should likely return their relevant intermediate values too, so that we can convert those native_layer_norm-style functions.

@zewenli98

Yes! It's still in progress. Thanks for the reminder!

@zewenli98 added the WIP label ("Work is in progress, pull request should not be merged yet") on Oct 3, 2023

@zewenli98


support group norm, and improve batch and layer norms

@zewenli98

gs-olive


Looks good to me - will update again pending a manual check against SD

gs-olive


Works on SD - looks good to me!

gs-olive pushed a commit that referenced this pull request on Oct 10, 2023

@zewenli98 @gs-olive

This was referenced on Oct 10, 2023
