
feat: FP8 Rowwise quantization support for Cohere models #3127


Open · aikitoria wants to merge 2 commits into main

Conversation

@aikitoria (Author)

This adds FP8 support for the LayerNorm kernel in the same way as was done for the RmsNorm kernel, which then allows us to use FP8 Rowwise quantization with the Cohere models.

For previous discussion, see #2912
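
To make the intent concrete, here is a minimal conceptual sketch of what "rowwise" FP8 quantization means in this context: after normalization, each token row gets its own scale derived from the row's absolute maximum before the values are cast to FP8. This is illustrative only (plain C++, invented names, no CUDA), not the kernel code in this PR.

#include <algorithm>
#include <cmath>
#include <vector>

// Conceptual sketch of per-token ("rowwise") FP8 quantization after LayerNorm.
// A real fused kernel would do this per row inside the normalization kernel and
// cast to __nv_fp8_e4m3; here plain floats stand in for the FP8 values.
void quantizeRowwise(std::vector<float> const& normed, int tokens, int hidden_dim,
    std::vector<float>& scale_per_token, std::vector<float>& scaled)
{
    constexpr float kFp8E4m3Max = 448.f; // largest finite FP8 E4M3 value
    scale_per_token.assign(tokens, 0.f);
    scaled.assign(normed.size(), 0.f);
    for (int t = 0; t < tokens; ++t)
    {
        // The per-row absolute maximum determines this token's scale.
        float maxAbs = 1e-6f;
        for (int h = 0; h < hidden_dim; ++h)
        {
            maxAbs = std::max(maxAbs, std::fabs(normed[t * hidden_dim + h]));
        }
        float const scale = maxAbs / kFp8E4m3Max; // dequantize with fp8_value * scale
        scale_per_token[t] = scale;
        for (int h = 0; h < hidden_dim; ++h)
        {
            // A real kernel would cast this value to __nv_fp8_e4m3 here.
            scaled[t * hidden_dim + h] = normed[t * hidden_dim + h] / scale;
        }
    }
}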

@juney-nvidia (Collaborator)

/bot run

juney-nvidia requested review from ming-wei and QiJune on March 28, 2025 00:06
@juney-nvidia (Collaborator)

@QiJune @ming-wei please help review this MR.

@juney-nvidia (Collaborator)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #672 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #672 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #565 completed with status: 'FAILURE'

@aikitoria (Author)

It looks like the CI failed, but the links go to some internal domains, so I can't see what the error is. I have some ideas about what it might be... I probably need to update other usages of the LayerNorm quantization plugin to handle the new parameters.

@juney-nvidia (Collaborator) commented Mar 29, 2025


@aikitoria your code failed to pass the pre-commit check.

Currently the pre-commit check failure is not copied back to the public CI results to be viewable, and we are working to improve this with this MR:

For now I just manually copy the error message to provide quick feedback:
[image: pre-commit error message]

You can also refer here to do the pre-commit check locally in your own dev environment.

June

@aikitoria (Author)

Oh I see, I will fix the formatting for both PRs

@ming-wei (Collaborator) left a comment


Thank you for the contribution!

I've left a few comments, but the PR looks overall good.

@juney-nvidia It'd be good if we can find someone familiar with quantization support. I personally don't have hands-on quantization experience, so I might miss something.

int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool hasFp8MinScaling)

  • hasFp8MinScaling -> has_fp8_min_scaling
  • clampPtr -> clamp_ptr

to keep the coding style consistent with other params?
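
For illustration, a declaration-only sketch with the suggested snake_case names (the function name here is hypothetical; the parameter list mirrors the one quoted above):

// Hypothetical declaration showing the suggested renames; everything else is
// unchanged relative to the signature quoted above.
template <typename QuantT>
void invokeLayerNormQuantization(int tokens, int hidden_dim, float const* clamp_ptr,
    float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, float* sum_per_token,
    QuantT* normed_output_quant, bool has_fp8_min_scaling);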

float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block,
const size_t shmem_size, cudaStream_t stream)
float const eps, int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling,

ditto, naming convention issue. Please also check other occurrences of clampPtr/hasFp8MinScaling.

}
// Dynamic scaling if enabled
return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
else if (pos == 5 + int(mClampValEnabled))

I feel like it'd be clearer if we treated pos as the position relative to the input/output starting index:

if (pos < nbInputs) {
  // pos is 0-based input pos.
  if (pos < 3) {
    ...
  }
  ...
} else {
  pos -= nbInputs;
  // pos is 0-based output pos.
  if (pos == 0) {
    // Quantized output
    ...
  }
  ...
}
...

}

@@ -185,14 +270,25 @@ int LayernormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* input
nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0));
assert(index <= 2);

Shouldn't this check depend on the value of mDynActScaling and mSumPerToken?

Besides, I'd use index < 3 instead of index <= 2 if there are 3 outputs in total.
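
A minimal sketch of the kind of check being suggested (illustrative free function; the real plugin would use its mDynActScaling and mSumPerToken members, and its actual output layout may differ):

#include <cassert>

// One quantized output always exists; the per-token scale and per-token sum
// outputs are optional, so the valid output index range depends on both flags.
inline void checkOutputIndex(int index, bool dynActScaling, bool sumPerToken)
{
    assert(index >= 0 && index < 1 + int(dynActScaling) + int(sumPerToken));
}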

juney-nvidia requested review from Tracin and wm2012011492 and removed the request for QiJune on March 31, 2025 06:13
@juney-nvidia (Collaborator)

> Thank you for the contribution!
>
> I've left a few comments, but the PR looks overall good.
>
> @juney-nvidia It'd be good if we can find someone familiar with quantization support. I personally don't have hands-on quantization experience, so I might miss something.

Sure, I just added @Tracin to the code reviewer loop.

Thanks
June

bool mClampValEnabled;
// The quantization mode.
tensorrt_llm::common::QuantMode mQuantMode;
// Should we output the sum of channels per-token? (Used by QServe GEMM)

Should we output sum in this scenario? @bobboli


To my knowledge, the per-token sum is only used by QServe's per-channel w4a8 GEMM. FP8 rowwise should not need this.

byshiue self-requested a review on March 31, 2025 07:22
@wm2012011492 (Collaborator) commented Mar 31, 2025

Hi @aikitoria, would you mind adding a functional unit test like tests/unittest/trt/quantization/test_smooth_quant_layer_norm.py? It would also be better to add an example usage in examples/commandr/README.md. Thanks.

default=False,
help="Enable FP8 per-token per-channel quantization")
parser.add_argument(
"--use_meta_fp8_rowwise_recipe",

Hi @aikitoria, use_meta_fp8_rowwise_recipe cannot be enabled here because the Cohere model only has input_layernorm, which is directly followed by the MLP layer. If use_meta_fp8_rowwise_recipe is enabled, the input_layernorm will be excluded from quantization and will produce only one output, while the following Fp8RowwiseFusedGatedMLP requires two tensors (quantized_input and scale). So it's better to delete the argument.

quant_config.quant_algo = QuantAlgo.FP8
elif args.use_fp8_rowwise:
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
quant_config.use_meta_recipe = args.use_meta_fp8_rowwise_recipe

Please also delete this line to avoid confusion. Thanks

@ming-wei (Collaborator) commented Apr 8, 2025

@aikitoria any update on this?

@aikitoria (Author) commented Apr 8, 2025

Sorry, I have been busy at work. I will come back to this this week!

Edit: Still haven't been able to get to it
