feat: FP8 Rowwise quantization support for Cohere models #3127
base: main
Conversation
Signed-off-by: aikitoria <[email protected]>
/bot run
PR_Github #672 [ run ] triggered by Bot
PR_Github #672 [ run ] completed with state
It looks like the CI failed, but the links go to internal domains, so I can't see what the error is. I have some ideas about what it might be... I probably need to update other usages of the LayerNorm quantization plugin to handle the new parameters.
@aikitoria your code failed to pass the pre-commit check. Currently the pre-commit check failure is not copied back to the public CI view, and we are working to improve that with this MR: For now I have manually copied the error message to provide quick feedback: You can also refer here to run the pre-commit check locally in your own dev environment. June
Oh I see, I will fix the formatting for both PRs |
Thank you for the contribution!
I've left a few comments, but the PR looks overall good.
@juney-nvidia It'd be good if we could find someone familiar with quantization support. I personally don't have hands-on quantization experience, so I might miss something.
-    int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token,
-    int8_t* normed_output_quant, bool use_shmem)
+    int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
+    float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool hasFp8MinScaling)
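To make the new parameters concrete, here is a minimal host-side reference of what a rowwise quantization pass computes per token. This is illustrative only, not the kernel's actual code; the names mirror the signature above, but the exact scale convention, the FP8 min-scaling floor value, and the use of plain float as a stand-in for QuantT are all assumptions.

```cpp
// Illustrative reference only (not the actual TensorRT-LLM kernel): what the
// new parameters do on the rowwise path. "QuantT" in the real signature would
// be int8_t on the INT8 path or an FP8 type on the new FP8 path.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

constexpr float kFp8E4m3Max = 448.0f; // max magnitude representable in E4M3

void quantizeRowwiseRef(std::vector<float> const& normed, int tokens, int hidden_dim,
    float const* clampPtr,             // optional [min, max] clamp pair; may be null
    float* scale_orig_quant_per_token, // out: one dequant scale per token
    float* sum_per_token,              // out: per-token channel sum; may be null
    std::vector<float>& quantized,     // out: stand-in for QuantT*
    bool hasFp8MinScaling)
{
    std::vector<float> row(hidden_dim);
    for (int t = 0; t < tokens; ++t)
    {
        float maxAbs = 1e-6f, sum = 0.0f;
        for (int h = 0; h < hidden_dim; ++h)
        {
            float v = normed[size_t(t) * hidden_dim + h];
            if (clampPtr != nullptr)
                v = std::min(std::max(v, clampPtr[0]), clampPtr[1]);
            row[h] = v;
            maxAbs = std::max(maxAbs, std::fabs(v));
            sum += v;
        }
        // Per-token ("rowwise") scale derived from the row's max magnitude.
        float scale = maxAbs / kFp8E4m3Max;
        if (hasFp8MinScaling) // floor the scale; the actual floor value is a guess
            scale = std::max(scale, 1.0f / (kFp8E4m3Max * 512.0f));
        scale_orig_quant_per_token[t] = scale;
        if (sum_per_token != nullptr)
            sum_per_token[t] = sum;
        for (int h = 0; h < hidden_dim; ++h)
            quantized[size_t(t) * hidden_dim + h] = row[h] / scale; // cast to QuantT in the kernel
    }
}

int main()
{
    std::vector<float> x = {0.5f, -2.0f, 1.0f, 4.0f}; // 2 tokens, hidden_dim = 2
    std::vector<float> q(4), scales(2), sums(2);
    quantizeRowwiseRef(x, 2, 2, nullptr, scales.data(), sums.data(), q, false);
    std::printf("token0 scale=%g sum=%g\n", scales[0], sums[0]);
    return 0;
}
```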
hasFp8MinScaling -> has_fp8_min_scaling and clampPtr -> clamp_ptr, to keep the coding style consistent with other params?
-    float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block,
-    const size_t shmem_size, cudaStream_t stream)
+    float const eps, int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
+    float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling,
Ditto, naming convention issue. Please also check other occurrences of clampPtr/hasFp8MinScaling.
cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp (resolved)
cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp (resolved)
     }
     // Dynamic scaling if enabled
-    return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
+    else if (pos == 5 + int(mClampValEnabled))
I feel it'd be clearer if we treat pos as the position relative to the input/output starting index:

if (pos < nbInputs) {
    // pos is 0-based input pos.
    if (pos < 3) {
        ...
    }
    ...
} else {
    pos -= nbInputs;
    // pos is 0-based output pos.
    if (pos == 0) {
        // Quantized output
        ...
    }
    ...
}
...
}
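Fleshed out, that suggestion might look like the sketch below. This is hypothetical: the tensor order (input, gamma, beta, optional clamp values, then quantized output, per-token scales, per-token sums) and the member names mType and mFp8Quant are assumptions, not taken from the actual plugin.

```cpp
// Hypothetical fleshed-out version of the suggested restructure; the tensor
// ordering and the member flags used here are assumptions.
bool LayernormQuantizationPlugin::supportsFormatCombination(
    int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
    if (pos < nbInputs)
    {
        // pos is a 0-based input position.
        if (pos < 3)
        {
            // input / gamma / beta share the activation type.
            return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
        }
        // Remaining (optional) inputs such as the clamp values are fp32.
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
    int const outPos = pos - nbInputs; // 0-based output position
    if (outPos == 0)
    {
        // Quantized output: FP8 on the rowwise path, INT8 otherwise (assumed flag).
        auto const expected = mFp8Quant ? nvinfer1::DataType::kFP8 : nvinfer1::DataType::kINT8;
        return (inOut[pos].type == expected) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
    // Per-token scales / per-token sums are fp32.
    return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
}
```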
@@ -185,14 +270,25 @@ int LayernormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* input
 nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType(
     int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
 {
-    assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0));
+    assert(index <= 2);
Shouldn't this check depend on the values of mDynActScaling and mSumPerToken? Besides, I'd use index < 3 instead of index <= 2 if there are 3 outputs in total.
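To make the reviewer's point concrete, a hedged sketch of what the dependent assert could look like (mSumPerToken and mFp8Quant are assumed member names following the PR's apparent additions, and the output order scaling-before-sums is assumed):

```cpp
// Hypothetical sketch: make the assert track which outputs actually exist.
nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType(
    int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
    // Output 0: quantized tensor. Output 1: per-token scales (if mDynActScaling).
    // Output 2: per-token sums (if mSumPerToken).
    int const nbExpectedOutputs = 1 + int(mDynActScaling) + int(mSumPerToken);
    assert(index < nbExpectedOutputs);
    if (index == 0)
        return mFp8Quant ? nvinfer1::DataType::kFP8 : nvinfer1::DataType::kINT8;
    return nvinfer1::DataType::kFLOAT; // scales and sums are fp32
}
```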
Sure, I just added @Tracin to the code reviewer loop. Thanks
     bool mClampValEnabled;
     // The quantization mode.
     tensorrt_llm::common::QuantMode mQuantMode;
+    // Should we output the sum of channels per-token? (Used by QServe GEMM)
Should we output the sum in this scenario? @bobboli
To my knowledge, the per-token sum is only used by QServe's per-channel w4a8 GEMM. FP8 rowwise should not need this.
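For background on why QServe wants that sum (my understanding, not something stated in this PR): with asymmetric per-channel weight quantization, the zero-point contribution factors out of the dot product, so a GEMM epilogue only needs the per-token activation sum to correct the result. A small self-contained check of that identity:

```cpp
// Numeric check of the zero-point factorization that makes a per-token sum
// useful for asymmetric w4a8 GEMMs (illustrative values; not QServe code):
//   sum_k s*(q_k - z)*a_k  ==  s * (sum_k q_k*a_k - z * sum_k a_k)
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> qw = {3, 7, 1, 5}; // asymmetric 4-bit weight codes
    int const z = 4;                    // per-channel zero point
    float const s = 0.25f;              // per-channel weight scale
    std::vector<float> a = {0.5f, -1.0f, 2.0f, 0.25f}; // activations (one token)

    float direct = 0.f, qDot = 0.f, tokenSum = 0.f;
    for (size_t k = 0; k < qw.size(); ++k)
    {
        direct += s * float(qw[k] - z) * a[k]; // dequantize-then-dot
        qDot += float(qw[k]) * a[k];           // integer-code dot product
        tokenSum += a[k];                      // this is the "sum per token"
    }
    float const viaSum = s * (qDot - float(z) * tokenSum); // fold zero point via the sum

    std::printf("direct=%f viaSum=%f\n", direct, viaSum); // identical up to fp rounding
    return 0;
}
```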
Hi @aikitoria, would you mind adding a functional unittest like …
         default=False,
         help="Enable FP8 per-token per-channel quantization")
+    parser.add_argument(
+        "--use_meta_fp8_rowwise_recipe",
Hi @aikitoria, use_meta_fp8_rowwise_recipe cannot be enabled here because the Cohere model only has input_layernorm, which is directly followed by the MLP layer. If use_meta_fp8_rowwise_recipe is enabled, the input_layernorm will be excluded from quantization and will only generate one output, while the following Fp8RowwiseFusedGatedMLP requires two tensors (quantized_input and scale). So it's better to delete the argument.
         quant_config.quant_algo = QuantAlgo.FP8
     elif args.use_fp8_rowwise:
         quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
+        quant_config.use_meta_recipe = args.use_meta_fp8_rowwise_recipe
Please also delete this line to avoid confusion. Thanks
@aikitoria any update on this?
Sorry, I have been busy at work; I will come back to this this week! Edit: Still haven't been able to get to it.
This adds FP8 support for the LayerNorm kernel in the same way as was done for the RmsNorm kernel, which then allows us to use FP8 Rowwise quantization with the Cohere models.
For previous discussion, see #2912